package api

import (
	"bytes"
	"fmt"
	"math"
	"net/http"
	"net/http/httptest"
	"net/url"
	"testing"
	"time"

	"github.com/ez-api/ez-api/internal/model"
	"github.com/gin-gonic/gin"
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

// flushRateTolerance is the absolute tolerance used when comparing the
// floating-point success/failure rates persisted by the flush handlers.
const flushRateTolerance = 1e-6

// openFlushTestDB opens an in-memory SQLite database that is private to the
// calling test and auto-migrates the given models into it.
//
// The database is named after the test. A bare "file::memory:?cache=shared" DSN
// refers to a single process-wide shared database, so rows created by one test
// (or by an earlier iteration under -count=N) would leak into the next one.
// Closing the underlying *sql.DB on cleanup releases the in-memory database.
func openFlushTestDB(t *testing.T, models ...any) *gorm.DB {
	t.Helper()

	dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", url.PathEscape(t.Name()))
	db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
	if err != nil {
		t.Fatalf("open sqlite: %v", err)
	}
	sqlDB, err := db.DB()
	if err != nil {
		t.Fatalf("sqlite handle: %v", err)
	}
	t.Cleanup(func() { _ = sqlDB.Close() }) // best-effort: the test is already over

	if err := db.AutoMigrate(models...); err != nil {
		t.Fatalf("migrate: %v", err)
	}
	return db
}

// serveFlushRequest mounts handler at path on a fresh gin engine, POSTs body to
// it as JSON and fails the test unless the response status is 200 OK.
func serveFlushRequest(t *testing.T, path string, handler gin.HandlerFunc, body []byte) {
	t.Helper()

	r := gin.New()
	r.POST(path, handler)

	req := httptest.NewRequest(http.MethodPost, path, bytes.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	r.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Fatalf("unexpected status: got=%d body=%s", rec.Code, rec.Body.String())
	}
}

// assertFlushRates fails the test when either stored rate differs from its
// expected value by more than flushRateTolerance. what labels the entity in
// the failure message.
func assertFlushRates(t *testing.T, what string, gotSuccess, gotFailure, wantSuccess, wantFailure float64) {
	t.Helper()

	if math.Abs(gotSuccess-wantSuccess) > flushRateTolerance || math.Abs(gotFailure-wantFailure) > flushRateTolerance {
		t.Fatalf("%s rates: success=%f (want %f) failure=%f (want %f)",
			what, gotSuccess, wantSuccess, gotFailure, wantFailure)
	}
}

// TestInternalHandler_FlushStatsUpdatesCounters checks that FlushStats adds the
// flushed request/token deltas to each key's counters and stores the reported
// last-access time. Key 1 (QuotaLimit 100) is expected to grow QuotaUsed by the
// flushed tokens; key 2 (QuotaLimit -1) is expected to keep QuotaUsed unchanged.
func TestInternalHandler_FlushStatsUpdatesCounters(t *testing.T) {
	gin.SetMode(gin.TestMode)
	db := openFlushTestDB(t, &model.Key{})

	key1 := model.Key{
		MasterID:      1,
		IssuedAtEpoch: 1,
		TokenHash:     "hash-1",
		RequestCount:  5,
		UsedTokens:    20,
		QuotaLimit:    100,
		QuotaUsed:     10,
	}
	key2 := model.Key{
		MasterID:      1,
		IssuedAtEpoch: 1,
		TokenHash:     "hash-2",
		RequestCount:  0,
		UsedTokens:    0,
		QuotaLimit:    -1,
		QuotaUsed:     7,
	}
	if err := db.Create(&key1).Error; err != nil {
		t.Fatalf("create key1: %v", err)
	}
	if err := db.Create(&key2).Error; err != nil {
		t.Fatalf("create key2: %v", err)
	}

	handler := NewInternalHandler(db)
	body := []byte(`{
		"keys": [
			{"token_hash": "hash-1", "requests": 3, "tokens": 15, "last_accessed_at": 1700000000},
			{"token_hash": "hash-2", "requests": 1, "tokens": 5, "last_accessed_at": 1700000010}
		]
	}`)
	serveFlushRequest(t, "/internal/stats/flush", handler.FlushStats, body)

	var got1 model.Key
	if err := db.First(&got1, "token_hash = ?", "hash-1").Error; err != nil {
		t.Fatalf("load key1: %v", err)
	}
	if got1.RequestCount != 8 {
		t.Fatalf("key1 request_count: got=%d want=8", got1.RequestCount)
	}
	if got1.UsedTokens != 35 {
		t.Fatalf("key1 used_tokens: got=%d want=35", got1.UsedTokens)
	}
	if got1.QuotaUsed != 25 {
		t.Fatalf("key1 quota_used: got=%d want=25", got1.QuotaUsed)
	}
	if got1.LastAccessedAt == nil || !got1.LastAccessedAt.Equal(time.Unix(1700000000, 0)) {
		t.Fatalf("key1 last_accessed_at: got=%v want unix=1700000000", got1.LastAccessedAt)
	}

	var got2 model.Key
	if err := db.First(&got2, "token_hash = ?", "hash-2").Error; err != nil {
		t.Fatalf("load key2: %v", err)
	}
	if got2.RequestCount != 1 {
		t.Fatalf("key2 request_count: got=%d want=1", got2.RequestCount)
	}
	if got2.UsedTokens != 5 {
		t.Fatalf("key2 used_tokens: got=%d want=5", got2.UsedTokens)
	}
	if got2.QuotaUsed != 7 {
		t.Fatalf("key2 quota_used: got=%d want=7", got2.QuotaUsed)
	}
	if got2.LastAccessedAt == nil || !got2.LastAccessedAt.Equal(time.Unix(1700000010, 0)) {
		t.Fatalf("key2 last_accessed_at: got=%v want unix=1700000010", got2.LastAccessedAt)
	}
}

// TestInternalHandler_FlushAPIKeyStatsUpdatesCounters checks that
// FlushAPIKeyStats accumulates total/success/failure counts and the derived
// rates on each API key, and rolls the same totals up into the owning
// provider group.
func TestInternalHandler_FlushAPIKeyStatsUpdatesCounters(t *testing.T) {
	gin.SetMode(gin.TestMode)
	db := openFlushTestDB(t, &model.ProviderGroup{}, &model.APIKey{})

	group := model.ProviderGroup{Name: "g1", Type: "openai", BaseURL: "https://example.com"}
	if err := db.Create(&group).Error; err != nil {
		t.Fatalf("create group: %v", err)
	}
	key1 := model.APIKey{GroupID: group.ID, APIKey: "k1", Status: "active"}
	key2 := model.APIKey{GroupID: group.ID, APIKey: "k2", Status: "active"}
	if err := db.Create(&key1).Error; err != nil {
		t.Fatalf("create key1: %v", err)
	}
	if err := db.Create(&key2).Error; err != nil {
		t.Fatalf("create key2: %v", err)
	}

	handler := NewInternalHandler(db)
	body := []byte(fmt.Sprintf(`{
		"keys": [
			{"api_key_id": %d, "requests": 5, "success_requests": 3},
			{"api_key_id": %d, "requests": 4, "success_requests": 4}
		]
	}`, key1.ID, key2.ID))
	serveFlushRequest(t, "/internal/apikey-stats/flush", handler.FlushAPIKeyStats, body)

	var got1 model.APIKey
	if err := db.First(&got1, key1.ID).Error; err != nil {
		t.Fatalf("load key1: %v", err)
	}
	if got1.TotalRequests != 5 || got1.SuccessRequests != 3 || got1.FailureRequests != 2 {
		t.Fatalf("key1 counts: total=%d success=%d failure=%d (want 5/3/2)",
			got1.TotalRequests, got1.SuccessRequests, got1.FailureRequests)
	}
	assertFlushRates(t, "key1", got1.SuccessRate, got1.FailureRate, 0.6, 0.4)

	var got2 model.APIKey
	if err := db.First(&got2, key2.ID).Error; err != nil {
		t.Fatalf("load key2: %v", err)
	}
	if got2.TotalRequests != 4 || got2.SuccessRequests != 4 || got2.FailureRequests != 0 {
		t.Fatalf("key2 counts: total=%d success=%d failure=%d (want 4/4/0)",
			got2.TotalRequests, got2.SuccessRequests, got2.FailureRequests)
	}
	assertFlushRates(t, "key2", got2.SuccessRate, got2.FailureRate, 1.0, 0.0)

	// The group aggregates both keys: 5+4 requests, 3+4 successes, 2+0 failures.
	var gotGroup model.ProviderGroup
	if err := db.First(&gotGroup, group.ID).Error; err != nil {
		t.Fatalf("load group: %v", err)
	}
	if gotGroup.TotalRequests != 9 || gotGroup.SuccessRequests != 7 || gotGroup.FailureRequests != 2 {
		t.Fatalf("group counts: total=%d success=%d failure=%d (want 9/7/2)",
			gotGroup.TotalRequests, gotGroup.SuccessRequests, gotGroup.FailureRequests)
	}
	assertFlushRates(t, "group", gotGroup.SuccessRate, gotGroup.FailureRate, 7.0/9.0, 2.0/9.0)
}