Files
ez-api/cmd/server/main.go
RC-CHN 1ee6bea413 feat(api): enhance whoami endpoint with realtime stats and extended key info
Add realtime statistics (requests, tokens, QPS, rate limiting) to whoami
response for both master and key authentication types. Extend key response
with additional fields including master name, model limits, quota tracking,
and usage statistics.

- Inject StatsService into AuthHandler for realtime stats retrieval
- Add WhoamiRealtimeView struct for realtime statistics
- Include admin permissions field in admin response
- Add comprehensive key metadata (quotas, model limits, usage stats)
- Add test for expired key returning 401 Unauthorized
2026-01-06 09:15:49 +08:00

534 lines
20 KiB
Go

package main
import (
"context"
"encoding/json"
"expvar"
"flag"
"fmt"
"log/slog"
"net/http"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/ez-api/ez-api/docs"
"github.com/ez-api/ez-api/internal/api"
"github.com/ez-api/ez-api/internal/config"
"github.com/ez-api/ez-api/internal/cron"
"github.com/ez-api/ez-api/internal/middleware"
"github.com/ez-api/ez-api/internal/migrate"
"github.com/ez-api/ez-api/internal/model"
"github.com/ez-api/ez-api/internal/service"
"github.com/ez-api/foundation/logging"
"github.com/ez-api/foundation/scheduler"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
swaggerFiles "github.com/swaggo/files"
ginSwagger "github.com/swaggo/gin-swagger"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
// @title EZ-API Control Plane
// @version 0.0.1
// @description Management API for EZ-API Gateway system.
// @termsOfService http://swagger.io/terms/
// @contact.name API Support
// @contact.url http://www.swagger.io/support
// @contact.email support@swagger.io
// @license.name Apache 2.0
// @license.url http://www.apache.org/licenses/LICENSE-2.0.html
// @host localhost:8080
// @BasePath /
// @securityDefinitions.apikey AdminAuth
// @in header
// @name Authorization
// @description Type "Bearer" followed by a space and the admin token. Example: Bearer admin123
// @securityDefinitions.apikey MasterAuth
// @in header
// @name Authorization
// @description Type "Bearer" followed by a space and the master key. Example: Bearer sk-xxx
// fatal writes msg and its structured key/value args to logger at error
// level, then terminates the process with exit status 1. Because os.Exit is
// used, deferred calls in the caller's goroutine are not executed.
func fatal(logger *slog.Logger, msg string, args ...any) {
	const exitCode = 1
	logger.Log(context.Background(), slog.LevelError, msg, args...)
	os.Exit(exitCode)
}
// isOriginAllowed reports whether origin is permitted by the CORS allow-list.
//
// Each entry is trimmed of surrounding whitespace before use:
//   - a blank entry is ignored, so a stray empty config value never matches
//     an empty origin;
//   - an entry of "*" allows every origin;
//   - any other entry matches origin case-insensitively (origin is trimmed too).
//
// A nil or empty allow-list permits nothing.
func isOriginAllowed(allowed []string, origin string) bool {
	origin = strings.TrimSpace(origin)
	for _, item := range allowed {
		item = strings.TrimSpace(item)
		if item == "" {
			continue
		}
		// Trim before the wildcard check so " * " in config behaves like "*",
		// consistent with how concrete origins are compared.
		if item == "*" || strings.EqualFold(item, origin) {
			return true
		}
	}
	return false
}
// main boots the EZ-API control plane.
//
// Startup:
//   - loads configuration;
//   - connects to Redis and PostgreSQL, plus an optional separate log database
//     when cfg.Log.DSN is set;
//   - runs schema migrations;
//   - wires services, handlers and scheduled jobs;
//   - registers the HTTP routes.
//
// It then serves until SIGINT/SIGTERM triggers a graceful shutdown.
//
// Two argument forms are handled before normal startup:
//   - "--version" / "-v" prints the version and exits 0;
//   - "import" runs the offline importer (runImport) and exits with its code.
func main() {
	// Handle --version flag before any initialization
	if len(os.Args) > 1 && (os.Args[1] == "--version" || os.Args[1] == "-v") {
		fmt.Printf("ez-api %s\n", api.Version)
		os.Exit(0)
	}
	// NOTE(review): the second return value of logging.New is discarded;
	// confirm it is not an error or cleanup func that should be handled.
	logger, _ := logging.New(logging.Options{Service: "ez-api"})
	// appCtx is the scheduler's base context; it is cancelled at shutdown.
	appCtx, appCancel := context.WithCancel(context.Background())
	defer appCancel()
	if len(os.Args) > 1 && os.Args[1] == "import" {
		code := runImport(logger, os.Args[2:])
		os.Exit(code)
	}

	// 1. Load Configuration
	cfg, err := config.Load()
	if err != nil {
		fatal(logger, "failed to load config", "err", err)
	}

	// 2. Initialize Redis Client
	rdb := redis.NewClient(&redis.Options{
		Addr:     cfg.Redis.Addr,
		Password: cfg.Redis.Password,
		DB:       cfg.Redis.DB,
	})
	// Verify the Redis connection. The timeout context is released right
	// after the ping instead of living (via defer) for the whole process.
	pingCtx, pingCancel := context.WithTimeout(context.Background(), 5*time.Second)
	err = rdb.Ping(pingCtx).Err()
	pingCancel()
	if err != nil {
		fatal(logger, "failed to connect to redis", "err", err)
	}
	logger.Info("connected to redis successfully")

	// 3. Initialize GORM (PostgreSQL)
	db, err := gorm.Open(postgres.Open(cfg.Postgres.DSN), &gorm.Config{})
	if err != nil {
		fatal(logger, "failed to connect to postgresql", "err", err)
	}
	sqlDB, err := db.DB()
	if err != nil {
		fatal(logger, "failed to get generic database object", "err", err)
	}
	// Verify DB connection
	if err := sqlDB.Ping(); err != nil {
		fatal(logger, "failed to ping postgresql", "err", err)
	}
	logger.Info("connected to postgresql successfully")

	// Log records go to the primary database unless a dedicated log DSN is set.
	logDB := db
	if cfg.Log.DSN != "" {
		logDB, err = gorm.Open(postgres.Open(cfg.Log.DSN), &gorm.Config{})
		if err != nil {
			fatal(logger, "failed to connect to log database", "err", err)
		}
		sqlLogDB, err := logDB.DB()
		if err != nil {
			fatal(logger, "failed to get log database object", "err", err)
		}
		if err := sqlLogDB.Ping(); err != nil {
			fatal(logger, "failed to ping log database", "err", err)
		}
		logger.Info("connected to log database successfully")
	}

	// Auto Migrate: with a separate log database, LogRecord (and its indexes)
	// are migrated there; otherwise everything lives in the primary database.
	if logDB != db {
		if err := db.AutoMigrate(&model.Master{}, &model.Key{}, &model.ProviderGroup{}, &model.APIKey{}, &model.Model{}, &model.Binding{}, &model.Namespace{}, &model.OperationLog{}, &model.SyncOutbox{}, &model.Alert{}, &model.AlertThresholdConfig{}, &model.IPBan{}); err != nil {
			fatal(logger, "failed to auto migrate", "err", err)
		}
		if err := logDB.AutoMigrate(&model.LogRecord{}); err != nil {
			fatal(logger, "failed to auto migrate log tables", "err", err)
		}
		if err := service.EnsureLogIndexes(logDB); err != nil {
			fatal(logger, "failed to ensure log indexes", "err", err)
		}
	} else {
		if err := db.AutoMigrate(&model.Master{}, &model.Key{}, &model.ProviderGroup{}, &model.APIKey{}, &model.Model{}, &model.Binding{}, &model.Namespace{}, &model.OperationLog{}, &model.LogRecord{}, &model.SyncOutbox{}, &model.Alert{}, &model.AlertThresholdConfig{}, &model.IPBan{}); err != nil {
			fatal(logger, "failed to auto migrate", "err", err)
		}
		if err := service.EnsureLogIndexes(db); err != nil {
			fatal(logger, "failed to ensure log indexes", "err", err)
		}
	}

	// 4. Setup Services and Handlers
	syncService := service.NewSyncService(rdb)
	// A single StatsService instance is shared by the alert detector and the
	// HTTP handlers rather than constructing two independent instances.
	statsService := service.NewStatsService(rdb)
	var outboxService *service.SyncOutboxService
	if cfg.SyncOutbox.Enabled {
		outboxCfg := service.SyncOutboxConfig{
			Enabled:    cfg.SyncOutbox.Enabled,
			Interval:   time.Duration(cfg.SyncOutbox.IntervalSeconds) * time.Second,
			BatchSize:  cfg.SyncOutbox.BatchSize,
			MaxRetries: cfg.SyncOutbox.MaxRetries,
		}
		outboxService = service.NewSyncOutboxService(db, syncService, outboxCfg, logger)
		syncService.SetOutbox(outboxService)
	}
	logPartitioner := service.NewLogPartitioner(logDB, cfg.Log.Partitioning)
	if logPartitioner.Enabled() {
		if _, err := logPartitioner.EnsurePartitionFor(time.Now().UTC()); err != nil {
			fatal(logger, "failed to ensure log partition", "err", err)
		}
	}
	logWriter := service.NewLogWriter(logDB, cfg.Log.QueueCapacity, cfg.Log.BatchSize, cfg.Log.FlushInterval, logPartitioner)
	logCtx, cancelLogs := context.WithCancel(context.Background())
	defer cancelLogs()
	logWriter.Start(logCtx)

	// Initialize cron jobs
	quotaResetter := cron.NewQuotaResetter(db, syncService)
	logCleaner := cron.NewLogCleaner(logDB, rdb, cfg.Log.RetentionDays, int64(cfg.Log.MaxRecords), logPartitioner)
	tokenRefresher := cron.NewTokenRefresher(
		db,
		rdb,
		syncService,
		time.Duration(cfg.TokenRefresh.RefreshSkewSeconds)*time.Second,
		cfg.TokenRefresh.BatchSize,
		cfg.TokenRefresh.MaxRetries,
	)
	alertDetectorConfig := cron.DefaultAlertDetectorConfig()
	alertDetector := cron.NewAlertDetector(db, logDB, rdb, statsService, alertDetectorConfig, logger)

	// Setup scheduler (jobs are added incrementally, Start() called after all services initialized)
	sched := scheduler.New(
		scheduler.WithLogger(logger),
		scheduler.WithSkipIfRunning(),
		scheduler.WithBaseContext(appCtx),
	)
	sched.Every("quota-reset", time.Duration(cfg.Quota.ResetIntervalSeconds)*time.Second, quotaResetter.RunOnce)
	sched.Every("log-cleanup", time.Hour, logCleaner.RunOnce)
	sched.Every("token-refresh", time.Duration(cfg.TokenRefresh.IntervalSeconds)*time.Second, tokenRefresher.RunOnce)
	sched.Every("alert-detection", time.Minute, alertDetector.RunOnce)
	if outboxService != nil && outboxService.Enabled() {
		sched.Every("sync-outbox", outboxService.Interval(), outboxService.RunOnce)
	}

	adminService, err := service.NewAdminService()
	if err != nil {
		fatal(logger, "failed to create admin service", "err", err)
	}
	masterService := service.NewMasterService(db)
	healthService := service.NewHealthCheckService(db, rdb)
	statusHandler := api.NewStatusHandler(healthService)
	handler := api.NewHandler(db, logDB, syncService, logWriter, rdb, logPartitioner)
	adminHandler := api.NewAdminHandler(db, logDB, masterService, syncService, statsService, logPartitioner)
	masterHandler := api.NewMasterHandler(db, logDB, masterService, syncService, statsService, logPartitioner)
	dashboardHandler := api.NewDashboardHandler(db, logDB, statsService, logPartitioner)
	alertHandler := api.NewAlertHandler(db)
	internalHandler := api.NewInternalHandler(db)
	featureHandler := api.NewFeatureHandler(rdb)
	authHandler := api.NewAuthHandler(db, rdb, adminService, masterService, statsService)
	ipBanService := service.NewIPBanService(db, rdb)
	ipBanHandler := api.NewIPBanHandler(ipBanService)
	ipBanManager := cron.NewIPBanManager(ipBanService)
	modelRegistryService := service.NewModelRegistryService(db, rdb, service.ModelRegistryConfig{
		Enabled:             cfg.ModelRegistry.Enabled,
		RefreshEvery:        time.Duration(cfg.ModelRegistry.RefreshSeconds) * time.Second,
		ModelsDevBaseURL:    cfg.ModelRegistry.ModelsDevBaseURL,
		ModelsDevAPIBaseURL: cfg.ModelRegistry.ModelsDevAPIBaseURL,
		ModelsDevRef:        cfg.ModelRegistry.ModelsDevRef,
		CacheDir:            cfg.ModelRegistry.CacheDir,
		Timeout:             time.Duration(cfg.ModelRegistry.TimeoutSeconds) * time.Second,
	})
	modelRegistryHandler := api.NewModelRegistryHandler(modelRegistryService)

	// 4.1 Prime Redis snapshots so DP can start with data
	if err := syncService.SyncAll(db); err != nil {
		logger.Warn("initial sync warning", "err", err)
	}
	if err := ipBanService.SyncAllToRedis(context.Background()); err != nil {
		logger.Warn("initial IP ban sync warning", "err", err)
	}
	// Initial model registry refresh before scheduler starts
	if modelRegistryService.Enabled() {
		modelRegistryService.RunOnce(context.Background())
		sched.Every("model-registry-refresh", modelRegistryService.RefreshEvery(), modelRegistryService.RunOnce)
	}
	sched.Every("ip-ban-expire", time.Minute, ipBanManager.ExpireRunOnce)
	sched.Every("ip-ban-hit-sync", 5*time.Minute, ipBanManager.HitSyncRunOnce)
	sched.Every("ip-ban-full-sync", 5*time.Minute, ipBanManager.FullSyncRunOnce)
	sched.Start()

	// 5. Setup Gin Router
	r := gin.Default()
	r.Use(middleware.RequestID())
	allowedOrigins := cfg.CORS.AllowOrigins
	allowAllOrigins := isOriginAllowed(allowedOrigins, "*")

	// CORS Middleware
	r.Use(func(c *gin.Context) {
		origin := c.Request.Header.Get("Origin")
		header := c.Writer.Header()
		switch {
		case allowAllOrigins:
			// The CORS spec forbids pairing a wildcard origin with
			// Access-Control-Allow-Credentials: true (browsers reject the
			// response), so credentials are only advertised for echoed origins.
			header.Set("Access-Control-Allow-Origin", "*")
		case origin != "" && isOriginAllowed(allowedOrigins, origin):
			header.Set("Access-Control-Allow-Origin", origin)
			header.Set("Access-Control-Allow-Credentials", "true")
		}
		if !allowAllOrigins {
			// With an allow-list the CORS headers depend on the request's
			// Origin, so shared caches must key responses on it.
			header.Add("Vary", "Origin")
		}
		header.Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
		header.Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE")
		// Preflight requests are answered here and never reach the routes.
		if c.Request.Method == http.MethodOptions {
			c.AbortWithStatus(http.StatusNoContent)
			return
		}
		c.Next()
	})

	// Set the Swagger host dynamically; an empty value means relative paths.
	docs.SwaggerInfo.Host = cfg.Server.SwaggerHost

	// Health Check Endpoint
	r.GET("/health", func(c *gin.Context) {
		status := healthService.Check(c.Request.Context())
		httpStatus := http.StatusOK
		if status.Status == "down" {
			httpStatus = http.StatusServiceUnavailable
		}
		c.JSON(httpStatus, status)
	})

	// Public Status Endpoints
	r.GET("/status", statusHandler.Status)
	r.GET("/about", statusHandler.About)

	// Swagger Documentation
	r.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler))

	// API Routes
	// Internal Routes
	internalGroup := r.Group("/internal")
	internalGroup.Use(middleware.InternalAuthMiddleware(cfg.Internal.StatsToken, cfg.Internal.AllowAnonymous))
	{
		internalGroup.POST("/stats/flush", internalHandler.FlushStats)
		internalGroup.POST("/apikey-stats/flush", internalHandler.FlushAPIKeyStats)
		internalGroup.POST("/alerts/report", internalHandler.ReportAlerts)
		internalGroup.GET("/metrics", gin.WrapH(expvar.Handler()))
	}

	// Admin Routes
	adminGroup := r.Group("/admin")
	adminGroup.Use(middleware.AdminAuthMiddleware(adminService))
	adminGroup.Use(middleware.OperationLogMiddleware(db))
	{
		adminGroup.POST("/masters", adminHandler.CreateMaster)
		adminGroup.GET("/masters", adminHandler.ListMasters)
		adminGroup.GET("/masters/:id", adminHandler.GetMaster)
		adminGroup.GET("/masters/:id/realtime", adminHandler.GetMasterRealtime)
		adminGroup.PUT("/masters/:id", adminHandler.UpdateMaster)
		adminGroup.DELETE("/masters/:id", adminHandler.DeleteMaster)
		adminGroup.POST("/masters/batch", adminHandler.BatchMasters)
		adminGroup.POST("/masters/:id/manage", adminHandler.ManageMaster)
		adminGroup.POST("/masters/:id/keys", adminHandler.IssueChildKeyForMaster)
		adminGroup.GET("/masters/:id/access", handler.GetMasterAccess)
		adminGroup.PUT("/masters/:id/access", handler.UpdateMasterAccess)
		adminGroup.GET("/keys/:id/access", handler.GetKeyAccess)
		adminGroup.PUT("/keys/:id/access", handler.UpdateKeyAccess)
		adminGroup.GET("/operation-logs", adminHandler.ListOperationLogs)
		adminGroup.POST("/namespaces", handler.CreateNamespace)
		adminGroup.GET("/namespaces", handler.ListNamespaces)
		adminGroup.GET("/namespaces/:id", handler.GetNamespace)
		adminGroup.PUT("/namespaces/:id", handler.UpdateNamespace)
		adminGroup.DELETE("/namespaces/:id", handler.DeleteNamespace)
		adminGroup.GET("/features", featureHandler.ListFeatures)
		adminGroup.PUT("/features", featureHandler.UpdateFeatures)
		adminGroup.GET("/model-registry/status", modelRegistryHandler.GetStatus)
		adminGroup.POST("/model-registry/check", modelRegistryHandler.Check)
		adminGroup.POST("/model-registry/refresh", modelRegistryHandler.Refresh)
		adminGroup.POST("/model-registry/rollback", modelRegistryHandler.Rollback)
		// Other admin routes for managing providers, models, etc.
		adminGroup.POST("/provider-groups", handler.CreateProviderGroup)
		adminGroup.GET("/provider-groups", handler.ListProviderGroups)
		adminGroup.GET("/provider-groups/:id", handler.GetProviderGroup)
		adminGroup.PUT("/provider-groups/:id", handler.UpdateProviderGroup)
		adminGroup.DELETE("/provider-groups/:id", handler.DeleteProviderGroup)
		adminGroup.POST("/api-keys", handler.CreateAPIKey)
		adminGroup.GET("/api-keys", handler.ListAPIKeys)
		adminGroup.GET("/api-keys/:id", handler.GetAPIKey)
		adminGroup.PUT("/api-keys/:id", handler.UpdateAPIKey)
		adminGroup.DELETE("/api-keys/:id", handler.DeleteAPIKey)
		adminGroup.POST("/api-keys/batch", handler.BatchAPIKeys)
		adminGroup.POST("/models", handler.CreateModel)
		adminGroup.GET("/models", handler.ListModels)
		adminGroup.PUT("/models/:id", handler.UpdateModel)
		adminGroup.DELETE("/models/:id", handler.DeleteModel)
		adminGroup.POST("/models/batch", handler.BatchModels)
		adminGroup.GET("/logs", handler.ListLogs)
		adminGroup.DELETE("/logs", handler.DeleteLogs)
		adminGroup.GET("/logs/stats", handler.LogStats)
		adminGroup.GET("/logs/stats/traffic-chart", handler.GetTrafficChart)
		adminGroup.GET("/logs/webhook", handler.GetLogWebhookConfig)
		adminGroup.PUT("/logs/webhook", handler.UpdateLogWebhookConfig)
		adminGroup.GET("/stats", adminHandler.GetAdminStats)
		adminGroup.GET("/realtime", adminHandler.GetAdminRealtime)
		adminGroup.GET("/dashboard/summary", dashboardHandler.GetSummary)
		adminGroup.GET("/apikey-stats/summary", adminHandler.GetAPIKeyStatsSummary)
		adminGroup.GET("/alerts", alertHandler.ListAlerts)
		adminGroup.POST("/alerts", alertHandler.CreateAlert)
		adminGroup.GET("/alerts/stats", alertHandler.GetAlertStats)
		adminGroup.GET("/alerts/thresholds", alertHandler.GetAlertThresholds)
		adminGroup.PUT("/alerts/thresholds", alertHandler.UpdateAlertThresholds)
		adminGroup.GET("/alerts/:id", alertHandler.GetAlert)
		adminGroup.POST("/alerts/:id/ack", alertHandler.AcknowledgeAlert)
		adminGroup.POST("/alerts/:id/resolve", alertHandler.ResolveAlert)
		adminGroup.DELETE("/alerts/:id", alertHandler.DismissAlert)
		adminGroup.POST("/bindings", handler.CreateBinding)
		adminGroup.GET("/bindings", handler.ListBindings)
		adminGroup.GET("/bindings/:id", handler.GetBinding)
		adminGroup.PUT("/bindings/:id", handler.UpdateBinding)
		adminGroup.DELETE("/bindings/:id", handler.DeleteBinding)
		adminGroup.POST("/bindings/batch", handler.BatchBindings)
		adminGroup.POST("/sync/snapshot", handler.SyncSnapshot)
		// IP Ban routes
		adminGroup.POST("/ip-bans", ipBanHandler.Create)
		adminGroup.GET("/ip-bans", ipBanHandler.List)
		adminGroup.GET("/ip-bans/:id", ipBanHandler.Get)
		adminGroup.PUT("/ip-bans/:id", ipBanHandler.Update)
		adminGroup.DELETE("/ip-bans/:id", ipBanHandler.Delete)
	}

	// Master Routes
	masterGroup := r.Group("/v1")
	masterGroup.Use(middleware.MasterAuthMiddleware(masterService))
	{
		masterGroup.GET("/self", masterHandler.GetSelf)
		masterGroup.POST("/tokens", masterHandler.IssueChildKey)
		masterGroup.GET("/tokens", masterHandler.ListTokens)
		masterGroup.GET("/tokens/:id", masterHandler.GetToken)
		masterGroup.PUT("/tokens/:id", masterHandler.UpdateToken)
		masterGroup.DELETE("/tokens/:id", masterHandler.DeleteToken)
		masterGroup.GET("/logs", masterHandler.ListSelfLogs)
		masterGroup.GET("/logs/stats", masterHandler.GetSelfLogStats)
		masterGroup.GET("/realtime", masterHandler.GetSelfRealtime)
		masterGroup.GET("/stats", masterHandler.GetSelfStats)
	}

	// Auth Routes (public, no middleware - self-validates token)
	authGroup := r.Group("/auth")
	{
		authGroup.GET("/whoami", authHandler.Whoami)
	}

	// Public/General Routes (if any)
	r.POST("/logs", handler.IngestLog)

	// ReadHeaderTimeout bounds how long a client may take to send request
	// headers, so slow or stalled clients cannot pin connections indefinitely.
	// Body/response timeouts are left unset to avoid cutting off long requests.
	srv := &http.Server{
		Addr:              ":" + cfg.Server.Port,
		Handler:           r,
		ReadHeaderTimeout: 10 * time.Second,
	}

	// 6. Start Server with Graceful Shutdown
	go func() {
		logger.Info("starting ez-api", "port", cfg.Server.Port)
		if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
			fatal(logger, "server failed", "err", err)
		}
	}()

	// Wait for interrupt signal
	quit := make(chan os.Signal, 1)
	signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
	<-quit
	logger.Info("shutting down server...")
	appCancel()
	sched.Stop()

	// Shutdown with timeout
	shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer shutdownCancel()
	if err := srv.Shutdown(shutdownCtx); err != nil {
		fatal(logger, "server forced to shutdown", "err", err)
	}
	logger.Info("server exited properly")
}
// runImport implements the "import" subcommand.
//
// Steps:
//  1. Parse the subcommand's own flags.
//  2. Load configuration and connect to the primary PostgreSQL database.
//  3. Auto-migrate the tables the importer uses.
//  4. Import the JSON export named by --file.
//  5. Print the import summary to stdout as indented JSON.
//
// Flags:
//   - --file: path to the export JSON (required)
//   - --dry-run: forwarded to the importer as ImportOptions.DryRun
//   - --conflict: "skip" or "overwrite", compared case-insensitively after trimming
//   - --include-bindings: also import bindings from the payload
//
// The return value is a process exit code:
//   - 0 on success, or when -h/--help was requested;
//   - 2 for usage errors (bad flags, unexpected positional arguments, missing
//     --file, invalid --conflict);
//   - 1 for runtime failures.
//
// NOTE(review): AutoMigrate runs even with --dry-run, so schema changes can
// still be applied in that mode; confirm this is intended.
func runImport(logger *slog.Logger, args []string) int {
	fs := flag.NewFlagSet("import", flag.ContinueOnError)
	fs.SetOutput(os.Stderr)
	var (
		filePath        string
		dryRun          bool
		conflictPolicy  string
		includeBindings bool
	)
	fs.StringVar(&filePath, "file", "", "Path to export JSON")
	fs.BoolVar(&dryRun, "dry-run", false, "Validate only, do not write to database")
	fs.StringVar(&conflictPolicy, "conflict", migrate.ConflictSkip, "Conflict policy: skip or overwrite")
	fs.BoolVar(&includeBindings, "include-bindings", false, "Import bindings from payload")
	if err := fs.Parse(args); err != nil {
		// FlagSet.Parse returns the flag.ErrHelp sentinel itself (unwrapped)
		// for -h/-help after printing usage; asking for help is not a failure.
		if err == flag.ErrHelp {
			return 0
		}
		logger.Error("failed to parse flags", "err", err)
		return 2
	}
	// The flag package stops parsing at the first non-flag argument, so any
	// flag after a stray positional value (e.g. a trailing --dry-run) would be
	// silently ignored. Reject leftovers rather than run with surprising options.
	if fs.NArg() > 0 {
		fmt.Fprintf(os.Stderr, "unexpected arguments: %q\n", fs.Args())
		return 2
	}
	if strings.TrimSpace(filePath) == "" {
		fmt.Fprintln(os.Stderr, "missing --file")
		return 2
	}
	conflictPolicy = strings.ToLower(strings.TrimSpace(conflictPolicy))
	if conflictPolicy != migrate.ConflictSkip && conflictPolicy != migrate.ConflictOverwrite {
		fmt.Fprintf(os.Stderr, "invalid --conflict value: %s\n", conflictPolicy)
		return 2
	}

	cfg, err := config.Load()
	if err != nil {
		logger.Error("failed to load config", "err", err)
		return 1
	}
	db, err := gorm.Open(postgres.Open(cfg.Postgres.DSN), &gorm.Config{})
	if err != nil {
		logger.Error("failed to connect to postgresql", "err", err)
		return 1
	}
	// Release the connection pool when the import finishes. This is a
	// short-lived command, so a close error is not actionable and is ignored.
	if sqlDB, dbErr := db.DB(); dbErr == nil {
		defer func() { _ = sqlDB.Close() }()
	}
	if err := db.AutoMigrate(&model.Master{}, &model.Key{}, &model.ProviderGroup{}, &model.APIKey{}, &model.Model{}, &model.Binding{}, &model.Namespace{}, &model.SyncOutbox{}); err != nil {
		logger.Error("failed to auto migrate", "err", err)
		return 1
	}

	importer := migrate.NewImporter(db, migrate.ImportOptions{
		DryRun:          dryRun,
		ConflictPolicy:  conflictPolicy,
		IncludeBindings: includeBindings,
	})
	summary, err := importer.ImportFile(filePath)
	if err != nil {
		logger.Error("import failed", "err", err)
		return 1
	}
	payload, err := json.MarshalIndent(summary, "", " ")
	if err != nil {
		logger.Error("failed to render import summary", "err", err)
		return 1
	}
	fmt.Fprintln(os.Stdout, string(payload))
	if dryRun {
		fmt.Fprintln(os.Stdout, "dry-run only: no data written")
	}
	return 0
}