diff --git a/internal/api/alert_handler_test.go b/internal/api/alert_handler_test.go
new file mode 100644
index 0000000..3c4cc86
--- /dev/null
+++ b/internal/api/alert_handler_test.go
@@ -0,0 +1,264 @@
+package api
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"github.com/ez-api/ez-api/internal/model"
+	"github.com/gin-gonic/gin"
+	"gorm.io/driver/sqlite"
+	"gorm.io/gorm"
+)
+
+func setupAlertTestDB(t *testing.T) *gorm.DB {
+	dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", t.Name())
+	db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
+	if err != nil {
+		t.Fatalf("open sqlite: %v", err)
+	}
+	if err := db.AutoMigrate(&model.Alert{}, &model.AlertThresholdConfig{}); err != nil {
+		t.Fatalf("migrate: %v", err)
+	}
+	return db
+}
+
+func setupAlertRouter(db *gorm.DB) *gin.Engine {
+	gin.SetMode(gin.TestMode)
+	r := gin.New()
+	handler := NewAlertHandler(db)
+
+	r.GET("/admin/alerts/thresholds", handler.GetAlertThresholds)
+	r.PUT("/admin/alerts/thresholds", handler.UpdateAlertThresholds)
+	r.GET("/admin/alerts", handler.ListAlerts)
+	r.POST("/admin/alerts", handler.CreateAlert)
+	r.GET("/admin/alerts/stats", handler.GetAlertStats)
+
+	return r
+}
+
+func TestGetAlertThresholdsDefault(t *testing.T) {
+	db := setupAlertTestDB(t)
+	r := setupAlertRouter(db)
+
+	req, _ := http.NewRequest("GET", "/admin/alerts/thresholds", nil)
+	w := httptest.NewRecorder()
+	r.ServeHTTP(w, req)
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
+	}
+
+	var resp AlertThresholdView
+	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
+		t.Fatalf("unmarshal response: %v", err)
+	}
+
+	// Should return defaults
+	if resp.GlobalQPS != 100 {
+		t.Errorf("expected GlobalQPS=100, got %d", resp.GlobalQPS)
+	}
+	if resp.MasterRPM != 20 {
+		t.Errorf("expected MasterRPM=20, got %d", resp.MasterRPM)
+	}
+	if resp.MasterRPD != 1000 {
+		t.Errorf("expected MasterRPD=1000, got %d", resp.MasterRPD)
+	}
+	if resp.MasterTPM != 10_000_000 {
+		t.Errorf("expected MasterTPM=10000000, got %d", resp.MasterTPM)
+	}
+	if resp.MasterTPD != 100_000_000 {
+		t.Errorf("expected MasterTPD=100000000, got %d", resp.MasterTPD)
+	}
+}
+
+func TestUpdateAlertThresholds(t *testing.T) {
+	db := setupAlertTestDB(t)
+	r := setupAlertRouter(db)
+
+	// Update some thresholds
+	body := `{"global_qps": 500, "master_rpm": 100}`
+	req, _ := http.NewRequest("PUT", "/admin/alerts/thresholds", bytes.NewBufferString(body))
+	req.Header.Set("Content-Type", "application/json")
+	w := httptest.NewRecorder()
+	r.ServeHTTP(w, req)
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
+	}
+
+	var resp AlertThresholdView
+	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
+		t.Fatalf("unmarshal response: %v", err)
+	}
+
+	if resp.GlobalQPS != 500 {
+		t.Errorf("expected GlobalQPS=500, got %d", resp.GlobalQPS)
+	}
+	if resp.MasterRPM != 100 {
+		t.Errorf("expected MasterRPM=100, got %d", resp.MasterRPM)
+	}
+	// Other fields should remain default
+	if resp.MasterRPD != 1000 {
+		t.Errorf("expected MasterRPD=1000 (unchanged), got %d", resp.MasterRPD)
+	}
+}
+
+func TestUpdateAlertThresholdsValidation(t *testing.T) {
+	db := setupAlertTestDB(t)
+	r := setupAlertRouter(db)
+
+	tests := []struct {
+		name     string
+		body     string
+		expected int
+	}{
+		{"negative global_qps", `{"global_qps": -1}`, http.StatusBadRequest},
+		{"zero global_qps", `{"global_qps": 0}`, http.StatusBadRequest},
+		{"negative master_rpm", `{"master_rpm": -5}`, http.StatusBadRequest},
+		{"valid update", `{"global_qps": 200}`, http.StatusOK},
+		{"negative min_rpm_requests_1m", `{"min_rpm_requests_1m": -1}`, http.StatusBadRequest},
+		{"zero min_rpm_requests_1m allowed", `{"min_rpm_requests_1m": 0}`, http.StatusOK},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			req, _ := http.NewRequest("PUT", "/admin/alerts/thresholds", bytes.NewBufferString(tc.body))
+			req.Header.Set("Content-Type", "application/json")
+			w := httptest.NewRecorder()
+			r.ServeHTTP(w, req)
+
+			if w.Code != tc.expected {
+				t.Errorf("expected status %d, got %d: %s", tc.expected, w.Code, w.Body.String())
+			}
+		})
+	}
+}
+
+func TestCreateAlertWithTrafficSpikeType(t *testing.T) {
+	db := setupAlertTestDB(t)
+	r := setupAlertRouter(db)
+
+	body := `{
+		"type": "traffic_spike",
+		"severity": "warning",
+		"title": "RPM Exceeded",
+		"message": "Master exceeded RPM threshold",
+		"related_id": 1,
+		"related_type": "master",
+		"metadata": "{\"metric\":\"master_rpm\",\"value\":150,\"threshold\":20,\"window\":\"1m\"}"
+	}`
+
+	req, _ := http.NewRequest("POST", "/admin/alerts", bytes.NewBufferString(body))
+	req.Header.Set("Content-Type", "application/json")
+	w := httptest.NewRecorder()
+	r.ServeHTTP(w, req)
+
+	if w.Code != http.StatusCreated {
+		t.Fatalf("expected status 201, got %d: %s", w.Code, w.Body.String())
+	}
+
+	var resp AlertView
+	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
+		t.Fatalf("unmarshal response: %v", err)
+	}
+
+	if resp.Type != "traffic_spike" {
+		t.Errorf("expected type=traffic_spike, got %s", resp.Type)
+	}
+	if resp.Severity != "warning" {
+		t.Errorf("expected severity=warning, got %s", resp.Severity)
+	}
+	if resp.Status != "active" {
+		t.Errorf("expected status=active, got %s", resp.Status)
+	}
+}
+
+func TestListAlertsWithTypeFilter(t *testing.T) {
+	db := setupAlertTestDB(t)
+	r := setupAlertRouter(db)
+
+	// Create some alerts
+	alerts := []model.Alert{
+		{Type: model.AlertTypeRateLimit, Severity: model.AlertSeverityWarning, Status: model.AlertStatusActive, Title: "Rate Limit 1"},
+		{Type: model.AlertTypeTrafficSpike, Severity: model.AlertSeverityWarning, Status: model.AlertStatusActive, Title: "Traffic Spike 1"},
+		{Type: model.AlertTypeTrafficSpike, Severity: model.AlertSeverityCritical, Status: model.AlertStatusActive, Title: "Traffic Spike 2"},
+	}
+	for _, a := range alerts {
+		if err := db.Create(&a).Error; err != nil {
+			t.Fatalf("create alert: %v", err)
+		}
+	}
+
+	// Filter by type=traffic_spike
+	req, _ := http.NewRequest("GET", "/admin/alerts?type=traffic_spike", nil)
+	w := httptest.NewRecorder()
+	r.ServeHTTP(w, req)
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
+	}
+
+	var resp ListAlertsResponse
+	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
+		t.Fatalf("unmarshal response: %v", err)
+	}
+
+	if resp.Total != 2 {
+		t.Errorf("expected 2 traffic_spike alerts, got %d", resp.Total)
+	}
+	for _, a := range resp.Items {
+		if a.Type != "traffic_spike" {
+			t.Errorf("expected type=traffic_spike, got %s", a.Type)
+		}
+	}
+}
+
+func TestAlertStatsIncludesAllAlerts(t *testing.T) {
+	db := setupAlertTestDB(t)
+	r := setupAlertRouter(db)
+
+	// Create alerts of different types and statuses
+	alerts := []model.Alert{
+		{Type: model.AlertTypeRateLimit, Severity: model.AlertSeverityWarning, Status: model.AlertStatusActive, Title: "A1"},
+		{Type: model.AlertTypeTrafficSpike, Severity: model.AlertSeverityCritical, Status: model.AlertStatusActive, Title: "A2"},
+		{Type: model.AlertTypeErrorSpike, Severity: model.AlertSeverityInfo, Status: model.AlertStatusResolved, Title: "A3"},
+	}
+	for _, a := range alerts {
+		if err := db.Create(&a).Error; err != nil {
+			t.Fatalf("create alert: %v", err)
+		}
+	}
+
+	req, _ := http.NewRequest("GET", "/admin/alerts/stats", nil)
+	w := httptest.NewRecorder()
+	r.ServeHTTP(w, req)
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
+	}
+
+	var resp AlertStats
+	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
+		t.Fatalf("unmarshal response: %v", err)
+	}
+
+	if resp.Total != 3 {
+		t.Errorf("expected total=3, got %d", resp.Total)
+	}
+	if resp.Active != 2 {
+		t.Errorf("expected active=2, got %d", resp.Active)
+	}
+	if resp.Resolved != 1 {
+		t.Errorf("expected resolved=1, got %d", resp.Resolved)
+	}
+	if resp.Critical != 1 {
+		t.Errorf("expected critical=1, got %d", resp.Critical)
+	}
+	if resp.Warning != 1 {
+		t.Errorf("expected warning=1, got %d", resp.Warning)
+	}
+}
diff --git a/internal/cron/alert_detector.go b/internal/cron/alert_detector.go
index 7b4695e..b926f83 100644
--- a/internal/cron/alert_detector.go
+++ b/internal/cron/alert_detector.go
@@ -128,7 +128,7 @@ func (d *AlertDetector) detectRateLimits(ctx context.Context) {
 			model.AlertTypeRateLimit,
 			model.AlertSeverityWarning,
 			fmt.Sprintf("Master '%s' is rate limited", master.Name),
-			fmt.Sprintf("Master '%s' (ID: %d) is currently being rate limited. Current QPS: %.2f", master.Name, master.ID, snapshot.QPS),
+			fmt.Sprintf("Master '%s' (ID: %d) is currently being rate limited. Current QPS: %d", master.Name, master.ID, snapshot.QPS),
 			master.ID,
 			"master",
 			master.Name,
diff --git a/internal/cron/alert_detector_test.go b/internal/cron/alert_detector_test.go
new file mode 100644
index 0000000..c657ef2
--- /dev/null
+++ b/internal/cron/alert_detector_test.go
@@ -0,0 +1,328 @@
+package cron
+
+import (
+	"context"
+	"fmt"
+	"testing"
+	"time"
+
+	"github.com/ez-api/ez-api/internal/model"
+	"gorm.io/driver/sqlite"
+	"gorm.io/gorm"
+)
+
+func setupTestDB(t *testing.T) *gorm.DB {
+	dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", t.Name())
+	db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
+	if err != nil {
+		t.Fatalf("open sqlite: %v", err)
+	}
+	if err := db.AutoMigrate(&model.Alert{}, &model.AlertThresholdConfig{}, &model.Master{}, &model.Key{}, &model.APIKey{}, &model.ProviderGroup{}, &model.LogRecord{}); err != nil {
+		t.Fatalf("migrate: %v", err)
+	}
+	return db
+}
+
+func TestDefaultAlertThresholdConfig(t *testing.T) {
+	cfg := model.DefaultAlertThresholdConfig()
+
+	if cfg.GlobalQPS != 100 {
+		t.Errorf("expected GlobalQPS=100, got %d", cfg.GlobalQPS)
+	}
+	if cfg.MasterRPM != 20 {
+		t.Errorf("expected MasterRPM=20, got %d", cfg.MasterRPM)
+	}
+	if cfg.MasterRPD != 1000 {
+		t.Errorf("expected MasterRPD=1000, got %d", cfg.MasterRPD)
+	}
+	if cfg.MasterTPM != 10_000_000 {
+		t.Errorf("expected MasterTPM=10000000, got %d", cfg.MasterTPM)
+	}
+	if cfg.MasterTPD != 100_000_000 {
+		t.Errorf("expected MasterTPD=100000000, got %d", cfg.MasterTPD)
+	}
+	if cfg.MinRPMRequests1m != 10 {
+		t.Errorf("expected MinRPMRequests1m=10, got %d", cfg.MinRPMRequests1m)
+	}
+	if cfg.MinTPMTokens1m != 1_000_000 {
+		t.Errorf("expected MinTPMTokens1m=1000000, got %d", cfg.MinTPMTokens1m)
+	}
+}
+
+func TestAlertDetectorLoadThresholdConfigDefault(t *testing.T) {
+	db := setupTestDB(t)
+
+	detector := &AlertDetector{db: db}
+	cfg := detector.loadThresholdConfig()
+
+	// Should return defaults when no config in DB
+	if cfg.GlobalQPS != 100 {
+		t.Errorf("expected default GlobalQPS=100, got %d", cfg.GlobalQPS)
+	}
+	if cfg.MasterRPM != 20 {
+		t.Errorf("expected default MasterRPM=20, got %d", cfg.MasterRPM)
+	}
+}
+
+func TestAlertDetectorLoadThresholdConfigFromDB(t *testing.T) {
+	db := setupTestDB(t)
+
+	// Insert custom config
+	customCfg := model.AlertThresholdConfig{
+		GlobalQPS:        500,
+		MasterRPM:        100,
+		MasterRPD:        5000,
+		MasterTPM:        50_000_000,
+		MasterTPD:        500_000_000,
+		MinRPMRequests1m: 50,
+		MinTPMTokens1m:   5_000_000,
+	}
+	if err := db.Create(&customCfg).Error; err != nil {
+		t.Fatalf("create config: %v", err)
+	}
+
+	detector := &AlertDetector{db: db}
+	cfg := detector.loadThresholdConfig()
+
+	if cfg.GlobalQPS != 500 {
+		t.Errorf("expected GlobalQPS=500, got %d", cfg.GlobalQPS)
+	}
+	if cfg.MasterRPM != 100 {
+		t.Errorf("expected MasterRPM=100, got %d", cfg.MasterRPM)
+	}
+	if cfg.MasterRPD != 5000 {
+		t.Errorf("expected MasterRPD=5000, got %d", cfg.MasterRPD)
+	}
+}
+
+func TestTrafficSpikeSeverity(t *testing.T) {
+	tests := []struct {
+		value     int64
+		threshold int64
+		expected  model.AlertSeverity
+	}{
+		{50, 100, model.AlertSeverityWarning},   // below threshold, but this func is only called when >= threshold
+		{100, 100, model.AlertSeverityWarning},  // exactly at threshold
+		{150, 100, model.AlertSeverityWarning},  // 1.5x threshold
+		{199, 100, model.AlertSeverityWarning},  // just below 2x
+		{200, 100, model.AlertSeverityCritical}, // exactly 2x threshold
+		{300, 100, model.AlertSeverityCritical}, // 3x threshold
+	}
+
+	for _, tc := range tests {
+		result := trafficSpikeSeverity(tc.value, tc.threshold)
+		if result != tc.expected {
+			t.Errorf("trafficSpikeSeverity(%d, %d) = %s, expected %s", tc.value, tc.threshold, result, tc.expected)
+		}
+	}
+}
+
+func TestTrafficSpikeMetadataJSON(t *testing.T) {
+	meta := trafficSpikeMetadata{
+		Metric:    "master_rpm",
+		Value:     150,
+		Threshold: 20,
+		Window:    "1m",
+	}
+
+	json := meta.JSON()
+	if json == "" {
+		t.Error("expected non-empty JSON")
+	}
+	if len(json) < 10 {
+		t.Errorf("JSON too short: %s", json)
+	}
+}
+
+func TestAlertDetectorDeduplication(t *testing.T) {
+	db := setupTestDB(t)
+
+	config := DefaultAlertDetectorConfig()
+	config.DeduplicationCooldown = 5 * time.Minute
+
+	detector := NewAlertDetector(db, db, nil, nil, config, nil)
+
+	// Create first alert
+	detector.createAlertIfNew(
+		model.AlertTypeRateLimit,
+		model.AlertSeverityWarning,
+		"Test Alert",
+		"Test Message",
+		1,
+		"master",
+		"test-master",
+	)
+
+	var count int64
+	db.Model(&model.Alert{}).Count(&count)
+	if count != 1 {
+		t.Fatalf("expected 1 alert, got %d", count)
+	}
+
+	// Try to create duplicate (should be skipped)
+	detector.createAlertIfNew(
+		model.AlertTypeRateLimit,
+		model.AlertSeverityWarning,
+		"Test Alert Duplicate",
+		"Test Message Duplicate",
+		1,
+		"master",
+		"test-master",
+	)
+
+	db.Model(&model.Alert{}).Count(&count)
+	if count != 1 {
+		t.Fatalf("expected still 1 alert after duplicate, got %d", count)
+	}
+
+	// Different fingerprint should create new alert
+	detector.createAlertIfNew(
+		model.AlertTypeRateLimit,
+		model.AlertSeverityWarning,
+		"Different Alert",
+		"Different Message",
+		2, // Different related_id
+		"master",
+		"test-master-2",
+	)
+
+	db.Model(&model.Alert{}).Count(&count)
+	if count != 2 {
+		t.Fatalf("expected 2 alerts with different fingerprint, got %d", count)
+	}
+}
+
+func TestAlertDetectorTrafficSpikeDeduplication(t *testing.T) {
+	db := setupTestDB(t)
+
+	config := DefaultAlertDetectorConfig()
+	config.DeduplicationCooldown = 5 * time.Minute
+
+	detector := NewAlertDetector(db, db, nil, nil, config, nil)
+
+	meta := trafficSpikeMetadata{
+		Metric:    "master_rpm",
+		Value:     150,
+		Threshold: 20,
+		Window:    "1m",
+	}
+
+	// Create first traffic spike alert
+	detector.createTrafficSpikeAlert(
+		model.AlertSeverityWarning,
+		"RPM Exceeded",
+		"Master exceeded RPM",
+		1,
+		"master",
+		"test-master",
+		meta,
+	)
+
+	var count int64
+	db.Model(&model.Alert{}).Count(&count)
+	if count != 1 {
+		t.Fatalf("expected 1 alert, got %d", count)
+	}
+
+	// Try to create duplicate (same metric, same master)
+	detector.createTrafficSpikeAlert(
+		model.AlertSeverityWarning,
+		"RPM Exceeded Again",
+		"Master exceeded RPM again",
+		1,
+		"master",
+		"test-master",
+		meta,
+	)
+
+	db.Model(&model.Alert{}).Count(&count)
+	if count != 1 {
+		t.Fatalf("expected still 1 alert after duplicate, got %d", count)
+	}
+
+	// Different metric should create new alert
+	meta2 := trafficSpikeMetadata{
+		Metric:    "master_tpm", // Different metric
+		Value:     15000000,
+		Threshold: 10000000,
+		Window:    "1m",
+	}
+	detector.createTrafficSpikeAlert(
+		model.AlertSeverityWarning,
+		"TPM Exceeded",
+		"Master exceeded TPM",
+		1,
+		"master",
+		"test-master",
+		meta2,
+	)
+
+	db.Model(&model.Alert{}).Count(&count)
+	if count != 2 {
+		t.Fatalf("expected 2 alerts with different metric, got %d", count)
+	}
+}
+
+func TestAlertDetectorErrorRateSeverity(t *testing.T) {
+	detector := &AlertDetector{}
+
+	tests := []struct {
+		rate     float64
+		expected model.AlertSeverity
+	}{
+		{0.05, model.AlertSeverityInfo},     // 5%
+		{0.10, model.AlertSeverityInfo},     // 10%
+		{0.24, model.AlertSeverityInfo},     // 24%
+		{0.25, model.AlertSeverityWarning},  // 25%
+		{0.40, model.AlertSeverityWarning},  // 40%
+		{0.49, model.AlertSeverityWarning},  // 49%
+		{0.50, model.AlertSeverityCritical}, // 50%
+		{0.75, model.AlertSeverityCritical}, // 75%
+		{1.00, model.AlertSeverityCritical}, // 100%
+	}
+
+	for _, tc := range tests {
+		result := detector.errorRateSeverity(tc.rate)
+		if result != tc.expected {
+			t.Errorf("errorRateSeverity(%.2f) = %s, expected %s", tc.rate, result, tc.expected)
+		}
+	}
+}
+
+func TestAlertDetectorDetectOnceNilSafe(t *testing.T) {
+	// Test nil detector
+	var nilDetector *AlertDetector
+	nilDetector.detectOnce(context.Background())
+
+	// Test detector with nil db
+	detector := &AlertDetector{}
+	detector.detectOnce(context.Background())
+
+	// Should not panic
+}
+
+func TestAlertDetectorStartDisabled(t *testing.T) {
+	db := setupTestDB(t)
+
+	config := DefaultAlertDetectorConfig()
+	config.Enabled = false
+
+	detector := NewAlertDetector(db, db, nil, nil, config, nil)
+
+	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+	defer cancel()
+
+	// Should return immediately without blocking
+	done := make(chan struct{})
+	go func() {
+		detector.Start(ctx)
+		close(done)
+	}()
+
+	select {
+	case <-done:
+		// Expected: Start returned immediately because Enabled=false
+	case <-time.After(200 * time.Millisecond):
+		t.Error("Start did not return immediately when disabled")
+	}
+}
diff --git a/internal/cron/token_refresh.go b/internal/cron/token_refresh.go
index 1e75487..c2742f4 100644
--- a/internal/cron/token_refresh.go
+++ b/internal/cron/token_refresh.go
@@ -360,7 +360,7 @@ func (t *TokenRefresher) postForm(ctx context.Context, endpoint string, form url
 		if code == "invalid_grant" || code == "invalid_client" {
 			retryable = false
 		}
-		return nil, &refreshError{Retryable: retryable, Code: code, Err: fmt.Errorf(strings.TrimSpace(payload.ErrorDescription))}
+		return nil, &refreshError{Retryable: retryable, Code: code, Err: fmt.Errorf("%s", strings.TrimSpace(payload.ErrorDescription))}
 	}
 	if strings.TrimSpace(payload.AccessToken) == "" {
 		return nil, &refreshError{Retryable: true, Code: "empty_token", Err: fmt.Errorf("missing access_token")}