package cron

import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"time"

	"github.com/ez-api/ez-api/internal/model"
	"github.com/ez-api/ez-api/internal/service"
	"github.com/redis/go-redis/v9"
	"gorm.io/gorm"
)

// AlertDetectorConfig holds configuration for alert detection
type AlertDetectorConfig struct {
	ErrorSpikeThreshold   float64 // Error rate threshold (0.1 = 10%)
	ErrorSpikeWindow      time.Duration
	QuotaWarningThreshold float64 // Quota usage threshold (0.9 = 90%)
	ProviderFailThreshold int     // Consecutive failures before alert
	DeduplicationCooldown time.Duration
}

// DefaultAlertDetectorConfig returns default configuration
func DefaultAlertDetectorConfig() AlertDetectorConfig {
	return AlertDetectorConfig{
		ErrorSpikeThreshold:   0.1, // 10% error rate
		ErrorSpikeWindow:      5 * time.Minute,
		QuotaWarningThreshold: 0.9, // 90% quota used
		ProviderFailThreshold: 10,
		DeduplicationCooldown: 5 * time.Minute,
	}
}

// AlertDetector detects anomalies and creates alerts
type AlertDetector struct {
	db           *gorm.DB
	logDB        *gorm.DB
	rdb          *redis.Client
	statsService *service.StatsService
	config       AlertDetectorConfig
	logger       *slog.Logger
}

// NewAlertDetector creates a new AlertDetector
func NewAlertDetector(db, logDB *gorm.DB, rdb *redis.Client, statsService *service.StatsService, config AlertDetectorConfig, logger *slog.Logger) *AlertDetector {
	if logDB == nil {
		logDB = db
	}
	if logger == nil {
		logger = slog.Default()
	}
	return &AlertDetector{
		db:           db,
		logDB:        logDB,
		rdb:          rdb,
		statsService: statsService,
		config:       config,
		logger:       logger,
	}
}

// RunOnce executes a single detection cycle. Called by scheduler.
func (d *AlertDetector) RunOnce(ctx context.Context) {
	if d == nil || d.db == nil {
		return
	}
	d.detectOnce(ctx)
}

// detectOnce runs all detection rules once
func (d *AlertDetector) detectOnce(ctx context.Context) {
	if d == nil || d.db == nil {
		return
	}
	// Run each detection rule
	d.detectRateLimits(ctx)
	d.detectErrorSpikes(ctx)
	d.detectQuotaExceeded(ctx)
	d.detectProviderDown(ctx)
	d.detectTrafficSpikes(ctx)
}

// detectRateLimits checks for masters hitting rate limits
func (d *AlertDetector) detectRateLimits(ctx context.Context) {
	if d.rdb == nil || d.statsService == nil {
		return
	}

	// Get all active masters
	var masters []model.Master
	if err := d.db.Where("status = ?", "active").Find(&masters).Error; err != nil {
		d.logger.Warn("failed to load masters for rate limit detection", "err", err)
		return
	}

	for _, master := range masters {
		snapshot, err := d.statsService.GetMasterRealtimeSnapshot(ctx, master.ID)
		if err != nil {
			continue
		}
		if snapshot.RateLimited {
			d.createAlertIfNew(
				model.AlertTypeRateLimit,
				model.AlertSeverityWarning,
				fmt.Sprintf("Master '%s' is rate limited", master.Name),
				fmt.Sprintf("Master '%s' (ID: %d) is currently being rate limited. Current QPS: %d", master.Name, master.ID, snapshot.QPS),
				master.ID,
				"master",
				master.Name,
			)
		}
	}
}

// detectErrorSpikes checks for high error rates in recent logs
func (d *AlertDetector) detectErrorSpikes(ctx context.Context) {
	if d.logDB == nil {
		return
	}

	since := time.Now().UTC().Add(-d.config.ErrorSpikeWindow)

	// Query error stats grouped by key
	type errorStat struct {
		KeyID   uint
		Total   int64
		Errors  int64
		ErrRate float64
	}
	var stats []errorStat
	err := d.logDB.Model(&model.LogRecord{}).
		Select("key_id, COUNT(*) as total, SUM(CASE WHEN status_code >= 400 THEN 1 ELSE 0 END) as errors").
		Where("created_at >= ?", since).
		Where("key_id > 0").
		Group("key_id").
		Having("COUNT(*) >= 10"). // Minimum requests threshold
		Scan(&stats).Error
	if err != nil {
		d.logger.Warn("failed to query error stats", "err", err)
		return
	}

	for _, stat := range stats {
		if stat.Total == 0 {
			continue
		}
		errRate := float64(stat.Errors) / float64(stat.Total)
		if errRate >= d.config.ErrorSpikeThreshold {
			// Verify the key still exists before alerting
			var key model.Key
			if err := d.db.Select("id, master_id").First(&key, stat.KeyID).Error; err != nil {
				continue
			}
			d.createAlertIfNew(
				model.AlertTypeErrorSpike,
				d.errorRateSeverity(errRate),
				fmt.Sprintf("High error rate detected (%.1f%%)", errRate*100),
				fmt.Sprintf("Key ID %d has %.1f%% error rate (%d/%d requests) in the last %v", stat.KeyID, errRate*100, stat.Errors, stat.Total, d.config.ErrorSpikeWindow),
				stat.KeyID,
				"key",
				"",
			)
		}
	}
}

// detectQuotaExceeded checks for keys approaching or exceeding quota
func (d *AlertDetector) detectQuotaExceeded(ctx context.Context) {
	var keys []model.Key
	// Find keys with quota enabled and usage >= threshold
	err := d.db.Where("quota_limit > 0 AND quota_used >= quota_limit * ?", d.config.QuotaWarningThreshold).
		Find(&keys).Error
	if err != nil {
		d.logger.Warn("failed to query quota usage", "err", err)
		return
	}

	for _, key := range keys {
		usagePercent := float64(key.QuotaUsed) / float64(key.QuotaLimit) * 100
		exceeded := key.QuotaUsed >= key.QuotaLimit

		severity := model.AlertSeverityWarning
		title := fmt.Sprintf("Quota at %.0f%%", usagePercent)
		if exceeded {
			severity = model.AlertSeverityCritical
			title = "Quota exceeded"
		}

		d.createAlertIfNew(
			model.AlertTypeQuotaExceeded,
			severity,
			title,
			fmt.Sprintf("Key ID %d (Master %d) has used %d/%d tokens (%.1f%%)", key.ID, key.MasterID, key.QuotaUsed, key.QuotaLimit, usagePercent),
			key.ID,
			"key",
			"",
		)
	}
}

// detectProviderDown checks for API keys with consecutive failures
func (d *AlertDetector) detectProviderDown(ctx context.Context) {
	// Find API keys with high failure rate
	var apiKeys []model.APIKey
	err := d.db.Where("status = ? AND total_requests > 0", "active").
		Where("(total_requests - success_requests) >= ?", d.config.ProviderFailThreshold).
		Find(&apiKeys).Error
	if err != nil {
		d.logger.Warn("failed to query api key failures", "err", err)
		return
	}

	for _, apiKey := range apiKeys {
		failures := apiKey.TotalRequests - apiKey.SuccessRequests
		if failures < int64(d.config.ProviderFailThreshold) {
			continue
		}

		// Check overall failure rate
		failureRate := float64(failures) / float64(apiKey.TotalRequests)
		if failureRate < 0.5 { // At least 50% failure rate
			continue
		}

		// Get provider group name
		var group model.ProviderGroup
		groupName := ""
		if err := d.db.Select("name").First(&group, apiKey.GroupID).Error; err == nil {
			groupName = group.Name
		}

		d.createAlertIfNew(
			model.AlertTypeProviderDown,
			model.AlertSeverityCritical,
			fmt.Sprintf("Provider API key failing (%d failures)", failures),
			fmt.Sprintf("API Key ID %d in group '%s' has %d failures out of %d requests (%.1f%% failure rate)", apiKey.ID, groupName, failures, apiKey.TotalRequests, failureRate*100),
			apiKey.ID,
			"apikey",
			groupName,
		)
	}
}

// createAlertIfNew creates an alert if no duplicate exists within cooldown period
func (d *AlertDetector) createAlertIfNew(
	alertType model.AlertType,
	severity model.AlertSeverity,
	title, message string,
	relatedID uint,
	relatedType, relatedName string,
) {
	fingerprint := fmt.Sprintf("%s:%s:%d", alertType, relatedType, relatedID)
	cooldownTime := time.Now().UTC().Add(-d.config.DeduplicationCooldown)

	// Check for existing active alert with same fingerprint
	var count int64
	d.db.Model(&model.Alert{}).
		Where("fingerprint = ? AND status = ? AND created_at >= ?", fingerprint, model.AlertStatusActive, cooldownTime).
		Count(&count)
	if count > 0 {
		return // Duplicate within cooldown
	}

	alert := model.Alert{
		Type:        alertType,
		Severity:    severity,
		Status:      model.AlertStatusActive,
		Title:       title,
		Message:     message,
		RelatedID:   relatedID,
		RelatedType: relatedType,
		RelatedName: relatedName,
		Fingerprint: fingerprint,
	}
	if err := d.db.Create(&alert).Error; err != nil {
		d.logger.Warn("failed to create alert", "type", alertType, "err", err)
		return
	}
	d.logger.Info("alert created", "type", alertType, "severity", severity, "title", title)
}

// errorRateSeverity determines severity based on error rate
func (d *AlertDetector) errorRateSeverity(rate float64) model.AlertSeverity {
	if rate >= 0.5 {
		return model.AlertSeverityCritical
	}
	if rate >= 0.25 {
		return model.AlertSeverityWarning
	}
	return model.AlertSeverityInfo
}

// loadThresholdConfig loads the threshold config from DB with fixed ID=1, or returns defaults
func (d *AlertDetector) loadThresholdConfig() model.AlertThresholdConfig {
	cfg := model.DefaultAlertThresholdConfig()
	cfg.ID = 1 // Fixed ID to ensure single config row
	if err := d.db.Where("id = ?", 1).FirstOrCreate(&cfg).Error; err != nil {
		d.logger.Warn("failed to load threshold config, using defaults", "err", err)
		return model.DefaultAlertThresholdConfig()
	}
	return cfg
}

// trafficSpikeSeverity determines severity based on value vs threshold:
// warning when >= threshold, critical when >= 2x threshold
func trafficSpikeSeverity(value, threshold int64) model.AlertSeverity {
	if value >= threshold*2 {
		return model.AlertSeverityCritical
	}
	return model.AlertSeverityWarning
}

// trafficSpikeMetadata holds the JSON metadata attached to traffic spike alerts
type trafficSpikeMetadata struct {
	Metric    string `json:"metric"`
	Value     int64  `json:"value"`
	Threshold int64  `json:"threshold"`
	Window    string `json:"window"`
}

// JSON returns the metadata as a JSON string
func (m trafficSpikeMetadata) JSON() string {
	b, _ := json.Marshal(m)
	return string(b)
}

// detectTrafficSpikes checks for traffic threshold breaches
func (d *AlertDetector) detectTrafficSpikes(ctx context.Context) {
	cfg := d.loadThresholdConfig()

	// 1. Global QPS check (requires Redis)
	if d.rdb != nil && d.statsService != nil {
		d.detectGlobalQPSSpike(ctx, cfg)
	}

	// 2. Per-master RPM/TPM (1-minute window) - uses logDB, works without Redis
	d.detectMasterMinuteSpikes(ctx, cfg)
	// 3. Per-master RPD/TPD (24-hour window) - uses logDB, works without Redis
	d.detectMasterDaySpikes(ctx, cfg)
}

// detectGlobalQPSSpike checks global QPS against threshold
func (d *AlertDetector) detectGlobalQPSSpike(ctx context.Context, cfg model.AlertThresholdConfig) {
	// Sum QPS from all active masters
	var masters []model.Master
	if err := d.db.Where("status = ?", "active").Find(&masters).Error; err != nil {
		d.logger.Warn("failed to load masters for global QPS check", "err", err)
		return
	}

	var totalQPS int64
	for _, master := range masters {
		snapshot, err := d.statsService.GetMasterRealtimeSnapshot(ctx, master.ID)
		if err != nil {
			continue
		}
		totalQPS += snapshot.QPS
	}

	if totalQPS >= cfg.GlobalQPS {
		meta := trafficSpikeMetadata{
			Metric:    "global_qps",
			Value:     totalQPS,
			Threshold: cfg.GlobalQPS,
			Window:    "realtime",
		}
		d.createTrafficSpikeAlert(
			trafficSpikeSeverity(totalQPS, cfg.GlobalQPS),
			fmt.Sprintf("Global QPS threshold exceeded (%d >= %d)", totalQPS, cfg.GlobalQPS),
			fmt.Sprintf("System-wide QPS is %d, threshold is %d", totalQPS, cfg.GlobalQPS),
			0,
			"system",
			"global",
			meta,
		)
	}
}

// detectMasterMinuteSpikes checks per-master RPM/TPM in 1-minute window
func (d *AlertDetector) detectMasterMinuteSpikes(ctx context.Context, cfg model.AlertThresholdConfig) {
	since := time.Now().UTC().Add(-1 * time.Minute)

	// Query aggregated stats per master for 1-minute window
	type masterStat struct {
		MasterID  uint
		Requests  int64
		TokensIn  int64
		TokensOut int64
	}
	var stats []masterStat
	err := d.logDB.Model(&model.LogRecord{}).
		Select("master_id, COUNT(*) as requests, COALESCE(SUM(tokens_in), 0) as tokens_in, COALESCE(SUM(tokens_out), 0) as tokens_out").
		Where("created_at >= ?", since).
		Where("master_id > 0").
		Group("master_id").
		Scan(&stats).Error
	if err != nil {
		d.logger.Warn("failed to query master minute stats", "err", err)
		return
	}

	// Load master names for alerts
	masterNames := make(map[uint]string)
	var masters []model.Master
	if err := d.db.Select("id, name").Find(&masters).Error; err == nil {
		for _, m := range masters {
			masterNames[m.ID] = m.Name
		}
	}

	for _, stat := range stats {
		masterName := masterNames[stat.MasterID]
		if masterName == "" {
			masterName = fmt.Sprintf("Master#%d", stat.MasterID)
		}

		// RPM check (with minimum sample threshold)
		if stat.Requests >= cfg.MinRPMRequests1m && stat.Requests >= cfg.MasterRPM {
			meta := trafficSpikeMetadata{
				Metric:    "master_rpm",
				Value:     stat.Requests,
				Threshold: cfg.MasterRPM,
				Window:    "1m",
			}
			d.createTrafficSpikeAlert(
				trafficSpikeSeverity(stat.Requests, cfg.MasterRPM),
				fmt.Sprintf("Master '%s' RPM threshold exceeded (%d >= %d)", masterName, stat.Requests, cfg.MasterRPM),
				fmt.Sprintf("Master '%s' (ID: %d) has %d requests in the last minute, threshold is %d", masterName, stat.MasterID, stat.Requests, cfg.MasterRPM),
				stat.MasterID,
				"master",
				masterName,
				meta,
			)
		}

		// TPM check (with minimum sample threshold)
		totalTokens := stat.TokensIn + stat.TokensOut
		if totalTokens >= cfg.MinTPMTokens1m && totalTokens >= cfg.MasterTPM {
			meta := trafficSpikeMetadata{
				Metric:    "master_tpm",
				Value:     totalTokens,
				Threshold: cfg.MasterTPM,
				Window:    "1m",
			}
			d.createTrafficSpikeAlert(
				trafficSpikeSeverity(totalTokens, cfg.MasterTPM),
				fmt.Sprintf("Master '%s' TPM threshold exceeded (%d >= %d)", masterName, totalTokens, cfg.MasterTPM),
				fmt.Sprintf("Master '%s' (ID: %d) used %d tokens in the last minute, threshold is %d", masterName, stat.MasterID, totalTokens, cfg.MasterTPM),
				stat.MasterID,
				"master",
				masterName,
				meta,
			)
		}
	}
}

// detectMasterDaySpikes checks per-master RPD/TPD in 24-hour window
func (d *AlertDetector) detectMasterDaySpikes(ctx context.Context, cfg model.AlertThresholdConfig) {
	since := time.Now().UTC().Add(-24 * time.Hour)

	// Query aggregated stats per master for 24-hour window
	type masterStat struct {
		MasterID  uint
		Requests  int64
		TokensIn  int64
		TokensOut int64
	}
	var stats []masterStat
	err := d.logDB.Model(&model.LogRecord{}).
		Select("master_id, COUNT(*) as requests, COALESCE(SUM(tokens_in), 0) as tokens_in, COALESCE(SUM(tokens_out), 0) as tokens_out").
		Where("created_at >= ?", since).
		Where("master_id > 0").
		Group("master_id").
		Scan(&stats).Error
	if err != nil {
		d.logger.Warn("failed to query master day stats", "err", err)
		return
	}

	// Load master names for alerts
	masterNames := make(map[uint]string)
	var masters []model.Master
	if err := d.db.Select("id, name").Find(&masters).Error; err == nil {
		for _, m := range masters {
			masterNames[m.ID] = m.Name
		}
	}

	for _, stat := range stats {
		masterName := masterNames[stat.MasterID]
		if masterName == "" {
			masterName = fmt.Sprintf("Master#%d", stat.MasterID)
		}

		// RPD check
		if stat.Requests >= cfg.MasterRPD {
			meta := trafficSpikeMetadata{
				Metric:    "master_rpd",
				Value:     stat.Requests,
				Threshold: cfg.MasterRPD,
				Window:    "24h",
			}
			d.createTrafficSpikeAlert(
				trafficSpikeSeverity(stat.Requests, cfg.MasterRPD),
				fmt.Sprintf("Master '%s' RPD threshold exceeded (%d >= %d)", masterName, stat.Requests, cfg.MasterRPD),
				fmt.Sprintf("Master '%s' (ID: %d) has %d requests in the last 24 hours, threshold is %d", masterName, stat.MasterID, stat.Requests, cfg.MasterRPD),
				stat.MasterID,
				"master",
				masterName,
				meta,
			)
		}

		// TPD check
		totalTokens := stat.TokensIn + stat.TokensOut
		if totalTokens >= cfg.MasterTPD {
			meta := trafficSpikeMetadata{
				Metric:    "master_tpd",
				Value:     totalTokens,
				Threshold: cfg.MasterTPD,
				Window:    "24h",
			}
			d.createTrafficSpikeAlert(
				trafficSpikeSeverity(totalTokens, cfg.MasterTPD),
				fmt.Sprintf("Master '%s' TPD threshold exceeded (%d >= %d)", masterName, totalTokens, cfg.MasterTPD),
				fmt.Sprintf("Master '%s' (ID: %d) used %d tokens in the last 24 hours, threshold is %d", masterName, stat.MasterID, totalTokens, cfg.MasterTPD),
				stat.MasterID,
				"master",
				masterName,
				meta,
			)
		}
	}
}

// createTrafficSpikeAlert creates a traffic_spike alert with metadata
func (d *AlertDetector) createTrafficSpikeAlert(
	severity model.AlertSeverity,
	title, message string,
	relatedID uint,
	relatedType, relatedName string,
	meta trafficSpikeMetadata,
) {
	fingerprint := fmt.Sprintf("%s:%s:%s:%d", model.AlertTypeTrafficSpike, meta.Metric, relatedType, relatedID)
	cooldownTime := time.Now().UTC().Add(-d.config.DeduplicationCooldown)

	// Check for existing active alert with same fingerprint
	var count int64
	d.db.Model(&model.Alert{}).
		Where("fingerprint = ? AND status = ? AND created_at >= ?", fingerprint, model.AlertStatusActive, cooldownTime).
		Count(&count)
	if count > 0 {
		return // Duplicate within cooldown
	}

	alert := model.Alert{
		Type:        model.AlertTypeTrafficSpike,
		Severity:    severity,
		Status:      model.AlertStatusActive,
		Title:       title,
		Message:     message,
		RelatedID:   relatedID,
		RelatedType: relatedType,
		RelatedName: relatedName,
		Fingerprint: fingerprint,
		Metadata:    meta.JSON(),
	}
	if err := d.db.Create(&alert).Error; err != nil {
		d.logger.Warn("failed to create traffic spike alert", "metric", meta.Metric, "err", err)
		return
	}
	d.logger.Info("traffic spike alert created", "metric", meta.Metric, "severity", severity, "title", title)
}
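
// RunLoop is an illustrative sketch, not part of the original scheduler wiring:
// it shows one way a caller could drive RunOnce on a fixed interval until the
// context is cancelled. The method name RunLoop and the 1-minute default
// interval are assumptions for illustration only.
func (d *AlertDetector) RunLoop(ctx context.Context, interval time.Duration) {
	if interval <= 0 {
		interval = time.Minute // assumed default cadence
	}
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			d.RunOnce(ctx)
		}
	}
}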