From dea8363e412ca6f35f5ea2c2bd576275f4310622 Mon Sep 17 00:00:00 2001 From: zenfun Date: Wed, 24 Dec 2025 02:15:52 +0800 Subject: [PATCH] refactor(api): split Provider into ProviderGroup and APIKey models Restructure the provider management system by separating the monolithic Provider model into two distinct entities: - ProviderGroup: defines shared upstream configuration (type, base_url, google settings, models, status) - APIKey: represents individual credentials within a group (api_key, weight, status, auto_ban, ban settings) This change also updates: - Binding model to reference GroupID instead of RouteGroup string - All CRUD handlers for the new provider-group and api-key endpoints - Sync service to rebuild provider snapshots from joined tables - Model registry to aggregate capabilities across group/key pairs - Access handler to validate namespace existence and subset constraints - Migration importer to handle the new schema structure - All related tests to use the new model relationships BREAKING CHANGE: Provider API endpoints replaced with /provider-groups and /api-keys endpoints; Binding.RouteGroup replaced with Binding.GroupID --- cmd/server/main.go | 28 +- internal/api/access_handler.go | 54 +++ internal/api/api_key_handler.go | 259 ++++++++++++ internal/api/batch_handler.go | 32 +- internal/api/binding_handler.go | 59 ++- internal/api/handler.go | 251 ------------ internal/api/model_handler_test.go | 12 +- internal/api/namespace_handler_test.go | 17 +- internal/api/provider_admin_handler.go | 350 ----------------- internal/api/provider_admin_handler_test.go | 162 -------- internal/api/provider_create_handler.go | 288 -------------- internal/api/provider_create_handler_test.go | 92 ----- internal/api/provider_group_handler.go | 302 ++++++++++++++ internal/api/provider_handler_test.go | 105 ----- internal/dto/api_key.go | 14 + internal/dto/binding.go | 5 +- internal/dto/provider_group.go | 12 + internal/migrate/importer.go | 74 ++-- internal/migrate/schema.go | 2 +- internal/model/models.go | 30 +- internal/model/provider_group.go | 31 ++ internal/service/model_registry.go | 103 +++-- internal/service/model_registry_check_test.go | 2 +- internal/service/model_registry_test.go | 27 +- internal/service/sync.go | 370 ++++++++++-------- internal/service/sync_bindings_spec_test.go | 70 ++-- internal/service/sync_test.go | 96 ++--- 27 files changed, 1222 insertions(+), 1625 deletions(-) create mode 100644 internal/api/api_key_handler.go delete mode 100644 internal/api/provider_admin_handler.go delete mode 100644 internal/api/provider_admin_handler_test.go delete mode 100644 internal/api/provider_create_handler.go delete mode 100644 internal/api/provider_create_handler_test.go create mode 100644 internal/api/provider_group_handler.go delete mode 100644 internal/api/provider_handler_test.go create mode 100644 internal/dto/api_key.go create mode 100644 internal/dto/provider_group.go create mode 100644 internal/model/provider_group.go diff --git a/cmd/server/main.go b/cmd/server/main.go index 3c4c8b7..4a3ac62 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -135,7 +135,7 @@ func main() { // Auto Migrate if logDB != db { - if err := db.AutoMigrate(&model.Master{}, &model.Key{}, &model.Provider{}, &model.Model{}, &model.Binding{}, &model.Namespace{}, &model.OperationLog{}); err != nil { + if err := db.AutoMigrate(&model.Master{}, &model.Key{}, &model.ProviderGroup{}, &model.APIKey{}, &model.Model{}, &model.Binding{}, &model.Namespace{}, &model.OperationLog{}); err != nil { fatal(logger, "failed to auto migrate", "err", err) } if err := logDB.AutoMigrate(&model.LogRecord{}); err != nil { @@ -145,7 +145,7 @@ func main() { fatal(logger, "failed to ensure log indexes", "err", err) } } else { - if err := db.AutoMigrate(&model.Master{}, &model.Key{}, &model.Provider{}, &model.Model{}, &model.Binding{}, &model.Namespace{}, &model.OperationLog{}, &model.LogRecord{}); err != nil { + if err := db.AutoMigrate(&model.Master{}, &model.Key{}, &model.ProviderGroup{}, &model.APIKey{}, &model.Model{}, &model.Binding{}, &model.Namespace{}, &model.OperationLog{}, &model.LogRecord{}); err != nil { fatal(logger, "failed to auto migrate", "err", err) } if err := service.EnsureLogIndexes(db); err != nil { @@ -287,17 +287,17 @@ func main() { adminGroup.POST("/model-registry/refresh", modelRegistryHandler.Refresh) adminGroup.POST("/model-registry/rollback", modelRegistryHandler.Rollback) // Other admin routes for managing providers, models, etc. - adminGroup.POST("/providers", handler.CreateProvider) - adminGroup.GET("/providers", handler.ListProviders) - adminGroup.GET("/providers/:id", handler.GetProvider) - adminGroup.POST("/providers/preset", handler.CreateProviderPreset) - adminGroup.POST("/providers/custom", handler.CreateProviderCustom) - adminGroup.POST("/providers/google", handler.CreateProviderGoogle) - adminGroup.PUT("/providers/:id", handler.UpdateProvider) - adminGroup.DELETE("/providers/:id", handler.DeleteProvider) - adminGroup.POST("/providers/batch", handler.BatchProviders) - adminGroup.POST("/providers/:id/test", handler.TestProvider) - adminGroup.POST("/providers/:id/fetch-models", handler.FetchProviderModels) + adminGroup.POST("/provider-groups", handler.CreateProviderGroup) + adminGroup.GET("/provider-groups", handler.ListProviderGroups) + adminGroup.GET("/provider-groups/:id", handler.GetProviderGroup) + adminGroup.PUT("/provider-groups/:id", handler.UpdateProviderGroup) + adminGroup.DELETE("/provider-groups/:id", handler.DeleteProviderGroup) + adminGroup.POST("/api-keys", handler.CreateAPIKey) + adminGroup.GET("/api-keys", handler.ListAPIKeys) + adminGroup.GET("/api-keys/:id", handler.GetAPIKey) + adminGroup.PUT("/api-keys/:id", handler.UpdateAPIKey) + adminGroup.DELETE("/api-keys/:id", handler.DeleteAPIKey) + adminGroup.POST("/api-keys/batch", handler.BatchAPIKeys) adminGroup.POST("/models", handler.CreateModel) adminGroup.GET("/models", handler.ListModels) adminGroup.PUT("/models/:id", handler.UpdateModel) @@ -406,7 +406,7 @@ func runImport(logger *slog.Logger, args []string) int { return 1 } - if err := db.AutoMigrate(&model.Master{}, &model.Key{}, &model.Provider{}, &model.Model{}, &model.Binding{}, &model.Namespace{}); err != nil { + if err := db.AutoMigrate(&model.Master{}, &model.Key{}, &model.ProviderGroup{}, &model.APIKey{}, &model.Model{}, &model.Binding{}, &model.Namespace{}); err != nil { logger.Error("failed to auto migrate", "err", err) return 1 } diff --git a/internal/api/access_handler.go b/internal/api/access_handler.go index 365fb90..18fc4c1 100644 --- a/internal/api/access_handler.go +++ b/internal/api/access_handler.go @@ -1,12 +1,14 @@ package api import ( + "fmt" "net/http" "strconv" "strings" "github.com/ez-api/ez-api/internal/model" "github.com/gin-gonic/gin" + "gorm.io/gorm" ) type AccessResponse struct { @@ -94,6 +96,10 @@ func (h *Handler) UpdateMasterAccess(c *gin.Context) { } nsList := normalizeNamespaces(nextNamespaces, nextDefault) nextNamespaces = strings.Join(nsList, ",") + if err := ensureNamespacesExist(h.db, nsList); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } if err := h.db.Model(&m).Updates(map[string]any{ "default_namespace": nextDefault, @@ -203,6 +209,21 @@ func (h *Handler) UpdateKeyAccess(c *gin.Context) { } nsList := normalizeNamespaces(nextNamespaces, nextDefault) nextNamespaces = strings.Join(nsList, ",") + if err := ensureNamespacesExist(h.db, nsList); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + var master model.Master + if err := h.db.First(&master, k.MasterID).Error; err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "master not found"}) + return + } + masterNamespaces := normalizeNamespaces(master.Namespaces, master.DefaultNamespace) + if !isSubset(nsList, masterNamespaces) { + c.JSON(http.StatusBadRequest, gin.H{"error": "namespaces must be a subset of master namespaces"}) + return + } if err := h.db.Model(&k).Updates(map[string]any{ "default_namespace": nextDefault, @@ -264,6 +285,39 @@ func normalizeNamespaces(raw string, defaultNamespace string) []string { return out } +func ensureNamespacesExist(db *gorm.DB, namespaces []string) error { + if db == nil { + return fmt.Errorf("db required") + } + if len(namespaces) == 0 { + return fmt.Errorf("namespaces required") + } + var rows []model.Namespace + if err := db.Where("name IN ?", namespaces).Find(&rows).Error; err != nil { + return fmt.Errorf("failed to load namespaces") + } + if len(rows) != len(namespaces) { + return fmt.Errorf("namespace not found") + } + return nil +} + +func isSubset(child, parent []string) bool { + if len(child) == 0 { + return true + } + parentSet := make(map[string]struct{}, len(parent)) + for _, p := range parent { + parentSet[strings.TrimSpace(p)] = struct{}{} + } + for _, c := range child { + if _, ok := parentSet[strings.TrimSpace(c)]; !ok { + return false + } + } + return true +} + func parseUintParam(c *gin.Context, name string) (uint, bool) { idRaw := strings.TrimSpace(c.Param(name)) if idRaw == "" { diff --git a/internal/api/api_key_handler.go b/internal/api/api_key_handler.go new file mode 100644 index 0000000..c92ab5d --- /dev/null +++ b/internal/api/api_key_handler.go @@ -0,0 +1,259 @@ +package api + +import ( + "net/http" + "strings" + + "github.com/ez-api/ez-api/internal/dto" + "github.com/ez-api/ez-api/internal/model" + "github.com/ez-api/foundation/provider" + "github.com/gin-gonic/gin" +) + +// CreateAPIKey godoc +// @Summary Create an API key +// @Description Create an API key for a provider group +// @Tags admin +// @Accept json +// @Produce json +// @Security AdminAuth +// @Param key body dto.APIKeyDTO true "API key payload" +// @Success 201 {object} model.APIKey +// @Failure 400 {object} gin.H +// @Failure 500 {object} gin.H +// @Router /admin/api-keys [post] +func (h *Handler) CreateAPIKey(c *gin.Context) { + var req dto.APIKeyDTO + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if req.GroupID == 0 { + c.JSON(http.StatusBadRequest, gin.H{"error": "group_id required"}) + return + } + var group model.ProviderGroup + if err := h.db.First(&group, req.GroupID).Error; err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "provider group not found"}) + return + } + + apiKey := strings.TrimSpace(req.APIKey) + ptype := provider.NormalizeType(group.Type) + if provider.IsGoogleFamily(ptype) && !provider.IsVertexFamily(ptype) && apiKey == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "api_key required for gemini api providers"}) + return + } + + status := strings.TrimSpace(req.Status) + if status == "" { + status = "active" + } + autoBan := true + if req.AutoBan != nil { + autoBan = *req.AutoBan + } + + key := model.APIKey{ + GroupID: req.GroupID, + APIKey: apiKey, + Weight: normalizeWeight(req.Weight), + Status: status, + AutoBan: autoBan, + BanReason: strings.TrimSpace(req.BanReason), + } + if !req.BanUntil.IsZero() { + tu := req.BanUntil.UTC() + key.BanUntil = &tu + } + + if err := h.db.Create(&key).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create api key", "details": err.Error()}) + return + } + + if err := h.sync.SyncProviders(h.db); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync providers", "details": err.Error()}) + return + } + if err := h.sync.SyncBindings(h.db); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync bindings", "details": err.Error()}) + return + } + + c.JSON(http.StatusCreated, key) +} + +// ListAPIKeys godoc +// @Summary List API keys +// @Description List API keys +// @Tags admin +// @Produce json +// @Security AdminAuth +// @Param page query int false "page (1-based)" +// @Param limit query int false "limit (default 50, max 200)" +// @Param group_id query int false "filter by group_id" +// @Success 200 {array} model.APIKey +// @Failure 500 {object} gin.H +// @Router /admin/api-keys [get] +func (h *Handler) ListAPIKeys(c *gin.Context) { + var keys []model.APIKey + q := h.db.Model(&model.APIKey{}).Order("id desc") + if groupID := strings.TrimSpace(c.Query("group_id")); groupID != "" { + q = q.Where("group_id = ?", groupID) + } + query := parseListQuery(c) + q = applyListPagination(q, query) + if err := q.Find(&keys).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list api keys", "details": err.Error()}) + return + } + c.JSON(http.StatusOK, keys) +} + +// GetAPIKey godoc +// @Summary Get API key +// @Description Get an API key by id +// @Tags admin +// @Produce json +// @Security AdminAuth +// @Param id path int true "APIKey ID" +// @Success 200 {object} model.APIKey +// @Failure 400 {object} gin.H +// @Failure 404 {object} gin.H +// @Failure 500 {object} gin.H +// @Router /admin/api-keys/{id} [get] +func (h *Handler) GetAPIKey(c *gin.Context) { + id, ok := parseUintParam(c, "id") + if !ok { + return + } + var key model.APIKey + if err := h.db.First(&key, id).Error; err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "api key not found"}) + return + } + c.JSON(http.StatusOK, key) +} + +// UpdateAPIKey godoc +// @Summary Update API key +// @Description Update an API key +// @Tags admin +// @Accept json +// @Produce json +// @Security AdminAuth +// @Param id path int true "APIKey ID" +// @Param key body dto.APIKeyDTO true "API key payload" +// @Success 200 {object} model.APIKey +// @Failure 400 {object} gin.H +// @Failure 404 {object} gin.H +// @Failure 500 {object} gin.H +// @Router /admin/api-keys/{id} [put] +func (h *Handler) UpdateAPIKey(c *gin.Context) { + id, ok := parseUintParam(c, "id") + if !ok { + return + } + var key model.APIKey + if err := h.db.First(&key, id).Error; err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "api key not found"}) + return + } + var req dto.APIKeyDTO + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + update := map[string]any{} + if req.GroupID != 0 { + var group model.ProviderGroup + if err := h.db.First(&group, req.GroupID).Error; err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "provider group not found"}) + return + } + update["group_id"] = req.GroupID + } + if strings.TrimSpace(req.APIKey) != "" { + update["api_key"] = strings.TrimSpace(req.APIKey) + } + if req.Weight > 0 { + update["weight"] = normalizeWeight(req.Weight) + } + if strings.TrimSpace(req.Status) != "" { + update["status"] = strings.TrimSpace(req.Status) + } + if req.AutoBan != nil { + update["auto_ban"] = *req.AutoBan + } + if req.BanReason != "" || strings.TrimSpace(req.Status) == "active" { + update["ban_reason"] = strings.TrimSpace(req.BanReason) + } + if !req.BanUntil.IsZero() { + tu := req.BanUntil.UTC() + update["ban_until"] = &tu + } + if req.BanUntil.IsZero() && strings.TrimSpace(req.Status) == "active" { + update["ban_until"] = nil + } + + if err := h.db.Model(&key).Updates(update).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update api key", "details": err.Error()}) + return + } + if err := h.db.First(&key, id).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to reload api key", "details": err.Error()}) + return + } + + if err := h.sync.SyncProviders(h.db); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync providers", "details": err.Error()}) + return + } + if err := h.sync.SyncBindings(h.db); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync bindings", "details": err.Error()}) + return + } + + c.JSON(http.StatusOK, key) +} + +// DeleteAPIKey godoc +// @Summary Delete API key +// @Description Delete an API key +// @Tags admin +// @Produce json +// @Security AdminAuth +// @Param id path int true "APIKey ID" +// @Success 200 {object} gin.H +// @Failure 400 {object} gin.H +// @Failure 404 {object} gin.H +// @Failure 500 {object} gin.H +// @Router /admin/api-keys/{id} [delete] +func (h *Handler) DeleteAPIKey(c *gin.Context) { + id, ok := parseUintParam(c, "id") + if !ok { + return + } + var key model.APIKey + if err := h.db.First(&key, id).Error; err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "api key not found"}) + return + } + if err := h.db.Delete(&key).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete api key", "details": err.Error()}) + return + } + + if err := h.sync.SyncProviders(h.db); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync providers", "details": err.Error()}) + return + } + if err := h.sync.SyncBindings(h.db); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync bindings", "details": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"status": "deleted"}) +} diff --git a/internal/api/batch_handler.go b/internal/api/batch_handler.go index 78730de..ef0050c 100644 --- a/internal/api/batch_handler.go +++ b/internal/api/batch_handler.go @@ -104,9 +104,9 @@ func (h *AdminHandler) BatchMasters(c *gin.Context) { c.JSON(http.StatusOK, resp) } -// BatchProviders godoc -// @Summary Batch providers -// @Description Batch delete or status update for providers +// BatchAPIKeys godoc +// @Summary Batch api keys +// @Description Batch delete or status update for api keys // @Tags admin // @Accept json // @Produce json @@ -115,8 +115,8 @@ func (h *AdminHandler) BatchMasters(c *gin.Context) { // @Success 200 {object} BatchResponse // @Failure 400 {object} gin.H // @Failure 500 {object} gin.H -// @Router /admin/providers/batch [post] -func (h *Handler) BatchProviders(c *gin.Context) { +// @Router /admin/api-keys/batch [post] +func (h *Handler) BatchAPIKeys(c *gin.Context) { var req BatchActionRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"}) @@ -140,8 +140,8 @@ func (h *Handler) BatchProviders(c *gin.Context) { } needsBindingSync := false for _, id := range req.IDs { - var p model.Provider - if err := h.db.First(&p, id).Error; err != nil { + var key model.APIKey + if err := h.db.First(&key, id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { resp.Failed = append(resp.Failed, BatchResult{ID: id, Error: "not found"}) continue @@ -151,11 +151,7 @@ func (h *Handler) BatchProviders(c *gin.Context) { } switch action { case "delete": - if err := h.db.Delete(&p).Error; err != nil { - resp.Failed = append(resp.Failed, BatchResult{ID: id, Error: err.Error()}) - continue - } - if err := h.sync.SyncProviderDelete(&p); err != nil { + if err := h.db.Delete(&key).Error; err != nil { resp.Failed = append(resp.Failed, BatchResult{ID: id, Error: err.Error()}) continue } @@ -167,15 +163,11 @@ func (h *Handler) BatchProviders(c *gin.Context) { update["ban_reason"] = "" update["ban_until"] = nil } - if err := h.db.Model(&p).Updates(update).Error; err != nil { + if err := h.db.Model(&key).Updates(update).Error; err != nil { resp.Failed = append(resp.Failed, BatchResult{ID: id, Error: err.Error()}) continue } - if err := h.db.First(&p, id).Error; err != nil { - resp.Failed = append(resp.Failed, BatchResult{ID: id, Error: err.Error()}) - continue - } - if err := h.sync.SyncProvider(&p); err != nil { + if err := h.db.First(&key, id).Error; err != nil { resp.Failed = append(resp.Failed, BatchResult{ID: id, Error: err.Error()}) continue } @@ -184,6 +176,10 @@ func (h *Handler) BatchProviders(c *gin.Context) { } } if needsBindingSync { + if err := h.sync.SyncProviders(h.db); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync providers", "details": err.Error()}) + return + } if err := h.sync.SyncBindings(h.db); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync bindings", "details": err.Error()}) return diff --git a/internal/api/binding_handler.go b/internal/api/binding_handler.go index 21aeee1..633213d 100644 --- a/internal/api/binding_handler.go +++ b/internal/api/binding_handler.go @@ -1,19 +1,19 @@ package api import ( + "fmt" "net/http" "strconv" "strings" "github.com/ez-api/ez-api/internal/dto" "github.com/ez-api/ez-api/internal/model" - groupx "github.com/ez-api/foundation/group" "github.com/gin-gonic/gin" ) // CreateBinding godoc // @Summary Create a new binding -// @Description Create a new (namespace, public_model) binding to a route group and selector +// @Description Create a new (namespace, public_model) binding to a provider group and selector // @Tags admin // @Accept json // @Produce json @@ -32,14 +32,14 @@ func (h *Handler) CreateBinding(c *gin.Context) { ns := strings.TrimSpace(req.Namespace) pm := strings.TrimSpace(req.PublicModel) - if ns == "" || pm == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "namespace and public_model required"}) + if ns == "" || pm == "" || req.GroupID == 0 { + c.JSON(http.StatusBadRequest, gin.H{"error": "namespace, public_model, and group_id required"}) return } - rg := groupx.Normalize(req.RouteGroup) - if strings.TrimSpace(rg) == "" { - rg = "default" + if err := h.ensureActiveGroup(req.GroupID); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return } st := strings.TrimSpace(req.Status) @@ -55,7 +55,8 @@ func (h *Handler) CreateBinding(c *gin.Context) { b := model.Binding{ Namespace: ns, PublicModel: pm, - RouteGroup: rg, + GroupID: req.GroupID, + Weight: normalizeWeight(req.Weight), SelectorType: selectorType, SelectorValue: strings.TrimSpace(req.SelectorValue), Status: st, @@ -82,7 +83,7 @@ func (h *Handler) CreateBinding(c *gin.Context) { // @Security AdminAuth // @Param page query int false "page (1-based)" // @Param limit query int false "limit (default 50, max 200)" -// @Param search query string false "search by namespace/public_model/route_group" +// @Param search query string false "search by namespace/public_model" // @Success 200 {array} model.Binding // @Failure 500 {object} gin.H // @Router /admin/bindings [get] @@ -90,7 +91,7 @@ func (h *Handler) ListBindings(c *gin.Context) { var out []model.Binding q := h.db.Model(&model.Binding{}).Order("id desc") query := parseListQuery(c) - q = applyListSearch(q, query.Search, "namespace", "public_model", "route_group") + q = applyListSearch(q, query.Search, "namespace", "public_model") q = applyListPagination(q, query) if err := q.Find(&out).Error; err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list bindings", "details": err.Error()}) @@ -139,8 +140,15 @@ func (h *Handler) UpdateBinding(c *gin.Context) { if pm := strings.TrimSpace(req.PublicModel); pm != "" { existing.PublicModel = pm } - if rg := strings.TrimSpace(req.RouteGroup); rg != "" { - existing.RouteGroup = groupx.Normalize(rg) + if req.GroupID != 0 { + if err := h.ensureActiveGroup(req.GroupID); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + existing.GroupID = req.GroupID + } + if req.Weight > 0 { + existing.Weight = normalizeWeight(req.Weight) } if st := strings.TrimSpace(req.Status); st != "" { existing.Status = st @@ -229,3 +237,30 @@ func (h *Handler) DeleteBinding(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"status": "deleted"}) } + +func normalizeWeight(weight int) int { + if weight <= 0 { + return 1 + } + return weight +} + +func (h *Handler) ensureActiveGroup(groupID uint) error { + var group model.ProviderGroup + if err := h.db.First(&group, groupID).Error; err != nil { + return fmt.Errorf("provider group not found") + } + if strings.TrimSpace(group.Status) != "" && strings.TrimSpace(group.Status) != "active" { + return fmt.Errorf("provider group not active") + } + var count int64 + if err := h.db.Model(&model.APIKey{}). + Where("group_id = ? AND status = ?", groupID, "active"). + Count(&count).Error; err != nil { + return fmt.Errorf("failed to check api keys") + } + if count == 0 { + return fmt.Errorf("provider group has no active api keys") + } + return nil +} diff --git a/internal/api/handler.go b/internal/api/handler.go index f390e1a..3756d6c 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -9,8 +9,6 @@ import ( "github.com/ez-api/ez-api/internal/dto" "github.com/ez-api/ez-api/internal/model" "github.com/ez-api/ez-api/internal/service" - groupx "github.com/ez-api/foundation/group" - "github.com/ez-api/foundation/provider" "github.com/gin-gonic/gin" "github.com/redis/go-redis/v9" "gorm.io/gorm" @@ -54,255 +52,6 @@ func (h *Handler) logBaseQuery() *gorm.DB { // CreateKey is now handled by MasterHandler -// CreateProvider godoc -// @Summary Create a new provider -// @Description Register a new upstream AI provider -// @Tags admin -// @Accept json -// @Produce json -// @Security AdminAuth -// @Param provider body dto.ProviderDTO true "Provider Info" -// @Success 201 {object} model.Provider -// @Failure 400 {object} gin.H -// @Failure 500 {object} gin.H -// @Router /admin/providers [post] -func (h *Handler) CreateProvider(c *gin.Context) { - var req dto.ProviderDTO - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - providerType := provider.NormalizeType(req.Type) - baseURL := strings.TrimSpace(req.BaseURL) - googleLocation := provider.DefaultGoogleLocation(providerType, req.GoogleLocation) - - group := strings.TrimSpace(req.Group) - if group == "" { - group = "default" - } - - status := strings.TrimSpace(req.Status) - if status == "" { - status = "active" - } - autoBan := true - if req.AutoBan != nil { - autoBan = *req.AutoBan - } - - // CP-side defaults + validation to prevent DP runtime errors. - switch providerType { - case provider.TypeOpenAI: - if baseURL == "" { - baseURL = "https://api.openai.com/v1" - } - case provider.TypeAnthropic, provider.TypeClaude: - if baseURL == "" { - baseURL = "https://api.anthropic.com" - } - case provider.TypeCompatible: - if baseURL == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "base_url required for compatible providers"}) - return - } - default: - // Google SDK providers: base_url is not required. - if provider.IsVertexFamily(providerType) && strings.TrimSpace(googleLocation) == "" { - googleLocation = provider.DefaultGoogleLocation(providerType, "") - } - // For Gemini API providers, api_key is required. - if provider.IsGoogleFamily(providerType) && !provider.IsVertexFamily(providerType) && strings.TrimSpace(req.APIKey) == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "api_key required for gemini api providers"}) - return - } - } - - provider := model.Provider{ - Name: req.Name, - Type: strings.TrimSpace(req.Type), - BaseURL: baseURL, - APIKey: req.APIKey, - GoogleProject: strings.TrimSpace(req.GoogleProject), - GoogleLocation: googleLocation, - Group: group, - Models: strings.Join(req.Models, ","), - Status: status, - AutoBan: autoBan, - BanReason: req.BanReason, - Weight: req.Weight, - } - if !req.BanUntil.IsZero() { - tu := req.BanUntil.UTC() - provider.BanUntil = &tu - } - - if err := h.db.Create(&provider).Error; err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create provider", "details": err.Error()}) - return - } - - if err := h.sync.SyncProvider(&provider); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync provider", "details": err.Error()}) - return - } - // Provider model list changes can affect binding upstream mappings; rebuild bindings snapshot. - if err := h.sync.SyncBindings(h.db); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync bindings", "details": err.Error()}) - return - } - - c.JSON(http.StatusCreated, provider) -} - -// UpdateProvider godoc -// @Summary Update a provider -// @Description Update provider attributes including status/auto-ban flags -// @Tags admin -// @Accept json -// @Produce json -// @Security AdminAuth -// @Param id path int true "Provider ID" -// @Param provider body dto.ProviderDTO true "Provider Info" -// @Success 200 {object} model.Provider -// @Failure 400 {object} gin.H -// @Failure 404 {object} gin.H -// @Failure 500 {object} gin.H -// @Router /admin/providers/{id} [put] -func (h *Handler) UpdateProvider(c *gin.Context) { - idParam := c.Param("id") - id, err := strconv.Atoi(idParam) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"}) - return - } - - var existing model.Provider - if err := h.db.First(&existing, id).Error; err != nil { - c.JSON(http.StatusNotFound, gin.H{"error": "provider not found"}) - return - } - - var req dto.ProviderDTO - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - nextType := strings.TrimSpace(existing.Type) - if t := strings.TrimSpace(req.Type); t != "" { - nextType = t - } - nextTypeLower := provider.NormalizeType(nextType) - nextBaseURL := strings.TrimSpace(existing.BaseURL) - if strings.TrimSpace(req.BaseURL) != "" { - nextBaseURL = strings.TrimSpace(req.BaseURL) - } - - update := map[string]any{} - if strings.TrimSpace(req.Name) != "" { - update["name"] = req.Name - } - if strings.TrimSpace(req.Type) != "" { - update["type"] = strings.TrimSpace(req.Type) - } - if strings.TrimSpace(req.BaseURL) != "" { - update["base_url"] = req.BaseURL - } - if req.APIKey != "" { - update["api_key"] = req.APIKey - } - if strings.TrimSpace(req.GoogleProject) != "" { - update["google_project"] = strings.TrimSpace(req.GoogleProject) - } - if strings.TrimSpace(req.GoogleLocation) != "" { - update["google_location"] = strings.TrimSpace(req.GoogleLocation) - } else if provider.IsVertexFamily(nextTypeLower) && strings.TrimSpace(existing.GoogleLocation) == "" { - update["google_location"] = provider.DefaultGoogleLocation(nextTypeLower, "") - } - if req.Models != nil { - update["models"] = strings.Join(req.Models, ",") - } - if req.Weight > 0 { - update["weight"] = req.Weight - } - if strings.TrimSpace(req.Group) != "" { - update["group"] = groupx.Normalize(req.Group) - } - if req.AutoBan != nil { - update["auto_ban"] = *req.AutoBan - } - if strings.TrimSpace(req.Status) != "" { - update["status"] = req.Status - } - if req.BanReason != "" || strings.TrimSpace(req.Status) == "active" { - update["ban_reason"] = req.BanReason - } - if !req.BanUntil.IsZero() { - tu := req.BanUntil.UTC() - update["ban_until"] = &tu - } - if req.BanUntil.IsZero() && strings.TrimSpace(req.Status) == "active" { - update["ban_until"] = nil - } - - // Defaults/validation after considering intended type/base_url. - switch nextTypeLower { - case provider.TypeOpenAI: - if nextBaseURL == "" { - update["base_url"] = "https://api.openai.com/v1" - } - case provider.TypeAnthropic, provider.TypeClaude: - if nextBaseURL == "" { - update["base_url"] = "https://api.anthropic.com" - } - case provider.TypeCompatible: - if nextBaseURL == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "base_url required for compatible providers"}) - return - } - default: - if provider.IsGoogleFamily(nextTypeLower) && !provider.IsVertexFamily(nextTypeLower) { - // Ensure Gemini API providers have api_key. - // If update does not include api_key, keep existing; otherwise require new one not empty. - apiKey := existing.APIKey - if req.APIKey != "" { - apiKey = req.APIKey - } - if strings.TrimSpace(apiKey) == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "api_key required for gemini api providers"}) - return - } - } - } - - if len(update) == 0 { - c.JSON(http.StatusBadRequest, gin.H{"error": "no fields to update"}) - return - } - - if err := h.db.Model(&existing).Updates(update).Error; err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update provider", "details": err.Error()}) - return - } - - if err := h.db.First(&existing, id).Error; err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to reload provider", "details": err.Error()}) - return - } - - if err := h.sync.SyncProvider(&existing); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync provider", "details": err.Error()}) - return - } - if err := h.sync.SyncBindings(h.db); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync bindings", "details": err.Error()}) - return - } - - c.JSON(http.StatusOK, existing) -} - // CreateModel godoc // @Summary Register a new model // @Description Register a supported model with its capabilities diff --git a/internal/api/model_handler_test.go b/internal/api/model_handler_test.go index 216354f..726269c 100644 --- a/internal/api/model_handler_test.go +++ b/internal/api/model_handler_test.go @@ -26,7 +26,7 @@ func newTestHandlerWithRedis(t *testing.T) (*Handler, *gorm.DB, *miniredis.Minir if err != nil { t.Fatalf("open sqlite: %v", err) } - if err := db.AutoMigrate(&model.Provider{}, &model.Binding{}, &model.Model{}); err != nil { + if err := db.AutoMigrate(&model.ProviderGroup{}, &model.APIKey{}, &model.Binding{}, &model.Model{}); err != nil { t.Fatalf("migrate: %v", err) } @@ -204,10 +204,18 @@ func TestBatchModels_Delete(t *testing.T) { func TestBatchBindings_Status(t *testing.T) { h, db := newTestHandler(t) + group := model.ProviderGroup{Name: "default", Type: "openai", BaseURL: "https://api.openai.com/v1", Models: "m1", Status: "active"} + if err := db.Create(&group).Error; err != nil { + t.Fatalf("create group: %v", err) + } + if err := db.Create(&model.APIKey{GroupID: group.ID, APIKey: "k", Status: "active"}).Error; err != nil { + t.Fatalf("create api key: %v", err) + } b := &model.Binding{ Namespace: "ns", PublicModel: "m1", - RouteGroup: "default", + GroupID: group.ID, + Weight: 1, SelectorType: "exact", SelectorValue: "m1", Status: "active", diff --git a/internal/api/namespace_handler_test.go b/internal/api/namespace_handler_test.go index 36bc911..c085ad2 100644 --- a/internal/api/namespace_handler_test.go +++ b/internal/api/namespace_handler_test.go @@ -26,7 +26,7 @@ func newTestHandlerWithNamespace(t *testing.T) (*Handler, *gorm.DB, *miniredis.M if err != nil { t.Fatalf("open sqlite: %v", err) } - if err := db.AutoMigrate(&model.Provider{}, &model.Binding{}, &model.Namespace{}); err != nil { + if err := db.AutoMigrate(&model.ProviderGroup{}, &model.APIKey{}, &model.Binding{}, &model.Namespace{}); err != nil { t.Fatalf("migrate: %v", err) } @@ -39,10 +39,23 @@ func newTestHandlerWithNamespace(t *testing.T) (*Handler, *gorm.DB, *miniredis.M func TestNamespaceCRUD_DeleteCleansBindings(t *testing.T) { h, db, _ := newTestHandlerWithNamespace(t) + group := model.ProviderGroup{Name: "g1", Type: "openai", BaseURL: "https://api.openai.com/v1", Models: "m1", Status: "active"} + if err := db.Create(&group).Error; err != nil { + t.Fatalf("create group: %v", err) + } + if err := db.Create(&model.APIKey{ + GroupID: group.ID, + APIKey: "k1", + Status: "active", + }).Error; err != nil { + t.Fatalf("create api key: %v", err) + } + if err := db.Create(&model.Binding{ Namespace: "ns1", PublicModel: "m1", - RouteGroup: "default", + GroupID: group.ID, + Weight: 1, SelectorType: "exact", SelectorValue: "m1", Status: "active", diff --git a/internal/api/provider_admin_handler.go b/internal/api/provider_admin_handler.go deleted file mode 100644 index e03f585..0000000 --- a/internal/api/provider_admin_handler.go +++ /dev/null @@ -1,350 +0,0 @@ -package api - -import ( - "encoding/json" - "fmt" - "io" - "net/http" - "sort" - "strings" - "time" - - "github.com/ez-api/ez-api/internal/model" - "github.com/ez-api/foundation/provider" - "github.com/gin-gonic/gin" -) - -// ListProviders godoc -// @Summary List providers -// @Description List all configured upstream providers -// @Tags admin -// @Produce json -// @Security AdminAuth -// @Param page query int false "page (1-based)" -// @Param limit query int false "limit (default 50, max 200)" -// @Param search query string false "search by name/type/base_url/group" -// @Success 200 {array} model.Provider -// @Failure 500 {object} gin.H -// @Router /admin/providers [get] -func (h *Handler) ListProviders(c *gin.Context) { - var providers []model.Provider - q := h.db.Model(&model.Provider{}).Order("id desc") - query := parseListQuery(c) - q = applyListSearch(q, query.Search, "name", `"type"`, "base_url", `"group"`) - q = applyListPagination(q, query) - if err := q.Find(&providers).Error; err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list providers", "details": err.Error()}) - return - } - c.JSON(http.StatusOK, providers) -} - -// GetProvider godoc -// @Summary Get provider -// @Description Get a provider by id -// @Tags admin -// @Produce json -// @Security AdminAuth -// @Param id path int true "Provider ID" -// @Success 200 {object} model.Provider -// @Failure 400 {object} gin.H -// @Failure 404 {object} gin.H -// @Failure 500 {object} gin.H -// @Router /admin/providers/{id} [get] -func (h *Handler) GetProvider(c *gin.Context) { - id, ok := parseUintParam(c, "id") - if !ok { - return - } - var p model.Provider - if err := h.db.First(&p, id).Error; err != nil { - c.JSON(http.StatusNotFound, gin.H{"error": "provider not found"}) - return - } - c.JSON(http.StatusOK, p) -} - -// DeleteProvider godoc -// @Summary Delete provider -// @Description Deletes a provider and triggers a full snapshot sync to avoid stale routing -// @Tags admin -// @Produce json -// @Security AdminAuth -// @Param id path int true "Provider ID" -// @Success 200 {object} gin.H -// @Failure 400 {object} gin.H -// @Failure 404 {object} gin.H -// @Failure 500 {object} gin.H -// @Router /admin/providers/{id} [delete] -func (h *Handler) DeleteProvider(c *gin.Context) { - id, ok := parseUintParam(c, "id") - if !ok { - return - } - - var p model.Provider - if err := h.db.First(&p, id).Error; err != nil { - c.JSON(http.StatusNotFound, gin.H{"error": "provider not found"}) - return - } - - if err := h.db.Delete(&p).Error; err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete provider", "details": err.Error()}) - return - } - - if err := h.sync.SyncProviderDelete(&p); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync provider delete", "details": err.Error()}) - return - } - if err := h.sync.SyncBindings(h.db); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync bindings", "details": err.Error()}) - return - } - - c.JSON(http.StatusOK, gin.H{"status": "deleted"}) -} - -type testProviderResponse struct { - StatusCode int `json:"status_code"` - OK bool `json:"ok"` - URL string `json:"url"` - Body string `json:"body,omitempty"` -} - -// TestProvider godoc -// @Summary Test provider connectivity -// @Description Performs a lightweight upstream request to verify the provider configuration -// @Tags admin -// @Produce json -// @Security AdminAuth -// @Param id path int true "Provider ID" -// @Success 200 {object} testProviderResponse -// @Failure 400 {object} gin.H -// @Failure 404 {object} gin.H -// @Failure 500 {object} gin.H -// @Router /admin/providers/{id}/test [post] -func (h *Handler) TestProvider(c *gin.Context) { - id, ok := parseUintParam(c, "id") - if !ok { - return - } - var p model.Provider - if err := h.db.First(&p, id).Error; err != nil { - c.JSON(http.StatusNotFound, gin.H{"error": "provider not found"}) - return - } - - req, err := buildProviderModelsRequest(&p) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - client := &http.Client{Timeout: 10 * time.Second} - resp, err := client.Do(req) - if err != nil { - c.JSON(http.StatusOK, testProviderResponse{StatusCode: 0, OK: false, URL: req.URL.String(), Body: err.Error()}) - return - } - defer resp.Body.Close() - - body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048)) - ok = resp.StatusCode >= 200 && resp.StatusCode < 300 - - c.JSON(http.StatusOK, testProviderResponse{ - StatusCode: resp.StatusCode, - OK: ok, - URL: req.URL.String(), - Body: string(body), - }) -} - -// FetchProviderModels godoc -// @Summary Fetch models from provider -// @Description Calls upstream /models (or /v1/models) and updates provider model list -// @Tags admin -// @Produce json -// @Security AdminAuth -// @Param id path int true "Provider ID" -// @Success 200 {object} gin.H -// @Failure 400 {object} gin.H -// @Failure 404 {object} gin.H -// @Failure 502 {object} gin.H -// @Router /admin/providers/{id}/fetch-models [post] -func (h *Handler) FetchProviderModels(c *gin.Context) { - id, ok := parseUintParam(c, "id") - if !ok { - return - } - var p model.Provider - if err := h.db.First(&p, id).Error; err != nil { - c.JSON(http.StatusNotFound, gin.H{"error": "provider not found"}) - return - } - - req, err := buildProviderModelsRequest(&p) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - client := &http.Client{Timeout: 15 * time.Second} - resp, err := client.Do(req) - if err != nil { - c.JSON(http.StatusBadGateway, gin.H{"error": "failed to fetch models", "details": err.Error()}) - return - } - defer resp.Body.Close() - - body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - c.JSON(http.StatusBadGateway, gin.H{ - "error": "upstream returned non-2xx", - "status_code": resp.StatusCode, - "body": string(body), - }) - return - } - - models, err := parseProviderModelIDs(body) - if err != nil { - c.JSON(http.StatusBadGateway, gin.H{"error": "failed to parse models", "details": err.Error()}) - return - } - - if err := h.db.Model(&p).Update("models", strings.Join(models, ",")).Error; err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to update provider models", "details": err.Error()}) - return - } - if err := h.db.First(&p, id).Error; err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to reload provider", "details": err.Error()}) - return - } - if err := h.sync.SyncProvider(&p); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to sync provider", "details": err.Error()}) - return - } - if err := h.sync.SyncBindings(h.db); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to sync bindings", "details": err.Error()}) - return - } - - c.JSON(http.StatusOK, gin.H{ - "status": "updated", - "count": len(models), - "models": models, - }) -} - -func buildProviderModelsRequest(p *model.Provider) (*http.Request, error) { - if p == nil { - return nil, fmt.Errorf("provider required") - } - pt := provider.NormalizeType(p.Type) - baseURL := strings.TrimRight(strings.TrimSpace(p.BaseURL), "/") - if baseURL == "" { - return nil, fmt.Errorf("base_url required for provider models fetch") - } - - url := "" - switch pt { - case provider.TypeOpenAI, provider.TypeCompatible: - if strings.HasSuffix(baseURL, "/v1") { - url = baseURL + "/models" - } else { - url = baseURL + "/v1/models" - } - case provider.TypeAnthropic, provider.TypeClaude: - url = baseURL + "/v1/models" - default: - return nil, fmt.Errorf("provider type not supported for model fetch") - } - - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - return nil, fmt.Errorf("build request: %w", err) - } - - apiKey := strings.TrimSpace(p.APIKey) - switch pt { - case provider.TypeOpenAI, provider.TypeCompatible: - if apiKey != "" { - req.Header.Set("Authorization", "Bearer "+apiKey) - } - case provider.TypeAnthropic, provider.TypeClaude: - if apiKey != "" { - req.Header.Set("x-api-key", apiKey) - } - req.Header.Set("anthropic-version", "2023-06-01") - } - - return req, nil -} - -type providerModelsResponse struct { - Data []json.RawMessage `json:"data"` - Models []string `json:"models"` -} - -func parseProviderModelIDs(payload []byte) ([]string, error) { - var resp providerModelsResponse - if err := json.Unmarshal(payload, &resp); err != nil { - return nil, fmt.Errorf("decode response: %w", err) - } - - models := make([]string, 0, len(resp.Data)+len(resp.Models)) - for _, name := range resp.Models { - name = strings.TrimSpace(name) - if name != "" { - models = append(models, name) - } - } - for _, raw := range resp.Data { - var item struct { - ID string `json:"id"` - Model string `json:"model"` - Name string `json:"name"` - } - if err := json.Unmarshal(raw, &item); err == nil { - if item.ID != "" { - models = append(models, strings.TrimSpace(item.ID)) - continue - } - if item.Model != "" { - models = append(models, strings.TrimSpace(item.Model)) - continue - } - if item.Name != "" { - models = append(models, strings.TrimSpace(item.Name)) - continue - } - } - var name string - if err := json.Unmarshal(raw, &name); err == nil { - name = strings.TrimSpace(name) - if name != "" { - models = append(models, name) - } - } - } - - unique := make(map[string]struct{}, len(models)) - out := make([]string, 0, len(models)) - for _, name := range models { - name = strings.TrimSpace(name) - if name == "" { - continue - } - if _, ok := unique[name]; ok { - continue - } - unique[name] = struct{}{} - out = append(out, name) - } - if len(out) == 0 { - return nil, fmt.Errorf("no models found in response") - } - sort.Strings(out) - return out, nil -} diff --git a/internal/api/provider_admin_handler_test.go b/internal/api/provider_admin_handler_test.go deleted file mode 100644 index 8827cd4..0000000 --- a/internal/api/provider_admin_handler_test.go +++ /dev/null @@ -1,162 +0,0 @@ -package api - -import ( - "bytes" - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/ez-api/ez-api/internal/model" - "github.com/gin-gonic/gin" -) - -func TestAdmin_TestProvider_OpenAICompatible(t *testing.T) { - h, db := newTestHandler(t) - - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != "/v1/models" { - http.NotFound(w, r) - return - } - if got := r.Header.Get("Authorization"); got != "Bearer k" { - http.Error(w, "unauthorized", http.StatusUnauthorized) - return - } - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"object":"list","data":[]}`)) - })) - defer upstream.Close() - - p := &model.Provider{ - Name: "p1", - Type: "openai", - BaseURL: upstream.URL + "/v1", - APIKey: "k", - Group: "default", - Models: "gpt-4o-mini", - Status: "active", - } - if err := db.Create(p).Error; err != nil { - t.Fatalf("create provider: %v", err) - } - - r := gin.New() - r.POST("/admin/providers/:id/test", h.TestProvider) - - req := httptest.NewRequest(http.MethodPost, "/admin/providers/1/test", nil) - rr := httptest.NewRecorder() - r.ServeHTTP(rr, req) - - if rr.Code != http.StatusOK { - t.Fatalf("expected 200, got %d body=%s", rr.Code, rr.Body.String()) - } - var payload map[string]any - if err := json.Unmarshal(rr.Body.Bytes(), &payload); err != nil { - t.Fatalf("unmarshal: %v", err) - } - if ok, _ := payload["ok"].(bool); !ok { - t.Fatalf("expected ok=true, got %v body=%s", payload["ok"], rr.Body.String()) - } -} - -func TestAdmin_FetchProviderModels_OpenAICompatible(t *testing.T) { - h, db := newTestHandler(t) - - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != "/v1/models" { - http.NotFound(w, r) - return - } - if got := r.Header.Get("Authorization"); got != "Bearer k" { - http.Error(w, "unauthorized", http.StatusUnauthorized) - return - } - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"object":"list","data":[{"id":"gpt-4o-mini"},{"id":"gpt-4o"}]}`)) - })) - defer upstream.Close() - - p := &model.Provider{ - Name: "p1", - Type: "openai", - BaseURL: upstream.URL + "/v1", - APIKey: "k", - Group: "default", - Models: "old-model", - Status: "active", - } - if err := db.Create(p).Error; err != nil { - t.Fatalf("create provider: %v", err) - } - - r := gin.New() - r.POST("/admin/providers/:id/fetch-models", h.FetchProviderModels) - - req := httptest.NewRequest(http.MethodPost, "/admin/providers/1/fetch-models", nil) - rr := httptest.NewRecorder() - r.ServeHTTP(rr, req) - - if rr.Code != http.StatusOK { - t.Fatalf("expected 200, got %d body=%s", rr.Code, rr.Body.String()) - } - - var updated model.Provider - if err := db.First(&updated, p.ID).Error; err != nil { - t.Fatalf("reload provider: %v", err) - } - if updated.Models != "gpt-4o,gpt-4o-mini" { - t.Fatalf("expected models to update, got %q", updated.Models) - } -} - -func TestAdmin_BatchProviders_Status(t *testing.T) { - h, db := newTestHandler(t) - - banUntil := time.Now().Add(2 * time.Hour).UTC() - p := &model.Provider{ - Name: "p1", - Type: "openai", - BaseURL: "https://api.openai.com/v1", - Group: "default", - Models: "gpt-4o-mini", - Status: "manual_disabled", - BanReason: "bad", - BanUntil: &banUntil, - } - if err := db.Create(p).Error; err != nil { - t.Fatalf("create provider: %v", err) - } - - r := gin.New() - r.POST("/admin/providers/batch", h.BatchProviders) - - payload := map[string]any{ - "action": "status", - "status": "active", - "ids": []uint{p.ID}, - } - b, _ := json.Marshal(payload) - req := httptest.NewRequest(http.MethodPost, "/admin/providers/batch", bytes.NewReader(b)) - req.Header.Set("Content-Type", "application/json") - rr := httptest.NewRecorder() - r.ServeHTTP(rr, req) - - if rr.Code != http.StatusOK { - t.Fatalf("expected 200, got %d body=%s", rr.Code, rr.Body.String()) - } - var updated model.Provider - if err := db.First(&updated, p.ID).Error; err != nil { - t.Fatalf("reload provider: %v", err) - } - if updated.Status != "active" { - t.Fatalf("expected status active, got %q", updated.Status) - } - if updated.BanReason != "" { - t.Fatalf("expected ban_reason cleared, got %q", updated.BanReason) - } - if updated.BanUntil != nil { - t.Fatalf("expected ban_until cleared, got %v", updated.BanUntil) - } -} diff --git a/internal/api/provider_create_handler.go b/internal/api/provider_create_handler.go deleted file mode 100644 index 66beecc..0000000 --- a/internal/api/provider_create_handler.go +++ /dev/null @@ -1,288 +0,0 @@ -package api - -import ( - "crypto/rand" - "encoding/hex" - "net/http" - "strings" - - "github.com/ez-api/ez-api/internal/dto" - "github.com/ez-api/ez-api/internal/model" - groupx "github.com/ez-api/foundation/group" - providerx "github.com/ez-api/foundation/provider" - "github.com/gin-gonic/gin" -) - -// CreateProviderPreset godoc -// @Summary Create a preset provider -// @Description Create an official OpenAI/Anthropic provider (only api_key is typically required) -// @Tags admin -// @Accept json -// @Produce json -// @Security AdminAuth -// @Param provider body dto.ProviderPresetCreateDTO true "Provider preset payload" -// @Success 201 {object} model.Provider -// @Failure 400 {object} gin.H -// @Failure 500 {object} gin.H -// @Router /admin/providers/preset [post] -func (h *Handler) CreateProviderPreset(c *gin.Context) { - var req dto.ProviderPresetCreateDTO - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - preset := providerx.NormalizeType(req.Preset) - if preset == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "preset required"}) - return - } - - var providerType string - var baseURL string - switch preset { - case providerx.TypeOpenAI: - providerType = providerx.TypeOpenAI - baseURL = "https://api.openai.com/v1" - case providerx.TypeAnthropic, providerx.TypeClaude: - providerType = providerx.TypeAnthropic - baseURL = "https://api.anthropic.com" - default: - c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported preset: " + preset + " (use /admin/providers/google for Google SDK providers)"}) - return - } - - name := strings.TrimSpace(req.Name) - if name == "" { - name = providerType + "-" + randomSuffix(4) - } - group := strings.TrimSpace(req.Group) - if group == "" { - group = "default" - } - status := strings.TrimSpace(req.Status) - if status == "" { - status = "active" - } - autoBan := true - if req.AutoBan != nil { - autoBan = *req.AutoBan - } - - googleLocation := providerx.DefaultGoogleLocation(providerType, req.GoogleLocation) - - p := model.Provider{ - Name: name, - Type: providerType, - BaseURL: baseURL, - APIKey: strings.TrimSpace(req.APIKey), - GoogleProject: strings.TrimSpace(req.GoogleProject), - GoogleLocation: googleLocation, - Group: groupx.Normalize(group), - Models: strings.Join(req.Models, ","), - Status: status, - AutoBan: autoBan, - } - if req.Weight > 0 { - p.Weight = req.Weight - } - - if err := h.db.Create(&p).Error; err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create provider", "details": err.Error()}) - return - } - - if err := h.sync.SyncProvider(&p); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync provider", "details": err.Error()}) - return - } - if err := h.sync.SyncBindings(h.db); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync bindings", "details": err.Error()}) - return - } - - c.JSON(http.StatusCreated, p) -} - -// CreateProviderGoogle godoc -// @Summary Create a Google SDK provider -// @Description Create a Google SDK provider (Gemini API key or Vertex project/location); base_url is not used -// @Tags admin -// @Accept json -// @Produce json -// @Security AdminAuth -// @Param provider body dto.ProviderGoogleCreateDTO true "Google provider payload" -// @Success 201 {object} model.Provider -// @Failure 400 {object} gin.H -// @Failure 500 {object} gin.H -// @Router /admin/providers/google [post] -func (h *Handler) CreateProviderGoogle(c *gin.Context) { - var req dto.ProviderGoogleCreateDTO - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - pt := providerx.NormalizeType(req.Type) - if pt == "" { - pt = providerx.TypeGemini - } - if !providerx.IsGoogleFamily(pt) { - c.JSON(http.StatusBadRequest, gin.H{"error": "type must be google family"}) - return - } - - name := strings.TrimSpace(req.Name) - if name == "" { - name = pt + "-" + randomSuffix(4) - } - group := strings.TrimSpace(req.Group) - if group == "" { - group = "default" - } - status := strings.TrimSpace(req.Status) - if status == "" { - status = "active" - } - autoBan := true - if req.AutoBan != nil { - autoBan = *req.AutoBan - } - - // Validate fields by type. - apiKey := strings.TrimSpace(req.APIKey) - googleProject := strings.TrimSpace(req.GoogleProject) - googleLocation := providerx.DefaultGoogleLocation(pt, req.GoogleLocation) - - if providerx.IsVertexFamily(pt) { - // Vertex uses ADC and project/location; api_key is not required. - if strings.TrimSpace(googleLocation) == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "google_location required"}) - return - } - apiKey = "" - } else { - // Gemini API requires api_key. - if apiKey == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "api_key required for gemini api"}) - return - } - googleProject = "" - googleLocation = "" - } - - p := model.Provider{ - Name: name, - Type: pt, - BaseURL: "", // intentionally unused for Google SDK - APIKey: apiKey, - GoogleProject: googleProject, - GoogleLocation: googleLocation, - Group: groupx.Normalize(group), - Models: strings.Join(req.Models, ","), - Status: status, - AutoBan: autoBan, - } - if req.Weight > 0 { - p.Weight = req.Weight - } - - if err := h.db.Create(&p).Error; err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create provider", "details": err.Error()}) - return - } - if err := h.sync.SyncProvider(&p); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync provider", "details": err.Error()}) - return - } - if err := h.sync.SyncBindings(h.db); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync bindings", "details": err.Error()}) - return - } - c.JSON(http.StatusCreated, p) -} - -// CreateProviderCustom godoc -// @Summary Create a custom provider -// @Description Create an OpenAI-compatible provider (base_url + api_key required) -// @Tags admin -// @Accept json -// @Produce json -// @Security AdminAuth -// @Param provider body dto.ProviderCustomCreateDTO true "Provider custom payload" -// @Success 201 {object} model.Provider -// @Failure 400 {object} gin.H -// @Failure 500 {object} gin.H -// @Router /admin/providers/custom [post] -func (h *Handler) CreateProviderCustom(c *gin.Context) { - var req dto.ProviderCustomCreateDTO - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - name := strings.TrimSpace(req.Name) - if name == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "name required"}) - return - } - baseURL := strings.TrimSpace(req.BaseURL) - if baseURL == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "base_url required"}) - return - } - - group := strings.TrimSpace(req.Group) - if group == "" { - group = "default" - } - status := strings.TrimSpace(req.Status) - if status == "" { - status = "active" - } - autoBan := true - if req.AutoBan != nil { - autoBan = *req.AutoBan - } - - p := model.Provider{ - Name: name, - Type: providerx.TypeCompatible, - BaseURL: baseURL, - APIKey: strings.TrimSpace(req.APIKey), - Group: groupx.Normalize(group), - Models: strings.Join(req.Models, ","), - Status: status, - AutoBan: autoBan, - } - if req.Weight > 0 { - p.Weight = req.Weight - } - - if err := h.db.Create(&p).Error; err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create provider", "details": err.Error()}) - return - } - - if err := h.sync.SyncProvider(&p); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync provider", "details": err.Error()}) - return - } - if err := h.sync.SyncBindings(h.db); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync bindings", "details": err.Error()}) - return - } - - c.JSON(http.StatusCreated, p) -} - -func randomSuffix(bytesLen int) string { - if bytesLen <= 0 { - bytesLen = 4 - } - b := make([]byte, bytesLen) - if _, err := rand.Read(b); err != nil { - return "rand" - } - return hex.EncodeToString(b) -} diff --git a/internal/api/provider_create_handler_test.go b/internal/api/provider_create_handler_test.go deleted file mode 100644 index 7a62141..0000000 --- a/internal/api/provider_create_handler_test.go +++ /dev/null @@ -1,92 +0,0 @@ -package api - -import ( - "bytes" - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - - "github.com/ez-api/ez-api/internal/model" - "github.com/gin-gonic/gin" -) - -func TestCreateProviderPreset_OpenAI_SetsBaseURL(t *testing.T) { - h, _ := newTestHandler(t) - - r := gin.New() - r.POST("/admin/providers/preset", h.CreateProviderPreset) - - reqBody := map[string]any{ - "preset": "openai", - "api_key": "k", - "models": []string{"gpt-4o-mini"}, - } - b, _ := json.Marshal(reqBody) - - req := httptest.NewRequest(http.MethodPost, "/admin/providers/preset", bytes.NewReader(b)) - req.Header.Set("Content-Type", "application/json") - rr := httptest.NewRecorder() - r.ServeHTTP(rr, req) - - if rr.Code != http.StatusCreated { - t.Fatalf("expected 201, got %d body=%s", rr.Code, rr.Body.String()) - } - var got model.Provider - if err := json.Unmarshal(rr.Body.Bytes(), &got); err != nil { - t.Fatalf("unmarshal: %v", err) - } - if got.Type != "openai" { - t.Fatalf("expected type openai, got %q", got.Type) - } - if got.BaseURL != "https://api.openai.com/v1" { - t.Fatalf("expected base_url=https://api.openai.com/v1, got %q", got.BaseURL) - } - if got.Name == "" { - t.Fatalf("expected generated name") - } -} - -func TestCreateProviderCustom_RequiresBaseURL(t *testing.T) { - h, _ := newTestHandler(t) - - r := gin.New() - r.POST("/admin/providers/custom", h.CreateProviderCustom) - - reqBody := map[string]any{ - "name": "c1", - "api_key": "k", - } - b, _ := json.Marshal(reqBody) - - req := httptest.NewRequest(http.MethodPost, "/admin/providers/custom", bytes.NewReader(b)) - req.Header.Set("Content-Type", "application/json") - rr := httptest.NewRecorder() - r.ServeHTTP(rr, req) - - if rr.Code != http.StatusBadRequest { - t.Fatalf("expected 400, got %d body=%s", rr.Code, rr.Body.String()) - } -} - -func TestCreateProviderGoogle_GeminiRequiresAPIKey(t *testing.T) { - h, _ := newTestHandler(t) - - r := gin.New() - r.POST("/admin/providers/google", h.CreateProviderGoogle) - - reqBody := map[string]any{ - "type": "gemini", - "models": []string{"gemini-2.0-flash"}, - } - b, _ := json.Marshal(reqBody) - - req := httptest.NewRequest(http.MethodPost, "/admin/providers/google", bytes.NewReader(b)) - req.Header.Set("Content-Type", "application/json") - rr := httptest.NewRecorder() - r.ServeHTTP(rr, req) - - if rr.Code != http.StatusBadRequest { - t.Fatalf("expected 400, got %d body=%s", rr.Code, rr.Body.String()) - } -} diff --git a/internal/api/provider_group_handler.go b/internal/api/provider_group_handler.go new file mode 100644 index 0000000..5936f50 --- /dev/null +++ b/internal/api/provider_group_handler.go @@ -0,0 +1,302 @@ +package api + +import ( + "net/http" + "strconv" + "strings" + + "github.com/ez-api/ez-api/internal/dto" + "github.com/ez-api/ez-api/internal/model" + "github.com/ez-api/foundation/provider" + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +// CreateProviderGroup godoc +// @Summary Create a provider group +// @Description Create a provider group definition +// @Tags admin +// @Accept json +// @Produce json +// @Security AdminAuth +// @Param group body dto.ProviderGroupDTO true "Provider group payload" +// @Success 201 {object} model.ProviderGroup +// @Failure 400 {object} gin.H +// @Failure 500 {object} gin.H +// @Router /admin/provider-groups [post] +func (h *Handler) CreateProviderGroup(c *gin.Context) { + var req dto.ProviderGroupDTO + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + name := strings.TrimSpace(req.Name) + if name == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "name required"}) + return + } + + ptype := provider.NormalizeType(req.Type) + if ptype == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "type required"}) + return + } + + baseURL := strings.TrimSpace(req.BaseURL) + googleLocation := provider.DefaultGoogleLocation(ptype, req.GoogleLocation) + + switch ptype { + case provider.TypeOpenAI: + if baseURL == "" { + baseURL = "https://api.openai.com/v1" + } + case provider.TypeAnthropic, provider.TypeClaude: + if baseURL == "" { + baseURL = "https://api.anthropic.com" + } + case provider.TypeCompatible: + if baseURL == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "base_url required for compatible providers"}) + return + } + default: + if provider.IsVertexFamily(ptype) && strings.TrimSpace(googleLocation) == "" { + googleLocation = provider.DefaultGoogleLocation(ptype, "") + } + } + + status := strings.TrimSpace(req.Status) + if status == "" { + status = "active" + } + + group := model.ProviderGroup{ + Name: name, + Type: strings.TrimSpace(req.Type), + BaseURL: baseURL, + GoogleProject: strings.TrimSpace(req.GoogleProject), + GoogleLocation: googleLocation, + Models: strings.Join(req.Models, ","), + Status: status, + } + if err := h.db.Create(&group).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create provider group", "details": err.Error()}) + return + } + + if err := h.sync.SyncProviders(h.db); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync providers", "details": err.Error()}) + return + } + if err := h.sync.SyncBindings(h.db); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync bindings", "details": err.Error()}) + return + } + + c.JSON(http.StatusCreated, group) +} + +// ListProviderGroups godoc +// @Summary List provider groups +// @Description List all provider groups +// @Tags admin +// @Produce json +// @Security AdminAuth +// @Param page query int false "page (1-based)" +// @Param limit query int false "limit (default 50, max 200)" +// @Param search query string false "search by name/type" +// @Success 200 {array} model.ProviderGroup +// @Failure 500 {object} gin.H +// @Router /admin/provider-groups [get] +func (h *Handler) ListProviderGroups(c *gin.Context) { + var groups []model.ProviderGroup + q := h.db.Model(&model.ProviderGroup{}).Order("id desc") + query := parseListQuery(c) + q = applyListSearch(q, query.Search, "name", "type") + q = applyListPagination(q, query) + if err := q.Find(&groups).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list provider groups", "details": err.Error()}) + return + } + c.JSON(http.StatusOK, groups) +} + +// GetProviderGroup godoc +// @Summary Get provider group +// @Description Get a provider group by id +// @Tags admin +// @Produce json +// @Security AdminAuth +// @Param id path int true "ProviderGroup ID" +// @Success 200 {object} model.ProviderGroup +// @Failure 400 {object} gin.H +// @Failure 404 {object} gin.H +// @Failure 500 {object} gin.H +// @Router /admin/provider-groups/{id} [get] +func (h *Handler) GetProviderGroup(c *gin.Context) { + id, ok := parseUintParam(c, "id") + if !ok { + return + } + var group model.ProviderGroup + if err := h.db.First(&group, id).Error; err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "provider group not found"}) + return + } + c.JSON(http.StatusOK, group) +} + +// UpdateProviderGroup godoc +// @Summary Update provider group +// @Description Update a provider group +// @Tags admin +// @Accept json +// @Produce json +// @Security AdminAuth +// @Param id path int true "ProviderGroup ID" +// @Param group body dto.ProviderGroupDTO true "Provider group payload" +// @Success 200 {object} model.ProviderGroup +// @Failure 400 {object} gin.H +// @Failure 404 {object} gin.H +// @Failure 500 {object} gin.H +// @Router /admin/provider-groups/{id} [put] +func (h *Handler) UpdateProviderGroup(c *gin.Context) { + idParam := c.Param("id") + id, err := strconv.Atoi(idParam) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"}) + return + } + + var group model.ProviderGroup + if err := h.db.First(&group, id).Error; err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "provider group not found"}) + return + } + + var req dto.ProviderGroupDTO + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + nextType := strings.TrimSpace(group.Type) + if t := strings.TrimSpace(req.Type); t != "" { + nextType = t + } + nextTypeLower := provider.NormalizeType(nextType) + nextBaseURL := strings.TrimSpace(group.BaseURL) + if strings.TrimSpace(req.BaseURL) != "" { + nextBaseURL = strings.TrimSpace(req.BaseURL) + } + + update := map[string]any{} + if strings.TrimSpace(req.Name) != "" { + update["name"] = strings.TrimSpace(req.Name) + } + if strings.TrimSpace(req.Type) != "" { + update["type"] = strings.TrimSpace(req.Type) + } + if strings.TrimSpace(req.BaseURL) != "" { + update["base_url"] = strings.TrimSpace(req.BaseURL) + } + if strings.TrimSpace(req.GoogleProject) != "" { + update["google_project"] = strings.TrimSpace(req.GoogleProject) + } + if strings.TrimSpace(req.GoogleLocation) != "" { + update["google_location"] = strings.TrimSpace(req.GoogleLocation) + } else if provider.IsVertexFamily(nextTypeLower) && strings.TrimSpace(group.GoogleLocation) == "" { + update["google_location"] = provider.DefaultGoogleLocation(nextTypeLower, "") + } + if req.Models != nil { + update["models"] = strings.Join(req.Models, ",") + } + if strings.TrimSpace(req.Status) != "" { + update["status"] = strings.TrimSpace(req.Status) + } + + switch nextTypeLower { + case provider.TypeOpenAI: + if nextBaseURL == "" { + update["base_url"] = "https://api.openai.com/v1" + } + case provider.TypeAnthropic, provider.TypeClaude: + if nextBaseURL == "" { + update["base_url"] = "https://api.anthropic.com" + } + case provider.TypeCompatible: + if nextBaseURL == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "base_url required for compatible providers"}) + return + } + } + + if err := h.db.Model(&group).Updates(update).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update provider group", "details": err.Error()}) + return + } + if err := h.db.First(&group, id).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to reload provider group", "details": err.Error()}) + return + } + + if err := h.sync.SyncProviders(h.db); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync providers", "details": err.Error()}) + return + } + if err := h.sync.SyncBindings(h.db); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync bindings", "details": err.Error()}) + return + } + + c.JSON(http.StatusOK, group) +} + +// DeleteProviderGroup godoc +// @Summary Delete provider group +// @Description Delete a provider group and its api keys/bindings +// @Tags admin +// @Produce json +// @Security AdminAuth +// @Param id path int true "ProviderGroup ID" +// @Success 200 {object} gin.H +// @Failure 400 {object} gin.H +// @Failure 404 {object} gin.H +// @Failure 500 {object} gin.H +// @Router /admin/provider-groups/{id} [delete] +func (h *Handler) DeleteProviderGroup(c *gin.Context) { + id, ok := parseUintParam(c, "id") + if !ok { + return + } + var group model.ProviderGroup + if err := h.db.First(&group, id).Error; err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "provider group not found"}) + return + } + + if err := h.db.Transaction(func(tx *gorm.DB) error { + if err := tx.Where("group_id = ?", group.ID).Delete(&model.APIKey{}).Error; err != nil { + return err + } + if err := tx.Where("group_id = ?", group.ID).Delete(&model.Binding{}).Error; err != nil { + return err + } + return tx.Delete(&group).Error + }); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete provider group", "details": err.Error()}) + return + } + + if err := h.sync.SyncProviders(h.db); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync providers", "details": err.Error()}) + return + } + if err := h.sync.SyncBindings(h.db); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync bindings", "details": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"status": "deleted"}) +} diff --git a/internal/api/provider_handler_test.go b/internal/api/provider_handler_test.go deleted file mode 100644 index b0508e5..0000000 --- a/internal/api/provider_handler_test.go +++ /dev/null @@ -1,105 +0,0 @@ -package api - -import ( - "bytes" - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "testing" - - "github.com/alicebob/miniredis/v2" - "github.com/ez-api/ez-api/internal/dto" - "github.com/ez-api/ez-api/internal/model" - "github.com/ez-api/ez-api/internal/service" - "github.com/gin-gonic/gin" - "github.com/redis/go-redis/v9" - "gorm.io/driver/sqlite" - "gorm.io/gorm" -) - -func newTestHandler(t *testing.T) (*Handler, *gorm.DB) { - t.Helper() - gin.SetMode(gin.TestMode) - - // Use a unique in-memory DB per test to avoid cross-test interference. - dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", t.Name()) - db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{}) - if err != nil { - t.Fatalf("open sqlite: %v", err) - } - if err := db.AutoMigrate(&model.Master{}, &model.Key{}, &model.Provider{}, &model.Model{}, &model.Binding{}, &model.LogRecord{}); err != nil { - t.Fatalf("migrate: %v", err) - } - - mr := miniredis.RunT(t) - rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()}) - sync := service.NewSyncService(rdb) - - return NewHandler(db, db, sync, nil, rdb, nil), db -} - -func TestCreateProvider_DefaultsVertexLocationGlobal(t *testing.T) { - h, _ := newTestHandler(t) - - r := gin.New() - r.POST("/admin/providers", h.CreateProvider) - - reqBody := dto.ProviderDTO{ - Name: "g1", - Type: "vertex-express", - Group: "default", - Models: []string{"gemini-3-pro-preview"}, - } - b, _ := json.Marshal(reqBody) - - req := httptest.NewRequest(http.MethodPost, "/admin/providers", bytes.NewReader(b)) - req.Header.Set("Content-Type", "application/json") - rr := httptest.NewRecorder() - r.ServeHTTP(rr, req) - - if rr.Code != http.StatusCreated { - t.Fatalf("expected 201, got %d body=%s", rr.Code, rr.Body.String()) - } - var got model.Provider - if err := json.Unmarshal(rr.Body.Bytes(), &got); err != nil { - t.Fatalf("unmarshal: %v", err) - } - if got.GoogleLocation != "global" { - t.Fatalf("expected google_location=global, got %q", got.GoogleLocation) - } -} - -func TestUpdateProvider_DefaultsVertexLocationGlobalWhenMissing(t *testing.T) { - h, db := newTestHandler(t) - - existing := &model.Provider{ - Name: "g2", - Type: "vertex", - Group: "default", - Models: "gemini-3-pro-preview", - Status: "active", - } - if err := db.Create(existing).Error; err != nil { - t.Fatalf("create provider: %v", err) - } - - r := gin.New() - r.PUT("/admin/providers/:id", h.UpdateProvider) - - req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/admin/providers/%d", existing.ID), bytes.NewReader([]byte(`{}`))) - req.Header.Set("Content-Type", "application/json") - rr := httptest.NewRecorder() - r.ServeHTTP(rr, req) - - if rr.Code != http.StatusOK { - t.Fatalf("expected 200, got %d body=%s", rr.Code, rr.Body.String()) - } - var got model.Provider - if err := json.Unmarshal(rr.Body.Bytes(), &got); err != nil { - t.Fatalf("unmarshal: %v", err) - } - if got.GoogleLocation != "global" { - t.Fatalf("expected google_location=global, got %q", got.GoogleLocation) - } -} diff --git a/internal/dto/api_key.go b/internal/dto/api_key.go new file mode 100644 index 0000000..53ea860 --- /dev/null +++ b/internal/dto/api_key.go @@ -0,0 +1,14 @@ +package dto + +import "time" + +// APIKeyDTO defines inbound payload for API key creation/update. +type APIKeyDTO struct { + GroupID uint `json:"group_id"` + APIKey string `json:"api_key"` + Weight int `json:"weight,omitempty"` + Status string `json:"status"` + AutoBan *bool `json:"auto_ban,omitempty"` + BanReason string `json:"ban_reason,omitempty"` + BanUntil time.Time `json:"ban_until,omitempty"` +} diff --git a/internal/dto/binding.go b/internal/dto/binding.go index 5e36f6f..9df4cf4 100644 --- a/internal/dto/binding.go +++ b/internal/dto/binding.go @@ -1,11 +1,12 @@ package dto // BindingDTO defines inbound payload for binding creation/update. -// It maps "(namespace, public_model)" to a RouteGroup and an upstream selector. +// It maps "(namespace, public_model)" to a ProviderGroup and an upstream selector. type BindingDTO struct { Namespace string `json:"namespace"` PublicModel string `json:"public_model"` - RouteGroup string `json:"route_group"` + GroupID uint `json:"group_id"` + Weight int `json:"weight"` SelectorType string `json:"selector_type"` SelectorValue string `json:"selector_value"` Status string `json:"status"` diff --git a/internal/dto/provider_group.go b/internal/dto/provider_group.go new file mode 100644 index 0000000..2358cb5 --- /dev/null +++ b/internal/dto/provider_group.go @@ -0,0 +1,12 @@ +package dto + +// ProviderGroupDTO defines inbound payload for provider group creation/update. +type ProviderGroupDTO struct { + Name string `json:"name"` + Type string `json:"type"` + BaseURL string `json:"base_url"` + GoogleProject string `json:"google_project,omitempty"` + GoogleLocation string `json:"google_location,omitempty"` + Models []string `json:"models"` + Status string `json:"status"` +} diff --git a/internal/migrate/importer.go b/internal/migrate/importer.go index 7f07381..8b09cdc 100644 --- a/internal/migrate/importer.go +++ b/internal/migrate/importer.go @@ -258,14 +258,45 @@ func (i *Importer) importMasters(items []Master, summary *ImportSummary) (map[st } func (i *Importer) importProviders(items []Provider, summary *ImportSummary) error { + groupCache := make(map[string]model.ProviderGroup) for _, item := range items { - name := strings.TrimSpace(item.Name) - if name == "" { - summary.Warnings = append(summary.Warnings, "skip provider with empty name") + groupName := normalizeGroup(item.PrimaryGroup) + if strings.TrimSpace(groupName) == "" { + groupName = "default" + } + group, ok := groupCache[groupName] + if !ok { + var existing model.ProviderGroup + err := i.db.Where("name = ?", groupName).First(&existing).Error + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + if err == nil { + group = existing + } else { + group = model.ProviderGroup{ + Name: groupName, + Type: strings.TrimSpace(item.Type), + BaseURL: strings.TrimSpace(item.BaseURL), + Models: strings.Join(item.Models, ","), + Status: normalizeStatus(item.Status, "active"), + } + if !i.opts.DryRun { + if err := i.db.Create(&group).Error; err != nil { + return err + } + } + } + groupCache[groupName] = group + } + + apiKey := strings.TrimSpace(item.APIKey) + if apiKey == "" { + summary.Warnings = append(summary.Warnings, "skip api key with empty api_key") continue } - var existing model.Provider - err := i.db.Where("name = ?", name).First(&existing).Error + var existingKey model.APIKey + err := i.db.Where("group_id = ? AND api_key = ?", group.ID, apiKey).First(&existingKey).Error if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return err } @@ -277,16 +308,11 @@ func (i *Importer) importProviders(items []Provider, summary *ImportSummary) err continue } update := map[string]any{ - "type": strings.TrimSpace(item.Type), - "base_url": strings.TrimSpace(item.BaseURL), - "api_key": strings.TrimSpace(item.APIKey), - "group": normalizeGroup(item.PrimaryGroup), - "models": strings.Join(item.Models, ","), "weight": resolveWeight(item.Weight, item.Priority), "status": normalizeProviderStatus(item.Status), "auto_ban": item.AutoBan, } - if err := i.db.Model(&existing).Updates(update).Error; err != nil { + if err := i.db.Model(&existingKey).Updates(update).Error; err != nil { return err } summary.ProvidersUpdated++ @@ -301,18 +327,14 @@ func (i *Importer) importProviders(items []Provider, summary *ImportSummary) err continue } - provider := model.Provider{ - Name: name, - Type: strings.TrimSpace(item.Type), - BaseURL: strings.TrimSpace(item.BaseURL), - APIKey: strings.TrimSpace(item.APIKey), - Group: normalizeGroup(item.PrimaryGroup), - Models: strings.Join(item.Models, ","), + key := model.APIKey{ + GroupID: group.ID, + APIKey: apiKey, Weight: resolveWeight(item.Weight, item.Priority), Status: normalizeProviderStatus(item.Status), AutoBan: item.AutoBan, } - if err := i.db.Create(&provider).Error; err != nil { + if err := i.db.Create(&key).Error; err != nil { return err } summary.ProvidersCreated++ @@ -420,8 +442,14 @@ func (i *Importer) importBindings(items []Binding, summary *ImportSummary) error summary.Warnings = append(summary.Warnings, "skip binding with empty model") continue } + groupName := normalizeGroup(item.RouteGroup) + var group model.ProviderGroup + if err := i.db.Where("name = ?", groupName).First(&group).Error; err != nil { + summary.Warnings = append(summary.Warnings, "skip binding with missing provider group: "+groupName) + continue + } var existing model.Binding - err := i.db.Where("namespace = ? AND public_model = ?", ns, publicModel).First(&existing).Error + err := i.db.Where("namespace = ? AND public_model = ? AND group_id = ?", ns, publicModel, group.ID).First(&existing).Error if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return err } @@ -433,7 +461,8 @@ func (i *Importer) importBindings(items []Binding, summary *ImportSummary) error continue } update := map[string]any{ - "route_group": normalizeGroup(item.RouteGroup), + "group_id": group.ID, + "weight": 1, "selector_type": "exact", "selector_value": publicModel, "status": normalizeStatus(item.Status, "active"), @@ -456,7 +485,8 @@ func (i *Importer) importBindings(items []Binding, summary *ImportSummary) error binding := model.Binding{ Namespace: ns, PublicModel: publicModel, - RouteGroup: normalizeGroup(item.RouteGroup), + GroupID: group.ID, + Weight: 1, SelectorType: "exact", SelectorValue: publicModel, Status: normalizeStatus(item.Status, "active"), diff --git a/internal/migrate/schema.go b/internal/migrate/schema.go index 7188aca..78de928 100644 --- a/internal/migrate/schema.go +++ b/internal/migrate/schema.go @@ -92,7 +92,7 @@ type Key struct { // Binding represents an EZ-API binding (optional, from abilities). type Binding struct { Namespace string `json:"namespace"` - RouteGroup string `json:"route_group"` + RouteGroup string `json:"route_group"` // provider group name Model string `json:"model"` Status string `json:"status"` } diff --git a/internal/model/models.go b/internal/model/models.go index 9f07020..f94bceb 100644 --- a/internal/model/models.go +++ b/internal/model/models.go @@ -1,10 +1,6 @@ package model -import ( - "time" - - "gorm.io/gorm" -) +import "gorm.io/gorm" // Admin is not a database model. It's configured via environment variables. @@ -50,24 +46,6 @@ type Key struct { QuotaResetType string `gorm:"size:20" json:"quota_reset_type"` } -// Provider remains the same. -type Provider struct { - gorm.Model - Name string `gorm:"not null" json:"name"` - Type string `gorm:"not null" json:"type"` // openai, anthropic, etc. - BaseURL string `json:"base_url"` - APIKey string `json:"api_key"` - GoogleProject string `gorm:"size:128" json:"google_project,omitempty"` - GoogleLocation string `gorm:"size:64" json:"google_location,omitempty"` - Group string `gorm:"default:'default'" json:"group"` // routing group/tier - Models string `json:"models"` // comma-separated list of supported models (e.g. "gpt-4,gpt-3.5-turbo") - Weight int `gorm:"default:1" json:"weight"` // routing weight inside route_group - Status string `gorm:"size:50;default:'active'" json:"status"` // active, auto_disabled, manual_disabled - AutoBan bool `gorm:"default:true" json:"auto_ban"` // whether DP-triggered disable is allowed - BanReason string `gorm:"size:255" json:"ban_reason"` // reason for current disable - BanUntil *time.Time `json:"ban_until"` // optional TTL for disable -} - // Model remains the same. type Model struct { gorm.Model @@ -82,13 +60,13 @@ type Model struct { MaxOutputTokens int `json:"max_output_tokens"` } -// Binding defines a stable "namespace.public_model" routing key and its target RouteGroup + selector. -// RouteGroup currently reuses Provider.Group. +// Binding defines a stable "namespace.public_model" routing key and its target ProviderGroup + selector. type Binding struct { gorm.Model Namespace string `gorm:"size:100;not null;index:idx_binding_key,unique" json:"namespace"` PublicModel string `gorm:"size:255;not null;index:idx_binding_key,unique" json:"public_model"` - RouteGroup string `gorm:"size:100;not null" json:"route_group"` + GroupID uint `gorm:"not null;index:idx_binding_key,unique" json:"group_id"` + Weight int `gorm:"default:1" json:"weight"` SelectorType string `gorm:"size:50;default:'exact'" json:"selector_type"` SelectorValue string `gorm:"size:255" json:"selector_value"` Status string `gorm:"size:50;default:'active'" json:"status"` diff --git a/internal/model/provider_group.go b/internal/model/provider_group.go new file mode 100644 index 0000000..37e1e0c --- /dev/null +++ b/internal/model/provider_group.go @@ -0,0 +1,31 @@ +package model + +import ( + "time" + + "gorm.io/gorm" +) + +// ProviderGroup represents a shared upstream definition. +type ProviderGroup struct { + gorm.Model + Name string `gorm:"size:255;uniqueIndex;not null" json:"name"` + Type string `gorm:"size:50;not null" json:"type"` // openai, anthropic, gemini + BaseURL string `gorm:"size:512;not null" json:"base_url"` + GoogleProject string `gorm:"size:128" json:"google_project,omitempty"` + GoogleLocation string `gorm:"size:64" json:"google_location,omitempty"` + Models string `json:"models"` // comma-separated list of supported models + Status string `gorm:"size:50;default:'active'" json:"status"` +} + +// APIKey represents a credential within a provider group. +type APIKey struct { + gorm.Model + GroupID uint `gorm:"not null;index" json:"group_id"` + APIKey string `gorm:"not null" json:"api_key"` + Weight int `gorm:"default:1" json:"weight"` + Status string `gorm:"size:50;default:'active'" json:"status"` + AutoBan bool `gorm:"default:true" json:"auto_ban"` + BanReason string `gorm:"size:255" json:"ban_reason"` + BanUntil *time.Time `json:"ban_until"` +} diff --git a/internal/service/model_registry.go b/internal/service/model_registry.go index 492f3b2..e8951fb 100644 --- a/internal/service/model_registry.go +++ b/internal/service/model_registry.go @@ -18,7 +18,6 @@ import ( "time" "github.com/ez-api/ez-api/internal/model" - groupx "github.com/ez-api/foundation/group" "github.com/ez-api/foundation/jsoncodec" "github.com/ez-api/foundation/modelcap" "github.com/ez-api/foundation/routing" @@ -373,6 +372,33 @@ type upstreamCap struct { SupportsTools boolVal } +func boolValEqual(a, b boolVal) bool { + if a.Known != b.Known { + return false + } + if !a.Known { + return true + } + return a.Val == b.Val +} + +func intValEqual(a, b intVal) bool { + if a.Known != b.Known { + return false + } + if !a.Known { + return true + } + return a.Val == b.Val +} + +func capsEqual(a, b upstreamCap) bool { + return boolValEqual(a.SupportsVision, b.SupportsVision) && + boolValEqual(a.SupportsTools, b.SupportsTools) && + intValEqual(a.ContextWindow, b.ContextWindow) && + intValEqual(a.MaxOutputTokens, b.MaxOutputTokens) +} + type modelsDevRegistry struct { ByProviderModel map[string]upstreamCap // key: providerID|modelID ByModel map[string]upstreamCap // fallback: modelID @@ -707,9 +733,13 @@ func (a *capAgg) finalize(name string) modelcap.Model { } func (s *ModelRegistryService) buildBindingModels(ctx context.Context, reg *modelsDevRegistry) (map[string]modelcap.Model, map[string]string, error) { - var providers []model.Provider - if err := s.db.Find(&providers).Error; err != nil { - return nil, nil, fmt.Errorf("load providers: %w", err) + var groups []model.ProviderGroup + if err := s.db.Find(&groups).Error; err != nil { + return nil, nil, fmt.Errorf("load provider groups: %w", err) + } + var apiKeys []model.APIKey + if err := s.db.Find(&apiKeys).Error; err != nil { + return nil, nil, fmt.Errorf("load api keys: %w", err) } var bindings []model.Binding if err := s.db.Find(&bindings).Error; err != nil { @@ -718,21 +748,29 @@ func (s *ModelRegistryService) buildBindingModels(ctx context.Context, reg *mode type providerLite struct { id uint - group string ptype string models []string } - providersByGroup := make(map[string][]providerLite) + providersByGroupID := make(map[uint]providerLite) now := time.Now().Unix() - for _, p := range providers { - if strings.TrimSpace(p.Status) != "" && strings.TrimSpace(p.Status) != "active" { + activeKeys := make(map[uint]bool) + for _, k := range apiKeys { + if strings.TrimSpace(k.Status) != "" && strings.TrimSpace(k.Status) != "active" { continue } - if p.BanUntil != nil && p.BanUntil.UTC().Unix() > now { + if k.BanUntil != nil && k.BanUntil.UTC().Unix() > now { continue } - group := groupx.Normalize(p.Group) - rawModels := strings.Split(p.Models, ",") + activeKeys[k.GroupID] = true + } + for _, g := range groups { + if strings.TrimSpace(g.Status) != "" && strings.TrimSpace(g.Status) != "active" { + continue + } + if !activeKeys[g.ID] { + continue + } + rawModels := strings.Split(g.Models, ",") var outModels []string for _, m := range rawModels { m = strings.TrimSpace(m) @@ -740,19 +778,20 @@ func (s *ModelRegistryService) buildBindingModels(ctx context.Context, reg *mode outModels = append(outModels, m) } } - if group == "" || len(outModels) == 0 { + if len(outModels) == 0 { continue } - providersByGroup[group] = append(providersByGroup[group], providerLite{ - id: p.ID, - group: group, - ptype: strings.TrimSpace(p.Type), + providersByGroupID[g.ID] = providerLite{ + id: g.ID, + ptype: strings.TrimSpace(g.Type), models: outModels, - }) + } } modelsOut := make(map[string]modelcap.Model) payloads := make(map[string]string) + capBaseline := make(map[string]upstreamCap) + capBaselineOK := make(map[string]bool) for _, b := range bindings { if strings.TrimSpace(b.Status) != "" && strings.TrimSpace(b.Status) != "active" { @@ -764,12 +803,8 @@ func (s *ModelRegistryService) buildBindingModels(ctx context.Context, reg *mode continue } key := ns + "." + pm - rg := groupx.Normalize(b.RouteGroup) - if rg == "" { - continue - } - pgroup := providersByGroup[rg] - if len(pgroup) == 0 { + group := providersByGroupID[b.GroupID] + if group.id == 0 { continue } @@ -782,13 +817,25 @@ func (s *ModelRegistryService) buildBindingModels(ctx context.Context, reg *mode selectorType := routing.SelectorType(strings.TrimSpace(b.SelectorType)) selectorValue := strings.TrimSpace(b.SelectorValue) - for _, p := range pgroup { - up, err := routing.ResolveUpstreamModel(selectorType, selectorValue, pm, p.models) - if err != nil { - continue + up, err := routing.ResolveUpstreamModel(selectorType, selectorValue, pm, group.models) + if err == nil { + cap, ok := lookupModelsDevCap(reg, modelsDevProviderKey(group.ptype), up) + if baseOK, seen := capBaselineOK[key]; seen { + if !ok || !baseOK || !capsEqual(capBaseline[key], cap) { + return nil, nil, fmt.Errorf("bindingKey %s has inconsistent capabilities", key) + } + } else { + capBaselineOK[key] = ok + if ok { + capBaseline[key] = cap + } } - cap, ok := lookupModelsDevCap(reg, modelsDevProviderKey(p.ptype), up) agg.merge(cap, ok) + } else { + if _, seen := capBaselineOK[key]; seen { + return nil, nil, fmt.Errorf("bindingKey %s has inconsistent capabilities", key) + } + capBaselineOK[key] = false } out := agg.finalize(key) diff --git a/internal/service/model_registry_check_test.go b/internal/service/model_registry_check_test.go index 2ff8df1..7886866 100644 --- a/internal/service/model_registry_check_test.go +++ b/internal/service/model_registry_check_test.go @@ -23,7 +23,7 @@ func TestModelRegistry_Check(t *testing.T) { if err != nil { t.Fatalf("open sqlite: %v", err) } - if err := db.AutoMigrate(&model.Provider{}, &model.Binding{}, &model.Model{}); err != nil { + if err := db.AutoMigrate(&model.ProviderGroup{}, &model.APIKey{}, &model.Binding{}, &model.Model{}); err != nil { t.Fatalf("migrate: %v", err) } diff --git a/internal/service/model_registry_test.go b/internal/service/model_registry_test.go index 82e600c..cbc3f2a 100644 --- a/internal/service/model_registry_test.go +++ b/internal/service/model_registry_test.go @@ -58,22 +58,31 @@ func TestModelRegistry_RefreshAndRollback(t *testing.T) { if err != nil { t.Fatalf("open sqlite: %v", err) } - if err := db.AutoMigrate(&model.Provider{}, &model.Binding{}, &model.Model{}); err != nil { + if err := db.AutoMigrate(&model.ProviderGroup{}, &model.APIKey{}, &model.Binding{}, &model.Model{}); err != nil { t.Fatalf("migrate: %v", err) } - if err := db.Create(&model.Provider{ - Name: "p1", - Type: "openai", - Group: "rg", - Models: "gpt-4o-mini", - Status: "active", + group := model.ProviderGroup{ + Name: "rg", + Type: "openai", + BaseURL: "https://api.openai.com/v1", + Models: "gpt-4o-mini", + Status: "active", + } + if err := db.Create(&group).Error; err != nil { + t.Fatalf("create provider group: %v", err) + } + if err := db.Create(&model.APIKey{ + GroupID: group.ID, + APIKey: "k", + Status: "active", }).Error; err != nil { - t.Fatalf("create provider: %v", err) + t.Fatalf("create api key: %v", err) } if err := db.Create(&model.Binding{ Namespace: "ns", PublicModel: "m", - RouteGroup: "rg", + GroupID: group.ID, + Weight: 1, SelectorType: "exact", SelectorValue: "gpt-4o-mini", Status: "active", diff --git a/internal/service/sync.go b/internal/service/sync.go index c4374c6..f0f5f20 100644 --- a/internal/service/sync.go +++ b/internal/service/sync.go @@ -77,78 +77,29 @@ func (s *SyncService) SyncMaster(master *model.Master) error { return nil } -// SyncProvider writes a single provider into Redis hash storage and updates routing tables. -func (s *SyncService) SyncProvider(provider *model.Provider) error { - ctx := context.Background() - group := groupx.Normalize(provider.Group) - models := strings.Split(provider.Models, ",") - - snap := providerSnapshot{ - ID: provider.ID, - Name: provider.Name, - Type: provider.Type, - BaseURL: provider.BaseURL, - APIKey: provider.APIKey, - GoogleProject: provider.GoogleProject, - GoogleLocation: provider.GoogleLocation, - Group: group, - Models: models, - Weight: provider.Weight, - Status: normalizeStatus(provider.Status), - AutoBan: provider.AutoBan, - BanReason: provider.BanReason, - } - if provider.BanUntil != nil { - snap.BanUntil = provider.BanUntil.UTC().Unix() - } - - // 1. Update Provider Config - if err := s.hsetJSON(ctx, "config:providers", fmt.Sprintf("%d", provider.ID), snap); err != nil { - return err - } - - // 2. Update Routing Table: route:group:{group}:{model} -> Set(provider_id) - // Note: This is an additive operation. Removing models requires full sync or smarter logic. - pipe := s.rdb.Pipeline() - for _, m := range models { - m = strings.TrimSpace(m) - if m == "" { - continue - } - if snap.Status != "active" { - continue - } - if snap.BanUntil > 0 && time.Now().Unix() < snap.BanUntil { - continue - } - routeKey := fmt.Sprintf("route:group:%s:%s", group, m) - pipe.SAdd(ctx, routeKey, provider.ID) - } - _, err := pipe.Exec(ctx) - return err -} - -// SyncProviderDelete removes provider snapshot and routing entries from Redis. -func (s *SyncService) SyncProviderDelete(provider *model.Provider) error { - if provider == nil { - return fmt.Errorf("provider required") +// SyncProviders rebuilds provider snapshots from ProviderGroup + APIKey tables. +func (s *SyncService) SyncProviders(db *gorm.DB) error { + if db == nil { + return fmt.Errorf("db required") } ctx := context.Background() - group := groupx.Normalize(provider.Group) - models := strings.Split(provider.Models, ",") + + var groups []model.ProviderGroup + if err := db.Find(&groups).Error; err != nil { + return fmt.Errorf("load provider groups: %w", err) + } + var apiKeys []model.APIKey + if err := db.Find(&apiKeys).Error; err != nil { + return fmt.Errorf("load api keys: %w", err) + } pipe := s.rdb.TxPipeline() - pipe.HDel(ctx, "config:providers", fmt.Sprintf("%d", provider.ID)) - for _, m := range models { - m = strings.TrimSpace(m) - if m == "" { - continue - } - routeKey := fmt.Sprintf("route:group:%s:%s", group, m) - pipe.SRem(ctx, routeKey, provider.ID) + pipe.Del(ctx, "config:providers") + if err := s.writeProvidersSnapshot(ctx, pipe, groups, apiKeys); err != nil { + return err } if _, err := pipe.Exec(ctx); err != nil { - return fmt.Errorf("delete provider snapshot: %w", err) + return fmt.Errorf("write provider snapshot: %w", err) } return nil } @@ -203,6 +154,7 @@ type providerSnapshot struct { APIKey string `json:"api_key"` GoogleProject string `json:"google_project,omitempty"` GoogleLocation string `json:"google_location,omitempty"` + GroupID uint `json:"group_id,omitempty"` Group string `json:"group"` Models []string `json:"models"` Weight int `json:"weight,omitempty"` @@ -212,15 +164,100 @@ type providerSnapshot struct { BanUntil int64 `json:"ban_until,omitempty"` // unix seconds } +func (s *SyncService) writeProvidersSnapshot(ctx context.Context, pipe redis.Pipeliner, groups []model.ProviderGroup, apiKeys []model.APIKey) error { + groupMap := make(map[uint]model.ProviderGroup, len(groups)) + for _, g := range groups { + groupMap[g.ID] = g + } + + for _, k := range apiKeys { + g, ok := groupMap[k.GroupID] + if !ok { + continue + } + groupName := groupx.Normalize(g.Name) + if strings.TrimSpace(groupName) == "" { + groupName = "default" + } + groupStatus := normalizeStatus(g.Status) + keyStatus := normalizeStatus(k.Status) + status := keyStatus + if groupStatus != "" && groupStatus != "active" { + status = groupStatus + } + + rawModels := strings.Split(g.Models, ",") + var models []string + for _, m := range rawModels { + m = strings.TrimSpace(m) + if m != "" { + models = append(models, m) + } + } + + name := strings.TrimSpace(g.Name) + if name == "" { + name = groupName + } + name = fmt.Sprintf("%s#%d", name, k.ID) + + snap := providerSnapshot{ + ID: k.ID, + Name: name, + Type: strings.TrimSpace(g.Type), + BaseURL: strings.TrimSpace(g.BaseURL), + APIKey: strings.TrimSpace(k.APIKey), + GoogleProject: strings.TrimSpace(g.GoogleProject), + GoogleLocation: strings.TrimSpace(g.GoogleLocation), + GroupID: g.ID, + Group: groupName, + Models: models, + Weight: k.Weight, + Status: status, + AutoBan: k.AutoBan, + BanReason: strings.TrimSpace(k.BanReason), + } + if k.BanUntil != nil { + snap.BanUntil = k.BanUntil.UTC().Unix() + } + + payload, err := jsoncodec.Marshal(snap) + if err != nil { + return fmt.Errorf("marshal provider %d: %w", k.ID, err) + } + pipe.HSet(ctx, "config:providers", fmt.Sprintf("%d", k.ID), payload) + + // Legacy route table maintenance for compatibility. + for _, m := range models { + if m == "" { + continue + } + if snap.Status != "active" { + continue + } + if snap.BanUntil > 0 && time.Now().Unix() < snap.BanUntil { + continue + } + routeKey := fmt.Sprintf("route:group:%s:%s", groupName, m) + pipe.SAdd(ctx, routeKey, k.ID) + } + } + return nil +} + // keySnapshot is no longer needed as we write directly to auth:token:* // SyncAll rebuilds Redis hashes from the database; use for cold starts or forced refreshes. func (s *SyncService) SyncAll(db *gorm.DB) error { ctx := context.Background() - var providers []model.Provider - if err := db.Find(&providers).Error; err != nil { - return fmt.Errorf("load providers: %w", err) + var groups []model.ProviderGroup + if err := db.Find(&groups).Error; err != nil { + return fmt.Errorf("load provider groups: %w", err) + } + var apiKeys []model.APIKey + if err := db.Find(&apiKeys).Error; err != nil { + return fmt.Errorf("load api keys: %w", err) } var keys []model.Key @@ -259,53 +296,8 @@ func (s *SyncService) SyncAll(db *gorm.DB) error { pipe.Del(ctx, masterKeys...) } - // Clear old routing tables (pattern scan would be better in prod, but keys are predictable if we knew them) - // For MVP, we rely on the fact that we are rebuilding. - // Ideally, we should scan "route:group:*" and del, but let's just rebuild. - - for _, p := range providers { - group := groupx.Normalize(p.Group) - models := strings.Split(p.Models, ",") - - snap := providerSnapshot{ - ID: p.ID, - Name: p.Name, - Type: p.Type, - BaseURL: p.BaseURL, - APIKey: p.APIKey, - GoogleProject: p.GoogleProject, - GoogleLocation: p.GoogleLocation, - Group: group, - Models: models, - Weight: p.Weight, - Status: normalizeStatus(p.Status), - AutoBan: p.AutoBan, - BanReason: p.BanReason, - } - if p.BanUntil != nil { - snap.BanUntil = p.BanUntil.UTC().Unix() - } - payload, err := jsoncodec.Marshal(snap) - if err != nil { - return fmt.Errorf("marshal provider %d: %w", p.ID, err) - } - pipe.HSet(ctx, "config:providers", fmt.Sprintf("%d", p.ID), payload) - - // Rebuild Routing Table - for _, m := range models { - m = strings.TrimSpace(m) - if m == "" { - continue - } - if snap.Status != "active" { - continue - } - if snap.BanUntil > 0 && time.Now().Unix() < snap.BanUntil { - continue - } - routeKey := fmt.Sprintf("route:group:%s:%s", group, m) - pipe.SAdd(ctx, routeKey, p.ID) - } + if err := s.writeProvidersSnapshot(ctx, pipe, groups, apiKeys); err != nil { + return err } for _, k := range keys { @@ -382,7 +374,7 @@ func (s *SyncService) SyncAll(db *gorm.DB) error { return err } - if err := s.writeBindingsSnapshot(ctx, pipe, bindings, providers); err != nil { + if err := s.writeBindingsSnapshot(ctx, pipe, bindings, groups, apiKeys); err != nil { return err } @@ -398,9 +390,13 @@ func (s *SyncService) SyncAll(db *gorm.DB) error { func (s *SyncService) SyncBindings(db *gorm.DB) error { ctx := context.Background() - var providers []model.Provider - if err := db.Find(&providers).Error; err != nil { - return fmt.Errorf("load providers: %w", err) + var groups []model.ProviderGroup + if err := db.Find(&groups).Error; err != nil { + return fmt.Errorf("load provider groups: %w", err) + } + var apiKeys []model.APIKey + if err := db.Find(&apiKeys).Error; err != nil { + return fmt.Errorf("load api keys: %w", err) } var bindings []model.Binding if err := db.Find(&bindings).Error; err != nil { @@ -409,7 +405,7 @@ func (s *SyncService) SyncBindings(db *gorm.DB) error { pipe := s.rdb.TxPipeline() pipe.Del(ctx, "config:bindings", "meta:bindings_meta") - if err := s.writeBindingsSnapshot(ctx, pipe, bindings, providers); err != nil { + if err := s.writeBindingsSnapshot(ctx, pipe, bindings, groups, apiKeys); err != nil { return err } if _, err := pipe.Exec(ctx); err != nil { @@ -418,32 +414,65 @@ func (s *SyncService) SyncBindings(db *gorm.DB) error { return nil } -func (s *SyncService) writeBindingsSnapshot(ctx context.Context, pipe redis.Pipeliner, bindings []model.Binding, providers []model.Provider) error { - // Group providers by route group for selector resolution. - type providerLite struct { - id uint - group string - models []string +func (s *SyncService) writeBindingsSnapshot(ctx context.Context, pipe redis.Pipeliner, bindings []model.Binding, groups []model.ProviderGroup, apiKeys []model.APIKey) error { + type groupLite struct { + id uint + name string + ptype string + baseURL string + googleProject string + googleLoc string + models []string + status string } - providersByGroup := make(map[string][]providerLite) - for _, p := range providers { - group := groupx.Normalize(p.Group) - models := strings.Split(p.Models, ",") + groupsByID := make(map[uint]groupLite, len(groups)) + for _, g := range groups { + rawModels := strings.Split(g.Models, ",") var outModels []string - for _, m := range models { + for _, m := range rawModels { m = strings.TrimSpace(m) if m != "" { outModels = append(outModels, m) } } - providersByGroup[group] = append(providersByGroup[group], providerLite{ - id: p.ID, - group: group, - models: outModels, + groupsByID[g.ID] = groupLite{ + id: g.ID, + name: groupx.Normalize(g.Name), + ptype: strings.TrimSpace(g.Type), + baseURL: strings.TrimSpace(g.BaseURL), + googleProject: strings.TrimSpace(g.GoogleProject), + googleLoc: strings.TrimSpace(g.GoogleLocation), + models: outModels, + status: normalizeStatus(g.Status), + } + } + + type apiKeyLite struct { + id uint + groupID uint + status string + weight int + autoBan bool + banUntil *time.Time + } + keysByGroup := make(map[uint][]apiKeyLite) + for _, k := range apiKeys { + keysByGroup[k.GroupID] = append(keysByGroup[k.GroupID], apiKeyLite{ + id: k.ID, + groupID: k.GroupID, + status: normalizeStatus(k.Status), + weight: k.Weight, + autoBan: k.AutoBan, + banUntil: k.BanUntil, }) } + type bindingAgg struct { + snap routing.BindingSnapshot + } + snaps := make(map[string]*routing.BindingSnapshot) now := time.Now().Unix() + for _, b := range bindings { if strings.TrimSpace(b.Status) != "" && strings.TrimSpace(b.Status) != "active" { continue @@ -453,43 +482,65 @@ func (s *SyncService) writeBindingsSnapshot(ctx context.Context, pipe redis.Pipe if ns == "" || pm == "" { continue } - rg := groupx.Normalize(b.RouteGroup) - if rg == "" { + group, ok := groupsByID[b.GroupID] + if !ok { + continue + } + if group.status != "" && group.status != "active" { continue } - snap := struct { - Namespace string `json:"namespace"` - PublicModel string `json:"public_model"` - RouteGroup string `json:"route_group"` - SelectorType string `json:"selector_type,omitempty"` - SelectorValue string `json:"selector_value,omitempty"` - Status string `json:"status,omitempty"` - UpdatedAt int64 `json:"updated_at,omitempty"` - Upstreams map[string]string `json:"upstreams"` - }{ - Namespace: ns, - PublicModel: pm, - RouteGroup: rg, + key := ns + "." + pm + snap := snaps[key] + if snap == nil { + snap = &routing.BindingSnapshot{ + Namespace: ns, + PublicModel: pm, + Status: "active", + UpdatedAt: now, + } + snaps[key] = snap + } + + candidate := routing.BindingCandidate{ + GroupID: group.id, + RouteGroup: group.name, + Weight: normalizeWeight(b.Weight), SelectorType: strings.TrimSpace(b.SelectorType), SelectorValue: strings.TrimSpace(b.SelectorValue), Status: "active", - UpdatedAt: now, Upstreams: make(map[string]string), } selectorType := strings.TrimSpace(b.SelectorType) selectorValue := strings.TrimSpace(b.SelectorValue) + keys := keysByGroup[b.GroupID] + if len(keys) == 0 { + candidate.Error = "no_provider" + } - for _, p := range providersByGroup[rg] { - up, err := routing.ResolveUpstreamModel(routing.SelectorType(selectorType), selectorValue, pm, p.models) + nowUnix := time.Now().Unix() + for _, k := range keys { + if k.status != "" && k.status != "active" { + continue + } + if k.banUntil != nil && k.banUntil.UTC().Unix() > nowUnix { + continue + } + up, err := routing.ResolveUpstreamModel(routing.SelectorType(selectorType), selectorValue, pm, group.models) if err != nil { continue } - snap.Upstreams[fmt.Sprintf("%d", p.id)] = up + candidate.Upstreams[fmt.Sprintf("%d", k.id)] = up + } + if len(candidate.Upstreams) == 0 && candidate.Error == "" { + candidate.Error = "config_error" } - key := ns + "." + pm + snap.Candidates = append(snap.Candidates, candidate) + } + + for key, snap := range snaps { payload, err := jsoncodec.Marshal(snap) if err != nil { return fmt.Errorf("marshal config:bindings:%s: %w", key, err) @@ -519,6 +570,13 @@ func (s *SyncService) hsetJSON(ctx context.Context, key, field string, val inter return nil } +func normalizeWeight(weight int) int { + if weight <= 0 { + return 1 + } + return weight +} + func normalizeStatus(status string) string { st := strings.ToLower(strings.TrimSpace(status)) if st == "" { diff --git a/internal/service/sync_bindings_spec_test.go b/internal/service/sync_bindings_spec_test.go index bb6c181..d70fcbc 100644 --- a/internal/service/sync_bindings_spec_test.go +++ b/internal/service/sync_bindings_spec_test.go @@ -13,10 +13,13 @@ import ( ) type bindingSnapshot struct { - Namespace string `json:"namespace"` - PublicModel string `json:"public_model"` - RouteGroup string `json:"route_group"` - Upstreams map[string]string `json:"upstreams"` + Namespace string `json:"namespace"` + PublicModel string `json:"public_model"` + Candidates []struct { + RouteGroup string `json:"route_group"` + Error string `json:"error,omitempty"` + Upstreams map[string]string `json:"upstreams"` + } `json:"candidates"` } func TestSyncBindings_SelectorExact(t *testing.T) { @@ -26,15 +29,19 @@ func TestSyncBindings_SelectorExact(t *testing.T) { if err != nil { t.Fatalf("open sqlite: %v", err) } - if err := db.AutoMigrate(&model.Provider{}, &model.Binding{}); err != nil { + if err := db.AutoMigrate(&model.ProviderGroup{}, &model.APIKey{}, &model.Binding{}); err != nil { t.Fatalf("migrate: %v", err) } - p := model.Provider{Name: "p1", Type: "openai", Group: "rg", Models: "m"} - if err := db.Create(&p).Error; err != nil { - t.Fatalf("create provider: %v", err) + group := model.ProviderGroup{Name: "rg", Type: "openai", BaseURL: "https://api.openai.com/v1", Models: "m", Status: "active"} + if err := db.Create(&group).Error; err != nil { + t.Fatalf("create group: %v", err) } - b := model.Binding{Namespace: "ns", PublicModel: "m", RouteGroup: "rg", SelectorType: "exact", Status: "active"} + key := model.APIKey{GroupID: group.ID, APIKey: "k1", Status: "active"} + if err := db.Create(&key).Error; err != nil { + t.Fatalf("create api key: %v", err) + } + b := model.Binding{Namespace: "ns", PublicModel: "m", GroupID: group.ID, Weight: 1, SelectorType: "exact", Status: "active"} if err := db.Create(&b).Error; err != nil { t.Fatalf("create binding: %v", err) } @@ -54,8 +61,11 @@ func TestSyncBindings_SelectorExact(t *testing.T) { if err := json.Unmarshal([]byte(raw), &snap); err != nil { t.Fatalf("unmarshal: %v", err) } - if snap.Upstreams == nil || snap.Upstreams[jsonID(p.ID)] != "m" { - t.Fatalf("unexpected upstreams: %+v", snap.Upstreams) + if len(snap.Candidates) != 1 { + t.Fatalf("expected 1 candidate, got %+v", snap.Candidates) + } + if snap.Candidates[0].Upstreams == nil || snap.Candidates[0].Upstreams[jsonID(key.ID)] != "m" { + t.Fatalf("unexpected upstreams: %+v", snap.Candidates[0].Upstreams) } } @@ -66,27 +76,31 @@ func TestSyncBindings_SelectorRegexAndNormalize(t *testing.T) { if err != nil { t.Fatalf("open sqlite: %v", err) } - if err := db.AutoMigrate(&model.Provider{}, &model.Binding{}); err != nil { + if err := db.AutoMigrate(&model.ProviderGroup{}, &model.APIKey{}, &model.Binding{}); err != nil { t.Fatalf("migrate: %v", err) } - p1 := model.Provider{Name: "p1", Type: "openai", Group: "rg", Models: "moonshot/kimi2,kimi2"} - p2 := model.Provider{Name: "p2", Type: "openai", Group: "rg", Models: "moonshot/kimi2"} - if err := db.Create(&p1).Error; err != nil { - t.Fatalf("create provider1: %v", err) + group := model.ProviderGroup{Name: "rg", Type: "openai", BaseURL: "https://api.openai.com/v1", Models: "moonshot/kimi2,kimi2", Status: "active"} + if err := db.Create(&group).Error; err != nil { + t.Fatalf("create group: %v", err) } - if err := db.Create(&p2).Error; err != nil { - t.Fatalf("create provider2: %v", err) + k1 := model.APIKey{GroupID: group.ID, APIKey: "k1", Status: "active"} + k2 := model.APIKey{GroupID: group.ID, APIKey: "k2", Status: "active"} + if err := db.Create(&k1).Error; err != nil { + t.Fatalf("create api key1: %v", err) + } + if err := db.Create(&k2).Error; err != nil { + t.Fatalf("create api key2: %v", err) } // Regex should match uniquely (moonshot/kimi2 only). - bRegex := model.Binding{Namespace: "ns", PublicModel: "kimi2", RouteGroup: "rg", SelectorType: "regex", SelectorValue: "^moonshot/kimi2$", Status: "active"} + bRegex := model.Binding{Namespace: "ns", PublicModel: "kimi2", GroupID: group.ID, Weight: 1, SelectorType: "regex", SelectorValue: "^moonshot/kimi2$", Status: "active"} if err := db.Create(&bRegex).Error; err != nil { t.Fatalf("create binding regex: %v", err) } // Normalize_exact should match p2 (moonshot/kimi2) for "kimi2". - bNorm := model.Binding{Namespace: "ns", PublicModel: "kimi2-n", RouteGroup: "rg", SelectorType: "normalize_exact", SelectorValue: "kimi2", Status: "active"} + bNorm := model.Binding{Namespace: "ns", PublicModel: "kimi2-n", GroupID: group.ID, Weight: 1, SelectorType: "normalize_exact", SelectorValue: "kimi2", Status: "active"} if err := db.Create(&bNorm).Error; err != nil { t.Fatalf("create binding normalize: %v", err) } @@ -104,8 +118,12 @@ func TestSyncBindings_SelectorRegexAndNormalize(t *testing.T) { if err := json.Unmarshal([]byte(raw), &snapRegex); err != nil { t.Fatalf("unmarshal regex: %v", err) } - if snapRegex.Upstreams[jsonID(p1.ID)] != "moonshot/kimi2" || snapRegex.Upstreams[jsonID(p2.ID)] != "moonshot/kimi2" { - t.Fatalf("unexpected regex upstreams: %+v", snapRegex.Upstreams) + if len(snapRegex.Candidates) != 1 { + t.Fatalf("expected 1 candidate, got %+v", snapRegex.Candidates) + } + upstreams := snapRegex.Candidates[0].Upstreams + if upstreams[jsonID(k1.ID)] != "moonshot/kimi2" || upstreams[jsonID(k2.ID)] != "moonshot/kimi2" { + t.Fatalf("unexpected regex upstreams: %+v", upstreams) } // Normalize_exact binding should include p2 but exclude p1 due to multi-match (moonshot/kimi2 + kimi2). @@ -114,11 +132,11 @@ func TestSyncBindings_SelectorRegexAndNormalize(t *testing.T) { if err := json.Unmarshal([]byte(raw), &snapNorm); err != nil { t.Fatalf("unmarshal normalize: %v", err) } - if snapNorm.Upstreams[jsonID(p2.ID)] != "moonshot/kimi2" { - t.Fatalf("expected p2 upstream, got %+v", snapNorm.Upstreams) + if len(snapNorm.Candidates) != 1 { + t.Fatalf("expected 1 candidate, got %+v", snapNorm.Candidates) } - if _, ok := snapNorm.Upstreams[jsonID(p1.ID)]; ok { - t.Fatalf("did not expect p1 upstream due to normalize multi-match, got %+v", snapNorm.Upstreams) + if len(snapNorm.Candidates[0].Upstreams) != 0 || snapNorm.Candidates[0].Error != "config_error" { + t.Fatalf("expected config_error with no upstreams, got %+v", snapNorm.Candidates[0]) } } diff --git a/internal/service/sync_test.go b/internal/service/sync_test.go index dea6f51..3232ece 100644 --- a/internal/service/sync_test.go +++ b/internal/service/sync_test.go @@ -2,68 +2,76 @@ package service import ( "encoding/json" - "reflect" + "strconv" "testing" "github.com/alicebob/miniredis/v2" "github.com/ez-api/ez-api/internal/model" - "github.com/ez-api/foundation/contract" "github.com/redis/go-redis/v9" + "gorm.io/driver/sqlite" + "gorm.io/gorm" ) -func TestSyncProvider_WritesSnapshotAndRouting(t *testing.T) { - goldenRaw := contract.ProviderSnapshotJSON() - var golden map[string]any - if err := json.Unmarshal(goldenRaw, &golden); err != nil { - t.Fatalf("parse golden json: %v", err) - } - +func TestSyncProviders_WritesSnapshotAndRouting(t *testing.T) { mr := miniredis.RunT(t) rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()}) - svc := NewSyncService(rdb) - p := &model.Provider{ - Name: "p1", + db, err := gorm.Open(sqlite.Open("file:"+t.Name()+"?mode=memory&cache=shared"), &gorm.Config{}) + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + if err := db.AutoMigrate(&model.ProviderGroup{}, &model.APIKey{}); err != nil { + t.Fatalf("migrate: %v", err) + } + + group := model.ProviderGroup{ + Name: "default", Type: "vertex-express", - Group: "default", + BaseURL: "https://vertex.example", + GoogleLocation: "global", Models: "gemini-3-pro-preview", Status: "active", - AutoBan: true, - GoogleProject: "", - GoogleLocation: "global", } - p.ID = 42 - - if err := svc.SyncProvider(p); err != nil { - t.Fatalf("SyncProvider: %v", err) + if err := db.Create(&group).Error; err != nil { + t.Fatalf("create group: %v", err) + } + key := model.APIKey{ + GroupID: group.ID, + APIKey: "k", + Status: "active", + AutoBan: true, + } + if err := db.Create(&key).Error; err != nil { + t.Fatalf("create key: %v", err) } - raw := mr.HGet("config:providers", "42") + if err := svc.SyncProviders(db); err != nil { + t.Fatalf("SyncProviders: %v", err) + } + + raw := mr.HGet("config:providers", jsonID(key.ID)) if raw == "" { t.Fatalf("expected config:providers hash entry") } - var snap map[string]any if err := json.Unmarshal([]byte(raw), &snap); err != nil { t.Fatalf("invalid snapshot json: %v", err) } - for k, v := range golden { - if !reflect.DeepEqual(snap[k], v) { - t.Fatalf("snapshot mismatch for %q: got=%#v want=%#v", k, snap[k], v) - } + if snap["group"] != "default" { + t.Fatalf("expected group default, got %#v", snap["group"]) } routeKey := "route:group:default:gemini-3-pro-preview" if !mr.Exists(routeKey) { t.Fatalf("expected routing key %q to exist", routeKey) } - ok, err := mr.SIsMember(routeKey, "42") + ok, err := mr.SIsMember(routeKey, jsonID(key.ID)) if err != nil { t.Fatalf("SIsMember: %v", err) } if !ok { - t.Fatalf("expected provider id 42 in routing set %q", routeKey) + t.Fatalf("expected provider id in routing set %q", routeKey) } } @@ -113,34 +121,6 @@ func TestSyncModelDelete_RemovesMeta(t *testing.T) { } } -func TestSyncProviderDelete_RemovesSnapshotAndRouting(t *testing.T) { - mr := miniredis.RunT(t) - rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()}) - svc := NewSyncService(rdb) - - p := &model.Provider{ - Name: "p1", - Type: "openai", - Group: "default", - Models: "gpt-4o-mini,gpt-4o", - Status: "active", - } - p.ID = 7 - - if err := svc.SyncProvider(p); err != nil { - t.Fatalf("SyncProvider: %v", err) - } - if err := svc.SyncProviderDelete(p); err != nil { - t.Fatalf("SyncProviderDelete: %v", err) - } - - if got := mr.HGet("config:providers", "7"); got != "" { - t.Fatalf("expected provider snapshot removed, got %q", got) - } - if ok, _ := mr.SIsMember("route:group:default:gpt-4o-mini", "7"); ok { - t.Fatalf("expected provider removed from route set") - } - if ok, _ := mr.SIsMember("route:group:default:gpt-4o", "7"); ok { - t.Fatalf("expected provider removed from route set") - } +func jsonID(id uint) string { + return strconv.FormatUint(uint64(id), 10) }