diff --git a/cmd/server/main.go b/cmd/server/main.go index e68f11e..1c52419 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -12,6 +12,7 @@ import ( _ "github.com/ez-api/ez-api/docs" "github.com/ez-api/ez-api/internal/api" "github.com/ez-api/ez-api/internal/config" + "github.com/ez-api/ez-api/internal/cron" "github.com/ez-api/ez-api/internal/middleware" "github.com/ez-api/ez-api/internal/model" "github.com/ez-api/ez-api/internal/service" @@ -102,6 +103,10 @@ func main() { logCtx, cancelLogs := context.WithCancel(context.Background()) defer cancelLogs() logWriter.Start(logCtx) + quotaResetter := cron.NewQuotaResetter(db, syncService, time.Duration(cfg.Quota.ResetIntervalSeconds)*time.Second) + quotaCtx, cancelQuota := context.WithCancel(context.Background()) + defer cancelQuota() + go quotaResetter.Start(quotaCtx) adminService, err := service.NewAdminService() if err != nil { @@ -200,6 +205,7 @@ func main() { adminGroup.PUT("/models/:id", handler.UpdateModel) adminGroup.GET("/logs", handler.ListLogs) adminGroup.GET("/logs/stats", handler.LogStats) + adminGroup.GET("/stats", adminHandler.GetAdminStats) adminGroup.POST("/bindings", handler.CreateBinding) adminGroup.GET("/bindings", handler.ListBindings) adminGroup.GET("/bindings/:id", handler.GetBinding) @@ -219,6 +225,7 @@ func main() { masterGroup.PUT("/tokens/:id", masterHandler.UpdateToken) masterGroup.DELETE("/tokens/:id", masterHandler.DeleteToken) masterGroup.GET("/logs", masterHandler.ListSelfLogs) + masterGroup.GET("/logs/stats", masterHandler.GetSelfLogStats) masterGroup.GET("/stats", masterHandler.GetSelfStats) } diff --git a/internal/api/log_handler.go b/internal/api/log_handler.go index 2f0b949..9b989e3 100644 --- a/internal/api/log_handler.go +++ b/internal/api/log_handler.go @@ -16,6 +16,9 @@ type LogView struct { Group string `json:"group"` KeyID uint `json:"key_id"` ModelName string `json:"model"` + ProviderID uint `json:"provider_id"` + ProviderType string `json:"provider_type"` + ProviderName string `json:"provider_name"` StatusCode int `json:"status_code"` LatencyMs int64 `json:"latency_ms"` TokensIn int64 `json:"tokens_in"` @@ -34,6 +37,9 @@ func toLogView(r model.LogRecord) LogView { Group: r.Group, KeyID: r.KeyID, ModelName: r.ModelName, + ProviderID: r.ProviderID, + ProviderType: r.ProviderType, + ProviderName: r.ProviderName, StatusCode: r.StatusCode, LatencyMs: r.LatencyMs, TokensIn: r.TokensIn, @@ -276,7 +282,7 @@ func (h *MasterHandler) ListSelfLogs(c *gin.Context) { c.JSON(http.StatusOK, ListLogsResponse{Total: total, Limit: limit, Offset: offset, Items: out}) } -// GetSelfStats godoc +// GetSelfLogStats godoc // @Summary Log stats (master) // @Description Aggregate request log stats for the authenticated master // @Tags master @@ -287,8 +293,8 @@ func (h *MasterHandler) ListSelfLogs(c *gin.Context) { // @Success 200 {object} LogStatsResponse // @Failure 401 {object} gin.H // @Failure 500 {object} gin.H -// @Router /v1/stats [get] -func (h *MasterHandler) GetSelfStats(c *gin.Context) { +// @Router /v1/logs/stats [get] +func (h *MasterHandler) GetSelfLogStats(c *gin.Context) { master, exists := c.Get("master") if !exists { c.JSON(http.StatusUnauthorized, gin.H{"error": "master key not found in context"}) diff --git a/internal/api/stats_handler.go b/internal/api/stats_handler.go new file mode 100644 index 0000000..088f919 --- /dev/null +++ b/internal/api/stats_handler.go @@ -0,0 +1,270 @@ +package api + +import ( + "fmt" + "net/http" + "strings" + "time" + + "github.com/ez-api/ez-api/internal/model" + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +type KeyUsageStat struct { + KeyID uint `json:"key_id"` + Requests int64 `json:"requests"` + Tokens int64 `json:"tokens"` +} + +type ModelUsageStat struct { + Model string `json:"model"` + Requests int64 `json:"requests"` + Tokens int64 `json:"tokens"` +} + +type MasterUsageStatsResponse struct { + Period string `json:"period,omitempty"` + TotalRequests int64 `json:"total_requests"` + TotalTokens int64 `json:"total_tokens"` + ByKey []KeyUsageStat `json:"by_key"` + ByModel []ModelUsageStat `json:"by_model"` +} + +// GetSelfStats godoc +// @Summary Usage stats (master) +// @Description Aggregate request stats for the authenticated master +// @Tags master +// @Produce json +// @Security MasterAuth +// @Param period query string false "today|week|month|all" +// @Param since query int false "unix seconds" +// @Param until query int false "unix seconds" +// @Success 200 {object} MasterUsageStatsResponse +// @Failure 400 {object} gin.H +// @Failure 401 {object} gin.H +// @Failure 500 {object} gin.H +// @Router /v1/stats [get] +func (h *MasterHandler) GetSelfStats(c *gin.Context) { + master, exists := c.Get("master") + if !exists { + c.JSON(http.StatusUnauthorized, gin.H{"error": "master key not found in context"}) + return + } + m := master.(*model.Master) + + rng, err := parseStatsRange(c) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + base := h.db.Model(&model.LogRecord{}). + Joins("JOIN keys ON keys.id = log_records.key_id"). + Where("keys.master_id = ?", m.ID) + base = applyStatsRange(base, rng) + + totalRequests, totalTokens, err := aggregateTotals(base) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to aggregate stats", "details": err.Error()}) + return + } + + var byKey []KeyUsageStat + if err := base.Session(&gorm.Session{}). + Select("log_records.key_id as key_id, COUNT(*) as requests, COALESCE(SUM(log_records.tokens_in + log_records.tokens_out),0) as tokens"). + Group("log_records.key_id"). + Scan(&byKey).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to group by key", "details": err.Error()}) + return + } + + var byModel []ModelUsageStat + if err := base.Session(&gorm.Session{}). + Select("log_records.model_name as model, COUNT(*) as requests, COALESCE(SUM(log_records.tokens_in + log_records.tokens_out),0) as tokens"). + Group("log_records.model_name"). + Scan(&byModel).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to group by model", "details": err.Error()}) + return + } + + c.JSON(http.StatusOK, MasterUsageStatsResponse{ + Period: rng.Period, + TotalRequests: totalRequests, + TotalTokens: totalTokens, + ByKey: byKey, + ByModel: byModel, + }) +} + +type MasterUsageAgg struct { + MasterID uint `json:"master_id"` + Requests int64 `json:"requests"` + Tokens int64 `json:"tokens"` +} + +type ProviderUsageAgg struct { + ProviderID uint `json:"provider_id"` + ProviderType string `json:"provider_type"` + ProviderName string `json:"provider_name"` + Requests int64 `json:"requests"` + Tokens int64 `json:"tokens"` +} + +type AdminUsageStatsResponse struct { + Period string `json:"period,omitempty"` + TotalMasters int64 `json:"total_masters"` + ActiveMasters int64 `json:"active_masters"` + TotalRequests int64 `json:"total_requests"` + TotalTokens int64 `json:"total_tokens"` + ByMaster []MasterUsageAgg `json:"by_master"` + ByProvider []ProviderUsageAgg `json:"by_provider"` +} + +// GetAdminStats godoc +// @Summary Usage stats (admin) +// @Description Aggregate request stats across all masters +// @Tags admin +// @Produce json +// @Security AdminAuth +// @Param period query string false "today|week|month|all" +// @Param since query int false "unix seconds" +// @Param until query int false "unix seconds" +// @Success 200 {object} AdminUsageStatsResponse +// @Failure 400 {object} gin.H +// @Failure 500 {object} gin.H +// @Router /admin/stats [get] +func (h *AdminHandler) GetAdminStats(c *gin.Context) { + rng, err := parseStatsRange(c) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + var totalMasters int64 + if err := h.db.Model(&model.Master{}).Count(&totalMasters).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to count masters", "details": err.Error()}) + return + } + var activeMasters int64 + if err := h.db.Model(&model.Master{}).Where("status = ?", "active").Count(&activeMasters).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to count active masters", "details": err.Error()}) + return + } + + base := h.db.Model(&model.LogRecord{}) + base = applyStatsRange(base, rng) + + totalRequests, totalTokens, err := aggregateTotals(base) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to aggregate stats", "details": err.Error()}) + return + } + + var byMaster []MasterUsageAgg + if err := base.Session(&gorm.Session{}). + Joins("JOIN keys ON keys.id = log_records.key_id"). + Select("keys.master_id as master_id, COUNT(*) as requests, COALESCE(SUM(log_records.tokens_in + log_records.tokens_out),0) as tokens"). + Group("keys.master_id"). + Scan(&byMaster).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to group by master", "details": err.Error()}) + return + } + + var byProvider []ProviderUsageAgg + if err := base.Session(&gorm.Session{}). + Select("log_records.provider_id as provider_id, log_records.provider_type as provider_type, log_records.provider_name as provider_name, COUNT(*) as requests, COALESCE(SUM(log_records.tokens_in + log_records.tokens_out),0) as tokens"). + Group("log_records.provider_id, log_records.provider_type, log_records.provider_name"). + Scan(&byProvider).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to group by provider", "details": err.Error()}) + return + } + + c.JSON(http.StatusOK, AdminUsageStatsResponse{ + Period: rng.Period, + TotalMasters: totalMasters, + ActiveMasters: activeMasters, + TotalRequests: totalRequests, + TotalTokens: totalTokens, + ByMaster: byMaster, + ByProvider: byProvider, + }) +} + +type statsRange struct { + Since *time.Time + Until *time.Time + Period string +} + +func parseStatsRange(c *gin.Context) (statsRange, error) { + period := strings.ToLower(strings.TrimSpace(c.Query("period"))) + if period != "" { + if period == "all" { + return statsRange{Period: period}, nil + } + start, now := periodWindow(period) + if start.IsZero() { + return statsRange{}, fmt.Errorf("invalid period") + } + return statsRange{Since: &start, Until: &now, Period: period}, nil + } + + var since *time.Time + if t, ok := parseUnixSeconds(c.Query("since")); ok { + since = &t + } + var until *time.Time + if t, ok := parseUnixSeconds(c.Query("until")); ok { + until = &t + } + return statsRange{Since: since, Until: until}, nil +} + +func periodWindow(period string) (time.Time, time.Time) { + now := time.Now().UTC() + startOfDay := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC) + switch period { + case "today": + return startOfDay, now + case "week": + weekday := int(startOfDay.Weekday()) + if weekday == 0 { + weekday = 7 + } + start := startOfDay.AddDate(0, 0, -(weekday - 1)) + return start, now + case "month": + start := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, time.UTC) + return start, now + default: + return time.Time{}, time.Time{} + } +} + +func applyStatsRange(q *gorm.DB, rng statsRange) *gorm.DB { + if rng.Since != nil { + q = q.Where("log_records.created_at >= ?", *rng.Since) + } + if rng.Until != nil { + q = q.Where("log_records.created_at <= ?", *rng.Until) + } + return q +} + +func aggregateTotals(q *gorm.DB) (int64, int64, error) { + var totalRequests int64 + if err := q.Session(&gorm.Session{}).Count(&totalRequests).Error; err != nil { + return 0, 0, err + } + type totals struct { + Tokens int64 + } + var t totals + if err := q.Session(&gorm.Session{}). + Select("COALESCE(SUM(log_records.tokens_in + log_records.tokens_out),0) as tokens"). + Scan(&t).Error; err != nil { + return 0, 0, err + } + return totalRequests, t.Tokens, nil +} diff --git a/internal/api/stats_handler_test.go b/internal/api/stats_handler_test.go new file mode 100644 index 0000000..02fa1eb --- /dev/null +++ b/internal/api/stats_handler_test.go @@ -0,0 +1,191 @@ +package api + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/alicebob/miniredis/v2" + "github.com/ez-api/ez-api/internal/model" + "github.com/ez-api/ez-api/internal/service" + "github.com/gin-gonic/gin" + "github.com/redis/go-redis/v9" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +func TestMasterStats_AggregatesByKeyAndModel(t *testing.T) { + gin.SetMode(gin.TestMode) + + dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", t.Name()) + db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{}) + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + if err := db.AutoMigrate(&model.Master{}, &model.Key{}, &model.LogRecord{}); err != nil { + t.Fatalf("migrate: %v", err) + } + + m := &model.Master{Name: "m1", Group: "g", Status: "active", Epoch: 1, MasterKeyDigest: "d1"} + if err := db.Create(m).Error; err != nil { + t.Fatalf("create master: %v", err) + } + k1 := &model.Key{MasterID: m.ID, TokenHash: "h1", Group: "g", Status: "active", IssuedAtEpoch: 1} + k2 := &model.Key{MasterID: m.ID, TokenHash: "h2", Group: "g", Status: "active", IssuedAtEpoch: 1} + if err := db.Create(k1).Error; err != nil { + t.Fatalf("create k1: %v", err) + } + if err := db.Create(k2).Error; err != nil { + t.Fatalf("create k2: %v", err) + } + + if err := db.Create(&model.LogRecord{ + Group: "rg", + KeyID: k1.ID, + ModelName: "ns.m1", + ProviderID: 10, + ProviderType: "openai", + ProviderName: "p1", + StatusCode: 200, + TokensIn: 5, + TokensOut: 7, + }).Error; err != nil { + t.Fatalf("create log1: %v", err) + } + if err := db.Create(&model.LogRecord{ + Group: "rg", + KeyID: k2.ID, + ModelName: "ns.m2", + ProviderID: 11, + ProviderType: "anthropic", + ProviderName: "p2", + StatusCode: 200, + TokensIn: 2, + TokensOut: 3, + }).Error; err != nil { + t.Fatalf("create log2: %v", err) + } + + mr := miniredis.RunT(t) + rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()}) + masterSvc := service.NewMasterService(db) + syncSvc := service.NewSyncService(rdb) + h := NewMasterHandler(db, masterSvc, syncSvc) + + withMaster := func(next gin.HandlerFunc) gin.HandlerFunc { + return func(c *gin.Context) { + c.Set("master", m) + next(c) + } + } + + r := gin.New() + r.GET("/v1/stats", withMaster(h.GetSelfStats)) + + req := httptest.NewRequest(http.MethodGet, "/v1/stats?period=all", nil) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rr.Code, rr.Body.String()) + } + var resp MasterUsageStatsResponse + if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if resp.TotalRequests != 2 || resp.TotalTokens != 17 { + t.Fatalf("unexpected totals: %+v", resp) + } + if len(resp.ByKey) != 2 || len(resp.ByModel) != 2 { + t.Fatalf("unexpected breakdown: %+v", resp) + } +} + +func TestAdminStats_AggregatesByProvider(t *testing.T) { + gin.SetMode(gin.TestMode) + + dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", t.Name()) + db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{}) + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + if err := db.AutoMigrate(&model.Master{}, &model.Key{}, &model.LogRecord{}); err != nil { + t.Fatalf("migrate: %v", err) + } + + m1 := &model.Master{Name: "m1", Group: "g", Status: "active", Epoch: 1, MasterKeyDigest: "d1"} + m2 := &model.Master{Name: "m2", Group: "g", Status: "suspended", Epoch: 1, MasterKeyDigest: "d2"} + if err := db.Create(m1).Error; err != nil { + t.Fatalf("create m1: %v", err) + } + if err := db.Create(m2).Error; err != nil { + t.Fatalf("create m2: %v", err) + } + k1 := &model.Key{MasterID: m1.ID, TokenHash: "h1", Group: "g", Status: "active", IssuedAtEpoch: 1} + k2 := &model.Key{MasterID: m2.ID, TokenHash: "h2", Group: "g", Status: "active", IssuedAtEpoch: 1} + if err := db.Create(k1).Error; err != nil { + t.Fatalf("create k1: %v", err) + } + if err := db.Create(k2).Error; err != nil { + t.Fatalf("create k2: %v", err) + } + + if err := db.Create(&model.LogRecord{ + Group: "rg", + KeyID: k1.ID, + ModelName: "ns.m1", + ProviderID: 10, + ProviderType: "openai", + ProviderName: "p1", + StatusCode: 200, + TokensIn: 4, + TokensOut: 6, + }).Error; err != nil { + t.Fatalf("create log1: %v", err) + } + if err := db.Create(&model.LogRecord{ + Group: "rg", + KeyID: k2.ID, + ModelName: "ns.m2", + ProviderID: 11, + ProviderType: "anthropic", + ProviderName: "p2", + StatusCode: 200, + TokensIn: 1, + TokensOut: 2, + }).Error; err != nil { + t.Fatalf("create log2: %v", err) + } + + mr := miniredis.RunT(t) + rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()}) + masterSvc := service.NewMasterService(db) + syncSvc := service.NewSyncService(rdb) + adminHandler := NewAdminHandler(db, masterSvc, syncSvc) + + r := gin.New() + r.GET("/admin/stats", adminHandler.GetAdminStats) + + req := httptest.NewRequest(http.MethodGet, "/admin/stats?period=all", nil) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rr.Code, rr.Body.String()) + } + var resp AdminUsageStatsResponse + if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if resp.TotalMasters != 2 || resp.ActiveMasters != 1 { + t.Fatalf("unexpected master counts: %+v", resp) + } + if resp.TotalRequests != 2 || resp.TotalTokens != 13 { + t.Fatalf("unexpected totals: %+v", resp) + } + if len(resp.ByProvider) != 2 { + t.Fatalf("expected provider breakdown, got %+v", resp.ByProvider) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index 8362bb6..3dc4c9b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -16,6 +16,7 @@ type Config struct { Log LogConfig Auth AuthConfig ModelRegistry ModelRegistryConfig + Quota QuotaConfig } type ServerConfig struct { @@ -52,6 +53,10 @@ type ModelRegistryConfig struct { TimeoutSeconds int } +type QuotaConfig struct { + ResetIntervalSeconds int +} + func Load() (*Config, error) { v := viper.New() @@ -71,6 +76,7 @@ func Load() (*Config, error) { v.SetDefault("model_registry.models_dev_ref", "dev") v.SetDefault("model_registry.cache_dir", "./data/model-registry") v.SetDefault("model_registry.timeout_seconds", 30) + v.SetDefault("quota.reset_interval_seconds", 300) v.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) v.AutomaticEnv() @@ -91,6 +97,7 @@ func Load() (*Config, error) { _ = v.BindEnv("model_registry.models_dev_ref", "EZ_MODEL_REGISTRY_MODELS_DEV_REF") _ = v.BindEnv("model_registry.cache_dir", "EZ_MODEL_REGISTRY_CACHE_DIR") _ = v.BindEnv("model_registry.timeout_seconds", "EZ_MODEL_REGISTRY_TIMEOUT_SECONDS") + _ = v.BindEnv("quota.reset_interval_seconds", "EZ_QUOTA_RESET_INTERVAL_SECONDS") if configFile := os.Getenv("EZ_CONFIG_FILE"); configFile != "" { v.SetConfigFile(configFile) @@ -136,6 +143,9 @@ func Load() (*Config, error) { CacheDir: v.GetString("model_registry.cache_dir"), TimeoutSeconds: v.GetInt("model_registry.timeout_seconds"), }, + Quota: QuotaConfig{ + ResetIntervalSeconds: v.GetInt("quota.reset_interval_seconds"), + }, } return cfg, nil diff --git a/internal/cron/quota_reset.go b/internal/cron/quota_reset.go new file mode 100644 index 0000000..a6e2a68 --- /dev/null +++ b/internal/cron/quota_reset.go @@ -0,0 +1,92 @@ +package cron + +import ( + "context" + "log/slog" + "strings" + "time" + + "github.com/ez-api/ez-api/internal/model" + "github.com/ez-api/ez-api/internal/service" + "gorm.io/gorm" +) + +type QuotaResetter struct { + db *gorm.DB + sync *service.SyncService + interval time.Duration +} + +func NewQuotaResetter(db *gorm.DB, sync *service.SyncService, interval time.Duration) *QuotaResetter { + if interval <= 0 { + interval = 5 * time.Minute + } + return &QuotaResetter{db: db, sync: sync, interval: interval} +} + +func (q *QuotaResetter) Start(ctx context.Context) { + if q == nil || q.db == nil { + return + } + if ctx == nil { + ctx = context.Background() + } + ticker := time.NewTicker(q.interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := q.resetOnce(ctx); err != nil { + slog.Default().Warn("quota reset failed", "err", err) + } + } + } +} + +func (q *QuotaResetter) resetOnce(ctx context.Context) error { + if q == nil || q.db == nil { + return nil + } + now := time.Now().UTC() + var keys []model.Key + if err := q.db.Where("quota_reset_type IN ? AND (quota_reset_at IS NULL OR quota_reset_at <= ?)", []string{"daily", "monthly"}, now).Find(&keys).Error; err != nil { + return err + } + for i := range keys { + resetType := strings.ToLower(strings.TrimSpace(keys[i].QuotaResetType)) + nextAt, ok := nextQuotaReset(now, resetType) + if !ok { + continue + } + if err := q.db.Model(&keys[i]).Updates(map[string]any{ + "quota_used": 0, + "quota_reset_at": nextAt, + }).Error; err != nil { + slog.Default().Warn("quota reset update failed", "key_id", keys[i].ID, "err", err) + continue + } + keys[i].QuotaUsed = 0 + keys[i].QuotaResetAt = &nextAt + if q.sync != nil { + _ = q.sync.SyncKey(&keys[i]) + } + } + return nil +} + +func nextQuotaReset(now time.Time, resetType string) (time.Time, bool) { + now = now.UTC() + switch resetType { + case "daily": + next := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC).AddDate(0, 0, 1) + return next, true + case "monthly": + next := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, time.UTC).AddDate(0, 1, 0) + return next, true + default: + return time.Time{}, false + } +} diff --git a/internal/model/log.go b/internal/model/log.go index cc1aef5..552cc5c 100644 --- a/internal/model/log.go +++ b/internal/model/log.go @@ -8,6 +8,9 @@ type LogRecord struct { Group string `json:"group"` KeyID uint `json:"key_id"` ModelName string `json:"model"` + ProviderID uint `json:"provider_id"` + ProviderType string `json:"provider_type"` + ProviderName string `json:"provider_name"` StatusCode int `json:"status_code"` LatencyMs int64 `json:"latency_ms"` TokensIn int64 `json:"tokens_in"` diff --git a/internal/service/stats.go b/internal/service/stats.go new file mode 100644 index 0000000..83c3ae6 --- /dev/null +++ b/internal/service/stats.go @@ -0,0 +1,79 @@ +package service + +import ( + "context" + "fmt" + "strconv" + "strings" + "time" + + "github.com/redis/go-redis/v9" +) + +type StatsService struct { + rdb *redis.Client +} + +type RealtimeStats struct { + Requests int64 + Tokens int64 + LastAccessedAt *time.Time +} + +func NewStatsService(rdb *redis.Client) *StatsService { + return &StatsService{rdb: rdb} +} + +func (s *StatsService) GetKeyRealtimeStats(ctx context.Context, tokenHash string) (RealtimeStats, error) { + if s == nil || s.rdb == nil { + return RealtimeStats{}, fmt.Errorf("redis client is required") + } + tokenHash = strings.TrimSpace(tokenHash) + if tokenHash == "" { + return RealtimeStats{}, fmt.Errorf("token hash required") + } + if ctx == nil { + ctx = context.Background() + } + reqs, err := s.rdb.Get(ctx, fmt.Sprintf("key:stats:%s:requests", tokenHash)).Int64() + if err != nil && err != redis.Nil { + return RealtimeStats{}, fmt.Errorf("read key requests: %w", err) + } + tokens, err := s.rdb.Get(ctx, fmt.Sprintf("key:stats:%s:tokens", tokenHash)).Int64() + if err != nil && err != redis.Nil { + return RealtimeStats{}, fmt.Errorf("read key tokens: %w", err) + } + lastRaw, err := s.rdb.Get(ctx, fmt.Sprintf("key:stats:%s:last_access", tokenHash)).Result() + if err != nil && err != redis.Nil { + return RealtimeStats{}, fmt.Errorf("read key last access: %w", err) + } + var lastAt *time.Time + if lastRaw != "" { + if sec, err := strconv.ParseInt(lastRaw, 10, 64); err == nil && sec > 0 { + t := time.Unix(sec, 0).UTC() + lastAt = &t + } + } + return RealtimeStats{Requests: reqs, Tokens: tokens, LastAccessedAt: lastAt}, nil +} + +func (s *StatsService) GetMasterRealtimeStats(ctx context.Context, masterID uint) (RealtimeStats, error) { + if s == nil || s.rdb == nil { + return RealtimeStats{}, fmt.Errorf("redis client is required") + } + if masterID == 0 { + return RealtimeStats{}, fmt.Errorf("master id required") + } + if ctx == nil { + ctx = context.Background() + } + reqs, err := s.rdb.Get(ctx, fmt.Sprintf("master:stats:%d:requests", masterID)).Int64() + if err != nil && err != redis.Nil { + return RealtimeStats{}, fmt.Errorf("read master requests: %w", err) + } + tokens, err := s.rdb.Get(ctx, fmt.Sprintf("master:stats:%d:tokens", masterID)).Int64() + if err != nil && err != redis.Nil { + return RealtimeStats{}, fmt.Errorf("read master tokens: %w", err) + } + return RealtimeStats{Requests: reqs, Tokens: tokens}, nil +} diff --git a/internal/service/stats_test.go b/internal/service/stats_test.go new file mode 100644 index 0000000..0b19e9f --- /dev/null +++ b/internal/service/stats_test.go @@ -0,0 +1,52 @@ +package service + +import ( + "context" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/redis/go-redis/v9" +) + +func TestStatsService_KeyRealtimeStats(t *testing.T) { + t.Parallel() + + mr := miniredis.RunT(t) + rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()}) + svc := NewStatsService(rdb) + + mr.Set("key:stats:hash:requests", "3") + mr.Set("key:stats:hash:tokens", "42") + mr.Set("key:stats:hash:last_access", "1700000000") + + stats, err := svc.GetKeyRealtimeStats(context.Background(), "hash") + if err != nil { + t.Fatalf("GetKeyRealtimeStats: %v", err) + } + if stats.Requests != 3 || stats.Tokens != 42 { + t.Fatalf("unexpected stats: %+v", stats) + } + if stats.LastAccessedAt == nil || !stats.LastAccessedAt.Equal(time.Unix(1700000000, 0).UTC()) { + t.Fatalf("unexpected last access: %+v", stats.LastAccessedAt) + } +} + +func TestStatsService_MasterRealtimeStats(t *testing.T) { + t.Parallel() + + mr := miniredis.RunT(t) + rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()}) + svc := NewStatsService(rdb) + + mr.Set("master:stats:99:requests", "7") + mr.Set("master:stats:99:tokens", "100") + + stats, err := svc.GetMasterRealtimeStats(context.Background(), 99) + if err != nil { + t.Fatalf("GetMasterRealtimeStats: %v", err) + } + if stats.Requests != 7 || stats.Tokens != 100 { + t.Fatalf("unexpected stats: %+v", stats) + } +}