diff --git a/cmd/server/main.go b/cmd/server/main.go index 4136bfc..285d974 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -297,6 +297,7 @@ func main() { internalGroup.Use(middleware.InternalAuthMiddleware(cfg.Internal.StatsToken)) { internalGroup.POST("/stats/flush", internalHandler.FlushStats) + internalGroup.POST("/apikey-stats/flush", internalHandler.FlushAPIKeyStats) internalGroup.GET("/metrics", gin.WrapH(expvar.Handler())) } @@ -353,6 +354,7 @@ func main() { adminGroup.GET("/logs/webhook", handler.GetLogWebhookConfig) adminGroup.PUT("/logs/webhook", handler.UpdateLogWebhookConfig) adminGroup.GET("/stats", adminHandler.GetAdminStats) + adminGroup.GET("/apikey-stats/summary", adminHandler.GetAPIKeyStatsSummary) adminGroup.POST("/bindings", handler.CreateBinding) adminGroup.GET("/bindings", handler.ListBindings) adminGroup.GET("/bindings/:id", handler.GetBinding) diff --git a/internal/api/apikey_stats_handler.go b/internal/api/apikey_stats_handler.go new file mode 100644 index 0000000..7c62d79 --- /dev/null +++ b/internal/api/apikey_stats_handler.go @@ -0,0 +1,67 @@ +package api + +import ( + "net/http" + + "github.com/ez-api/ez-api/internal/model" + "github.com/gin-gonic/gin" +) + +type APIKeyStatsSummaryResponse struct { + TotalRequests int64 `json:"total_requests"` + SuccessRequests int64 `json:"success_requests"` + FailureRequests int64 `json:"failure_requests"` + SuccessRate float64 `json:"success_rate"` + FailureRate float64 `json:"failure_rate"` +} + +// GetAPIKeyStatsSummary godoc +// @Summary APIKey stats summary (admin) +// @Description Aggregate APIKey success/failure stats across all provider groups +// @Tags admin +// @Produce json +// @Security AdminAuth +// @Success 200 {object} APIKeyStatsSummaryResponse +// @Failure 500 {object} gin.H +// @Router /admin/apikey-stats/summary [get] +func (h *AdminHandler) GetAPIKeyStatsSummary(c *gin.Context) { + if h == nil || h.db == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "database not configured"}) + return + } + + var totals struct { + TotalRequests int64 `json:"total_requests"` + SuccessRequests int64 `json:"success_requests"` + FailureRequests int64 `json:"failure_requests"` + } + + if err := h.db.Model(&model.APIKey{}). + Select("COALESCE(SUM(total_requests),0) as total_requests, COALESCE(SUM(success_requests),0) as success_requests"). + Scan(&totals).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to aggregate api key stats", "details": err.Error()}) + return + } + + total := totals.TotalRequests + success := totals.SuccessRequests + failure := total - success + if failure < 0 { + failure = 0 + } + + var successRate float64 + var failureRate float64 + if total > 0 { + successRate = float64(success) / float64(total) + failureRate = float64(failure) / float64(total) + } + + c.JSON(http.StatusOK, APIKeyStatsSummaryResponse{ + TotalRequests: total, + SuccessRequests: success, + FailureRequests: failure, + SuccessRate: successRate, + FailureRate: failureRate, + }) +} diff --git a/internal/api/apikey_stats_handler_test.go b/internal/api/apikey_stats_handler_test.go new file mode 100644 index 0000000..97d18e2 --- /dev/null +++ b/internal/api/apikey_stats_handler_test.go @@ -0,0 +1,68 @@ +package api + +import ( + "encoding/json" + "math" + "net/http" + "net/http/httptest" + "testing" + + "github.com/ez-api/ez-api/internal/model" + "github.com/gin-gonic/gin" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +func TestAdminHandler_GetAPIKeyStatsSummary(t *testing.T) { + gin.SetMode(gin.TestMode) + + db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{}) + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + if err := db.AutoMigrate(&model.APIKey{}); err != nil { + t.Fatalf("migrate: %v", err) + } + + if err := db.Create(&model.APIKey{ + GroupID: 1, + APIKey: "k1", + TotalRequests: 10, + SuccessRequests: 7, + FailureRequests: 3, + }).Error; err != nil { + t.Fatalf("create key1: %v", err) + } + if err := db.Create(&model.APIKey{ + GroupID: 1, + APIKey: "k2", + TotalRequests: 5, + SuccessRequests: 5, + FailureRequests: 0, + }).Error; err != nil { + t.Fatalf("create key2: %v", err) + } + + handler := &AdminHandler{db: db} + r := gin.New() + r.GET("/admin/apikey-stats/summary", handler.GetAPIKeyStatsSummary) + + req := httptest.NewRequest(http.MethodGet, "/admin/apikey-stats/summary", nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: got=%d body=%s", rec.Code, rec.Body.String()) + } + + var resp APIKeyStatsSummaryResponse + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("decode response: %v", err) + } + if resp.TotalRequests != 15 || resp.SuccessRequests != 12 || resp.FailureRequests != 3 { + t.Fatalf("totals mismatch: %+v", resp) + } + if math.Abs(resp.SuccessRate-0.8) > 1e-6 || math.Abs(resp.FailureRate-0.2) > 1e-6 { + t.Fatalf("rates mismatch: success=%f failure=%f", resp.SuccessRate, resp.FailureRate) + } +} diff --git a/internal/api/internal_handler.go b/internal/api/internal_handler.go index 5d07603..59c2b43 100644 --- a/internal/api/internal_handler.go +++ b/internal/api/internal_handler.go @@ -29,6 +29,16 @@ type statsFlushEntry struct { LastAccessedAt int64 `json:"last_accessed_at"` } +type apiKeyStatsFlushRequest struct { + Keys []apiKeyStatsFlushEntry `json:"keys"` +} + +type apiKeyStatsFlushEntry struct { + APIKeyID uint `json:"api_key_id"` + Requests int64 `json:"requests"` + SuccessRequests int64 `json:"success_requests"` +} + // FlushStats godoc // @Summary Flush key stats // @Description Internal endpoint for flushing accumulated key usage stats from DP to CP database @@ -105,3 +115,139 @@ func (h *InternalHandler) FlushStats(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"updated": updated}) } + +// FlushAPIKeyStats godoc +// @Summary Flush API key stats +// @Description Internal endpoint for flushing accumulated APIKey stats from DP to CP database +// @Tags internal +// @Accept json +// @Produce json +// @Param request body apiKeyStatsFlushRequest true "Stats to flush" +// @Success 200 {object} gin.H +// @Failure 400 {object} gin.H +// @Failure 500 {object} gin.H +// @Router /internal/apikey-stats/flush [post] +func (h *InternalHandler) FlushAPIKeyStats(c *gin.Context) { + if h == nil || h.db == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "database not configured"}) + return + } + + var req apiKeyStatsFlushRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"}) + return + } + if len(req.Keys) == 0 { + c.JSON(http.StatusOK, gin.H{"updated": 0, "groups_updated": 0}) + return + } + + type apiKeyDelta struct { + requests int64 + success int64 + } + + deltas := make(map[uint]apiKeyDelta, len(req.Keys)) + for _, entry := range req.Keys { + if entry.APIKeyID == 0 { + continue + } + if entry.Requests < 0 || entry.SuccessRequests < 0 || entry.SuccessRequests > entry.Requests { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid stats payload"}) + return + } + if entry.Requests == 0 && entry.SuccessRequests == 0 { + continue + } + delta := deltas[entry.APIKeyID] + delta.requests += entry.Requests + delta.success += entry.SuccessRequests + deltas[entry.APIKeyID] = delta + } + + if len(deltas) == 0 { + c.JSON(http.StatusOK, gin.H{"updated": 0, "groups_updated": 0}) + return + } + + ids := make([]uint, 0, len(deltas)) + for id := range deltas { + ids = append(ids, id) + } + + var apiKeys []model.APIKey + if err := h.db.Model(&model.APIKey{}).Select("id, group_id").Where("id IN ?", ids).Find(&apiKeys).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to load api keys", "details": err.Error()}) + return + } + groupByKey := make(map[uint]uint, len(apiKeys)) + for _, key := range apiKeys { + groupByKey[key.ID] = key.GroupID + } + + statsUpdates := func(requests, success int64) map[string]any { + return map[string]any{ + "total_requests": gorm.Expr("total_requests + ?", requests), + "success_requests": gorm.Expr("success_requests + ?", success), + "failure_requests": gorm.Expr("(total_requests + ?) - (success_requests + ?)", requests, success), + "success_rate": gorm.Expr( + "CASE WHEN (total_requests + ?) > 0 THEN (success_requests + ?) * 1.0 / (total_requests + ?) ELSE 0 END", + requests, success, requests, + ), + "failure_rate": gorm.Expr( + "CASE WHEN (total_requests + ?) > 0 THEN ((total_requests + ?) - (success_requests + ?)) * 1.0 / (total_requests + ?) ELSE 0 END", + requests, requests, success, requests, + ), + } + } + + updated := 0 + groupsUpdated := 0 + groupDeltas := make(map[uint]apiKeyDelta) + + err := h.db.Transaction(func(tx *gorm.DB) error { + for id, delta := range deltas { + groupID, ok := groupByKey[id] + if !ok { + continue + } + if delta.requests == 0 && delta.success == 0 { + continue + } + res := tx.Model(&model.APIKey{}).Where("id = ?", id).Updates(statsUpdates(delta.requests, delta.success)) + if res.Error != nil { + return res.Error + } + if res.RowsAffected > 0 { + updated++ + } + if groupID > 0 { + groupDelta := groupDeltas[groupID] + groupDelta.requests += delta.requests + groupDelta.success += delta.success + groupDeltas[groupID] = groupDelta + } + } + + for groupID, delta := range groupDeltas { + if delta.requests == 0 && delta.success == 0 { + continue + } + res := tx.Model(&model.ProviderGroup{}).Where("id = ?", groupID).Updates(statsUpdates(delta.requests, delta.success)) + if res.Error != nil { + return res.Error + } + if res.RowsAffected > 0 { + groupsUpdated++ + } + } + return nil + }) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to flush api key stats", "details": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"updated": updated, "groups_updated": groupsUpdated}) +} diff --git a/internal/api/internal_handler_test.go b/internal/api/internal_handler_test.go index 1eed227..0b16dfc 100644 --- a/internal/api/internal_handler_test.go +++ b/internal/api/internal_handler_test.go @@ -2,6 +2,8 @@ package api import ( "bytes" + "fmt" + "math" "net/http" "net/http/httptest" "testing" @@ -102,3 +104,79 @@ func TestInternalHandler_FlushStatsUpdatesCounters(t *testing.T) { t.Fatalf("key2 last_accessed_at: got=%v", got2.LastAccessedAt) } } + +func TestInternalHandler_FlushAPIKeyStatsUpdatesCounters(t *testing.T) { + gin.SetMode(gin.TestMode) + + db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{}) + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + if err := db.AutoMigrate(&model.ProviderGroup{}, &model.APIKey{}); err != nil { + t.Fatalf("migrate: %v", err) + } + + group := model.ProviderGroup{Name: "g1", Type: "openai", BaseURL: "https://example.com"} + if err := db.Create(&group).Error; err != nil { + t.Fatalf("create group: %v", err) + } + key1 := model.APIKey{GroupID: group.ID, APIKey: "k1", Status: "active"} + key2 := model.APIKey{GroupID: group.ID, APIKey: "k2", Status: "active"} + if err := db.Create(&key1).Error; err != nil { + t.Fatalf("create key1: %v", err) + } + if err := db.Create(&key2).Error; err != nil { + t.Fatalf("create key2: %v", err) + } + + handler := NewInternalHandler(db) + r := gin.New() + r.POST("/internal/apikey-stats/flush", handler.FlushAPIKeyStats) + + body := []byte(`{ + "keys": [ + {"api_key_id": ` + fmt.Sprint(key1.ID) + `, "requests": 5, "success_requests": 3}, + {"api_key_id": ` + fmt.Sprint(key2.ID) + `, "requests": 4, "success_requests": 4} + ] + }`) + req := httptest.NewRequest(http.MethodPost, "/internal/apikey-stats/flush", bytes.NewReader(body)) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: got=%d body=%s", rec.Code, rec.Body.String()) + } + + var got1 model.APIKey + if err := db.First(&got1, key1.ID).Error; err != nil { + t.Fatalf("load key1: %v", err) + } + if got1.TotalRequests != 5 || got1.SuccessRequests != 3 || got1.FailureRequests != 2 { + t.Fatalf("key1 counts: total=%d success=%d failure=%d", got1.TotalRequests, got1.SuccessRequests, got1.FailureRequests) + } + if math.Abs(got1.SuccessRate-0.6) > 1e-6 || math.Abs(got1.FailureRate-0.4) > 1e-6 { + t.Fatalf("key1 rates: success=%f failure=%f", got1.SuccessRate, got1.FailureRate) + } + + var got2 model.APIKey + if err := db.First(&got2, key2.ID).Error; err != nil { + t.Fatalf("load key2: %v", err) + } + if got2.TotalRequests != 4 || got2.SuccessRequests != 4 || got2.FailureRequests != 0 { + t.Fatalf("key2 counts: total=%d success=%d failure=%d", got2.TotalRequests, got2.SuccessRequests, got2.FailureRequests) + } + if math.Abs(got2.SuccessRate-1.0) > 1e-6 || math.Abs(got2.FailureRate-0.0) > 1e-6 { + t.Fatalf("key2 rates: success=%f failure=%f", got2.SuccessRate, got2.FailureRate) + } + + var gotGroup model.ProviderGroup + if err := db.First(&gotGroup, group.ID).Error; err != nil { + t.Fatalf("load group: %v", err) + } + if gotGroup.TotalRequests != 9 || gotGroup.SuccessRequests != 7 || gotGroup.FailureRequests != 2 { + t.Fatalf("group counts: total=%d success=%d failure=%d", gotGroup.TotalRequests, gotGroup.SuccessRequests, gotGroup.FailureRequests) + } + if math.Abs(gotGroup.SuccessRate-(7.0/9.0)) > 1e-6 || math.Abs(gotGroup.FailureRate-(2.0/9.0)) > 1e-6 { + t.Fatalf("group rates: success=%f failure=%f", gotGroup.SuccessRate, gotGroup.FailureRate) + } +} diff --git a/internal/model/provider_group.go b/internal/model/provider_group.go index 00add9e..e45269b 100644 --- a/internal/model/provider_group.go +++ b/internal/model/provider_group.go @@ -9,18 +9,18 @@ import ( // ProviderGroup represents a shared upstream definition. type ProviderGroup struct { gorm.Model - Name string `gorm:"size:255;uniqueIndex;not null" json:"name"` - Type string `gorm:"size:50;not null" json:"type"` // openai, anthropic, gemini - BaseURL string `gorm:"size:512;not null" json:"base_url"` - GoogleProject string `gorm:"size:128" json:"google_project,omitempty"` - GoogleLocation string `gorm:"size:64" json:"google_location,omitempty"` - StaticHeaders string `gorm:"type:text" json:"static_headers,omitempty"` - HeadersProfile string `gorm:"size:64" json:"headers_profile,omitempty"` - Models string `json:"models"` // comma-separated list of supported models - Status string `gorm:"size:50;default:'active'" json:"status"` - TotalRequests int64 `gorm:"default:0" json:"total_requests"` - SuccessRequests int64 `gorm:"default:0" json:"success_requests"` - FailureRequests int64 `gorm:"default:0" json:"failure_requests"` + Name string `gorm:"size:255;uniqueIndex;not null" json:"name"` + Type string `gorm:"size:50;not null" json:"type"` // openai, anthropic, gemini + BaseURL string `gorm:"size:512;not null" json:"base_url"` + GoogleProject string `gorm:"size:128" json:"google_project,omitempty"` + GoogleLocation string `gorm:"size:64" json:"google_location,omitempty"` + StaticHeaders string `gorm:"type:text" json:"static_headers,omitempty"` + HeadersProfile string `gorm:"size:64" json:"headers_profile,omitempty"` + Models string `json:"models"` // comma-separated list of supported models + Status string `gorm:"size:50;default:'active'" json:"status"` + TotalRequests int64 `gorm:"default:0" json:"total_requests"` + SuccessRequests int64 `gorm:"default:0" json:"success_requests"` + FailureRequests int64 `gorm:"default:0" json:"failure_requests"` SuccessRate float64 `gorm:"default:0" json:"success_rate"` FailureRate float64 `gorm:"default:0" json:"failure_rate"` } @@ -28,21 +28,21 @@ type ProviderGroup struct { // APIKey represents a credential within a provider group. type APIKey struct { gorm.Model - GroupID uint `gorm:"not null;index" json:"group_id"` - APIKey string `gorm:"not null" json:"api_key"` - AccessToken string `gorm:"type:text" json:"access_token,omitempty"` - RefreshToken string `gorm:"type:text" json:"-"` - ExpiresAt *time.Time `json:"expires_at,omitempty"` - AccountID string `gorm:"size:255" json:"account_id,omitempty"` - ProjectID string `gorm:"size:255" json:"project_id,omitempty"` - Weight int `gorm:"default:1" json:"weight"` - Status string `gorm:"size:50;default:'active'" json:"status"` - AutoBan bool `gorm:"default:true" json:"auto_ban"` - BanReason string `gorm:"size:255" json:"ban_reason"` - BanUntil *time.Time `json:"ban_until"` - TotalRequests int64 `gorm:"default:0" json:"total_requests"` - SuccessRequests int64 `gorm:"default:0" json:"success_requests"` - FailureRequests int64 `gorm:"default:0" json:"failure_requests"` - SuccessRate float64 `gorm:"default:0" json:"success_rate"` - FailureRate float64 `gorm:"default:0" json:"failure_rate"` + GroupID uint `gorm:"not null;index" json:"group_id"` + APIKey string `gorm:"not null" json:"api_key"` + AccessToken string `gorm:"type:text" json:"access_token,omitempty"` + RefreshToken string `gorm:"type:text" json:"-"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` + AccountID string `gorm:"size:255" json:"account_id,omitempty"` + ProjectID string `gorm:"size:255" json:"project_id,omitempty"` + Weight int `gorm:"default:1" json:"weight"` + Status string `gorm:"size:50;default:'active'" json:"status"` + AutoBan bool `gorm:"default:true" json:"auto_ban"` + BanReason string `gorm:"size:255" json:"ban_reason"` + BanUntil *time.Time `json:"ban_until"` + TotalRequests int64 `gorm:"default:0" json:"total_requests"` + SuccessRequests int64 `gorm:"default:0" json:"success_requests"` + FailureRequests int64 `gorm:"default:0" json:"failure_requests"` + SuccessRate float64 `gorm:"default:0" json:"success_rate"` + FailureRate float64 `gorm:"default:0" json:"failure_rate"` }