mirror of
https://github.com/EZ-Api/ez-api.git
synced 2026-01-13 17:47:51 +00:00
feat(stats): add usage stats and quota reset
This commit is contained in:
@@ -12,6 +12,7 @@ import (
|
|||||||
_ "github.com/ez-api/ez-api/docs"
|
_ "github.com/ez-api/ez-api/docs"
|
||||||
"github.com/ez-api/ez-api/internal/api"
|
"github.com/ez-api/ez-api/internal/api"
|
||||||
"github.com/ez-api/ez-api/internal/config"
|
"github.com/ez-api/ez-api/internal/config"
|
||||||
|
"github.com/ez-api/ez-api/internal/cron"
|
||||||
"github.com/ez-api/ez-api/internal/middleware"
|
"github.com/ez-api/ez-api/internal/middleware"
|
||||||
"github.com/ez-api/ez-api/internal/model"
|
"github.com/ez-api/ez-api/internal/model"
|
||||||
"github.com/ez-api/ez-api/internal/service"
|
"github.com/ez-api/ez-api/internal/service"
|
||||||
@@ -102,6 +103,10 @@ func main() {
|
|||||||
logCtx, cancelLogs := context.WithCancel(context.Background())
|
logCtx, cancelLogs := context.WithCancel(context.Background())
|
||||||
defer cancelLogs()
|
defer cancelLogs()
|
||||||
logWriter.Start(logCtx)
|
logWriter.Start(logCtx)
|
||||||
|
quotaResetter := cron.NewQuotaResetter(db, syncService, time.Duration(cfg.Quota.ResetIntervalSeconds)*time.Second)
|
||||||
|
quotaCtx, cancelQuota := context.WithCancel(context.Background())
|
||||||
|
defer cancelQuota()
|
||||||
|
go quotaResetter.Start(quotaCtx)
|
||||||
|
|
||||||
adminService, err := service.NewAdminService()
|
adminService, err := service.NewAdminService()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -200,6 +205,7 @@ func main() {
|
|||||||
adminGroup.PUT("/models/:id", handler.UpdateModel)
|
adminGroup.PUT("/models/:id", handler.UpdateModel)
|
||||||
adminGroup.GET("/logs", handler.ListLogs)
|
adminGroup.GET("/logs", handler.ListLogs)
|
||||||
adminGroup.GET("/logs/stats", handler.LogStats)
|
adminGroup.GET("/logs/stats", handler.LogStats)
|
||||||
|
adminGroup.GET("/stats", adminHandler.GetAdminStats)
|
||||||
adminGroup.POST("/bindings", handler.CreateBinding)
|
adminGroup.POST("/bindings", handler.CreateBinding)
|
||||||
adminGroup.GET("/bindings", handler.ListBindings)
|
adminGroup.GET("/bindings", handler.ListBindings)
|
||||||
adminGroup.GET("/bindings/:id", handler.GetBinding)
|
adminGroup.GET("/bindings/:id", handler.GetBinding)
|
||||||
@@ -219,6 +225,7 @@ func main() {
|
|||||||
masterGroup.PUT("/tokens/:id", masterHandler.UpdateToken)
|
masterGroup.PUT("/tokens/:id", masterHandler.UpdateToken)
|
||||||
masterGroup.DELETE("/tokens/:id", masterHandler.DeleteToken)
|
masterGroup.DELETE("/tokens/:id", masterHandler.DeleteToken)
|
||||||
masterGroup.GET("/logs", masterHandler.ListSelfLogs)
|
masterGroup.GET("/logs", masterHandler.ListSelfLogs)
|
||||||
|
masterGroup.GET("/logs/stats", masterHandler.GetSelfLogStats)
|
||||||
masterGroup.GET("/stats", masterHandler.GetSelfStats)
|
masterGroup.GET("/stats", masterHandler.GetSelfStats)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -16,6 +16,9 @@ type LogView struct {
|
|||||||
Group string `json:"group"`
|
Group string `json:"group"`
|
||||||
KeyID uint `json:"key_id"`
|
KeyID uint `json:"key_id"`
|
||||||
ModelName string `json:"model"`
|
ModelName string `json:"model"`
|
||||||
|
ProviderID uint `json:"provider_id"`
|
||||||
|
ProviderType string `json:"provider_type"`
|
||||||
|
ProviderName string `json:"provider_name"`
|
||||||
StatusCode int `json:"status_code"`
|
StatusCode int `json:"status_code"`
|
||||||
LatencyMs int64 `json:"latency_ms"`
|
LatencyMs int64 `json:"latency_ms"`
|
||||||
TokensIn int64 `json:"tokens_in"`
|
TokensIn int64 `json:"tokens_in"`
|
||||||
@@ -34,6 +37,9 @@ func toLogView(r model.LogRecord) LogView {
|
|||||||
Group: r.Group,
|
Group: r.Group,
|
||||||
KeyID: r.KeyID,
|
KeyID: r.KeyID,
|
||||||
ModelName: r.ModelName,
|
ModelName: r.ModelName,
|
||||||
|
ProviderID: r.ProviderID,
|
||||||
|
ProviderType: r.ProviderType,
|
||||||
|
ProviderName: r.ProviderName,
|
||||||
StatusCode: r.StatusCode,
|
StatusCode: r.StatusCode,
|
||||||
LatencyMs: r.LatencyMs,
|
LatencyMs: r.LatencyMs,
|
||||||
TokensIn: r.TokensIn,
|
TokensIn: r.TokensIn,
|
||||||
@@ -276,7 +282,7 @@ func (h *MasterHandler) ListSelfLogs(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, ListLogsResponse{Total: total, Limit: limit, Offset: offset, Items: out})
|
c.JSON(http.StatusOK, ListLogsResponse{Total: total, Limit: limit, Offset: offset, Items: out})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSelfStats godoc
|
// GetSelfLogStats godoc
|
||||||
// @Summary Log stats (master)
|
// @Summary Log stats (master)
|
||||||
// @Description Aggregate request log stats for the authenticated master
|
// @Description Aggregate request log stats for the authenticated master
|
||||||
// @Tags master
|
// @Tags master
|
||||||
@@ -287,8 +293,8 @@ func (h *MasterHandler) ListSelfLogs(c *gin.Context) {
|
|||||||
// @Success 200 {object} LogStatsResponse
|
// @Success 200 {object} LogStatsResponse
|
||||||
// @Failure 401 {object} gin.H
|
// @Failure 401 {object} gin.H
|
||||||
// @Failure 500 {object} gin.H
|
// @Failure 500 {object} gin.H
|
||||||
// @Router /v1/stats [get]
|
// @Router /v1/logs/stats [get]
|
||||||
func (h *MasterHandler) GetSelfStats(c *gin.Context) {
|
func (h *MasterHandler) GetSelfLogStats(c *gin.Context) {
|
||||||
master, exists := c.Get("master")
|
master, exists := c.Get("master")
|
||||||
if !exists {
|
if !exists {
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "master key not found in context"})
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "master key not found in context"})
|
||||||
|
|||||||
270
internal/api/stats_handler.go
Normal file
270
internal/api/stats_handler.go
Normal file
@@ -0,0 +1,270 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ez-api/ez-api/internal/model"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Response payload types for the master usage-stats endpoint.
type (
	// KeyUsageStat is one row of the per-key breakdown: the request count
	// and summed input+output tokens logged for a single key.
	KeyUsageStat struct {
		KeyID    uint  `json:"key_id"`
		Requests int64 `json:"requests"`
		Tokens   int64 `json:"tokens"`
	}

	// ModelUsageStat is one row of the per-model breakdown: the request
	// count and summed input+output tokens logged for a single model name.
	ModelUsageStat struct {
		Model    string `json:"model"`
		Requests int64  `json:"requests"`
		Tokens   int64  `json:"tokens"`
	}

	// MasterUsageStatsResponse is the JSON body returned by GetSelfStats.
	// Period is omitted from the output when it is empty.
	MasterUsageStatsResponse struct {
		Period        string           `json:"period,omitempty"`
		TotalRequests int64            `json:"total_requests"`
		TotalTokens   int64            `json:"total_tokens"`
		ByKey         []KeyUsageStat   `json:"by_key"`
		ByModel       []ModelUsageStat `json:"by_model"`
	}
)
|
||||||
|
|
||||||
|
// GetSelfStats godoc
|
||||||
|
// @Summary Usage stats (master)
|
||||||
|
// @Description Aggregate request stats for the authenticated master
|
||||||
|
// @Tags master
|
||||||
|
// @Produce json
|
||||||
|
// @Security MasterAuth
|
||||||
|
// @Param period query string false "today|week|month|all"
|
||||||
|
// @Param since query int false "unix seconds"
|
||||||
|
// @Param until query int false "unix seconds"
|
||||||
|
// @Success 200 {object} MasterUsageStatsResponse
|
||||||
|
// @Failure 400 {object} gin.H
|
||||||
|
// @Failure 401 {object} gin.H
|
||||||
|
// @Failure 500 {object} gin.H
|
||||||
|
// @Router /v1/stats [get]
|
||||||
|
func (h *MasterHandler) GetSelfStats(c *gin.Context) {
|
||||||
|
master, exists := c.Get("master")
|
||||||
|
if !exists {
|
||||||
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "master key not found in context"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m := master.(*model.Master)
|
||||||
|
|
||||||
|
rng, err := parseStatsRange(c)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
base := h.db.Model(&model.LogRecord{}).
|
||||||
|
Joins("JOIN keys ON keys.id = log_records.key_id").
|
||||||
|
Where("keys.master_id = ?", m.ID)
|
||||||
|
base = applyStatsRange(base, rng)
|
||||||
|
|
||||||
|
totalRequests, totalTokens, err := aggregateTotals(base)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to aggregate stats", "details": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var byKey []KeyUsageStat
|
||||||
|
if err := base.Session(&gorm.Session{}).
|
||||||
|
Select("log_records.key_id as key_id, COUNT(*) as requests, COALESCE(SUM(log_records.tokens_in + log_records.tokens_out),0) as tokens").
|
||||||
|
Group("log_records.key_id").
|
||||||
|
Scan(&byKey).Error; err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to group by key", "details": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var byModel []ModelUsageStat
|
||||||
|
if err := base.Session(&gorm.Session{}).
|
||||||
|
Select("log_records.model_name as model, COUNT(*) as requests, COALESCE(SUM(log_records.tokens_in + log_records.tokens_out),0) as tokens").
|
||||||
|
Group("log_records.model_name").
|
||||||
|
Scan(&byModel).Error; err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to group by model", "details": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, MasterUsageStatsResponse{
|
||||||
|
Period: rng.Period,
|
||||||
|
TotalRequests: totalRequests,
|
||||||
|
TotalTokens: totalTokens,
|
||||||
|
ByKey: byKey,
|
||||||
|
ByModel: byModel,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Response payload types for the admin usage-stats endpoint.
type (
	// MasterUsageAgg is one row of the per-master breakdown: the request
	// count and summed input+output tokens for keys owned by one master.
	MasterUsageAgg struct {
		MasterID uint  `json:"master_id"`
		Requests int64 `json:"requests"`
		Tokens   int64 `json:"tokens"`
	}

	// ProviderUsageAgg is one row of the per-provider breakdown, grouped by
	// the provider id, type and name recorded on each log record.
	ProviderUsageAgg struct {
		ProviderID   uint   `json:"provider_id"`
		ProviderType string `json:"provider_type"`
		ProviderName string `json:"provider_name"`
		Requests     int64  `json:"requests"`
		Tokens       int64  `json:"tokens"`
	}

	// AdminUsageStatsResponse is the JSON body returned by GetAdminStats.
	// Period is omitted from the output when it is empty.
	AdminUsageStatsResponse struct {
		Period        string             `json:"period,omitempty"`
		TotalMasters  int64              `json:"total_masters"`
		ActiveMasters int64              `json:"active_masters"`
		TotalRequests int64              `json:"total_requests"`
		TotalTokens   int64              `json:"total_tokens"`
		ByMaster      []MasterUsageAgg   `json:"by_master"`
		ByProvider    []ProviderUsageAgg `json:"by_provider"`
	}
)
|
||||||
|
|
||||||
|
// GetAdminStats godoc
|
||||||
|
// @Summary Usage stats (admin)
|
||||||
|
// @Description Aggregate request stats across all masters
|
||||||
|
// @Tags admin
|
||||||
|
// @Produce json
|
||||||
|
// @Security AdminAuth
|
||||||
|
// @Param period query string false "today|week|month|all"
|
||||||
|
// @Param since query int false "unix seconds"
|
||||||
|
// @Param until query int false "unix seconds"
|
||||||
|
// @Success 200 {object} AdminUsageStatsResponse
|
||||||
|
// @Failure 400 {object} gin.H
|
||||||
|
// @Failure 500 {object} gin.H
|
||||||
|
// @Router /admin/stats [get]
|
||||||
|
func (h *AdminHandler) GetAdminStats(c *gin.Context) {
|
||||||
|
rng, err := parseStatsRange(c)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var totalMasters int64
|
||||||
|
if err := h.db.Model(&model.Master{}).Count(&totalMasters).Error; err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to count masters", "details": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var activeMasters int64
|
||||||
|
if err := h.db.Model(&model.Master{}).Where("status = ?", "active").Count(&activeMasters).Error; err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to count active masters", "details": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
base := h.db.Model(&model.LogRecord{})
|
||||||
|
base = applyStatsRange(base, rng)
|
||||||
|
|
||||||
|
totalRequests, totalTokens, err := aggregateTotals(base)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to aggregate stats", "details": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var byMaster []MasterUsageAgg
|
||||||
|
if err := base.Session(&gorm.Session{}).
|
||||||
|
Joins("JOIN keys ON keys.id = log_records.key_id").
|
||||||
|
Select("keys.master_id as master_id, COUNT(*) as requests, COALESCE(SUM(log_records.tokens_in + log_records.tokens_out),0) as tokens").
|
||||||
|
Group("keys.master_id").
|
||||||
|
Scan(&byMaster).Error; err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to group by master", "details": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var byProvider []ProviderUsageAgg
|
||||||
|
if err := base.Session(&gorm.Session{}).
|
||||||
|
Select("log_records.provider_id as provider_id, log_records.provider_type as provider_type, log_records.provider_name as provider_name, COUNT(*) as requests, COALESCE(SUM(log_records.tokens_in + log_records.tokens_out),0) as tokens").
|
||||||
|
Group("log_records.provider_id, log_records.provider_type, log_records.provider_name").
|
||||||
|
Scan(&byProvider).Error; err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to group by provider", "details": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, AdminUsageStatsResponse{
|
||||||
|
Period: rng.Period,
|
||||||
|
TotalMasters: totalMasters,
|
||||||
|
ActiveMasters: activeMasters,
|
||||||
|
TotalRequests: totalRequests,
|
||||||
|
TotalTokens: totalTokens,
|
||||||
|
ByMaster: byMaster,
|
||||||
|
ByProvider: byProvider,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// statsRange is the time window selected by a stats request.
// A nil bound means that side of the window is open; Period carries the
// named period ("today", "week", "month", "all") when one was requested.
type statsRange struct {
	Since, Until *time.Time
	Period       string
}
|
||||||
|
|
||||||
|
func parseStatsRange(c *gin.Context) (statsRange, error) {
|
||||||
|
period := strings.ToLower(strings.TrimSpace(c.Query("period")))
|
||||||
|
if period != "" {
|
||||||
|
if period == "all" {
|
||||||
|
return statsRange{Period: period}, nil
|
||||||
|
}
|
||||||
|
start, now := periodWindow(period)
|
||||||
|
if start.IsZero() {
|
||||||
|
return statsRange{}, fmt.Errorf("invalid period")
|
||||||
|
}
|
||||||
|
return statsRange{Since: &start, Until: &now, Period: period}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var since *time.Time
|
||||||
|
if t, ok := parseUnixSeconds(c.Query("since")); ok {
|
||||||
|
since = &t
|
||||||
|
}
|
||||||
|
var until *time.Time
|
||||||
|
if t, ok := parseUnixSeconds(c.Query("until")); ok {
|
||||||
|
until = &t
|
||||||
|
}
|
||||||
|
return statsRange{Since: since, Until: until}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// periodWindow returns the [start, now] window in UTC for a named period.
//
//   - "today": from 00:00 UTC of the current day
//   - "week":  from 00:00 UTC of the most recent Monday (today if it is Monday)
//   - "month": from 00:00 UTC on the first day of the current month
//
// Any other period yields two zero times.
func periodWindow(period string) (time.Time, time.Time) {
	current := time.Now().UTC()
	year, month, day := current.Date()
	midnight := time.Date(year, month, day, 0, 0, 0, 0, time.UTC)

	if period == "today" {
		return midnight, current
	}
	if period == "week" {
		// Days elapsed since Monday: Monday -> 0, ..., Sunday -> 6.
		sinceMonday := (int(midnight.Weekday()) + 6) % 7
		return midnight.AddDate(0, 0, -sinceMonday), current
	}
	if period == "month" {
		return time.Date(year, month, 1, 0, 0, 0, 0, time.UTC), current
	}
	return time.Time{}, time.Time{}
}
|
||||||
|
|
||||||
|
func applyStatsRange(q *gorm.DB, rng statsRange) *gorm.DB {
|
||||||
|
if rng.Since != nil {
|
||||||
|
q = q.Where("log_records.created_at >= ?", *rng.Since)
|
||||||
|
}
|
||||||
|
if rng.Until != nil {
|
||||||
|
q = q.Where("log_records.created_at <= ?", *rng.Until)
|
||||||
|
}
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func aggregateTotals(q *gorm.DB) (int64, int64, error) {
|
||||||
|
var totalRequests int64
|
||||||
|
if err := q.Session(&gorm.Session{}).Count(&totalRequests).Error; err != nil {
|
||||||
|
return 0, 0, err
|
||||||
|
}
|
||||||
|
type totals struct {
|
||||||
|
Tokens int64
|
||||||
|
}
|
||||||
|
var t totals
|
||||||
|
if err := q.Session(&gorm.Session{}).
|
||||||
|
Select("COALESCE(SUM(log_records.tokens_in + log_records.tokens_out),0) as tokens").
|
||||||
|
Scan(&t).Error; err != nil {
|
||||||
|
return 0, 0, err
|
||||||
|
}
|
||||||
|
return totalRequests, t.Tokens, nil
|
||||||
|
}
|
||||||
191
internal/api/stats_handler_test.go
Normal file
191
internal/api/stats_handler_test.go
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/alicebob/miniredis/v2"
|
||||||
|
"github.com/ez-api/ez-api/internal/model"
|
||||||
|
"github.com/ez-api/ez-api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMasterStats_AggregatesByKeyAndModel(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", t.Name())
|
||||||
|
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open sqlite: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.AutoMigrate(&model.Master{}, &model.Key{}, &model.LogRecord{}); err != nil {
|
||||||
|
t.Fatalf("migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m := &model.Master{Name: "m1", Group: "g", Status: "active", Epoch: 1, MasterKeyDigest: "d1"}
|
||||||
|
if err := db.Create(m).Error; err != nil {
|
||||||
|
t.Fatalf("create master: %v", err)
|
||||||
|
}
|
||||||
|
k1 := &model.Key{MasterID: m.ID, TokenHash: "h1", Group: "g", Status: "active", IssuedAtEpoch: 1}
|
||||||
|
k2 := &model.Key{MasterID: m.ID, TokenHash: "h2", Group: "g", Status: "active", IssuedAtEpoch: 1}
|
||||||
|
if err := db.Create(k1).Error; err != nil {
|
||||||
|
t.Fatalf("create k1: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.Create(k2).Error; err != nil {
|
||||||
|
t.Fatalf("create k2: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.Create(&model.LogRecord{
|
||||||
|
Group: "rg",
|
||||||
|
KeyID: k1.ID,
|
||||||
|
ModelName: "ns.m1",
|
||||||
|
ProviderID: 10,
|
||||||
|
ProviderType: "openai",
|
||||||
|
ProviderName: "p1",
|
||||||
|
StatusCode: 200,
|
||||||
|
TokensIn: 5,
|
||||||
|
TokensOut: 7,
|
||||||
|
}).Error; err != nil {
|
||||||
|
t.Fatalf("create log1: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.Create(&model.LogRecord{
|
||||||
|
Group: "rg",
|
||||||
|
KeyID: k2.ID,
|
||||||
|
ModelName: "ns.m2",
|
||||||
|
ProviderID: 11,
|
||||||
|
ProviderType: "anthropic",
|
||||||
|
ProviderName: "p2",
|
||||||
|
StatusCode: 200,
|
||||||
|
TokensIn: 2,
|
||||||
|
TokensOut: 3,
|
||||||
|
}).Error; err != nil {
|
||||||
|
t.Fatalf("create log2: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mr := miniredis.RunT(t)
|
||||||
|
rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()})
|
||||||
|
masterSvc := service.NewMasterService(db)
|
||||||
|
syncSvc := service.NewSyncService(rdb)
|
||||||
|
h := NewMasterHandler(db, masterSvc, syncSvc)
|
||||||
|
|
||||||
|
withMaster := func(next gin.HandlerFunc) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
c.Set("master", m)
|
||||||
|
next(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
r := gin.New()
|
||||||
|
r.GET("/v1/stats", withMaster(h.GetSelfStats))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v1/stats?period=all", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d body=%s", rr.Code, rr.Body.String())
|
||||||
|
}
|
||||||
|
var resp MasterUsageStatsResponse
|
||||||
|
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
if resp.TotalRequests != 2 || resp.TotalTokens != 17 {
|
||||||
|
t.Fatalf("unexpected totals: %+v", resp)
|
||||||
|
}
|
||||||
|
if len(resp.ByKey) != 2 || len(resp.ByModel) != 2 {
|
||||||
|
t.Fatalf("unexpected breakdown: %+v", resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminStats_AggregatesByProvider(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", t.Name())
|
||||||
|
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open sqlite: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.AutoMigrate(&model.Master{}, &model.Key{}, &model.LogRecord{}); err != nil {
|
||||||
|
t.Fatalf("migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m1 := &model.Master{Name: "m1", Group: "g", Status: "active", Epoch: 1, MasterKeyDigest: "d1"}
|
||||||
|
m2 := &model.Master{Name: "m2", Group: "g", Status: "suspended", Epoch: 1, MasterKeyDigest: "d2"}
|
||||||
|
if err := db.Create(m1).Error; err != nil {
|
||||||
|
t.Fatalf("create m1: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.Create(m2).Error; err != nil {
|
||||||
|
t.Fatalf("create m2: %v", err)
|
||||||
|
}
|
||||||
|
k1 := &model.Key{MasterID: m1.ID, TokenHash: "h1", Group: "g", Status: "active", IssuedAtEpoch: 1}
|
||||||
|
k2 := &model.Key{MasterID: m2.ID, TokenHash: "h2", Group: "g", Status: "active", IssuedAtEpoch: 1}
|
||||||
|
if err := db.Create(k1).Error; err != nil {
|
||||||
|
t.Fatalf("create k1: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.Create(k2).Error; err != nil {
|
||||||
|
t.Fatalf("create k2: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.Create(&model.LogRecord{
|
||||||
|
Group: "rg",
|
||||||
|
KeyID: k1.ID,
|
||||||
|
ModelName: "ns.m1",
|
||||||
|
ProviderID: 10,
|
||||||
|
ProviderType: "openai",
|
||||||
|
ProviderName: "p1",
|
||||||
|
StatusCode: 200,
|
||||||
|
TokensIn: 4,
|
||||||
|
TokensOut: 6,
|
||||||
|
}).Error; err != nil {
|
||||||
|
t.Fatalf("create log1: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.Create(&model.LogRecord{
|
||||||
|
Group: "rg",
|
||||||
|
KeyID: k2.ID,
|
||||||
|
ModelName: "ns.m2",
|
||||||
|
ProviderID: 11,
|
||||||
|
ProviderType: "anthropic",
|
||||||
|
ProviderName: "p2",
|
||||||
|
StatusCode: 200,
|
||||||
|
TokensIn: 1,
|
||||||
|
TokensOut: 2,
|
||||||
|
}).Error; err != nil {
|
||||||
|
t.Fatalf("create log2: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mr := miniredis.RunT(t)
|
||||||
|
rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()})
|
||||||
|
masterSvc := service.NewMasterService(db)
|
||||||
|
syncSvc := service.NewSyncService(rdb)
|
||||||
|
adminHandler := NewAdminHandler(db, masterSvc, syncSvc)
|
||||||
|
|
||||||
|
r := gin.New()
|
||||||
|
r.GET("/admin/stats", adminHandler.GetAdminStats)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/admin/stats?period=all", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d body=%s", rr.Code, rr.Body.String())
|
||||||
|
}
|
||||||
|
var resp AdminUsageStatsResponse
|
||||||
|
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
if resp.TotalMasters != 2 || resp.ActiveMasters != 1 {
|
||||||
|
t.Fatalf("unexpected master counts: %+v", resp)
|
||||||
|
}
|
||||||
|
if resp.TotalRequests != 2 || resp.TotalTokens != 13 {
|
||||||
|
t.Fatalf("unexpected totals: %+v", resp)
|
||||||
|
}
|
||||||
|
if len(resp.ByProvider) != 2 {
|
||||||
|
t.Fatalf("expected provider breakdown, got %+v", resp.ByProvider)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -16,6 +16,7 @@ type Config struct {
|
|||||||
Log LogConfig
|
Log LogConfig
|
||||||
Auth AuthConfig
|
Auth AuthConfig
|
||||||
ModelRegistry ModelRegistryConfig
|
ModelRegistry ModelRegistryConfig
|
||||||
|
Quota QuotaConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
type ServerConfig struct {
|
type ServerConfig struct {
|
||||||
@@ -52,6 +53,10 @@ type ModelRegistryConfig struct {
|
|||||||
TimeoutSeconds int
|
TimeoutSeconds int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// QuotaConfig holds the settings for the periodic quota reset job.
type QuotaConfig struct {
	// ResetIntervalSeconds is how often, in seconds, the quota resetter
	// runs. Load fills it from "quota.reset_interval_seconds" (default 300,
	// environment variable EZ_QUOTA_RESET_INTERVAL_SECONDS).
	ResetIntervalSeconds int
}
|
||||||
|
|
||||||
func Load() (*Config, error) {
|
func Load() (*Config, error) {
|
||||||
v := viper.New()
|
v := viper.New()
|
||||||
|
|
||||||
@@ -71,6 +76,7 @@ func Load() (*Config, error) {
|
|||||||
v.SetDefault("model_registry.models_dev_ref", "dev")
|
v.SetDefault("model_registry.models_dev_ref", "dev")
|
||||||
v.SetDefault("model_registry.cache_dir", "./data/model-registry")
|
v.SetDefault("model_registry.cache_dir", "./data/model-registry")
|
||||||
v.SetDefault("model_registry.timeout_seconds", 30)
|
v.SetDefault("model_registry.timeout_seconds", 30)
|
||||||
|
v.SetDefault("quota.reset_interval_seconds", 300)
|
||||||
|
|
||||||
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||||||
v.AutomaticEnv()
|
v.AutomaticEnv()
|
||||||
@@ -91,6 +97,7 @@ func Load() (*Config, error) {
|
|||||||
_ = v.BindEnv("model_registry.models_dev_ref", "EZ_MODEL_REGISTRY_MODELS_DEV_REF")
|
_ = v.BindEnv("model_registry.models_dev_ref", "EZ_MODEL_REGISTRY_MODELS_DEV_REF")
|
||||||
_ = v.BindEnv("model_registry.cache_dir", "EZ_MODEL_REGISTRY_CACHE_DIR")
|
_ = v.BindEnv("model_registry.cache_dir", "EZ_MODEL_REGISTRY_CACHE_DIR")
|
||||||
_ = v.BindEnv("model_registry.timeout_seconds", "EZ_MODEL_REGISTRY_TIMEOUT_SECONDS")
|
_ = v.BindEnv("model_registry.timeout_seconds", "EZ_MODEL_REGISTRY_TIMEOUT_SECONDS")
|
||||||
|
_ = v.BindEnv("quota.reset_interval_seconds", "EZ_QUOTA_RESET_INTERVAL_SECONDS")
|
||||||
|
|
||||||
if configFile := os.Getenv("EZ_CONFIG_FILE"); configFile != "" {
|
if configFile := os.Getenv("EZ_CONFIG_FILE"); configFile != "" {
|
||||||
v.SetConfigFile(configFile)
|
v.SetConfigFile(configFile)
|
||||||
@@ -136,6 +143,9 @@ func Load() (*Config, error) {
|
|||||||
CacheDir: v.GetString("model_registry.cache_dir"),
|
CacheDir: v.GetString("model_registry.cache_dir"),
|
||||||
TimeoutSeconds: v.GetInt("model_registry.timeout_seconds"),
|
TimeoutSeconds: v.GetInt("model_registry.timeout_seconds"),
|
||||||
},
|
},
|
||||||
|
Quota: QuotaConfig{
|
||||||
|
ResetIntervalSeconds: v.GetInt("quota.reset_interval_seconds"),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
return cfg, nil
|
return cfg, nil
|
||||||
|
|||||||
92
internal/cron/quota_reset.go
Normal file
92
internal/cron/quota_reset.go
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
package cron
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ez-api/ez-api/internal/model"
|
||||||
|
"github.com/ez-api/ez-api/internal/service"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// QuotaResetter periodically zeroes the used quota of keys whose reset time
// has passed and schedules their next reset.
type QuotaResetter struct {
	// db is the database the keys are read from and updated in.
	db *gorm.DB
	// sync, when non-nil, is asked to sync each key after it has been reset.
	sync *service.SyncService
	// interval is the delay between two reset passes.
	interval time.Duration
}
|
||||||
|
|
||||||
|
func NewQuotaResetter(db *gorm.DB, sync *service.SyncService, interval time.Duration) *QuotaResetter {
|
||||||
|
if interval <= 0 {
|
||||||
|
interval = 5 * time.Minute
|
||||||
|
}
|
||||||
|
return &QuotaResetter{db: db, sync: sync, interval: interval}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *QuotaResetter) Start(ctx context.Context) {
|
||||||
|
if q == nil || q.db == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
ticker := time.NewTicker(q.interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
if err := q.resetOnce(ctx); err != nil {
|
||||||
|
slog.Default().Warn("quota reset failed", "err", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *QuotaResetter) resetOnce(ctx context.Context) error {
|
||||||
|
if q == nil || q.db == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
now := time.Now().UTC()
|
||||||
|
var keys []model.Key
|
||||||
|
if err := q.db.Where("quota_reset_type IN ? AND (quota_reset_at IS NULL OR quota_reset_at <= ?)", []string{"daily", "monthly"}, now).Find(&keys).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for i := range keys {
|
||||||
|
resetType := strings.ToLower(strings.TrimSpace(keys[i].QuotaResetType))
|
||||||
|
nextAt, ok := nextQuotaReset(now, resetType)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := q.db.Model(&keys[i]).Updates(map[string]any{
|
||||||
|
"quota_used": 0,
|
||||||
|
"quota_reset_at": nextAt,
|
||||||
|
}).Error; err != nil {
|
||||||
|
slog.Default().Warn("quota reset update failed", "key_id", keys[i].ID, "err", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
keys[i].QuotaUsed = 0
|
||||||
|
keys[i].QuotaResetAt = &nextAt
|
||||||
|
if q.sync != nil {
|
||||||
|
_ = q.sync.SyncKey(&keys[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// nextQuotaReset returns the next reset instant after now for resetType,
// computed in UTC:
//
//   - "daily":   00:00 UTC of the following day
//   - "monthly": 00:00 UTC on the first day of the following month
//
// Any other resetType yields the zero time and false.
func nextQuotaReset(now time.Time, resetType string) (time.Time, bool) {
	year, month, day := now.UTC().Date()
	if resetType == "daily" {
		// time.Date normalises an out-of-range day into the next month.
		return time.Date(year, month, day+1, 0, 0, 0, 0, time.UTC), true
	}
	if resetType == "monthly" {
		// time.Date normalises month 13 into January of the next year.
		return time.Date(year, month+1, 1, 0, 0, 0, 0, time.UTC), true
	}
	return time.Time{}, false
}
|
||||||
@@ -8,6 +8,9 @@ type LogRecord struct {
|
|||||||
Group string `json:"group"`
|
Group string `json:"group"`
|
||||||
KeyID uint `json:"key_id"`
|
KeyID uint `json:"key_id"`
|
||||||
ModelName string `json:"model"`
|
ModelName string `json:"model"`
|
||||||
|
ProviderID uint `json:"provider_id"`
|
||||||
|
ProviderType string `json:"provider_type"`
|
||||||
|
ProviderName string `json:"provider_name"`
|
||||||
StatusCode int `json:"status_code"`
|
StatusCode int `json:"status_code"`
|
||||||
LatencyMs int64 `json:"latency_ms"`
|
LatencyMs int64 `json:"latency_ms"`
|
||||||
TokensIn int64 `json:"tokens_in"`
|
TokensIn int64 `json:"tokens_in"`
|
||||||
|
|||||||
79
internal/service/stats.go
Normal file
79
internal/service/stats.go
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
type StatsService struct {
|
||||||
|
rdb *redis.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
type RealtimeStats struct {
|
||||||
|
Requests int64
|
||||||
|
Tokens int64
|
||||||
|
LastAccessedAt *time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewStatsService(rdb *redis.Client) *StatsService {
|
||||||
|
return &StatsService{rdb: rdb}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StatsService) GetKeyRealtimeStats(ctx context.Context, tokenHash string) (RealtimeStats, error) {
|
||||||
|
if s == nil || s.rdb == nil {
|
||||||
|
return RealtimeStats{}, fmt.Errorf("redis client is required")
|
||||||
|
}
|
||||||
|
tokenHash = strings.TrimSpace(tokenHash)
|
||||||
|
if tokenHash == "" {
|
||||||
|
return RealtimeStats{}, fmt.Errorf("token hash required")
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
reqs, err := s.rdb.Get(ctx, fmt.Sprintf("key:stats:%s:requests", tokenHash)).Int64()
|
||||||
|
if err != nil && err != redis.Nil {
|
||||||
|
return RealtimeStats{}, fmt.Errorf("read key requests: %w", err)
|
||||||
|
}
|
||||||
|
tokens, err := s.rdb.Get(ctx, fmt.Sprintf("key:stats:%s:tokens", tokenHash)).Int64()
|
||||||
|
if err != nil && err != redis.Nil {
|
||||||
|
return RealtimeStats{}, fmt.Errorf("read key tokens: %w", err)
|
||||||
|
}
|
||||||
|
lastRaw, err := s.rdb.Get(ctx, fmt.Sprintf("key:stats:%s:last_access", tokenHash)).Result()
|
||||||
|
if err != nil && err != redis.Nil {
|
||||||
|
return RealtimeStats{}, fmt.Errorf("read key last access: %w", err)
|
||||||
|
}
|
||||||
|
var lastAt *time.Time
|
||||||
|
if lastRaw != "" {
|
||||||
|
if sec, err := strconv.ParseInt(lastRaw, 10, 64); err == nil && sec > 0 {
|
||||||
|
t := time.Unix(sec, 0).UTC()
|
||||||
|
lastAt = &t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return RealtimeStats{Requests: reqs, Tokens: tokens, LastAccessedAt: lastAt}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StatsService) GetMasterRealtimeStats(ctx context.Context, masterID uint) (RealtimeStats, error) {
|
||||||
|
if s == nil || s.rdb == nil {
|
||||||
|
return RealtimeStats{}, fmt.Errorf("redis client is required")
|
||||||
|
}
|
||||||
|
if masterID == 0 {
|
||||||
|
return RealtimeStats{}, fmt.Errorf("master id required")
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
reqs, err := s.rdb.Get(ctx, fmt.Sprintf("master:stats:%d:requests", masterID)).Int64()
|
||||||
|
if err != nil && err != redis.Nil {
|
||||||
|
return RealtimeStats{}, fmt.Errorf("read master requests: %w", err)
|
||||||
|
}
|
||||||
|
tokens, err := s.rdb.Get(ctx, fmt.Sprintf("master:stats:%d:tokens", masterID)).Int64()
|
||||||
|
if err != nil && err != redis.Nil {
|
||||||
|
return RealtimeStats{}, fmt.Errorf("read master tokens: %w", err)
|
||||||
|
}
|
||||||
|
return RealtimeStats{Requests: reqs, Tokens: tokens}, nil
|
||||||
|
}
|
||||||
52
internal/service/stats_test.go
Normal file
52
internal/service/stats_test.go
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/alicebob/miniredis/v2"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStatsService_KeyRealtimeStats(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mr := miniredis.RunT(t)
|
||||||
|
rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()})
|
||||||
|
svc := NewStatsService(rdb)
|
||||||
|
|
||||||
|
mr.Set("key:stats:hash:requests", "3")
|
||||||
|
mr.Set("key:stats:hash:tokens", "42")
|
||||||
|
mr.Set("key:stats:hash:last_access", "1700000000")
|
||||||
|
|
||||||
|
stats, err := svc.GetKeyRealtimeStats(context.Background(), "hash")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetKeyRealtimeStats: %v", err)
|
||||||
|
}
|
||||||
|
if stats.Requests != 3 || stats.Tokens != 42 {
|
||||||
|
t.Fatalf("unexpected stats: %+v", stats)
|
||||||
|
}
|
||||||
|
if stats.LastAccessedAt == nil || !stats.LastAccessedAt.Equal(time.Unix(1700000000, 0).UTC()) {
|
||||||
|
t.Fatalf("unexpected last access: %+v", stats.LastAccessedAt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStatsService_MasterRealtimeStats(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mr := miniredis.RunT(t)
|
||||||
|
rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()})
|
||||||
|
svc := NewStatsService(rdb)
|
||||||
|
|
||||||
|
mr.Set("master:stats:99:requests", "7")
|
||||||
|
mr.Set("master:stats:99:tokens", "100")
|
||||||
|
|
||||||
|
stats, err := svc.GetMasterRealtimeStats(context.Background(), 99)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetMasterRealtimeStats: %v", err)
|
||||||
|
}
|
||||||
|
if stats.Requests != 7 || stats.Tokens != 100 {
|
||||||
|
t.Fatalf("unexpected stats: %+v", stats)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user