mirror of
https://github.com/EZ-Api/ez-api.git
synced 2026-01-13 17:47:51 +00:00
refactor(api): split Provider into ProviderGroup and APIKey models
Restructure the provider management system by separating the monolithic Provider model into two distinct entities: - ProviderGroup: defines shared upstream configuration (type, base_url, google settings, models, status) - APIKey: represents individual credentials within a group (api_key, weight, status, auto_ban, ban settings) This change also updates: - Binding model to reference GroupID instead of RouteGroup string - All CRUD handlers for the new provider-group and api-key endpoints - Sync service to rebuild provider snapshots from joined tables - Model registry to aggregate capabilities across group/key pairs - Access handler to validate namespace existence and subset constraints - Migration importer to handle the new schema structure - All related tests to use the new model relationships BREAKING CHANGE: Provider API endpoints replaced with /provider-groups and /api-keys endpoints; Binding.RouteGroup replaced with Binding.GroupID
This commit is contained in:
@@ -135,7 +135,7 @@ func main() {
|
|||||||
|
|
||||||
// Auto Migrate
|
// Auto Migrate
|
||||||
if logDB != db {
|
if logDB != db {
|
||||||
if err := db.AutoMigrate(&model.Master{}, &model.Key{}, &model.Provider{}, &model.Model{}, &model.Binding{}, &model.Namespace{}, &model.OperationLog{}); err != nil {
|
if err := db.AutoMigrate(&model.Master{}, &model.Key{}, &model.ProviderGroup{}, &model.APIKey{}, &model.Model{}, &model.Binding{}, &model.Namespace{}, &model.OperationLog{}); err != nil {
|
||||||
fatal(logger, "failed to auto migrate", "err", err)
|
fatal(logger, "failed to auto migrate", "err", err)
|
||||||
}
|
}
|
||||||
if err := logDB.AutoMigrate(&model.LogRecord{}); err != nil {
|
if err := logDB.AutoMigrate(&model.LogRecord{}); err != nil {
|
||||||
@@ -145,7 +145,7 @@ func main() {
|
|||||||
fatal(logger, "failed to ensure log indexes", "err", err)
|
fatal(logger, "failed to ensure log indexes", "err", err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if err := db.AutoMigrate(&model.Master{}, &model.Key{}, &model.Provider{}, &model.Model{}, &model.Binding{}, &model.Namespace{}, &model.OperationLog{}, &model.LogRecord{}); err != nil {
|
if err := db.AutoMigrate(&model.Master{}, &model.Key{}, &model.ProviderGroup{}, &model.APIKey{}, &model.Model{}, &model.Binding{}, &model.Namespace{}, &model.OperationLog{}, &model.LogRecord{}); err != nil {
|
||||||
fatal(logger, "failed to auto migrate", "err", err)
|
fatal(logger, "failed to auto migrate", "err", err)
|
||||||
}
|
}
|
||||||
if err := service.EnsureLogIndexes(db); err != nil {
|
if err := service.EnsureLogIndexes(db); err != nil {
|
||||||
@@ -287,17 +287,17 @@ func main() {
|
|||||||
adminGroup.POST("/model-registry/refresh", modelRegistryHandler.Refresh)
|
adminGroup.POST("/model-registry/refresh", modelRegistryHandler.Refresh)
|
||||||
adminGroup.POST("/model-registry/rollback", modelRegistryHandler.Rollback)
|
adminGroup.POST("/model-registry/rollback", modelRegistryHandler.Rollback)
|
||||||
// Other admin routes for managing providers, models, etc.
|
// Other admin routes for managing providers, models, etc.
|
||||||
adminGroup.POST("/providers", handler.CreateProvider)
|
adminGroup.POST("/provider-groups", handler.CreateProviderGroup)
|
||||||
adminGroup.GET("/providers", handler.ListProviders)
|
adminGroup.GET("/provider-groups", handler.ListProviderGroups)
|
||||||
adminGroup.GET("/providers/:id", handler.GetProvider)
|
adminGroup.GET("/provider-groups/:id", handler.GetProviderGroup)
|
||||||
adminGroup.POST("/providers/preset", handler.CreateProviderPreset)
|
adminGroup.PUT("/provider-groups/:id", handler.UpdateProviderGroup)
|
||||||
adminGroup.POST("/providers/custom", handler.CreateProviderCustom)
|
adminGroup.DELETE("/provider-groups/:id", handler.DeleteProviderGroup)
|
||||||
adminGroup.POST("/providers/google", handler.CreateProviderGoogle)
|
adminGroup.POST("/api-keys", handler.CreateAPIKey)
|
||||||
adminGroup.PUT("/providers/:id", handler.UpdateProvider)
|
adminGroup.GET("/api-keys", handler.ListAPIKeys)
|
||||||
adminGroup.DELETE("/providers/:id", handler.DeleteProvider)
|
adminGroup.GET("/api-keys/:id", handler.GetAPIKey)
|
||||||
adminGroup.POST("/providers/batch", handler.BatchProviders)
|
adminGroup.PUT("/api-keys/:id", handler.UpdateAPIKey)
|
||||||
adminGroup.POST("/providers/:id/test", handler.TestProvider)
|
adminGroup.DELETE("/api-keys/:id", handler.DeleteAPIKey)
|
||||||
adminGroup.POST("/providers/:id/fetch-models", handler.FetchProviderModels)
|
adminGroup.POST("/api-keys/batch", handler.BatchAPIKeys)
|
||||||
adminGroup.POST("/models", handler.CreateModel)
|
adminGroup.POST("/models", handler.CreateModel)
|
||||||
adminGroup.GET("/models", handler.ListModels)
|
adminGroup.GET("/models", handler.ListModels)
|
||||||
adminGroup.PUT("/models/:id", handler.UpdateModel)
|
adminGroup.PUT("/models/:id", handler.UpdateModel)
|
||||||
@@ -406,7 +406,7 @@ func runImport(logger *slog.Logger, args []string) int {
|
|||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := db.AutoMigrate(&model.Master{}, &model.Key{}, &model.Provider{}, &model.Model{}, &model.Binding{}, &model.Namespace{}); err != nil {
|
if err := db.AutoMigrate(&model.Master{}, &model.Key{}, &model.ProviderGroup{}, &model.APIKey{}, &model.Model{}, &model.Binding{}, &model.Namespace{}); err != nil {
|
||||||
logger.Error("failed to auto migrate", "err", err)
|
logger.Error("failed to auto migrate", "err", err)
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,14 @@
|
|||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/ez-api/ez-api/internal/model"
|
"github.com/ez-api/ez-api/internal/model"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type AccessResponse struct {
|
type AccessResponse struct {
|
||||||
@@ -94,6 +96,10 @@ func (h *Handler) UpdateMasterAccess(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
nsList := normalizeNamespaces(nextNamespaces, nextDefault)
|
nsList := normalizeNamespaces(nextNamespaces, nextDefault)
|
||||||
nextNamespaces = strings.Join(nsList, ",")
|
nextNamespaces = strings.Join(nsList, ",")
|
||||||
|
if err := ensureNamespacesExist(h.db, nsList); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if err := h.db.Model(&m).Updates(map[string]any{
|
if err := h.db.Model(&m).Updates(map[string]any{
|
||||||
"default_namespace": nextDefault,
|
"default_namespace": nextDefault,
|
||||||
@@ -203,6 +209,21 @@ func (h *Handler) UpdateKeyAccess(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
nsList := normalizeNamespaces(nextNamespaces, nextDefault)
|
nsList := normalizeNamespaces(nextNamespaces, nextDefault)
|
||||||
nextNamespaces = strings.Join(nsList, ",")
|
nextNamespaces = strings.Join(nsList, ",")
|
||||||
|
if err := ensureNamespacesExist(h.db, nsList); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var master model.Master
|
||||||
|
if err := h.db.First(&master, k.MasterID).Error; err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "master not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
masterNamespaces := normalizeNamespaces(master.Namespaces, master.DefaultNamespace)
|
||||||
|
if !isSubset(nsList, masterNamespaces) {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "namespaces must be a subset of master namespaces"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if err := h.db.Model(&k).Updates(map[string]any{
|
if err := h.db.Model(&k).Updates(map[string]any{
|
||||||
"default_namespace": nextDefault,
|
"default_namespace": nextDefault,
|
||||||
@@ -264,6 +285,39 @@ func normalizeNamespaces(raw string, defaultNamespace string) []string {
|
|||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ensureNamespacesExist(db *gorm.DB, namespaces []string) error {
|
||||||
|
if db == nil {
|
||||||
|
return fmt.Errorf("db required")
|
||||||
|
}
|
||||||
|
if len(namespaces) == 0 {
|
||||||
|
return fmt.Errorf("namespaces required")
|
||||||
|
}
|
||||||
|
var rows []model.Namespace
|
||||||
|
if err := db.Where("name IN ?", namespaces).Find(&rows).Error; err != nil {
|
||||||
|
return fmt.Errorf("failed to load namespaces")
|
||||||
|
}
|
||||||
|
if len(rows) != len(namespaces) {
|
||||||
|
return fmt.Errorf("namespace not found")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSubset(child, parent []string) bool {
|
||||||
|
if len(child) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
parentSet := make(map[string]struct{}, len(parent))
|
||||||
|
for _, p := range parent {
|
||||||
|
parentSet[strings.TrimSpace(p)] = struct{}{}
|
||||||
|
}
|
||||||
|
for _, c := range child {
|
||||||
|
if _, ok := parentSet[strings.TrimSpace(c)]; !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func parseUintParam(c *gin.Context, name string) (uint, bool) {
|
func parseUintParam(c *gin.Context, name string) (uint, bool) {
|
||||||
idRaw := strings.TrimSpace(c.Param(name))
|
idRaw := strings.TrimSpace(c.Param(name))
|
||||||
if idRaw == "" {
|
if idRaw == "" {
|
||||||
|
|||||||
259
internal/api/api_key_handler.go
Normal file
259
internal/api/api_key_handler.go
Normal file
@@ -0,0 +1,259 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ez-api/ez-api/internal/dto"
|
||||||
|
"github.com/ez-api/ez-api/internal/model"
|
||||||
|
"github.com/ez-api/foundation/provider"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CreateAPIKey godoc
|
||||||
|
// @Summary Create an API key
|
||||||
|
// @Description Create an API key for a provider group
|
||||||
|
// @Tags admin
|
||||||
|
// @Accept json
|
||||||
|
// @Produce json
|
||||||
|
// @Security AdminAuth
|
||||||
|
// @Param key body dto.APIKeyDTO true "API key payload"
|
||||||
|
// @Success 201 {object} model.APIKey
|
||||||
|
// @Failure 400 {object} gin.H
|
||||||
|
// @Failure 500 {object} gin.H
|
||||||
|
// @Router /admin/api-keys [post]
|
||||||
|
func (h *Handler) CreateAPIKey(c *gin.Context) {
|
||||||
|
var req dto.APIKeyDTO
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.GroupID == 0 {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "group_id required"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var group model.ProviderGroup
|
||||||
|
if err := h.db.First(&group, req.GroupID).Error; err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "provider group not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
apiKey := strings.TrimSpace(req.APIKey)
|
||||||
|
ptype := provider.NormalizeType(group.Type)
|
||||||
|
if provider.IsGoogleFamily(ptype) && !provider.IsVertexFamily(ptype) && apiKey == "" {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "api_key required for gemini api providers"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
status := strings.TrimSpace(req.Status)
|
||||||
|
if status == "" {
|
||||||
|
status = "active"
|
||||||
|
}
|
||||||
|
autoBan := true
|
||||||
|
if req.AutoBan != nil {
|
||||||
|
autoBan = *req.AutoBan
|
||||||
|
}
|
||||||
|
|
||||||
|
key := model.APIKey{
|
||||||
|
GroupID: req.GroupID,
|
||||||
|
APIKey: apiKey,
|
||||||
|
Weight: normalizeWeight(req.Weight),
|
||||||
|
Status: status,
|
||||||
|
AutoBan: autoBan,
|
||||||
|
BanReason: strings.TrimSpace(req.BanReason),
|
||||||
|
}
|
||||||
|
if !req.BanUntil.IsZero() {
|
||||||
|
tu := req.BanUntil.UTC()
|
||||||
|
key.BanUntil = &tu
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.db.Create(&key).Error; err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create api key", "details": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.sync.SyncProviders(h.db); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync providers", "details": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.sync.SyncBindings(h.db); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync bindings", "details": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusCreated, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListAPIKeys godoc
|
||||||
|
// @Summary List API keys
|
||||||
|
// @Description List API keys
|
||||||
|
// @Tags admin
|
||||||
|
// @Produce json
|
||||||
|
// @Security AdminAuth
|
||||||
|
// @Param page query int false "page (1-based)"
|
||||||
|
// @Param limit query int false "limit (default 50, max 200)"
|
||||||
|
// @Param group_id query int false "filter by group_id"
|
||||||
|
// @Success 200 {array} model.APIKey
|
||||||
|
// @Failure 500 {object} gin.H
|
||||||
|
// @Router /admin/api-keys [get]
|
||||||
|
func (h *Handler) ListAPIKeys(c *gin.Context) {
|
||||||
|
var keys []model.APIKey
|
||||||
|
q := h.db.Model(&model.APIKey{}).Order("id desc")
|
||||||
|
if groupID := strings.TrimSpace(c.Query("group_id")); groupID != "" {
|
||||||
|
q = q.Where("group_id = ?", groupID)
|
||||||
|
}
|
||||||
|
query := parseListQuery(c)
|
||||||
|
q = applyListPagination(q, query)
|
||||||
|
if err := q.Find(&keys).Error; err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list api keys", "details": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, keys)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAPIKey godoc
|
||||||
|
// @Summary Get API key
|
||||||
|
// @Description Get an API key by id
|
||||||
|
// @Tags admin
|
||||||
|
// @Produce json
|
||||||
|
// @Security AdminAuth
|
||||||
|
// @Param id path int true "APIKey ID"
|
||||||
|
// @Success 200 {object} model.APIKey
|
||||||
|
// @Failure 400 {object} gin.H
|
||||||
|
// @Failure 404 {object} gin.H
|
||||||
|
// @Failure 500 {object} gin.H
|
||||||
|
// @Router /admin/api-keys/{id} [get]
|
||||||
|
func (h *Handler) GetAPIKey(c *gin.Context) {
|
||||||
|
id, ok := parseUintParam(c, "id")
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var key model.APIKey
|
||||||
|
if err := h.db.First(&key, id).Error; err != nil {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "api key not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateAPIKey godoc
|
||||||
|
// @Summary Update API key
|
||||||
|
// @Description Update an API key
|
||||||
|
// @Tags admin
|
||||||
|
// @Accept json
|
||||||
|
// @Produce json
|
||||||
|
// @Security AdminAuth
|
||||||
|
// @Param id path int true "APIKey ID"
|
||||||
|
// @Param key body dto.APIKeyDTO true "API key payload"
|
||||||
|
// @Success 200 {object} model.APIKey
|
||||||
|
// @Failure 400 {object} gin.H
|
||||||
|
// @Failure 404 {object} gin.H
|
||||||
|
// @Failure 500 {object} gin.H
|
||||||
|
// @Router /admin/api-keys/{id} [put]
|
||||||
|
func (h *Handler) UpdateAPIKey(c *gin.Context) {
|
||||||
|
id, ok := parseUintParam(c, "id")
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var key model.APIKey
|
||||||
|
if err := h.db.First(&key, id).Error; err != nil {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "api key not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var req dto.APIKeyDTO
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
update := map[string]any{}
|
||||||
|
if req.GroupID != 0 {
|
||||||
|
var group model.ProviderGroup
|
||||||
|
if err := h.db.First(&group, req.GroupID).Error; err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "provider group not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
update["group_id"] = req.GroupID
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(req.APIKey) != "" {
|
||||||
|
update["api_key"] = strings.TrimSpace(req.APIKey)
|
||||||
|
}
|
||||||
|
if req.Weight > 0 {
|
||||||
|
update["weight"] = normalizeWeight(req.Weight)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(req.Status) != "" {
|
||||||
|
update["status"] = strings.TrimSpace(req.Status)
|
||||||
|
}
|
||||||
|
if req.AutoBan != nil {
|
||||||
|
update["auto_ban"] = *req.AutoBan
|
||||||
|
}
|
||||||
|
if req.BanReason != "" || strings.TrimSpace(req.Status) == "active" {
|
||||||
|
update["ban_reason"] = strings.TrimSpace(req.BanReason)
|
||||||
|
}
|
||||||
|
if !req.BanUntil.IsZero() {
|
||||||
|
tu := req.BanUntil.UTC()
|
||||||
|
update["ban_until"] = &tu
|
||||||
|
}
|
||||||
|
if req.BanUntil.IsZero() && strings.TrimSpace(req.Status) == "active" {
|
||||||
|
update["ban_until"] = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.db.Model(&key).Updates(update).Error; err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update api key", "details": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.db.First(&key, id).Error; err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to reload api key", "details": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.sync.SyncProviders(h.db); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync providers", "details": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.sync.SyncBindings(h.db); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync bindings", "details": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteAPIKey godoc
|
||||||
|
// @Summary Delete API key
|
||||||
|
// @Description Delete an API key
|
||||||
|
// @Tags admin
|
||||||
|
// @Produce json
|
||||||
|
// @Security AdminAuth
|
||||||
|
// @Param id path int true "APIKey ID"
|
||||||
|
// @Success 200 {object} gin.H
|
||||||
|
// @Failure 400 {object} gin.H
|
||||||
|
// @Failure 404 {object} gin.H
|
||||||
|
// @Failure 500 {object} gin.H
|
||||||
|
// @Router /admin/api-keys/{id} [delete]
|
||||||
|
func (h *Handler) DeleteAPIKey(c *gin.Context) {
|
||||||
|
id, ok := parseUintParam(c, "id")
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var key model.APIKey
|
||||||
|
if err := h.db.First(&key, id).Error; err != nil {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "api key not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.db.Delete(&key).Error; err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete api key", "details": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.sync.SyncProviders(h.db); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync providers", "details": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.sync.SyncBindings(h.db); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync bindings", "details": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{"status": "deleted"})
|
||||||
|
}
|
||||||
@@ -104,9 +104,9 @@ func (h *AdminHandler) BatchMasters(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, resp)
|
c.JSON(http.StatusOK, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// BatchProviders godoc
|
// BatchAPIKeys godoc
|
||||||
// @Summary Batch providers
|
// @Summary Batch api keys
|
||||||
// @Description Batch delete or status update for providers
|
// @Description Batch delete or status update for api keys
|
||||||
// @Tags admin
|
// @Tags admin
|
||||||
// @Accept json
|
// @Accept json
|
||||||
// @Produce json
|
// @Produce json
|
||||||
@@ -115,8 +115,8 @@ func (h *AdminHandler) BatchMasters(c *gin.Context) {
|
|||||||
// @Success 200 {object} BatchResponse
|
// @Success 200 {object} BatchResponse
|
||||||
// @Failure 400 {object} gin.H
|
// @Failure 400 {object} gin.H
|
||||||
// @Failure 500 {object} gin.H
|
// @Failure 500 {object} gin.H
|
||||||
// @Router /admin/providers/batch [post]
|
// @Router /admin/api-keys/batch [post]
|
||||||
func (h *Handler) BatchProviders(c *gin.Context) {
|
func (h *Handler) BatchAPIKeys(c *gin.Context) {
|
||||||
var req BatchActionRequest
|
var req BatchActionRequest
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
|
||||||
@@ -140,8 +140,8 @@ func (h *Handler) BatchProviders(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
needsBindingSync := false
|
needsBindingSync := false
|
||||||
for _, id := range req.IDs {
|
for _, id := range req.IDs {
|
||||||
var p model.Provider
|
var key model.APIKey
|
||||||
if err := h.db.First(&p, id).Error; err != nil {
|
if err := h.db.First(&key, id).Error; err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
resp.Failed = append(resp.Failed, BatchResult{ID: id, Error: "not found"})
|
resp.Failed = append(resp.Failed, BatchResult{ID: id, Error: "not found"})
|
||||||
continue
|
continue
|
||||||
@@ -151,11 +151,7 @@ func (h *Handler) BatchProviders(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
switch action {
|
switch action {
|
||||||
case "delete":
|
case "delete":
|
||||||
if err := h.db.Delete(&p).Error; err != nil {
|
if err := h.db.Delete(&key).Error; err != nil {
|
||||||
resp.Failed = append(resp.Failed, BatchResult{ID: id, Error: err.Error()})
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err := h.sync.SyncProviderDelete(&p); err != nil {
|
|
||||||
resp.Failed = append(resp.Failed, BatchResult{ID: id, Error: err.Error()})
|
resp.Failed = append(resp.Failed, BatchResult{ID: id, Error: err.Error()})
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -167,15 +163,11 @@ func (h *Handler) BatchProviders(c *gin.Context) {
|
|||||||
update["ban_reason"] = ""
|
update["ban_reason"] = ""
|
||||||
update["ban_until"] = nil
|
update["ban_until"] = nil
|
||||||
}
|
}
|
||||||
if err := h.db.Model(&p).Updates(update).Error; err != nil {
|
if err := h.db.Model(&key).Updates(update).Error; err != nil {
|
||||||
resp.Failed = append(resp.Failed, BatchResult{ID: id, Error: err.Error()})
|
resp.Failed = append(resp.Failed, BatchResult{ID: id, Error: err.Error()})
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if err := h.db.First(&p, id).Error; err != nil {
|
if err := h.db.First(&key, id).Error; err != nil {
|
||||||
resp.Failed = append(resp.Failed, BatchResult{ID: id, Error: err.Error()})
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err := h.sync.SyncProvider(&p); err != nil {
|
|
||||||
resp.Failed = append(resp.Failed, BatchResult{ID: id, Error: err.Error()})
|
resp.Failed = append(resp.Failed, BatchResult{ID: id, Error: err.Error()})
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -184,6 +176,10 @@ func (h *Handler) BatchProviders(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if needsBindingSync {
|
if needsBindingSync {
|
||||||
|
if err := h.sync.SyncProviders(h.db); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync providers", "details": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
if err := h.sync.SyncBindings(h.db); err != nil {
|
if err := h.sync.SyncBindings(h.db); err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync bindings", "details": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync bindings", "details": err.Error()})
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -1,19 +1,19 @@
|
|||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/ez-api/ez-api/internal/dto"
|
"github.com/ez-api/ez-api/internal/dto"
|
||||||
"github.com/ez-api/ez-api/internal/model"
|
"github.com/ez-api/ez-api/internal/model"
|
||||||
groupx "github.com/ez-api/foundation/group"
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CreateBinding godoc
|
// CreateBinding godoc
|
||||||
// @Summary Create a new binding
|
// @Summary Create a new binding
|
||||||
// @Description Create a new (namespace, public_model) binding to a route group and selector
|
// @Description Create a new (namespace, public_model) binding to a provider group and selector
|
||||||
// @Tags admin
|
// @Tags admin
|
||||||
// @Accept json
|
// @Accept json
|
||||||
// @Produce json
|
// @Produce json
|
||||||
@@ -32,14 +32,14 @@ func (h *Handler) CreateBinding(c *gin.Context) {
|
|||||||
|
|
||||||
ns := strings.TrimSpace(req.Namespace)
|
ns := strings.TrimSpace(req.Namespace)
|
||||||
pm := strings.TrimSpace(req.PublicModel)
|
pm := strings.TrimSpace(req.PublicModel)
|
||||||
if ns == "" || pm == "" {
|
if ns == "" || pm == "" || req.GroupID == 0 {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "namespace and public_model required"})
|
c.JSON(http.StatusBadRequest, gin.H{"error": "namespace, public_model, and group_id required"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
rg := groupx.Normalize(req.RouteGroup)
|
if err := h.ensureActiveGroup(req.GroupID); err != nil {
|
||||||
if strings.TrimSpace(rg) == "" {
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
rg = "default"
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
st := strings.TrimSpace(req.Status)
|
st := strings.TrimSpace(req.Status)
|
||||||
@@ -55,7 +55,8 @@ func (h *Handler) CreateBinding(c *gin.Context) {
|
|||||||
b := model.Binding{
|
b := model.Binding{
|
||||||
Namespace: ns,
|
Namespace: ns,
|
||||||
PublicModel: pm,
|
PublicModel: pm,
|
||||||
RouteGroup: rg,
|
GroupID: req.GroupID,
|
||||||
|
Weight: normalizeWeight(req.Weight),
|
||||||
SelectorType: selectorType,
|
SelectorType: selectorType,
|
||||||
SelectorValue: strings.TrimSpace(req.SelectorValue),
|
SelectorValue: strings.TrimSpace(req.SelectorValue),
|
||||||
Status: st,
|
Status: st,
|
||||||
@@ -82,7 +83,7 @@ func (h *Handler) CreateBinding(c *gin.Context) {
|
|||||||
// @Security AdminAuth
|
// @Security AdminAuth
|
||||||
// @Param page query int false "page (1-based)"
|
// @Param page query int false "page (1-based)"
|
||||||
// @Param limit query int false "limit (default 50, max 200)"
|
// @Param limit query int false "limit (default 50, max 200)"
|
||||||
// @Param search query string false "search by namespace/public_model/route_group"
|
// @Param search query string false "search by namespace/public_model"
|
||||||
// @Success 200 {array} model.Binding
|
// @Success 200 {array} model.Binding
|
||||||
// @Failure 500 {object} gin.H
|
// @Failure 500 {object} gin.H
|
||||||
// @Router /admin/bindings [get]
|
// @Router /admin/bindings [get]
|
||||||
@@ -90,7 +91,7 @@ func (h *Handler) ListBindings(c *gin.Context) {
|
|||||||
var out []model.Binding
|
var out []model.Binding
|
||||||
q := h.db.Model(&model.Binding{}).Order("id desc")
|
q := h.db.Model(&model.Binding{}).Order("id desc")
|
||||||
query := parseListQuery(c)
|
query := parseListQuery(c)
|
||||||
q = applyListSearch(q, query.Search, "namespace", "public_model", "route_group")
|
q = applyListSearch(q, query.Search, "namespace", "public_model")
|
||||||
q = applyListPagination(q, query)
|
q = applyListPagination(q, query)
|
||||||
if err := q.Find(&out).Error; err != nil {
|
if err := q.Find(&out).Error; err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list bindings", "details": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list bindings", "details": err.Error()})
|
||||||
@@ -139,8 +140,15 @@ func (h *Handler) UpdateBinding(c *gin.Context) {
|
|||||||
if pm := strings.TrimSpace(req.PublicModel); pm != "" {
|
if pm := strings.TrimSpace(req.PublicModel); pm != "" {
|
||||||
existing.PublicModel = pm
|
existing.PublicModel = pm
|
||||||
}
|
}
|
||||||
if rg := strings.TrimSpace(req.RouteGroup); rg != "" {
|
if req.GroupID != 0 {
|
||||||
existing.RouteGroup = groupx.Normalize(rg)
|
if err := h.ensureActiveGroup(req.GroupID); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
existing.GroupID = req.GroupID
|
||||||
|
}
|
||||||
|
if req.Weight > 0 {
|
||||||
|
existing.Weight = normalizeWeight(req.Weight)
|
||||||
}
|
}
|
||||||
if st := strings.TrimSpace(req.Status); st != "" {
|
if st := strings.TrimSpace(req.Status); st != "" {
|
||||||
existing.Status = st
|
existing.Status = st
|
||||||
@@ -229,3 +237,30 @@ func (h *Handler) DeleteBinding(c *gin.Context) {
|
|||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{"status": "deleted"})
|
c.JSON(http.StatusOK, gin.H{"status": "deleted"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func normalizeWeight(weight int) int {
|
||||||
|
if weight <= 0 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return weight
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) ensureActiveGroup(groupID uint) error {
|
||||||
|
var group model.ProviderGroup
|
||||||
|
if err := h.db.First(&group, groupID).Error; err != nil {
|
||||||
|
return fmt.Errorf("provider group not found")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(group.Status) != "" && strings.TrimSpace(group.Status) != "active" {
|
||||||
|
return fmt.Errorf("provider group not active")
|
||||||
|
}
|
||||||
|
var count int64
|
||||||
|
if err := h.db.Model(&model.APIKey{}).
|
||||||
|
Where("group_id = ? AND status = ?", groupID, "active").
|
||||||
|
Count(&count).Error; err != nil {
|
||||||
|
return fmt.Errorf("failed to check api keys")
|
||||||
|
}
|
||||||
|
if count == 0 {
|
||||||
|
return fmt.Errorf("provider group has no active api keys")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -9,8 +9,6 @@ import (
|
|||||||
"github.com/ez-api/ez-api/internal/dto"
|
"github.com/ez-api/ez-api/internal/dto"
|
||||||
"github.com/ez-api/ez-api/internal/model"
|
"github.com/ez-api/ez-api/internal/model"
|
||||||
"github.com/ez-api/ez-api/internal/service"
|
"github.com/ez-api/ez-api/internal/service"
|
||||||
groupx "github.com/ez-api/foundation/group"
|
|
||||||
"github.com/ez-api/foundation/provider"
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/redis/go-redis/v9"
|
"github.com/redis/go-redis/v9"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -54,255 +52,6 @@ func (h *Handler) logBaseQuery() *gorm.DB {
|
|||||||
|
|
||||||
// CreateKey is now handled by MasterHandler
|
// CreateKey is now handled by MasterHandler
|
||||||
|
|
||||||
// CreateProvider godoc
|
|
||||||
// @Summary Create a new provider
|
|
||||||
// @Description Register a new upstream AI provider
|
|
||||||
// @Tags admin
|
|
||||||
// @Accept json
|
|
||||||
// @Produce json
|
|
||||||
// @Security AdminAuth
|
|
||||||
// @Param provider body dto.ProviderDTO true "Provider Info"
|
|
||||||
// @Success 201 {object} model.Provider
|
|
||||||
// @Failure 400 {object} gin.H
|
|
||||||
// @Failure 500 {object} gin.H
|
|
||||||
// @Router /admin/providers [post]
|
|
||||||
func (h *Handler) CreateProvider(c *gin.Context) {
|
|
||||||
var req dto.ProviderDTO
|
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
providerType := provider.NormalizeType(req.Type)
|
|
||||||
baseURL := strings.TrimSpace(req.BaseURL)
|
|
||||||
googleLocation := provider.DefaultGoogleLocation(providerType, req.GoogleLocation)
|
|
||||||
|
|
||||||
group := strings.TrimSpace(req.Group)
|
|
||||||
if group == "" {
|
|
||||||
group = "default"
|
|
||||||
}
|
|
||||||
|
|
||||||
status := strings.TrimSpace(req.Status)
|
|
||||||
if status == "" {
|
|
||||||
status = "active"
|
|
||||||
}
|
|
||||||
autoBan := true
|
|
||||||
if req.AutoBan != nil {
|
|
||||||
autoBan = *req.AutoBan
|
|
||||||
}
|
|
||||||
|
|
||||||
// CP-side defaults + validation to prevent DP runtime errors.
|
|
||||||
switch providerType {
|
|
||||||
case provider.TypeOpenAI:
|
|
||||||
if baseURL == "" {
|
|
||||||
baseURL = "https://api.openai.com/v1"
|
|
||||||
}
|
|
||||||
case provider.TypeAnthropic, provider.TypeClaude:
|
|
||||||
if baseURL == "" {
|
|
||||||
baseURL = "https://api.anthropic.com"
|
|
||||||
}
|
|
||||||
case provider.TypeCompatible:
|
|
||||||
if baseURL == "" {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "base_url required for compatible providers"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
// Google SDK providers: base_url is not required.
|
|
||||||
if provider.IsVertexFamily(providerType) && strings.TrimSpace(googleLocation) == "" {
|
|
||||||
googleLocation = provider.DefaultGoogleLocation(providerType, "")
|
|
||||||
}
|
|
||||||
// For Gemini API providers, api_key is required.
|
|
||||||
if provider.IsGoogleFamily(providerType) && !provider.IsVertexFamily(providerType) && strings.TrimSpace(req.APIKey) == "" {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "api_key required for gemini api providers"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
provider := model.Provider{
|
|
||||||
Name: req.Name,
|
|
||||||
Type: strings.TrimSpace(req.Type),
|
|
||||||
BaseURL: baseURL,
|
|
||||||
APIKey: req.APIKey,
|
|
||||||
GoogleProject: strings.TrimSpace(req.GoogleProject),
|
|
||||||
GoogleLocation: googleLocation,
|
|
||||||
Group: group,
|
|
||||||
Models: strings.Join(req.Models, ","),
|
|
||||||
Status: status,
|
|
||||||
AutoBan: autoBan,
|
|
||||||
BanReason: req.BanReason,
|
|
||||||
Weight: req.Weight,
|
|
||||||
}
|
|
||||||
if !req.BanUntil.IsZero() {
|
|
||||||
tu := req.BanUntil.UTC()
|
|
||||||
provider.BanUntil = &tu
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := h.db.Create(&provider).Error; err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create provider", "details": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := h.sync.SyncProvider(&provider); err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync provider", "details": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Provider model list changes can affect binding upstream mappings; rebuild bindings snapshot.
|
|
||||||
if err := h.sync.SyncBindings(h.db); err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync bindings", "details": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(http.StatusCreated, provider)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateProvider godoc
|
|
||||||
// @Summary Update a provider
|
|
||||||
// @Description Update provider attributes including status/auto-ban flags
|
|
||||||
// @Tags admin
|
|
||||||
// @Accept json
|
|
||||||
// @Produce json
|
|
||||||
// @Security AdminAuth
|
|
||||||
// @Param id path int true "Provider ID"
|
|
||||||
// @Param provider body dto.ProviderDTO true "Provider Info"
|
|
||||||
// @Success 200 {object} model.Provider
|
|
||||||
// @Failure 400 {object} gin.H
|
|
||||||
// @Failure 404 {object} gin.H
|
|
||||||
// @Failure 500 {object} gin.H
|
|
||||||
// @Router /admin/providers/{id} [put]
|
|
||||||
func (h *Handler) UpdateProvider(c *gin.Context) {
|
|
||||||
idParam := c.Param("id")
|
|
||||||
id, err := strconv.Atoi(idParam)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var existing model.Provider
|
|
||||||
if err := h.db.First(&existing, id).Error; err != nil {
|
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": "provider not found"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var req dto.ProviderDTO
|
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
nextType := strings.TrimSpace(existing.Type)
|
|
||||||
if t := strings.TrimSpace(req.Type); t != "" {
|
|
||||||
nextType = t
|
|
||||||
}
|
|
||||||
nextTypeLower := provider.NormalizeType(nextType)
|
|
||||||
nextBaseURL := strings.TrimSpace(existing.BaseURL)
|
|
||||||
if strings.TrimSpace(req.BaseURL) != "" {
|
|
||||||
nextBaseURL = strings.TrimSpace(req.BaseURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
update := map[string]any{}
|
|
||||||
if strings.TrimSpace(req.Name) != "" {
|
|
||||||
update["name"] = req.Name
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(req.Type) != "" {
|
|
||||||
update["type"] = strings.TrimSpace(req.Type)
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(req.BaseURL) != "" {
|
|
||||||
update["base_url"] = req.BaseURL
|
|
||||||
}
|
|
||||||
if req.APIKey != "" {
|
|
||||||
update["api_key"] = req.APIKey
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(req.GoogleProject) != "" {
|
|
||||||
update["google_project"] = strings.TrimSpace(req.GoogleProject)
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(req.GoogleLocation) != "" {
|
|
||||||
update["google_location"] = strings.TrimSpace(req.GoogleLocation)
|
|
||||||
} else if provider.IsVertexFamily(nextTypeLower) && strings.TrimSpace(existing.GoogleLocation) == "" {
|
|
||||||
update["google_location"] = provider.DefaultGoogleLocation(nextTypeLower, "")
|
|
||||||
}
|
|
||||||
if req.Models != nil {
|
|
||||||
update["models"] = strings.Join(req.Models, ",")
|
|
||||||
}
|
|
||||||
if req.Weight > 0 {
|
|
||||||
update["weight"] = req.Weight
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(req.Group) != "" {
|
|
||||||
update["group"] = groupx.Normalize(req.Group)
|
|
||||||
}
|
|
||||||
if req.AutoBan != nil {
|
|
||||||
update["auto_ban"] = *req.AutoBan
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(req.Status) != "" {
|
|
||||||
update["status"] = req.Status
|
|
||||||
}
|
|
||||||
if req.BanReason != "" || strings.TrimSpace(req.Status) == "active" {
|
|
||||||
update["ban_reason"] = req.BanReason
|
|
||||||
}
|
|
||||||
if !req.BanUntil.IsZero() {
|
|
||||||
tu := req.BanUntil.UTC()
|
|
||||||
update["ban_until"] = &tu
|
|
||||||
}
|
|
||||||
if req.BanUntil.IsZero() && strings.TrimSpace(req.Status) == "active" {
|
|
||||||
update["ban_until"] = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Defaults/validation after considering intended type/base_url.
|
|
||||||
switch nextTypeLower {
|
|
||||||
case provider.TypeOpenAI:
|
|
||||||
if nextBaseURL == "" {
|
|
||||||
update["base_url"] = "https://api.openai.com/v1"
|
|
||||||
}
|
|
||||||
case provider.TypeAnthropic, provider.TypeClaude:
|
|
||||||
if nextBaseURL == "" {
|
|
||||||
update["base_url"] = "https://api.anthropic.com"
|
|
||||||
}
|
|
||||||
case provider.TypeCompatible:
|
|
||||||
if nextBaseURL == "" {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "base_url required for compatible providers"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
if provider.IsGoogleFamily(nextTypeLower) && !provider.IsVertexFamily(nextTypeLower) {
|
|
||||||
// Ensure Gemini API providers have api_key.
|
|
||||||
// If update does not include api_key, keep existing; otherwise require new one not empty.
|
|
||||||
apiKey := existing.APIKey
|
|
||||||
if req.APIKey != "" {
|
|
||||||
apiKey = req.APIKey
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(apiKey) == "" {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "api_key required for gemini api providers"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(update) == 0 {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "no fields to update"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := h.db.Model(&existing).Updates(update).Error; err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update provider", "details": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := h.db.First(&existing, id).Error; err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to reload provider", "details": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := h.sync.SyncProvider(&existing); err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync provider", "details": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := h.sync.SyncBindings(h.db); err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync bindings", "details": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(http.StatusOK, existing)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateModel godoc
|
// CreateModel godoc
|
||||||
// @Summary Register a new model
|
// @Summary Register a new model
|
||||||
// @Description Register a supported model with its capabilities
|
// @Description Register a supported model with its capabilities
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ func newTestHandlerWithRedis(t *testing.T) (*Handler, *gorm.DB, *miniredis.Minir
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("open sqlite: %v", err)
|
t.Fatalf("open sqlite: %v", err)
|
||||||
}
|
}
|
||||||
if err := db.AutoMigrate(&model.Provider{}, &model.Binding{}, &model.Model{}); err != nil {
|
if err := db.AutoMigrate(&model.ProviderGroup{}, &model.APIKey{}, &model.Binding{}, &model.Model{}); err != nil {
|
||||||
t.Fatalf("migrate: %v", err)
|
t.Fatalf("migrate: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -204,10 +204,18 @@ func TestBatchModels_Delete(t *testing.T) {
|
|||||||
func TestBatchBindings_Status(t *testing.T) {
|
func TestBatchBindings_Status(t *testing.T) {
|
||||||
h, db := newTestHandler(t)
|
h, db := newTestHandler(t)
|
||||||
|
|
||||||
|
group := model.ProviderGroup{Name: "default", Type: "openai", BaseURL: "https://api.openai.com/v1", Models: "m1", Status: "active"}
|
||||||
|
if err := db.Create(&group).Error; err != nil {
|
||||||
|
t.Fatalf("create group: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.Create(&model.APIKey{GroupID: group.ID, APIKey: "k", Status: "active"}).Error; err != nil {
|
||||||
|
t.Fatalf("create api key: %v", err)
|
||||||
|
}
|
||||||
b := &model.Binding{
|
b := &model.Binding{
|
||||||
Namespace: "ns",
|
Namespace: "ns",
|
||||||
PublicModel: "m1",
|
PublicModel: "m1",
|
||||||
RouteGroup: "default",
|
GroupID: group.ID,
|
||||||
|
Weight: 1,
|
||||||
SelectorType: "exact",
|
SelectorType: "exact",
|
||||||
SelectorValue: "m1",
|
SelectorValue: "m1",
|
||||||
Status: "active",
|
Status: "active",
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ func newTestHandlerWithNamespace(t *testing.T) (*Handler, *gorm.DB, *miniredis.M
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("open sqlite: %v", err)
|
t.Fatalf("open sqlite: %v", err)
|
||||||
}
|
}
|
||||||
if err := db.AutoMigrate(&model.Provider{}, &model.Binding{}, &model.Namespace{}); err != nil {
|
if err := db.AutoMigrate(&model.ProviderGroup{}, &model.APIKey{}, &model.Binding{}, &model.Namespace{}); err != nil {
|
||||||
t.Fatalf("migrate: %v", err)
|
t.Fatalf("migrate: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -39,10 +39,23 @@ func newTestHandlerWithNamespace(t *testing.T) (*Handler, *gorm.DB, *miniredis.M
|
|||||||
func TestNamespaceCRUD_DeleteCleansBindings(t *testing.T) {
|
func TestNamespaceCRUD_DeleteCleansBindings(t *testing.T) {
|
||||||
h, db, _ := newTestHandlerWithNamespace(t)
|
h, db, _ := newTestHandlerWithNamespace(t)
|
||||||
|
|
||||||
|
group := model.ProviderGroup{Name: "g1", Type: "openai", BaseURL: "https://api.openai.com/v1", Models: "m1", Status: "active"}
|
||||||
|
if err := db.Create(&group).Error; err != nil {
|
||||||
|
t.Fatalf("create group: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.Create(&model.APIKey{
|
||||||
|
GroupID: group.ID,
|
||||||
|
APIKey: "k1",
|
||||||
|
Status: "active",
|
||||||
|
}).Error; err != nil {
|
||||||
|
t.Fatalf("create api key: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := db.Create(&model.Binding{
|
if err := db.Create(&model.Binding{
|
||||||
Namespace: "ns1",
|
Namespace: "ns1",
|
||||||
PublicModel: "m1",
|
PublicModel: "m1",
|
||||||
RouteGroup: "default",
|
GroupID: group.ID,
|
||||||
|
Weight: 1,
|
||||||
SelectorType: "exact",
|
SelectorType: "exact",
|
||||||
SelectorValue: "m1",
|
SelectorValue: "m1",
|
||||||
Status: "active",
|
Status: "active",
|
||||||
|
|||||||
@@ -1,350 +0,0 @@
|
|||||||
package api
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"sort"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/ez-api/ez-api/internal/model"
|
|
||||||
"github.com/ez-api/foundation/provider"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ListProviders godoc
|
|
||||||
// @Summary List providers
|
|
||||||
// @Description List all configured upstream providers
|
|
||||||
// @Tags admin
|
|
||||||
// @Produce json
|
|
||||||
// @Security AdminAuth
|
|
||||||
// @Param page query int false "page (1-based)"
|
|
||||||
// @Param limit query int false "limit (default 50, max 200)"
|
|
||||||
// @Param search query string false "search by name/type/base_url/group"
|
|
||||||
// @Success 200 {array} model.Provider
|
|
||||||
// @Failure 500 {object} gin.H
|
|
||||||
// @Router /admin/providers [get]
|
|
||||||
func (h *Handler) ListProviders(c *gin.Context) {
|
|
||||||
var providers []model.Provider
|
|
||||||
q := h.db.Model(&model.Provider{}).Order("id desc")
|
|
||||||
query := parseListQuery(c)
|
|
||||||
q = applyListSearch(q, query.Search, "name", `"type"`, "base_url", `"group"`)
|
|
||||||
q = applyListPagination(q, query)
|
|
||||||
if err := q.Find(&providers).Error; err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list providers", "details": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.JSON(http.StatusOK, providers)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetProvider godoc
|
|
||||||
// @Summary Get provider
|
|
||||||
// @Description Get a provider by id
|
|
||||||
// @Tags admin
|
|
||||||
// @Produce json
|
|
||||||
// @Security AdminAuth
|
|
||||||
// @Param id path int true "Provider ID"
|
|
||||||
// @Success 200 {object} model.Provider
|
|
||||||
// @Failure 400 {object} gin.H
|
|
||||||
// @Failure 404 {object} gin.H
|
|
||||||
// @Failure 500 {object} gin.H
|
|
||||||
// @Router /admin/providers/{id} [get]
|
|
||||||
func (h *Handler) GetProvider(c *gin.Context) {
|
|
||||||
id, ok := parseUintParam(c, "id")
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var p model.Provider
|
|
||||||
if err := h.db.First(&p, id).Error; err != nil {
|
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": "provider not found"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.JSON(http.StatusOK, p)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteProvider godoc
|
|
||||||
// @Summary Delete provider
|
|
||||||
// @Description Deletes a provider and triggers a full snapshot sync to avoid stale routing
|
|
||||||
// @Tags admin
|
|
||||||
// @Produce json
|
|
||||||
// @Security AdminAuth
|
|
||||||
// @Param id path int true "Provider ID"
|
|
||||||
// @Success 200 {object} gin.H
|
|
||||||
// @Failure 400 {object} gin.H
|
|
||||||
// @Failure 404 {object} gin.H
|
|
||||||
// @Failure 500 {object} gin.H
|
|
||||||
// @Router /admin/providers/{id} [delete]
|
|
||||||
func (h *Handler) DeleteProvider(c *gin.Context) {
|
|
||||||
id, ok := parseUintParam(c, "id")
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var p model.Provider
|
|
||||||
if err := h.db.First(&p, id).Error; err != nil {
|
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": "provider not found"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := h.db.Delete(&p).Error; err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete provider", "details": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := h.sync.SyncProviderDelete(&p); err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync provider delete", "details": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := h.sync.SyncBindings(h.db); err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync bindings", "details": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{"status": "deleted"})
|
|
||||||
}
|
|
||||||
|
|
||||||
type testProviderResponse struct {
|
|
||||||
StatusCode int `json:"status_code"`
|
|
||||||
OK bool `json:"ok"`
|
|
||||||
URL string `json:"url"`
|
|
||||||
Body string `json:"body,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestProvider godoc
|
|
||||||
// @Summary Test provider connectivity
|
|
||||||
// @Description Performs a lightweight upstream request to verify the provider configuration
|
|
||||||
// @Tags admin
|
|
||||||
// @Produce json
|
|
||||||
// @Security AdminAuth
|
|
||||||
// @Param id path int true "Provider ID"
|
|
||||||
// @Success 200 {object} testProviderResponse
|
|
||||||
// @Failure 400 {object} gin.H
|
|
||||||
// @Failure 404 {object} gin.H
|
|
||||||
// @Failure 500 {object} gin.H
|
|
||||||
// @Router /admin/providers/{id}/test [post]
|
|
||||||
func (h *Handler) TestProvider(c *gin.Context) {
|
|
||||||
id, ok := parseUintParam(c, "id")
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var p model.Provider
|
|
||||||
if err := h.db.First(&p, id).Error; err != nil {
|
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": "provider not found"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := buildProviderModelsRequest(&p)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
client := &http.Client{Timeout: 10 * time.Second}
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusOK, testProviderResponse{StatusCode: 0, OK: false, URL: req.URL.String(), Body: err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
|
||||||
ok = resp.StatusCode >= 200 && resp.StatusCode < 300
|
|
||||||
|
|
||||||
c.JSON(http.StatusOK, testProviderResponse{
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
OK: ok,
|
|
||||||
URL: req.URL.String(),
|
|
||||||
Body: string(body),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// FetchProviderModels godoc
|
|
||||||
// @Summary Fetch models from provider
|
|
||||||
// @Description Calls upstream /models (or /v1/models) and updates provider model list
|
|
||||||
// @Tags admin
|
|
||||||
// @Produce json
|
|
||||||
// @Security AdminAuth
|
|
||||||
// @Param id path int true "Provider ID"
|
|
||||||
// @Success 200 {object} gin.H
|
|
||||||
// @Failure 400 {object} gin.H
|
|
||||||
// @Failure 404 {object} gin.H
|
|
||||||
// @Failure 502 {object} gin.H
|
|
||||||
// @Router /admin/providers/{id}/fetch-models [post]
|
|
||||||
func (h *Handler) FetchProviderModels(c *gin.Context) {
|
|
||||||
id, ok := parseUintParam(c, "id")
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var p model.Provider
|
|
||||||
if err := h.db.First(&p, id).Error; err != nil {
|
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": "provider not found"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := buildProviderModelsRequest(&p)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
client := &http.Client{Timeout: 15 * time.Second}
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusBadGateway, gin.H{"error": "failed to fetch models", "details": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
|
||||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
||||||
c.JSON(http.StatusBadGateway, gin.H{
|
|
||||||
"error": "upstream returned non-2xx",
|
|
||||||
"status_code": resp.StatusCode,
|
|
||||||
"body": string(body),
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
models, err := parseProviderModelIDs(body)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusBadGateway, gin.H{"error": "failed to parse models", "details": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := h.db.Model(&p).Update("models", strings.Join(models, ",")).Error; err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to update provider models", "details": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := h.db.First(&p, id).Error; err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to reload provider", "details": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := h.sync.SyncProvider(&p); err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to sync provider", "details": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := h.sync.SyncBindings(h.db); err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to sync bindings", "details": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"status": "updated",
|
|
||||||
"count": len(models),
|
|
||||||
"models": models,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildProviderModelsRequest(p *model.Provider) (*http.Request, error) {
|
|
||||||
if p == nil {
|
|
||||||
return nil, fmt.Errorf("provider required")
|
|
||||||
}
|
|
||||||
pt := provider.NormalizeType(p.Type)
|
|
||||||
baseURL := strings.TrimRight(strings.TrimSpace(p.BaseURL), "/")
|
|
||||||
if baseURL == "" {
|
|
||||||
return nil, fmt.Errorf("base_url required for provider models fetch")
|
|
||||||
}
|
|
||||||
|
|
||||||
url := ""
|
|
||||||
switch pt {
|
|
||||||
case provider.TypeOpenAI, provider.TypeCompatible:
|
|
||||||
if strings.HasSuffix(baseURL, "/v1") {
|
|
||||||
url = baseURL + "/models"
|
|
||||||
} else {
|
|
||||||
url = baseURL + "/v1/models"
|
|
||||||
}
|
|
||||||
case provider.TypeAnthropic, provider.TypeClaude:
|
|
||||||
url = baseURL + "/v1/models"
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("provider type not supported for model fetch")
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("build request: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
apiKey := strings.TrimSpace(p.APIKey)
|
|
||||||
switch pt {
|
|
||||||
case provider.TypeOpenAI, provider.TypeCompatible:
|
|
||||||
if apiKey != "" {
|
|
||||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
|
||||||
}
|
|
||||||
case provider.TypeAnthropic, provider.TypeClaude:
|
|
||||||
if apiKey != "" {
|
|
||||||
req.Header.Set("x-api-key", apiKey)
|
|
||||||
}
|
|
||||||
req.Header.Set("anthropic-version", "2023-06-01")
|
|
||||||
}
|
|
||||||
|
|
||||||
return req, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type providerModelsResponse struct {
|
|
||||||
Data []json.RawMessage `json:"data"`
|
|
||||||
Models []string `json:"models"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseProviderModelIDs(payload []byte) ([]string, error) {
|
|
||||||
var resp providerModelsResponse
|
|
||||||
if err := json.Unmarshal(payload, &resp); err != nil {
|
|
||||||
return nil, fmt.Errorf("decode response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
models := make([]string, 0, len(resp.Data)+len(resp.Models))
|
|
||||||
for _, name := range resp.Models {
|
|
||||||
name = strings.TrimSpace(name)
|
|
||||||
if name != "" {
|
|
||||||
models = append(models, name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, raw := range resp.Data {
|
|
||||||
var item struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Model string `json:"model"`
|
|
||||||
Name string `json:"name"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(raw, &item); err == nil {
|
|
||||||
if item.ID != "" {
|
|
||||||
models = append(models, strings.TrimSpace(item.ID))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if item.Model != "" {
|
|
||||||
models = append(models, strings.TrimSpace(item.Model))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if item.Name != "" {
|
|
||||||
models = append(models, strings.TrimSpace(item.Name))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var name string
|
|
||||||
if err := json.Unmarshal(raw, &name); err == nil {
|
|
||||||
name = strings.TrimSpace(name)
|
|
||||||
if name != "" {
|
|
||||||
models = append(models, name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
unique := make(map[string]struct{}, len(models))
|
|
||||||
out := make([]string, 0, len(models))
|
|
||||||
for _, name := range models {
|
|
||||||
name = strings.TrimSpace(name)
|
|
||||||
if name == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if _, ok := unique[name]; ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
unique[name] = struct{}{}
|
|
||||||
out = append(out, name)
|
|
||||||
}
|
|
||||||
if len(out) == 0 {
|
|
||||||
return nil, fmt.Errorf("no models found in response")
|
|
||||||
}
|
|
||||||
sort.Strings(out)
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
@@ -1,162 +0,0 @@
|
|||||||
package api
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/ez-api/ez-api/internal/model"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestAdmin_TestProvider_OpenAICompatible(t *testing.T) {
|
|
||||||
h, db := newTestHandler(t)
|
|
||||||
|
|
||||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if r.URL.Path != "/v1/models" {
|
|
||||||
http.NotFound(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if got := r.Header.Get("Authorization"); got != "Bearer k" {
|
|
||||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
_, _ = w.Write([]byte(`{"object":"list","data":[]}`))
|
|
||||||
}))
|
|
||||||
defer upstream.Close()
|
|
||||||
|
|
||||||
p := &model.Provider{
|
|
||||||
Name: "p1",
|
|
||||||
Type: "openai",
|
|
||||||
BaseURL: upstream.URL + "/v1",
|
|
||||||
APIKey: "k",
|
|
||||||
Group: "default",
|
|
||||||
Models: "gpt-4o-mini",
|
|
||||||
Status: "active",
|
|
||||||
}
|
|
||||||
if err := db.Create(p).Error; err != nil {
|
|
||||||
t.Fatalf("create provider: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
r := gin.New()
|
|
||||||
r.POST("/admin/providers/:id/test", h.TestProvider)
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodPost, "/admin/providers/1/test", nil)
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
r.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
if rr.Code != http.StatusOK {
|
|
||||||
t.Fatalf("expected 200, got %d body=%s", rr.Code, rr.Body.String())
|
|
||||||
}
|
|
||||||
var payload map[string]any
|
|
||||||
if err := json.Unmarshal(rr.Body.Bytes(), &payload); err != nil {
|
|
||||||
t.Fatalf("unmarshal: %v", err)
|
|
||||||
}
|
|
||||||
if ok, _ := payload["ok"].(bool); !ok {
|
|
||||||
t.Fatalf("expected ok=true, got %v body=%s", payload["ok"], rr.Body.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAdmin_FetchProviderModels_OpenAICompatible(t *testing.T) {
|
|
||||||
h, db := newTestHandler(t)
|
|
||||||
|
|
||||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if r.URL.Path != "/v1/models" {
|
|
||||||
http.NotFound(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if got := r.Header.Get("Authorization"); got != "Bearer k" {
|
|
||||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
_, _ = w.Write([]byte(`{"object":"list","data":[{"id":"gpt-4o-mini"},{"id":"gpt-4o"}]}`))
|
|
||||||
}))
|
|
||||||
defer upstream.Close()
|
|
||||||
|
|
||||||
p := &model.Provider{
|
|
||||||
Name: "p1",
|
|
||||||
Type: "openai",
|
|
||||||
BaseURL: upstream.URL + "/v1",
|
|
||||||
APIKey: "k",
|
|
||||||
Group: "default",
|
|
||||||
Models: "old-model",
|
|
||||||
Status: "active",
|
|
||||||
}
|
|
||||||
if err := db.Create(p).Error; err != nil {
|
|
||||||
t.Fatalf("create provider: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
r := gin.New()
|
|
||||||
r.POST("/admin/providers/:id/fetch-models", h.FetchProviderModels)
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodPost, "/admin/providers/1/fetch-models", nil)
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
r.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
if rr.Code != http.StatusOK {
|
|
||||||
t.Fatalf("expected 200, got %d body=%s", rr.Code, rr.Body.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
var updated model.Provider
|
|
||||||
if err := db.First(&updated, p.ID).Error; err != nil {
|
|
||||||
t.Fatalf("reload provider: %v", err)
|
|
||||||
}
|
|
||||||
if updated.Models != "gpt-4o,gpt-4o-mini" {
|
|
||||||
t.Fatalf("expected models to update, got %q", updated.Models)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAdmin_BatchProviders_Status(t *testing.T) {
|
|
||||||
h, db := newTestHandler(t)
|
|
||||||
|
|
||||||
banUntil := time.Now().Add(2 * time.Hour).UTC()
|
|
||||||
p := &model.Provider{
|
|
||||||
Name: "p1",
|
|
||||||
Type: "openai",
|
|
||||||
BaseURL: "https://api.openai.com/v1",
|
|
||||||
Group: "default",
|
|
||||||
Models: "gpt-4o-mini",
|
|
||||||
Status: "manual_disabled",
|
|
||||||
BanReason: "bad",
|
|
||||||
BanUntil: &banUntil,
|
|
||||||
}
|
|
||||||
if err := db.Create(p).Error; err != nil {
|
|
||||||
t.Fatalf("create provider: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
r := gin.New()
|
|
||||||
r.POST("/admin/providers/batch", h.BatchProviders)
|
|
||||||
|
|
||||||
payload := map[string]any{
|
|
||||||
"action": "status",
|
|
||||||
"status": "active",
|
|
||||||
"ids": []uint{p.ID},
|
|
||||||
}
|
|
||||||
b, _ := json.Marshal(payload)
|
|
||||||
req := httptest.NewRequest(http.MethodPost, "/admin/providers/batch", bytes.NewReader(b))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
r.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
if rr.Code != http.StatusOK {
|
|
||||||
t.Fatalf("expected 200, got %d body=%s", rr.Code, rr.Body.String())
|
|
||||||
}
|
|
||||||
var updated model.Provider
|
|
||||||
if err := db.First(&updated, p.ID).Error; err != nil {
|
|
||||||
t.Fatalf("reload provider: %v", err)
|
|
||||||
}
|
|
||||||
if updated.Status != "active" {
|
|
||||||
t.Fatalf("expected status active, got %q", updated.Status)
|
|
||||||
}
|
|
||||||
if updated.BanReason != "" {
|
|
||||||
t.Fatalf("expected ban_reason cleared, got %q", updated.BanReason)
|
|
||||||
}
|
|
||||||
if updated.BanUntil != nil {
|
|
||||||
t.Fatalf("expected ban_until cleared, got %v", updated.BanUntil)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,288 +0,0 @@
|
|||||||
package api
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"encoding/hex"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/ez-api/ez-api/internal/dto"
|
|
||||||
"github.com/ez-api/ez-api/internal/model"
|
|
||||||
groupx "github.com/ez-api/foundation/group"
|
|
||||||
providerx "github.com/ez-api/foundation/provider"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
// CreateProviderPreset godoc
|
|
||||||
// @Summary Create a preset provider
|
|
||||||
// @Description Create an official OpenAI/Anthropic provider (only api_key is typically required)
|
|
||||||
// @Tags admin
|
|
||||||
// @Accept json
|
|
||||||
// @Produce json
|
|
||||||
// @Security AdminAuth
|
|
||||||
// @Param provider body dto.ProviderPresetCreateDTO true "Provider preset payload"
|
|
||||||
// @Success 201 {object} model.Provider
|
|
||||||
// @Failure 400 {object} gin.H
|
|
||||||
// @Failure 500 {object} gin.H
|
|
||||||
// @Router /admin/providers/preset [post]
|
|
||||||
func (h *Handler) CreateProviderPreset(c *gin.Context) {
|
|
||||||
var req dto.ProviderPresetCreateDTO
|
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
preset := providerx.NormalizeType(req.Preset)
|
|
||||||
if preset == "" {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "preset required"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var providerType string
|
|
||||||
var baseURL string
|
|
||||||
switch preset {
|
|
||||||
case providerx.TypeOpenAI:
|
|
||||||
providerType = providerx.TypeOpenAI
|
|
||||||
baseURL = "https://api.openai.com/v1"
|
|
||||||
case providerx.TypeAnthropic, providerx.TypeClaude:
|
|
||||||
providerType = providerx.TypeAnthropic
|
|
||||||
baseURL = "https://api.anthropic.com"
|
|
||||||
default:
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported preset: " + preset + " (use /admin/providers/google for Google SDK providers)"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
name := strings.TrimSpace(req.Name)
|
|
||||||
if name == "" {
|
|
||||||
name = providerType + "-" + randomSuffix(4)
|
|
||||||
}
|
|
||||||
group := strings.TrimSpace(req.Group)
|
|
||||||
if group == "" {
|
|
||||||
group = "default"
|
|
||||||
}
|
|
||||||
status := strings.TrimSpace(req.Status)
|
|
||||||
if status == "" {
|
|
||||||
status = "active"
|
|
||||||
}
|
|
||||||
autoBan := true
|
|
||||||
if req.AutoBan != nil {
|
|
||||||
autoBan = *req.AutoBan
|
|
||||||
}
|
|
||||||
|
|
||||||
googleLocation := providerx.DefaultGoogleLocation(providerType, req.GoogleLocation)
|
|
||||||
|
|
||||||
p := model.Provider{
|
|
||||||
Name: name,
|
|
||||||
Type: providerType,
|
|
||||||
BaseURL: baseURL,
|
|
||||||
APIKey: strings.TrimSpace(req.APIKey),
|
|
||||||
GoogleProject: strings.TrimSpace(req.GoogleProject),
|
|
||||||
GoogleLocation: googleLocation,
|
|
||||||
Group: groupx.Normalize(group),
|
|
||||||
Models: strings.Join(req.Models, ","),
|
|
||||||
Status: status,
|
|
||||||
AutoBan: autoBan,
|
|
||||||
}
|
|
||||||
if req.Weight > 0 {
|
|
||||||
p.Weight = req.Weight
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := h.db.Create(&p).Error; err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create provider", "details": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := h.sync.SyncProvider(&p); err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync provider", "details": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := h.sync.SyncBindings(h.db); err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync bindings", "details": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(http.StatusCreated, p)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateProviderGoogle godoc
|
|
||||||
// @Summary Create a Google SDK provider
|
|
||||||
// @Description Create a Google SDK provider (Gemini API key or Vertex project/location); base_url is not used
|
|
||||||
// @Tags admin
|
|
||||||
// @Accept json
|
|
||||||
// @Produce json
|
|
||||||
// @Security AdminAuth
|
|
||||||
// @Param provider body dto.ProviderGoogleCreateDTO true "Google provider payload"
|
|
||||||
// @Success 201 {object} model.Provider
|
|
||||||
// @Failure 400 {object} gin.H
|
|
||||||
// @Failure 500 {object} gin.H
|
|
||||||
// @Router /admin/providers/google [post]
|
|
||||||
func (h *Handler) CreateProviderGoogle(c *gin.Context) {
|
|
||||||
var req dto.ProviderGoogleCreateDTO
|
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
pt := providerx.NormalizeType(req.Type)
|
|
||||||
if pt == "" {
|
|
||||||
pt = providerx.TypeGemini
|
|
||||||
}
|
|
||||||
if !providerx.IsGoogleFamily(pt) {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "type must be google family"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
name := strings.TrimSpace(req.Name)
|
|
||||||
if name == "" {
|
|
||||||
name = pt + "-" + randomSuffix(4)
|
|
||||||
}
|
|
||||||
group := strings.TrimSpace(req.Group)
|
|
||||||
if group == "" {
|
|
||||||
group = "default"
|
|
||||||
}
|
|
||||||
status := strings.TrimSpace(req.Status)
|
|
||||||
if status == "" {
|
|
||||||
status = "active"
|
|
||||||
}
|
|
||||||
autoBan := true
|
|
||||||
if req.AutoBan != nil {
|
|
||||||
autoBan = *req.AutoBan
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate fields by type.
|
|
||||||
apiKey := strings.TrimSpace(req.APIKey)
|
|
||||||
googleProject := strings.TrimSpace(req.GoogleProject)
|
|
||||||
googleLocation := providerx.DefaultGoogleLocation(pt, req.GoogleLocation)
|
|
||||||
|
|
||||||
if providerx.IsVertexFamily(pt) {
|
|
||||||
// Vertex uses ADC and project/location; api_key is not required.
|
|
||||||
if strings.TrimSpace(googleLocation) == "" {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "google_location required"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
apiKey = ""
|
|
||||||
} else {
|
|
||||||
// Gemini API requires api_key.
|
|
||||||
if apiKey == "" {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "api_key required for gemini api"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
googleProject = ""
|
|
||||||
googleLocation = ""
|
|
||||||
}
|
|
||||||
|
|
||||||
p := model.Provider{
|
|
||||||
Name: name,
|
|
||||||
Type: pt,
|
|
||||||
BaseURL: "", // intentionally unused for Google SDK
|
|
||||||
APIKey: apiKey,
|
|
||||||
GoogleProject: googleProject,
|
|
||||||
GoogleLocation: googleLocation,
|
|
||||||
Group: groupx.Normalize(group),
|
|
||||||
Models: strings.Join(req.Models, ","),
|
|
||||||
Status: status,
|
|
||||||
AutoBan: autoBan,
|
|
||||||
}
|
|
||||||
if req.Weight > 0 {
|
|
||||||
p.Weight = req.Weight
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := h.db.Create(&p).Error; err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create provider", "details": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := h.sync.SyncProvider(&p); err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync provider", "details": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := h.sync.SyncBindings(h.db); err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync bindings", "details": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.JSON(http.StatusCreated, p)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateProviderCustom godoc
|
|
||||||
// @Summary Create a custom provider
|
|
||||||
// @Description Create an OpenAI-compatible provider (base_url + api_key required)
|
|
||||||
// @Tags admin
|
|
||||||
// @Accept json
|
|
||||||
// @Produce json
|
|
||||||
// @Security AdminAuth
|
|
||||||
// @Param provider body dto.ProviderCustomCreateDTO true "Provider custom payload"
|
|
||||||
// @Success 201 {object} model.Provider
|
|
||||||
// @Failure 400 {object} gin.H
|
|
||||||
// @Failure 500 {object} gin.H
|
|
||||||
// @Router /admin/providers/custom [post]
|
|
||||||
func (h *Handler) CreateProviderCustom(c *gin.Context) {
|
|
||||||
var req dto.ProviderCustomCreateDTO
|
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
name := strings.TrimSpace(req.Name)
|
|
||||||
if name == "" {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "name required"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
baseURL := strings.TrimSpace(req.BaseURL)
|
|
||||||
if baseURL == "" {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "base_url required"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
group := strings.TrimSpace(req.Group)
|
|
||||||
if group == "" {
|
|
||||||
group = "default"
|
|
||||||
}
|
|
||||||
status := strings.TrimSpace(req.Status)
|
|
||||||
if status == "" {
|
|
||||||
status = "active"
|
|
||||||
}
|
|
||||||
autoBan := true
|
|
||||||
if req.AutoBan != nil {
|
|
||||||
autoBan = *req.AutoBan
|
|
||||||
}
|
|
||||||
|
|
||||||
p := model.Provider{
|
|
||||||
Name: name,
|
|
||||||
Type: providerx.TypeCompatible,
|
|
||||||
BaseURL: baseURL,
|
|
||||||
APIKey: strings.TrimSpace(req.APIKey),
|
|
||||||
Group: groupx.Normalize(group),
|
|
||||||
Models: strings.Join(req.Models, ","),
|
|
||||||
Status: status,
|
|
||||||
AutoBan: autoBan,
|
|
||||||
}
|
|
||||||
if req.Weight > 0 {
|
|
||||||
p.Weight = req.Weight
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := h.db.Create(&p).Error; err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create provider", "details": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := h.sync.SyncProvider(&p); err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync provider", "details": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := h.sync.SyncBindings(h.db); err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync bindings", "details": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(http.StatusCreated, p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func randomSuffix(bytesLen int) string {
|
|
||||||
if bytesLen <= 0 {
|
|
||||||
bytesLen = 4
|
|
||||||
}
|
|
||||||
b := make([]byte, bytesLen)
|
|
||||||
if _, err := rand.Read(b); err != nil {
|
|
||||||
return "rand"
|
|
||||||
}
|
|
||||||
return hex.EncodeToString(b)
|
|
||||||
}
|
|
||||||
@@ -1,92 +0,0 @@
|
|||||||
package api
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/ez-api/ez-api/internal/model"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestCreateProviderPreset_OpenAI_SetsBaseURL(t *testing.T) {
|
|
||||||
h, _ := newTestHandler(t)
|
|
||||||
|
|
||||||
r := gin.New()
|
|
||||||
r.POST("/admin/providers/preset", h.CreateProviderPreset)
|
|
||||||
|
|
||||||
reqBody := map[string]any{
|
|
||||||
"preset": "openai",
|
|
||||||
"api_key": "k",
|
|
||||||
"models": []string{"gpt-4o-mini"},
|
|
||||||
}
|
|
||||||
b, _ := json.Marshal(reqBody)
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodPost, "/admin/providers/preset", bytes.NewReader(b))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
r.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
if rr.Code != http.StatusCreated {
|
|
||||||
t.Fatalf("expected 201, got %d body=%s", rr.Code, rr.Body.String())
|
|
||||||
}
|
|
||||||
var got model.Provider
|
|
||||||
if err := json.Unmarshal(rr.Body.Bytes(), &got); err != nil {
|
|
||||||
t.Fatalf("unmarshal: %v", err)
|
|
||||||
}
|
|
||||||
if got.Type != "openai" {
|
|
||||||
t.Fatalf("expected type openai, got %q", got.Type)
|
|
||||||
}
|
|
||||||
if got.BaseURL != "https://api.openai.com/v1" {
|
|
||||||
t.Fatalf("expected base_url=https://api.openai.com/v1, got %q", got.BaseURL)
|
|
||||||
}
|
|
||||||
if got.Name == "" {
|
|
||||||
t.Fatalf("expected generated name")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCreateProviderCustom_RequiresBaseURL(t *testing.T) {
|
|
||||||
h, _ := newTestHandler(t)
|
|
||||||
|
|
||||||
r := gin.New()
|
|
||||||
r.POST("/admin/providers/custom", h.CreateProviderCustom)
|
|
||||||
|
|
||||||
reqBody := map[string]any{
|
|
||||||
"name": "c1",
|
|
||||||
"api_key": "k",
|
|
||||||
}
|
|
||||||
b, _ := json.Marshal(reqBody)
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodPost, "/admin/providers/custom", bytes.NewReader(b))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
r.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
if rr.Code != http.StatusBadRequest {
|
|
||||||
t.Fatalf("expected 400, got %d body=%s", rr.Code, rr.Body.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCreateProviderGoogle_GeminiRequiresAPIKey(t *testing.T) {
|
|
||||||
h, _ := newTestHandler(t)
|
|
||||||
|
|
||||||
r := gin.New()
|
|
||||||
r.POST("/admin/providers/google", h.CreateProviderGoogle)
|
|
||||||
|
|
||||||
reqBody := map[string]any{
|
|
||||||
"type": "gemini",
|
|
||||||
"models": []string{"gemini-2.0-flash"},
|
|
||||||
}
|
|
||||||
b, _ := json.Marshal(reqBody)
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodPost, "/admin/providers/google", bytes.NewReader(b))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
r.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
if rr.Code != http.StatusBadRequest {
|
|
||||||
t.Fatalf("expected 400, got %d body=%s", rr.Code, rr.Body.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
302
internal/api/provider_group_handler.go
Normal file
302
internal/api/provider_group_handler.go
Normal file
@@ -0,0 +1,302 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ez-api/ez-api/internal/dto"
|
||||||
|
"github.com/ez-api/ez-api/internal/model"
|
||||||
|
"github.com/ez-api/foundation/provider"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CreateProviderGroup godoc
|
||||||
|
// @Summary Create a provider group
|
||||||
|
// @Description Create a provider group definition
|
||||||
|
// @Tags admin
|
||||||
|
// @Accept json
|
||||||
|
// @Produce json
|
||||||
|
// @Security AdminAuth
|
||||||
|
// @Param group body dto.ProviderGroupDTO true "Provider group payload"
|
||||||
|
// @Success 201 {object} model.ProviderGroup
|
||||||
|
// @Failure 400 {object} gin.H
|
||||||
|
// @Failure 500 {object} gin.H
|
||||||
|
// @Router /admin/provider-groups [post]
|
||||||
|
func (h *Handler) CreateProviderGroup(c *gin.Context) {
|
||||||
|
var req dto.ProviderGroupDTO
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
name := strings.TrimSpace(req.Name)
|
||||||
|
if name == "" {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "name required"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ptype := provider.NormalizeType(req.Type)
|
||||||
|
if ptype == "" {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "type required"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
baseURL := strings.TrimSpace(req.BaseURL)
|
||||||
|
googleLocation := provider.DefaultGoogleLocation(ptype, req.GoogleLocation)
|
||||||
|
|
||||||
|
switch ptype {
|
||||||
|
case provider.TypeOpenAI:
|
||||||
|
if baseURL == "" {
|
||||||
|
baseURL = "https://api.openai.com/v1"
|
||||||
|
}
|
||||||
|
case provider.TypeAnthropic, provider.TypeClaude:
|
||||||
|
if baseURL == "" {
|
||||||
|
baseURL = "https://api.anthropic.com"
|
||||||
|
}
|
||||||
|
case provider.TypeCompatible:
|
||||||
|
if baseURL == "" {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "base_url required for compatible providers"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if provider.IsVertexFamily(ptype) && strings.TrimSpace(googleLocation) == "" {
|
||||||
|
googleLocation = provider.DefaultGoogleLocation(ptype, "")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
status := strings.TrimSpace(req.Status)
|
||||||
|
if status == "" {
|
||||||
|
status = "active"
|
||||||
|
}
|
||||||
|
|
||||||
|
group := model.ProviderGroup{
|
||||||
|
Name: name,
|
||||||
|
Type: strings.TrimSpace(req.Type),
|
||||||
|
BaseURL: baseURL,
|
||||||
|
GoogleProject: strings.TrimSpace(req.GoogleProject),
|
||||||
|
GoogleLocation: googleLocation,
|
||||||
|
Models: strings.Join(req.Models, ","),
|
||||||
|
Status: status,
|
||||||
|
}
|
||||||
|
if err := h.db.Create(&group).Error; err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create provider group", "details": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.sync.SyncProviders(h.db); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync providers", "details": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.sync.SyncBindings(h.db); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync bindings", "details": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusCreated, group)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListProviderGroups godoc
|
||||||
|
// @Summary List provider groups
|
||||||
|
// @Description List all provider groups
|
||||||
|
// @Tags admin
|
||||||
|
// @Produce json
|
||||||
|
// @Security AdminAuth
|
||||||
|
// @Param page query int false "page (1-based)"
|
||||||
|
// @Param limit query int false "limit (default 50, max 200)"
|
||||||
|
// @Param search query string false "search by name/type"
|
||||||
|
// @Success 200 {array} model.ProviderGroup
|
||||||
|
// @Failure 500 {object} gin.H
|
||||||
|
// @Router /admin/provider-groups [get]
|
||||||
|
func (h *Handler) ListProviderGroups(c *gin.Context) {
|
||||||
|
var groups []model.ProviderGroup
|
||||||
|
q := h.db.Model(&model.ProviderGroup{}).Order("id desc")
|
||||||
|
query := parseListQuery(c)
|
||||||
|
q = applyListSearch(q, query.Search, "name", "type")
|
||||||
|
q = applyListPagination(q, query)
|
||||||
|
if err := q.Find(&groups).Error; err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list provider groups", "details": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, groups)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProviderGroup godoc
|
||||||
|
// @Summary Get provider group
|
||||||
|
// @Description Get a provider group by id
|
||||||
|
// @Tags admin
|
||||||
|
// @Produce json
|
||||||
|
// @Security AdminAuth
|
||||||
|
// @Param id path int true "ProviderGroup ID"
|
||||||
|
// @Success 200 {object} model.ProviderGroup
|
||||||
|
// @Failure 400 {object} gin.H
|
||||||
|
// @Failure 404 {object} gin.H
|
||||||
|
// @Failure 500 {object} gin.H
|
||||||
|
// @Router /admin/provider-groups/{id} [get]
|
||||||
|
func (h *Handler) GetProviderGroup(c *gin.Context) {
|
||||||
|
id, ok := parseUintParam(c, "id")
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var group model.ProviderGroup
|
||||||
|
if err := h.db.First(&group, id).Error; err != nil {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "provider group not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, group)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateProviderGroup godoc
|
||||||
|
// @Summary Update provider group
|
||||||
|
// @Description Update a provider group
|
||||||
|
// @Tags admin
|
||||||
|
// @Accept json
|
||||||
|
// @Produce json
|
||||||
|
// @Security AdminAuth
|
||||||
|
// @Param id path int true "ProviderGroup ID"
|
||||||
|
// @Param group body dto.ProviderGroupDTO true "Provider group payload"
|
||||||
|
// @Success 200 {object} model.ProviderGroup
|
||||||
|
// @Failure 400 {object} gin.H
|
||||||
|
// @Failure 404 {object} gin.H
|
||||||
|
// @Failure 500 {object} gin.H
|
||||||
|
// @Router /admin/provider-groups/{id} [put]
|
||||||
|
func (h *Handler) UpdateProviderGroup(c *gin.Context) {
|
||||||
|
idParam := c.Param("id")
|
||||||
|
id, err := strconv.Atoi(idParam)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var group model.ProviderGroup
|
||||||
|
if err := h.db.First(&group, id).Error; err != nil {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "provider group not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req dto.ProviderGroupDTO
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
nextType := strings.TrimSpace(group.Type)
|
||||||
|
if t := strings.TrimSpace(req.Type); t != "" {
|
||||||
|
nextType = t
|
||||||
|
}
|
||||||
|
nextTypeLower := provider.NormalizeType(nextType)
|
||||||
|
nextBaseURL := strings.TrimSpace(group.BaseURL)
|
||||||
|
if strings.TrimSpace(req.BaseURL) != "" {
|
||||||
|
nextBaseURL = strings.TrimSpace(req.BaseURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
update := map[string]any{}
|
||||||
|
if strings.TrimSpace(req.Name) != "" {
|
||||||
|
update["name"] = strings.TrimSpace(req.Name)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(req.Type) != "" {
|
||||||
|
update["type"] = strings.TrimSpace(req.Type)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(req.BaseURL) != "" {
|
||||||
|
update["base_url"] = strings.TrimSpace(req.BaseURL)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(req.GoogleProject) != "" {
|
||||||
|
update["google_project"] = strings.TrimSpace(req.GoogleProject)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(req.GoogleLocation) != "" {
|
||||||
|
update["google_location"] = strings.TrimSpace(req.GoogleLocation)
|
||||||
|
} else if provider.IsVertexFamily(nextTypeLower) && strings.TrimSpace(group.GoogleLocation) == "" {
|
||||||
|
update["google_location"] = provider.DefaultGoogleLocation(nextTypeLower, "")
|
||||||
|
}
|
||||||
|
if req.Models != nil {
|
||||||
|
update["models"] = strings.Join(req.Models, ",")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(req.Status) != "" {
|
||||||
|
update["status"] = strings.TrimSpace(req.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch nextTypeLower {
|
||||||
|
case provider.TypeOpenAI:
|
||||||
|
if nextBaseURL == "" {
|
||||||
|
update["base_url"] = "https://api.openai.com/v1"
|
||||||
|
}
|
||||||
|
case provider.TypeAnthropic, provider.TypeClaude:
|
||||||
|
if nextBaseURL == "" {
|
||||||
|
update["base_url"] = "https://api.anthropic.com"
|
||||||
|
}
|
||||||
|
case provider.TypeCompatible:
|
||||||
|
if nextBaseURL == "" {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "base_url required for compatible providers"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.db.Model(&group).Updates(update).Error; err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update provider group", "details": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.db.First(&group, id).Error; err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to reload provider group", "details": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.sync.SyncProviders(h.db); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync providers", "details": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.sync.SyncBindings(h.db); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync bindings", "details": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, group)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteProviderGroup godoc
|
||||||
|
// @Summary Delete provider group
|
||||||
|
// @Description Delete a provider group and its api keys/bindings
|
||||||
|
// @Tags admin
|
||||||
|
// @Produce json
|
||||||
|
// @Security AdminAuth
|
||||||
|
// @Param id path int true "ProviderGroup ID"
|
||||||
|
// @Success 200 {object} gin.H
|
||||||
|
// @Failure 400 {object} gin.H
|
||||||
|
// @Failure 404 {object} gin.H
|
||||||
|
// @Failure 500 {object} gin.H
|
||||||
|
// @Router /admin/provider-groups/{id} [delete]
|
||||||
|
func (h *Handler) DeleteProviderGroup(c *gin.Context) {
|
||||||
|
id, ok := parseUintParam(c, "id")
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var group model.ProviderGroup
|
||||||
|
if err := h.db.First(&group, id).Error; err != nil {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "provider group not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
if err := tx.Where("group_id = ?", group.ID).Delete(&model.APIKey{}).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := tx.Where("group_id = ?", group.ID).Delete(&model.Binding{}).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return tx.Delete(&group).Error
|
||||||
|
}); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete provider group", "details": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.sync.SyncProviders(h.db); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync providers", "details": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.sync.SyncBindings(h.db); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync bindings", "details": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{"status": "deleted"})
|
||||||
|
}
|
||||||
@@ -1,105 +0,0 @@
|
|||||||
package api
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/alicebob/miniredis/v2"
|
|
||||||
"github.com/ez-api/ez-api/internal/dto"
|
|
||||||
"github.com/ez-api/ez-api/internal/model"
|
|
||||||
"github.com/ez-api/ez-api/internal/service"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/redis/go-redis/v9"
|
|
||||||
"gorm.io/driver/sqlite"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
)
|
|
||||||
|
|
||||||
func newTestHandler(t *testing.T) (*Handler, *gorm.DB) {
|
|
||||||
t.Helper()
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
|
|
||||||
// Use a unique in-memory DB per test to avoid cross-test interference.
|
|
||||||
dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", t.Name())
|
|
||||||
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("open sqlite: %v", err)
|
|
||||||
}
|
|
||||||
if err := db.AutoMigrate(&model.Master{}, &model.Key{}, &model.Provider{}, &model.Model{}, &model.Binding{}, &model.LogRecord{}); err != nil {
|
|
||||||
t.Fatalf("migrate: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
mr := miniredis.RunT(t)
|
|
||||||
rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()})
|
|
||||||
sync := service.NewSyncService(rdb)
|
|
||||||
|
|
||||||
return NewHandler(db, db, sync, nil, rdb, nil), db
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCreateProvider_DefaultsVertexLocationGlobal(t *testing.T) {
|
|
||||||
h, _ := newTestHandler(t)
|
|
||||||
|
|
||||||
r := gin.New()
|
|
||||||
r.POST("/admin/providers", h.CreateProvider)
|
|
||||||
|
|
||||||
reqBody := dto.ProviderDTO{
|
|
||||||
Name: "g1",
|
|
||||||
Type: "vertex-express",
|
|
||||||
Group: "default",
|
|
||||||
Models: []string{"gemini-3-pro-preview"},
|
|
||||||
}
|
|
||||||
b, _ := json.Marshal(reqBody)
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodPost, "/admin/providers", bytes.NewReader(b))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
r.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
if rr.Code != http.StatusCreated {
|
|
||||||
t.Fatalf("expected 201, got %d body=%s", rr.Code, rr.Body.String())
|
|
||||||
}
|
|
||||||
var got model.Provider
|
|
||||||
if err := json.Unmarshal(rr.Body.Bytes(), &got); err != nil {
|
|
||||||
t.Fatalf("unmarshal: %v", err)
|
|
||||||
}
|
|
||||||
if got.GoogleLocation != "global" {
|
|
||||||
t.Fatalf("expected google_location=global, got %q", got.GoogleLocation)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateProvider_DefaultsVertexLocationGlobalWhenMissing(t *testing.T) {
|
|
||||||
h, db := newTestHandler(t)
|
|
||||||
|
|
||||||
existing := &model.Provider{
|
|
||||||
Name: "g2",
|
|
||||||
Type: "vertex",
|
|
||||||
Group: "default",
|
|
||||||
Models: "gemini-3-pro-preview",
|
|
||||||
Status: "active",
|
|
||||||
}
|
|
||||||
if err := db.Create(existing).Error; err != nil {
|
|
||||||
t.Fatalf("create provider: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
r := gin.New()
|
|
||||||
r.PUT("/admin/providers/:id", h.UpdateProvider)
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/admin/providers/%d", existing.ID), bytes.NewReader([]byte(`{}`)))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
r.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
if rr.Code != http.StatusOK {
|
|
||||||
t.Fatalf("expected 200, got %d body=%s", rr.Code, rr.Body.String())
|
|
||||||
}
|
|
||||||
var got model.Provider
|
|
||||||
if err := json.Unmarshal(rr.Body.Bytes(), &got); err != nil {
|
|
||||||
t.Fatalf("unmarshal: %v", err)
|
|
||||||
}
|
|
||||||
if got.GoogleLocation != "global" {
|
|
||||||
t.Fatalf("expected google_location=global, got %q", got.GoogleLocation)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
14
internal/dto/api_key.go
Normal file
14
internal/dto/api_key.go
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
// APIKeyDTO defines inbound payload for API key creation/update.
|
||||||
|
type APIKeyDTO struct {
|
||||||
|
GroupID uint `json:"group_id"`
|
||||||
|
APIKey string `json:"api_key"`
|
||||||
|
Weight int `json:"weight,omitempty"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
AutoBan *bool `json:"auto_ban,omitempty"`
|
||||||
|
BanReason string `json:"ban_reason,omitempty"`
|
||||||
|
BanUntil time.Time `json:"ban_until,omitempty"`
|
||||||
|
}
|
||||||
@@ -1,11 +1,12 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
// BindingDTO defines inbound payload for binding creation/update.
|
// BindingDTO defines inbound payload for binding creation/update.
|
||||||
// It maps "(namespace, public_model)" to a RouteGroup and an upstream selector.
|
// It maps "(namespace, public_model)" to a ProviderGroup and an upstream selector.
|
||||||
type BindingDTO struct {
|
type BindingDTO struct {
|
||||||
Namespace string `json:"namespace"`
|
Namespace string `json:"namespace"`
|
||||||
PublicModel string `json:"public_model"`
|
PublicModel string `json:"public_model"`
|
||||||
RouteGroup string `json:"route_group"`
|
GroupID uint `json:"group_id"`
|
||||||
|
Weight int `json:"weight"`
|
||||||
SelectorType string `json:"selector_type"`
|
SelectorType string `json:"selector_type"`
|
||||||
SelectorValue string `json:"selector_value"`
|
SelectorValue string `json:"selector_value"`
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
|
|||||||
12
internal/dto/provider_group.go
Normal file
12
internal/dto/provider_group.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
// ProviderGroupDTO defines inbound payload for provider group creation/update.
|
||||||
|
type ProviderGroupDTO struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
BaseURL string `json:"base_url"`
|
||||||
|
GoogleProject string `json:"google_project,omitempty"`
|
||||||
|
GoogleLocation string `json:"google_location,omitempty"`
|
||||||
|
Models []string `json:"models"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
}
|
||||||
@@ -258,14 +258,45 @@ func (i *Importer) importMasters(items []Master, summary *ImportSummary) (map[st
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (i *Importer) importProviders(items []Provider, summary *ImportSummary) error {
|
func (i *Importer) importProviders(items []Provider, summary *ImportSummary) error {
|
||||||
|
groupCache := make(map[string]model.ProviderGroup)
|
||||||
for _, item := range items {
|
for _, item := range items {
|
||||||
name := strings.TrimSpace(item.Name)
|
groupName := normalizeGroup(item.PrimaryGroup)
|
||||||
if name == "" {
|
if strings.TrimSpace(groupName) == "" {
|
||||||
summary.Warnings = append(summary.Warnings, "skip provider with empty name")
|
groupName = "default"
|
||||||
|
}
|
||||||
|
group, ok := groupCache[groupName]
|
||||||
|
if !ok {
|
||||||
|
var existing model.ProviderGroup
|
||||||
|
err := i.db.Where("name = ?", groupName).First(&existing).Error
|
||||||
|
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
group = existing
|
||||||
|
} else {
|
||||||
|
group = model.ProviderGroup{
|
||||||
|
Name: groupName,
|
||||||
|
Type: strings.TrimSpace(item.Type),
|
||||||
|
BaseURL: strings.TrimSpace(item.BaseURL),
|
||||||
|
Models: strings.Join(item.Models, ","),
|
||||||
|
Status: normalizeStatus(item.Status, "active"),
|
||||||
|
}
|
||||||
|
if !i.opts.DryRun {
|
||||||
|
if err := i.db.Create(&group).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
groupCache[groupName] = group
|
||||||
|
}
|
||||||
|
|
||||||
|
apiKey := strings.TrimSpace(item.APIKey)
|
||||||
|
if apiKey == "" {
|
||||||
|
summary.Warnings = append(summary.Warnings, "skip api key with empty api_key")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
var existing model.Provider
|
var existingKey model.APIKey
|
||||||
err := i.db.Where("name = ?", name).First(&existing).Error
|
err := i.db.Where("group_id = ? AND api_key = ?", group.ID, apiKey).First(&existingKey).Error
|
||||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -277,16 +308,11 @@ func (i *Importer) importProviders(items []Provider, summary *ImportSummary) err
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
update := map[string]any{
|
update := map[string]any{
|
||||||
"type": strings.TrimSpace(item.Type),
|
|
||||||
"base_url": strings.TrimSpace(item.BaseURL),
|
|
||||||
"api_key": strings.TrimSpace(item.APIKey),
|
|
||||||
"group": normalizeGroup(item.PrimaryGroup),
|
|
||||||
"models": strings.Join(item.Models, ","),
|
|
||||||
"weight": resolveWeight(item.Weight, item.Priority),
|
"weight": resolveWeight(item.Weight, item.Priority),
|
||||||
"status": normalizeProviderStatus(item.Status),
|
"status": normalizeProviderStatus(item.Status),
|
||||||
"auto_ban": item.AutoBan,
|
"auto_ban": item.AutoBan,
|
||||||
}
|
}
|
||||||
if err := i.db.Model(&existing).Updates(update).Error; err != nil {
|
if err := i.db.Model(&existingKey).Updates(update).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
summary.ProvidersUpdated++
|
summary.ProvidersUpdated++
|
||||||
@@ -301,18 +327,14 @@ func (i *Importer) importProviders(items []Provider, summary *ImportSummary) err
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
provider := model.Provider{
|
key := model.APIKey{
|
||||||
Name: name,
|
GroupID: group.ID,
|
||||||
Type: strings.TrimSpace(item.Type),
|
APIKey: apiKey,
|
||||||
BaseURL: strings.TrimSpace(item.BaseURL),
|
|
||||||
APIKey: strings.TrimSpace(item.APIKey),
|
|
||||||
Group: normalizeGroup(item.PrimaryGroup),
|
|
||||||
Models: strings.Join(item.Models, ","),
|
|
||||||
Weight: resolveWeight(item.Weight, item.Priority),
|
Weight: resolveWeight(item.Weight, item.Priority),
|
||||||
Status: normalizeProviderStatus(item.Status),
|
Status: normalizeProviderStatus(item.Status),
|
||||||
AutoBan: item.AutoBan,
|
AutoBan: item.AutoBan,
|
||||||
}
|
}
|
||||||
if err := i.db.Create(&provider).Error; err != nil {
|
if err := i.db.Create(&key).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
summary.ProvidersCreated++
|
summary.ProvidersCreated++
|
||||||
@@ -420,8 +442,14 @@ func (i *Importer) importBindings(items []Binding, summary *ImportSummary) error
|
|||||||
summary.Warnings = append(summary.Warnings, "skip binding with empty model")
|
summary.Warnings = append(summary.Warnings, "skip binding with empty model")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
groupName := normalizeGroup(item.RouteGroup)
|
||||||
|
var group model.ProviderGroup
|
||||||
|
if err := i.db.Where("name = ?", groupName).First(&group).Error; err != nil {
|
||||||
|
summary.Warnings = append(summary.Warnings, "skip binding with missing provider group: "+groupName)
|
||||||
|
continue
|
||||||
|
}
|
||||||
var existing model.Binding
|
var existing model.Binding
|
||||||
err := i.db.Where("namespace = ? AND public_model = ?", ns, publicModel).First(&existing).Error
|
err := i.db.Where("namespace = ? AND public_model = ? AND group_id = ?", ns, publicModel, group.ID).First(&existing).Error
|
||||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -433,7 +461,8 @@ func (i *Importer) importBindings(items []Binding, summary *ImportSummary) error
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
update := map[string]any{
|
update := map[string]any{
|
||||||
"route_group": normalizeGroup(item.RouteGroup),
|
"group_id": group.ID,
|
||||||
|
"weight": 1,
|
||||||
"selector_type": "exact",
|
"selector_type": "exact",
|
||||||
"selector_value": publicModel,
|
"selector_value": publicModel,
|
||||||
"status": normalizeStatus(item.Status, "active"),
|
"status": normalizeStatus(item.Status, "active"),
|
||||||
@@ -456,7 +485,8 @@ func (i *Importer) importBindings(items []Binding, summary *ImportSummary) error
|
|||||||
binding := model.Binding{
|
binding := model.Binding{
|
||||||
Namespace: ns,
|
Namespace: ns,
|
||||||
PublicModel: publicModel,
|
PublicModel: publicModel,
|
||||||
RouteGroup: normalizeGroup(item.RouteGroup),
|
GroupID: group.ID,
|
||||||
|
Weight: 1,
|
||||||
SelectorType: "exact",
|
SelectorType: "exact",
|
||||||
SelectorValue: publicModel,
|
SelectorValue: publicModel,
|
||||||
Status: normalizeStatus(item.Status, "active"),
|
Status: normalizeStatus(item.Status, "active"),
|
||||||
|
|||||||
@@ -92,7 +92,7 @@ type Key struct {
|
|||||||
// Binding represents an EZ-API binding (optional, from abilities).
|
// Binding represents an EZ-API binding (optional, from abilities).
|
||||||
type Binding struct {
|
type Binding struct {
|
||||||
Namespace string `json:"namespace"`
|
Namespace string `json:"namespace"`
|
||||||
RouteGroup string `json:"route_group"`
|
RouteGroup string `json:"route_group"` // provider group name
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,6 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import "gorm.io/gorm"
|
||||||
"time"
|
|
||||||
|
|
||||||
"gorm.io/gorm"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Admin is not a database model. It's configured via environment variables.
|
// Admin is not a database model. It's configured via environment variables.
|
||||||
|
|
||||||
@@ -50,24 +46,6 @@ type Key struct {
|
|||||||
QuotaResetType string `gorm:"size:20" json:"quota_reset_type"`
|
QuotaResetType string `gorm:"size:20" json:"quota_reset_type"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Provider remains the same.
|
|
||||||
type Provider struct {
|
|
||||||
gorm.Model
|
|
||||||
Name string `gorm:"not null" json:"name"`
|
|
||||||
Type string `gorm:"not null" json:"type"` // openai, anthropic, etc.
|
|
||||||
BaseURL string `json:"base_url"`
|
|
||||||
APIKey string `json:"api_key"`
|
|
||||||
GoogleProject string `gorm:"size:128" json:"google_project,omitempty"`
|
|
||||||
GoogleLocation string `gorm:"size:64" json:"google_location,omitempty"`
|
|
||||||
Group string `gorm:"default:'default'" json:"group"` // routing group/tier
|
|
||||||
Models string `json:"models"` // comma-separated list of supported models (e.g. "gpt-4,gpt-3.5-turbo")
|
|
||||||
Weight int `gorm:"default:1" json:"weight"` // routing weight inside route_group
|
|
||||||
Status string `gorm:"size:50;default:'active'" json:"status"` // active, auto_disabled, manual_disabled
|
|
||||||
AutoBan bool `gorm:"default:true" json:"auto_ban"` // whether DP-triggered disable is allowed
|
|
||||||
BanReason string `gorm:"size:255" json:"ban_reason"` // reason for current disable
|
|
||||||
BanUntil *time.Time `json:"ban_until"` // optional TTL for disable
|
|
||||||
}
|
|
||||||
|
|
||||||
// Model remains the same.
|
// Model remains the same.
|
||||||
type Model struct {
|
type Model struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
@@ -82,13 +60,13 @@ type Model struct {
|
|||||||
MaxOutputTokens int `json:"max_output_tokens"`
|
MaxOutputTokens int `json:"max_output_tokens"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Binding defines a stable "namespace.public_model" routing key and its target RouteGroup + selector.
|
// Binding defines a stable "namespace.public_model" routing key and its target ProviderGroup + selector.
|
||||||
// RouteGroup currently reuses Provider.Group.
|
|
||||||
type Binding struct {
|
type Binding struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
Namespace string `gorm:"size:100;not null;index:idx_binding_key,unique" json:"namespace"`
|
Namespace string `gorm:"size:100;not null;index:idx_binding_key,unique" json:"namespace"`
|
||||||
PublicModel string `gorm:"size:255;not null;index:idx_binding_key,unique" json:"public_model"`
|
PublicModel string `gorm:"size:255;not null;index:idx_binding_key,unique" json:"public_model"`
|
||||||
RouteGroup string `gorm:"size:100;not null" json:"route_group"`
|
GroupID uint `gorm:"not null;index:idx_binding_key,unique" json:"group_id"`
|
||||||
|
Weight int `gorm:"default:1" json:"weight"`
|
||||||
SelectorType string `gorm:"size:50;default:'exact'" json:"selector_type"`
|
SelectorType string `gorm:"size:50;default:'exact'" json:"selector_type"`
|
||||||
SelectorValue string `gorm:"size:255" json:"selector_value"`
|
SelectorValue string `gorm:"size:255" json:"selector_value"`
|
||||||
Status string `gorm:"size:50;default:'active'" json:"status"`
|
Status string `gorm:"size:50;default:'active'" json:"status"`
|
||||||
|
|||||||
31
internal/model/provider_group.go
Normal file
31
internal/model/provider_group.go
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProviderGroup represents a shared upstream definition.
|
||||||
|
type ProviderGroup struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string `gorm:"size:255;uniqueIndex;not null" json:"name"`
|
||||||
|
Type string `gorm:"size:50;not null" json:"type"` // openai, anthropic, gemini
|
||||||
|
BaseURL string `gorm:"size:512;not null" json:"base_url"`
|
||||||
|
GoogleProject string `gorm:"size:128" json:"google_project,omitempty"`
|
||||||
|
GoogleLocation string `gorm:"size:64" json:"google_location,omitempty"`
|
||||||
|
Models string `json:"models"` // comma-separated list of supported models
|
||||||
|
Status string `gorm:"size:50;default:'active'" json:"status"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// APIKey represents a credential within a provider group.
|
||||||
|
type APIKey struct {
|
||||||
|
gorm.Model
|
||||||
|
GroupID uint `gorm:"not null;index" json:"group_id"`
|
||||||
|
APIKey string `gorm:"not null" json:"api_key"`
|
||||||
|
Weight int `gorm:"default:1" json:"weight"`
|
||||||
|
Status string `gorm:"size:50;default:'active'" json:"status"`
|
||||||
|
AutoBan bool `gorm:"default:true" json:"auto_ban"`
|
||||||
|
BanReason string `gorm:"size:255" json:"ban_reason"`
|
||||||
|
BanUntil *time.Time `json:"ban_until"`
|
||||||
|
}
|
||||||
@@ -18,7 +18,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ez-api/ez-api/internal/model"
|
"github.com/ez-api/ez-api/internal/model"
|
||||||
groupx "github.com/ez-api/foundation/group"
|
|
||||||
"github.com/ez-api/foundation/jsoncodec"
|
"github.com/ez-api/foundation/jsoncodec"
|
||||||
"github.com/ez-api/foundation/modelcap"
|
"github.com/ez-api/foundation/modelcap"
|
||||||
"github.com/ez-api/foundation/routing"
|
"github.com/ez-api/foundation/routing"
|
||||||
@@ -373,6 +372,33 @@ type upstreamCap struct {
|
|||||||
SupportsTools boolVal
|
SupportsTools boolVal
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func boolValEqual(a, b boolVal) bool {
|
||||||
|
if a.Known != b.Known {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !a.Known {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return a.Val == b.Val
|
||||||
|
}
|
||||||
|
|
||||||
|
func intValEqual(a, b intVal) bool {
|
||||||
|
if a.Known != b.Known {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !a.Known {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return a.Val == b.Val
|
||||||
|
}
|
||||||
|
|
||||||
|
func capsEqual(a, b upstreamCap) bool {
|
||||||
|
return boolValEqual(a.SupportsVision, b.SupportsVision) &&
|
||||||
|
boolValEqual(a.SupportsTools, b.SupportsTools) &&
|
||||||
|
intValEqual(a.ContextWindow, b.ContextWindow) &&
|
||||||
|
intValEqual(a.MaxOutputTokens, b.MaxOutputTokens)
|
||||||
|
}
|
||||||
|
|
||||||
type modelsDevRegistry struct {
|
type modelsDevRegistry struct {
|
||||||
ByProviderModel map[string]upstreamCap // key: providerID|modelID
|
ByProviderModel map[string]upstreamCap // key: providerID|modelID
|
||||||
ByModel map[string]upstreamCap // fallback: modelID
|
ByModel map[string]upstreamCap // fallback: modelID
|
||||||
@@ -707,9 +733,13 @@ func (a *capAgg) finalize(name string) modelcap.Model {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *ModelRegistryService) buildBindingModels(ctx context.Context, reg *modelsDevRegistry) (map[string]modelcap.Model, map[string]string, error) {
|
func (s *ModelRegistryService) buildBindingModels(ctx context.Context, reg *modelsDevRegistry) (map[string]modelcap.Model, map[string]string, error) {
|
||||||
var providers []model.Provider
|
var groups []model.ProviderGroup
|
||||||
if err := s.db.Find(&providers).Error; err != nil {
|
if err := s.db.Find(&groups).Error; err != nil {
|
||||||
return nil, nil, fmt.Errorf("load providers: %w", err)
|
return nil, nil, fmt.Errorf("load provider groups: %w", err)
|
||||||
|
}
|
||||||
|
var apiKeys []model.APIKey
|
||||||
|
if err := s.db.Find(&apiKeys).Error; err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("load api keys: %w", err)
|
||||||
}
|
}
|
||||||
var bindings []model.Binding
|
var bindings []model.Binding
|
||||||
if err := s.db.Find(&bindings).Error; err != nil {
|
if err := s.db.Find(&bindings).Error; err != nil {
|
||||||
@@ -718,21 +748,29 @@ func (s *ModelRegistryService) buildBindingModels(ctx context.Context, reg *mode
|
|||||||
|
|
||||||
type providerLite struct {
|
type providerLite struct {
|
||||||
id uint
|
id uint
|
||||||
group string
|
|
||||||
ptype string
|
ptype string
|
||||||
models []string
|
models []string
|
||||||
}
|
}
|
||||||
providersByGroup := make(map[string][]providerLite)
|
providersByGroupID := make(map[uint]providerLite)
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
for _, p := range providers {
|
activeKeys := make(map[uint]bool)
|
||||||
if strings.TrimSpace(p.Status) != "" && strings.TrimSpace(p.Status) != "active" {
|
for _, k := range apiKeys {
|
||||||
|
if strings.TrimSpace(k.Status) != "" && strings.TrimSpace(k.Status) != "active" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if p.BanUntil != nil && p.BanUntil.UTC().Unix() > now {
|
if k.BanUntil != nil && k.BanUntil.UTC().Unix() > now {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
group := groupx.Normalize(p.Group)
|
activeKeys[k.GroupID] = true
|
||||||
rawModels := strings.Split(p.Models, ",")
|
}
|
||||||
|
for _, g := range groups {
|
||||||
|
if strings.TrimSpace(g.Status) != "" && strings.TrimSpace(g.Status) != "active" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !activeKeys[g.ID] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rawModels := strings.Split(g.Models, ",")
|
||||||
var outModels []string
|
var outModels []string
|
||||||
for _, m := range rawModels {
|
for _, m := range rawModels {
|
||||||
m = strings.TrimSpace(m)
|
m = strings.TrimSpace(m)
|
||||||
@@ -740,19 +778,20 @@ func (s *ModelRegistryService) buildBindingModels(ctx context.Context, reg *mode
|
|||||||
outModels = append(outModels, m)
|
outModels = append(outModels, m)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if group == "" || len(outModels) == 0 {
|
if len(outModels) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
providersByGroup[group] = append(providersByGroup[group], providerLite{
|
providersByGroupID[g.ID] = providerLite{
|
||||||
id: p.ID,
|
id: g.ID,
|
||||||
group: group,
|
ptype: strings.TrimSpace(g.Type),
|
||||||
ptype: strings.TrimSpace(p.Type),
|
|
||||||
models: outModels,
|
models: outModels,
|
||||||
})
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
modelsOut := make(map[string]modelcap.Model)
|
modelsOut := make(map[string]modelcap.Model)
|
||||||
payloads := make(map[string]string)
|
payloads := make(map[string]string)
|
||||||
|
capBaseline := make(map[string]upstreamCap)
|
||||||
|
capBaselineOK := make(map[string]bool)
|
||||||
|
|
||||||
for _, b := range bindings {
|
for _, b := range bindings {
|
||||||
if strings.TrimSpace(b.Status) != "" && strings.TrimSpace(b.Status) != "active" {
|
if strings.TrimSpace(b.Status) != "" && strings.TrimSpace(b.Status) != "active" {
|
||||||
@@ -764,12 +803,8 @@ func (s *ModelRegistryService) buildBindingModels(ctx context.Context, reg *mode
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
key := ns + "." + pm
|
key := ns + "." + pm
|
||||||
rg := groupx.Normalize(b.RouteGroup)
|
group := providersByGroupID[b.GroupID]
|
||||||
if rg == "" {
|
if group.id == 0 {
|
||||||
continue
|
|
||||||
}
|
|
||||||
pgroup := providersByGroup[rg]
|
|
||||||
if len(pgroup) == 0 {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -782,13 +817,25 @@ func (s *ModelRegistryService) buildBindingModels(ctx context.Context, reg *mode
|
|||||||
selectorType := routing.SelectorType(strings.TrimSpace(b.SelectorType))
|
selectorType := routing.SelectorType(strings.TrimSpace(b.SelectorType))
|
||||||
selectorValue := strings.TrimSpace(b.SelectorValue)
|
selectorValue := strings.TrimSpace(b.SelectorValue)
|
||||||
|
|
||||||
for _, p := range pgroup {
|
up, err := routing.ResolveUpstreamModel(selectorType, selectorValue, pm, group.models)
|
||||||
up, err := routing.ResolveUpstreamModel(selectorType, selectorValue, pm, p.models)
|
if err == nil {
|
||||||
if err != nil {
|
cap, ok := lookupModelsDevCap(reg, modelsDevProviderKey(group.ptype), up)
|
||||||
continue
|
if baseOK, seen := capBaselineOK[key]; seen {
|
||||||
|
if !ok || !baseOK || !capsEqual(capBaseline[key], cap) {
|
||||||
|
return nil, nil, fmt.Errorf("bindingKey %s has inconsistent capabilities", key)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
capBaselineOK[key] = ok
|
||||||
|
if ok {
|
||||||
|
capBaseline[key] = cap
|
||||||
|
}
|
||||||
}
|
}
|
||||||
cap, ok := lookupModelsDevCap(reg, modelsDevProviderKey(p.ptype), up)
|
|
||||||
agg.merge(cap, ok)
|
agg.merge(cap, ok)
|
||||||
|
} else {
|
||||||
|
if _, seen := capBaselineOK[key]; seen {
|
||||||
|
return nil, nil, fmt.Errorf("bindingKey %s has inconsistent capabilities", key)
|
||||||
|
}
|
||||||
|
capBaselineOK[key] = false
|
||||||
}
|
}
|
||||||
|
|
||||||
out := agg.finalize(key)
|
out := agg.finalize(key)
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ func TestModelRegistry_Check(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("open sqlite: %v", err)
|
t.Fatalf("open sqlite: %v", err)
|
||||||
}
|
}
|
||||||
if err := db.AutoMigrate(&model.Provider{}, &model.Binding{}, &model.Model{}); err != nil {
|
if err := db.AutoMigrate(&model.ProviderGroup{}, &model.APIKey{}, &model.Binding{}, &model.Model{}); err != nil {
|
||||||
t.Fatalf("migrate: %v", err)
|
t.Fatalf("migrate: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -58,22 +58,31 @@ func TestModelRegistry_RefreshAndRollback(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("open sqlite: %v", err)
|
t.Fatalf("open sqlite: %v", err)
|
||||||
}
|
}
|
||||||
if err := db.AutoMigrate(&model.Provider{}, &model.Binding{}, &model.Model{}); err != nil {
|
if err := db.AutoMigrate(&model.ProviderGroup{}, &model.APIKey{}, &model.Binding{}, &model.Model{}); err != nil {
|
||||||
t.Fatalf("migrate: %v", err)
|
t.Fatalf("migrate: %v", err)
|
||||||
}
|
}
|
||||||
if err := db.Create(&model.Provider{
|
group := model.ProviderGroup{
|
||||||
Name: "p1",
|
Name: "rg",
|
||||||
Type: "openai",
|
Type: "openai",
|
||||||
Group: "rg",
|
BaseURL: "https://api.openai.com/v1",
|
||||||
Models: "gpt-4o-mini",
|
Models: "gpt-4o-mini",
|
||||||
Status: "active",
|
Status: "active",
|
||||||
|
}
|
||||||
|
if err := db.Create(&group).Error; err != nil {
|
||||||
|
t.Fatalf("create provider group: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.Create(&model.APIKey{
|
||||||
|
GroupID: group.ID,
|
||||||
|
APIKey: "k",
|
||||||
|
Status: "active",
|
||||||
}).Error; err != nil {
|
}).Error; err != nil {
|
||||||
t.Fatalf("create provider: %v", err)
|
t.Fatalf("create api key: %v", err)
|
||||||
}
|
}
|
||||||
if err := db.Create(&model.Binding{
|
if err := db.Create(&model.Binding{
|
||||||
Namespace: "ns",
|
Namespace: "ns",
|
||||||
PublicModel: "m",
|
PublicModel: "m",
|
||||||
RouteGroup: "rg",
|
GroupID: group.ID,
|
||||||
|
Weight: 1,
|
||||||
SelectorType: "exact",
|
SelectorType: "exact",
|
||||||
SelectorValue: "gpt-4o-mini",
|
SelectorValue: "gpt-4o-mini",
|
||||||
Status: "active",
|
Status: "active",
|
||||||
|
|||||||
@@ -77,78 +77,29 @@ func (s *SyncService) SyncMaster(master *model.Master) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SyncProvider writes a single provider into Redis hash storage and updates routing tables.
|
// SyncProviders rebuilds provider snapshots from ProviderGroup + APIKey tables.
|
||||||
func (s *SyncService) SyncProvider(provider *model.Provider) error {
|
func (s *SyncService) SyncProviders(db *gorm.DB) error {
|
||||||
ctx := context.Background()
|
if db == nil {
|
||||||
group := groupx.Normalize(provider.Group)
|
return fmt.Errorf("db required")
|
||||||
models := strings.Split(provider.Models, ",")
|
|
||||||
|
|
||||||
snap := providerSnapshot{
|
|
||||||
ID: provider.ID,
|
|
||||||
Name: provider.Name,
|
|
||||||
Type: provider.Type,
|
|
||||||
BaseURL: provider.BaseURL,
|
|
||||||
APIKey: provider.APIKey,
|
|
||||||
GoogleProject: provider.GoogleProject,
|
|
||||||
GoogleLocation: provider.GoogleLocation,
|
|
||||||
Group: group,
|
|
||||||
Models: models,
|
|
||||||
Weight: provider.Weight,
|
|
||||||
Status: normalizeStatus(provider.Status),
|
|
||||||
AutoBan: provider.AutoBan,
|
|
||||||
BanReason: provider.BanReason,
|
|
||||||
}
|
|
||||||
if provider.BanUntil != nil {
|
|
||||||
snap.BanUntil = provider.BanUntil.UTC().Unix()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 1. Update Provider Config
|
|
||||||
if err := s.hsetJSON(ctx, "config:providers", fmt.Sprintf("%d", provider.ID), snap); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2. Update Routing Table: route:group:{group}:{model} -> Set(provider_id)
|
|
||||||
// Note: This is an additive operation. Removing models requires full sync or smarter logic.
|
|
||||||
pipe := s.rdb.Pipeline()
|
|
||||||
for _, m := range models {
|
|
||||||
m = strings.TrimSpace(m)
|
|
||||||
if m == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if snap.Status != "active" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if snap.BanUntil > 0 && time.Now().Unix() < snap.BanUntil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
routeKey := fmt.Sprintf("route:group:%s:%s", group, m)
|
|
||||||
pipe.SAdd(ctx, routeKey, provider.ID)
|
|
||||||
}
|
|
||||||
_, err := pipe.Exec(ctx)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// SyncProviderDelete removes provider snapshot and routing entries from Redis.
|
|
||||||
func (s *SyncService) SyncProviderDelete(provider *model.Provider) error {
|
|
||||||
if provider == nil {
|
|
||||||
return fmt.Errorf("provider required")
|
|
||||||
}
|
}
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
group := groupx.Normalize(provider.Group)
|
|
||||||
models := strings.Split(provider.Models, ",")
|
var groups []model.ProviderGroup
|
||||||
|
if err := db.Find(&groups).Error; err != nil {
|
||||||
|
return fmt.Errorf("load provider groups: %w", err)
|
||||||
|
}
|
||||||
|
var apiKeys []model.APIKey
|
||||||
|
if err := db.Find(&apiKeys).Error; err != nil {
|
||||||
|
return fmt.Errorf("load api keys: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
pipe := s.rdb.TxPipeline()
|
pipe := s.rdb.TxPipeline()
|
||||||
pipe.HDel(ctx, "config:providers", fmt.Sprintf("%d", provider.ID))
|
pipe.Del(ctx, "config:providers")
|
||||||
for _, m := range models {
|
if err := s.writeProvidersSnapshot(ctx, pipe, groups, apiKeys); err != nil {
|
||||||
m = strings.TrimSpace(m)
|
return err
|
||||||
if m == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
routeKey := fmt.Sprintf("route:group:%s:%s", group, m)
|
|
||||||
pipe.SRem(ctx, routeKey, provider.ID)
|
|
||||||
}
|
}
|
||||||
if _, err := pipe.Exec(ctx); err != nil {
|
if _, err := pipe.Exec(ctx); err != nil {
|
||||||
return fmt.Errorf("delete provider snapshot: %w", err)
|
return fmt.Errorf("write provider snapshot: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -203,6 +154,7 @@ type providerSnapshot struct {
|
|||||||
APIKey string `json:"api_key"`
|
APIKey string `json:"api_key"`
|
||||||
GoogleProject string `json:"google_project,omitempty"`
|
GoogleProject string `json:"google_project,omitempty"`
|
||||||
GoogleLocation string `json:"google_location,omitempty"`
|
GoogleLocation string `json:"google_location,omitempty"`
|
||||||
|
GroupID uint `json:"group_id,omitempty"`
|
||||||
Group string `json:"group"`
|
Group string `json:"group"`
|
||||||
Models []string `json:"models"`
|
Models []string `json:"models"`
|
||||||
Weight int `json:"weight,omitempty"`
|
Weight int `json:"weight,omitempty"`
|
||||||
@@ -212,15 +164,100 @@ type providerSnapshot struct {
|
|||||||
BanUntil int64 `json:"ban_until,omitempty"` // unix seconds
|
BanUntil int64 `json:"ban_until,omitempty"` // unix seconds
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SyncService) writeProvidersSnapshot(ctx context.Context, pipe redis.Pipeliner, groups []model.ProviderGroup, apiKeys []model.APIKey) error {
|
||||||
|
groupMap := make(map[uint]model.ProviderGroup, len(groups))
|
||||||
|
for _, g := range groups {
|
||||||
|
groupMap[g.ID] = g
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, k := range apiKeys {
|
||||||
|
g, ok := groupMap[k.GroupID]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
groupName := groupx.Normalize(g.Name)
|
||||||
|
if strings.TrimSpace(groupName) == "" {
|
||||||
|
groupName = "default"
|
||||||
|
}
|
||||||
|
groupStatus := normalizeStatus(g.Status)
|
||||||
|
keyStatus := normalizeStatus(k.Status)
|
||||||
|
status := keyStatus
|
||||||
|
if groupStatus != "" && groupStatus != "active" {
|
||||||
|
status = groupStatus
|
||||||
|
}
|
||||||
|
|
||||||
|
rawModels := strings.Split(g.Models, ",")
|
||||||
|
var models []string
|
||||||
|
for _, m := range rawModels {
|
||||||
|
m = strings.TrimSpace(m)
|
||||||
|
if m != "" {
|
||||||
|
models = append(models, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
name := strings.TrimSpace(g.Name)
|
||||||
|
if name == "" {
|
||||||
|
name = groupName
|
||||||
|
}
|
||||||
|
name = fmt.Sprintf("%s#%d", name, k.ID)
|
||||||
|
|
||||||
|
snap := providerSnapshot{
|
||||||
|
ID: k.ID,
|
||||||
|
Name: name,
|
||||||
|
Type: strings.TrimSpace(g.Type),
|
||||||
|
BaseURL: strings.TrimSpace(g.BaseURL),
|
||||||
|
APIKey: strings.TrimSpace(k.APIKey),
|
||||||
|
GoogleProject: strings.TrimSpace(g.GoogleProject),
|
||||||
|
GoogleLocation: strings.TrimSpace(g.GoogleLocation),
|
||||||
|
GroupID: g.ID,
|
||||||
|
Group: groupName,
|
||||||
|
Models: models,
|
||||||
|
Weight: k.Weight,
|
||||||
|
Status: status,
|
||||||
|
AutoBan: k.AutoBan,
|
||||||
|
BanReason: strings.TrimSpace(k.BanReason),
|
||||||
|
}
|
||||||
|
if k.BanUntil != nil {
|
||||||
|
snap.BanUntil = k.BanUntil.UTC().Unix()
|
||||||
|
}
|
||||||
|
|
||||||
|
payload, err := jsoncodec.Marshal(snap)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("marshal provider %d: %w", k.ID, err)
|
||||||
|
}
|
||||||
|
pipe.HSet(ctx, "config:providers", fmt.Sprintf("%d", k.ID), payload)
|
||||||
|
|
||||||
|
// Legacy route table maintenance for compatibility.
|
||||||
|
for _, m := range models {
|
||||||
|
if m == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if snap.Status != "active" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if snap.BanUntil > 0 && time.Now().Unix() < snap.BanUntil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
routeKey := fmt.Sprintf("route:group:%s:%s", groupName, m)
|
||||||
|
pipe.SAdd(ctx, routeKey, k.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// keySnapshot is no longer needed as we write directly to auth:token:*
|
// keySnapshot is no longer needed as we write directly to auth:token:*
|
||||||
|
|
||||||
// SyncAll rebuilds Redis hashes from the database; use for cold starts or forced refreshes.
|
// SyncAll rebuilds Redis hashes from the database; use for cold starts or forced refreshes.
|
||||||
func (s *SyncService) SyncAll(db *gorm.DB) error {
|
func (s *SyncService) SyncAll(db *gorm.DB) error {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
var providers []model.Provider
|
var groups []model.ProviderGroup
|
||||||
if err := db.Find(&providers).Error; err != nil {
|
if err := db.Find(&groups).Error; err != nil {
|
||||||
return fmt.Errorf("load providers: %w", err)
|
return fmt.Errorf("load provider groups: %w", err)
|
||||||
|
}
|
||||||
|
var apiKeys []model.APIKey
|
||||||
|
if err := db.Find(&apiKeys).Error; err != nil {
|
||||||
|
return fmt.Errorf("load api keys: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var keys []model.Key
|
var keys []model.Key
|
||||||
@@ -259,53 +296,8 @@ func (s *SyncService) SyncAll(db *gorm.DB) error {
|
|||||||
pipe.Del(ctx, masterKeys...)
|
pipe.Del(ctx, masterKeys...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clear old routing tables (pattern scan would be better in prod, but keys are predictable if we knew them)
|
if err := s.writeProvidersSnapshot(ctx, pipe, groups, apiKeys); err != nil {
|
||||||
// For MVP, we rely on the fact that we are rebuilding.
|
return err
|
||||||
// Ideally, we should scan "route:group:*" and del, but let's just rebuild.
|
|
||||||
|
|
||||||
for _, p := range providers {
|
|
||||||
group := groupx.Normalize(p.Group)
|
|
||||||
models := strings.Split(p.Models, ",")
|
|
||||||
|
|
||||||
snap := providerSnapshot{
|
|
||||||
ID: p.ID,
|
|
||||||
Name: p.Name,
|
|
||||||
Type: p.Type,
|
|
||||||
BaseURL: p.BaseURL,
|
|
||||||
APIKey: p.APIKey,
|
|
||||||
GoogleProject: p.GoogleProject,
|
|
||||||
GoogleLocation: p.GoogleLocation,
|
|
||||||
Group: group,
|
|
||||||
Models: models,
|
|
||||||
Weight: p.Weight,
|
|
||||||
Status: normalizeStatus(p.Status),
|
|
||||||
AutoBan: p.AutoBan,
|
|
||||||
BanReason: p.BanReason,
|
|
||||||
}
|
|
||||||
if p.BanUntil != nil {
|
|
||||||
snap.BanUntil = p.BanUntil.UTC().Unix()
|
|
||||||
}
|
|
||||||
payload, err := jsoncodec.Marshal(snap)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("marshal provider %d: %w", p.ID, err)
|
|
||||||
}
|
|
||||||
pipe.HSet(ctx, "config:providers", fmt.Sprintf("%d", p.ID), payload)
|
|
||||||
|
|
||||||
// Rebuild Routing Table
|
|
||||||
for _, m := range models {
|
|
||||||
m = strings.TrimSpace(m)
|
|
||||||
if m == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if snap.Status != "active" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if snap.BanUntil > 0 && time.Now().Unix() < snap.BanUntil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
routeKey := fmt.Sprintf("route:group:%s:%s", group, m)
|
|
||||||
pipe.SAdd(ctx, routeKey, p.ID)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, k := range keys {
|
for _, k := range keys {
|
||||||
@@ -382,7 +374,7 @@ func (s *SyncService) SyncAll(db *gorm.DB) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.writeBindingsSnapshot(ctx, pipe, bindings, providers); err != nil {
|
if err := s.writeBindingsSnapshot(ctx, pipe, bindings, groups, apiKeys); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -398,9 +390,13 @@ func (s *SyncService) SyncAll(db *gorm.DB) error {
|
|||||||
func (s *SyncService) SyncBindings(db *gorm.DB) error {
|
func (s *SyncService) SyncBindings(db *gorm.DB) error {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
var providers []model.Provider
|
var groups []model.ProviderGroup
|
||||||
if err := db.Find(&providers).Error; err != nil {
|
if err := db.Find(&groups).Error; err != nil {
|
||||||
return fmt.Errorf("load providers: %w", err)
|
return fmt.Errorf("load provider groups: %w", err)
|
||||||
|
}
|
||||||
|
var apiKeys []model.APIKey
|
||||||
|
if err := db.Find(&apiKeys).Error; err != nil {
|
||||||
|
return fmt.Errorf("load api keys: %w", err)
|
||||||
}
|
}
|
||||||
var bindings []model.Binding
|
var bindings []model.Binding
|
||||||
if err := db.Find(&bindings).Error; err != nil {
|
if err := db.Find(&bindings).Error; err != nil {
|
||||||
@@ -409,7 +405,7 @@ func (s *SyncService) SyncBindings(db *gorm.DB) error {
|
|||||||
|
|
||||||
pipe := s.rdb.TxPipeline()
|
pipe := s.rdb.TxPipeline()
|
||||||
pipe.Del(ctx, "config:bindings", "meta:bindings_meta")
|
pipe.Del(ctx, "config:bindings", "meta:bindings_meta")
|
||||||
if err := s.writeBindingsSnapshot(ctx, pipe, bindings, providers); err != nil {
|
if err := s.writeBindingsSnapshot(ctx, pipe, bindings, groups, apiKeys); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := pipe.Exec(ctx); err != nil {
|
if _, err := pipe.Exec(ctx); err != nil {
|
||||||
@@ -418,32 +414,65 @@ func (s *SyncService) SyncBindings(db *gorm.DB) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SyncService) writeBindingsSnapshot(ctx context.Context, pipe redis.Pipeliner, bindings []model.Binding, providers []model.Provider) error {
|
func (s *SyncService) writeBindingsSnapshot(ctx context.Context, pipe redis.Pipeliner, bindings []model.Binding, groups []model.ProviderGroup, apiKeys []model.APIKey) error {
|
||||||
// Group providers by route group for selector resolution.
|
type groupLite struct {
|
||||||
type providerLite struct {
|
|
||||||
id uint
|
id uint
|
||||||
group string
|
name string
|
||||||
|
ptype string
|
||||||
|
baseURL string
|
||||||
|
googleProject string
|
||||||
|
googleLoc string
|
||||||
models []string
|
models []string
|
||||||
|
status string
|
||||||
}
|
}
|
||||||
providersByGroup := make(map[string][]providerLite)
|
groupsByID := make(map[uint]groupLite, len(groups))
|
||||||
for _, p := range providers {
|
for _, g := range groups {
|
||||||
group := groupx.Normalize(p.Group)
|
rawModels := strings.Split(g.Models, ",")
|
||||||
models := strings.Split(p.Models, ",")
|
|
||||||
var outModels []string
|
var outModels []string
|
||||||
for _, m := range models {
|
for _, m := range rawModels {
|
||||||
m = strings.TrimSpace(m)
|
m = strings.TrimSpace(m)
|
||||||
if m != "" {
|
if m != "" {
|
||||||
outModels = append(outModels, m)
|
outModels = append(outModels, m)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
providersByGroup[group] = append(providersByGroup[group], providerLite{
|
groupsByID[g.ID] = groupLite{
|
||||||
id: p.ID,
|
id: g.ID,
|
||||||
group: group,
|
name: groupx.Normalize(g.Name),
|
||||||
|
ptype: strings.TrimSpace(g.Type),
|
||||||
|
baseURL: strings.TrimSpace(g.BaseURL),
|
||||||
|
googleProject: strings.TrimSpace(g.GoogleProject),
|
||||||
|
googleLoc: strings.TrimSpace(g.GoogleLocation),
|
||||||
models: outModels,
|
models: outModels,
|
||||||
|
status: normalizeStatus(g.Status),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type apiKeyLite struct {
|
||||||
|
id uint
|
||||||
|
groupID uint
|
||||||
|
status string
|
||||||
|
weight int
|
||||||
|
autoBan bool
|
||||||
|
banUntil *time.Time
|
||||||
|
}
|
||||||
|
keysByGroup := make(map[uint][]apiKeyLite)
|
||||||
|
for _, k := range apiKeys {
|
||||||
|
keysByGroup[k.GroupID] = append(keysByGroup[k.GroupID], apiKeyLite{
|
||||||
|
id: k.ID,
|
||||||
|
groupID: k.GroupID,
|
||||||
|
status: normalizeStatus(k.Status),
|
||||||
|
weight: k.Weight,
|
||||||
|
autoBan: k.AutoBan,
|
||||||
|
banUntil: k.BanUntil,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type bindingAgg struct {
|
||||||
|
snap routing.BindingSnapshot
|
||||||
|
}
|
||||||
|
snaps := make(map[string]*routing.BindingSnapshot)
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
|
|
||||||
for _, b := range bindings {
|
for _, b := range bindings {
|
||||||
if strings.TrimSpace(b.Status) != "" && strings.TrimSpace(b.Status) != "active" {
|
if strings.TrimSpace(b.Status) != "" && strings.TrimSpace(b.Status) != "active" {
|
||||||
continue
|
continue
|
||||||
@@ -453,43 +482,65 @@ func (s *SyncService) writeBindingsSnapshot(ctx context.Context, pipe redis.Pipe
|
|||||||
if ns == "" || pm == "" {
|
if ns == "" || pm == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
rg := groupx.Normalize(b.RouteGroup)
|
group, ok := groupsByID[b.GroupID]
|
||||||
if rg == "" {
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if group.status != "" && group.status != "active" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
snap := struct {
|
key := ns + "." + pm
|
||||||
Namespace string `json:"namespace"`
|
snap := snaps[key]
|
||||||
PublicModel string `json:"public_model"`
|
if snap == nil {
|
||||||
RouteGroup string `json:"route_group"`
|
snap = &routing.BindingSnapshot{
|
||||||
SelectorType string `json:"selector_type,omitempty"`
|
|
||||||
SelectorValue string `json:"selector_value,omitempty"`
|
|
||||||
Status string `json:"status,omitempty"`
|
|
||||||
UpdatedAt int64 `json:"updated_at,omitempty"`
|
|
||||||
Upstreams map[string]string `json:"upstreams"`
|
|
||||||
}{
|
|
||||||
Namespace: ns,
|
Namespace: ns,
|
||||||
PublicModel: pm,
|
PublicModel: pm,
|
||||||
RouteGroup: rg,
|
Status: "active",
|
||||||
|
UpdatedAt: now,
|
||||||
|
}
|
||||||
|
snaps[key] = snap
|
||||||
|
}
|
||||||
|
|
||||||
|
candidate := routing.BindingCandidate{
|
||||||
|
GroupID: group.id,
|
||||||
|
RouteGroup: group.name,
|
||||||
|
Weight: normalizeWeight(b.Weight),
|
||||||
SelectorType: strings.TrimSpace(b.SelectorType),
|
SelectorType: strings.TrimSpace(b.SelectorType),
|
||||||
SelectorValue: strings.TrimSpace(b.SelectorValue),
|
SelectorValue: strings.TrimSpace(b.SelectorValue),
|
||||||
Status: "active",
|
Status: "active",
|
||||||
UpdatedAt: now,
|
|
||||||
Upstreams: make(map[string]string),
|
Upstreams: make(map[string]string),
|
||||||
}
|
}
|
||||||
|
|
||||||
selectorType := strings.TrimSpace(b.SelectorType)
|
selectorType := strings.TrimSpace(b.SelectorType)
|
||||||
selectorValue := strings.TrimSpace(b.SelectorValue)
|
selectorValue := strings.TrimSpace(b.SelectorValue)
|
||||||
|
keys := keysByGroup[b.GroupID]
|
||||||
|
if len(keys) == 0 {
|
||||||
|
candidate.Error = "no_provider"
|
||||||
|
}
|
||||||
|
|
||||||
for _, p := range providersByGroup[rg] {
|
nowUnix := time.Now().Unix()
|
||||||
up, err := routing.ResolveUpstreamModel(routing.SelectorType(selectorType), selectorValue, pm, p.models)
|
for _, k := range keys {
|
||||||
|
if k.status != "" && k.status != "active" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if k.banUntil != nil && k.banUntil.UTC().Unix() > nowUnix {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
up, err := routing.ResolveUpstreamModel(routing.SelectorType(selectorType), selectorValue, pm, group.models)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
snap.Upstreams[fmt.Sprintf("%d", p.id)] = up
|
candidate.Upstreams[fmt.Sprintf("%d", k.id)] = up
|
||||||
|
}
|
||||||
|
if len(candidate.Upstreams) == 0 && candidate.Error == "" {
|
||||||
|
candidate.Error = "config_error"
|
||||||
}
|
}
|
||||||
|
|
||||||
key := ns + "." + pm
|
snap.Candidates = append(snap.Candidates, candidate)
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, snap := range snaps {
|
||||||
payload, err := jsoncodec.Marshal(snap)
|
payload, err := jsoncodec.Marshal(snap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("marshal config:bindings:%s: %w", key, err)
|
return fmt.Errorf("marshal config:bindings:%s: %w", key, err)
|
||||||
@@ -519,6 +570,13 @@ func (s *SyncService) hsetJSON(ctx context.Context, key, field string, val inter
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func normalizeWeight(weight int) int {
|
||||||
|
if weight <= 0 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return weight
|
||||||
|
}
|
||||||
|
|
||||||
func normalizeStatus(status string) string {
|
func normalizeStatus(status string) string {
|
||||||
st := strings.ToLower(strings.TrimSpace(status))
|
st := strings.ToLower(strings.TrimSpace(status))
|
||||||
if st == "" {
|
if st == "" {
|
||||||
|
|||||||
@@ -15,8 +15,11 @@ import (
|
|||||||
type bindingSnapshot struct {
|
type bindingSnapshot struct {
|
||||||
Namespace string `json:"namespace"`
|
Namespace string `json:"namespace"`
|
||||||
PublicModel string `json:"public_model"`
|
PublicModel string `json:"public_model"`
|
||||||
|
Candidates []struct {
|
||||||
RouteGroup string `json:"route_group"`
|
RouteGroup string `json:"route_group"`
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
Upstreams map[string]string `json:"upstreams"`
|
Upstreams map[string]string `json:"upstreams"`
|
||||||
|
} `json:"candidates"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSyncBindings_SelectorExact(t *testing.T) {
|
func TestSyncBindings_SelectorExact(t *testing.T) {
|
||||||
@@ -26,15 +29,19 @@ func TestSyncBindings_SelectorExact(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("open sqlite: %v", err)
|
t.Fatalf("open sqlite: %v", err)
|
||||||
}
|
}
|
||||||
if err := db.AutoMigrate(&model.Provider{}, &model.Binding{}); err != nil {
|
if err := db.AutoMigrate(&model.ProviderGroup{}, &model.APIKey{}, &model.Binding{}); err != nil {
|
||||||
t.Fatalf("migrate: %v", err)
|
t.Fatalf("migrate: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
p := model.Provider{Name: "p1", Type: "openai", Group: "rg", Models: "m"}
|
group := model.ProviderGroup{Name: "rg", Type: "openai", BaseURL: "https://api.openai.com/v1", Models: "m", Status: "active"}
|
||||||
if err := db.Create(&p).Error; err != nil {
|
if err := db.Create(&group).Error; err != nil {
|
||||||
t.Fatalf("create provider: %v", err)
|
t.Fatalf("create group: %v", err)
|
||||||
}
|
}
|
||||||
b := model.Binding{Namespace: "ns", PublicModel: "m", RouteGroup: "rg", SelectorType: "exact", Status: "active"}
|
key := model.APIKey{GroupID: group.ID, APIKey: "k1", Status: "active"}
|
||||||
|
if err := db.Create(&key).Error; err != nil {
|
||||||
|
t.Fatalf("create api key: %v", err)
|
||||||
|
}
|
||||||
|
b := model.Binding{Namespace: "ns", PublicModel: "m", GroupID: group.ID, Weight: 1, SelectorType: "exact", Status: "active"}
|
||||||
if err := db.Create(&b).Error; err != nil {
|
if err := db.Create(&b).Error; err != nil {
|
||||||
t.Fatalf("create binding: %v", err)
|
t.Fatalf("create binding: %v", err)
|
||||||
}
|
}
|
||||||
@@ -54,8 +61,11 @@ func TestSyncBindings_SelectorExact(t *testing.T) {
|
|||||||
if err := json.Unmarshal([]byte(raw), &snap); err != nil {
|
if err := json.Unmarshal([]byte(raw), &snap); err != nil {
|
||||||
t.Fatalf("unmarshal: %v", err)
|
t.Fatalf("unmarshal: %v", err)
|
||||||
}
|
}
|
||||||
if snap.Upstreams == nil || snap.Upstreams[jsonID(p.ID)] != "m" {
|
if len(snap.Candidates) != 1 {
|
||||||
t.Fatalf("unexpected upstreams: %+v", snap.Upstreams)
|
t.Fatalf("expected 1 candidate, got %+v", snap.Candidates)
|
||||||
|
}
|
||||||
|
if snap.Candidates[0].Upstreams == nil || snap.Candidates[0].Upstreams[jsonID(key.ID)] != "m" {
|
||||||
|
t.Fatalf("unexpected upstreams: %+v", snap.Candidates[0].Upstreams)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -66,27 +76,31 @@ func TestSyncBindings_SelectorRegexAndNormalize(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("open sqlite: %v", err)
|
t.Fatalf("open sqlite: %v", err)
|
||||||
}
|
}
|
||||||
if err := db.AutoMigrate(&model.Provider{}, &model.Binding{}); err != nil {
|
if err := db.AutoMigrate(&model.ProviderGroup{}, &model.APIKey{}, &model.Binding{}); err != nil {
|
||||||
t.Fatalf("migrate: %v", err)
|
t.Fatalf("migrate: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
p1 := model.Provider{Name: "p1", Type: "openai", Group: "rg", Models: "moonshot/kimi2,kimi2"}
|
group := model.ProviderGroup{Name: "rg", Type: "openai", BaseURL: "https://api.openai.com/v1", Models: "moonshot/kimi2,kimi2", Status: "active"}
|
||||||
p2 := model.Provider{Name: "p2", Type: "openai", Group: "rg", Models: "moonshot/kimi2"}
|
if err := db.Create(&group).Error; err != nil {
|
||||||
if err := db.Create(&p1).Error; err != nil {
|
t.Fatalf("create group: %v", err)
|
||||||
t.Fatalf("create provider1: %v", err)
|
|
||||||
}
|
}
|
||||||
if err := db.Create(&p2).Error; err != nil {
|
k1 := model.APIKey{GroupID: group.ID, APIKey: "k1", Status: "active"}
|
||||||
t.Fatalf("create provider2: %v", err)
|
k2 := model.APIKey{GroupID: group.ID, APIKey: "k2", Status: "active"}
|
||||||
|
if err := db.Create(&k1).Error; err != nil {
|
||||||
|
t.Fatalf("create api key1: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.Create(&k2).Error; err != nil {
|
||||||
|
t.Fatalf("create api key2: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Regex should match uniquely (moonshot/kimi2 only).
|
// Regex should match uniquely (moonshot/kimi2 only).
|
||||||
bRegex := model.Binding{Namespace: "ns", PublicModel: "kimi2", RouteGroup: "rg", SelectorType: "regex", SelectorValue: "^moonshot/kimi2$", Status: "active"}
|
bRegex := model.Binding{Namespace: "ns", PublicModel: "kimi2", GroupID: group.ID, Weight: 1, SelectorType: "regex", SelectorValue: "^moonshot/kimi2$", Status: "active"}
|
||||||
if err := db.Create(&bRegex).Error; err != nil {
|
if err := db.Create(&bRegex).Error; err != nil {
|
||||||
t.Fatalf("create binding regex: %v", err)
|
t.Fatalf("create binding regex: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Normalize_exact should match p2 (moonshot/kimi2) for "kimi2".
|
// Normalize_exact should match p2 (moonshot/kimi2) for "kimi2".
|
||||||
bNorm := model.Binding{Namespace: "ns", PublicModel: "kimi2-n", RouteGroup: "rg", SelectorType: "normalize_exact", SelectorValue: "kimi2", Status: "active"}
|
bNorm := model.Binding{Namespace: "ns", PublicModel: "kimi2-n", GroupID: group.ID, Weight: 1, SelectorType: "normalize_exact", SelectorValue: "kimi2", Status: "active"}
|
||||||
if err := db.Create(&bNorm).Error; err != nil {
|
if err := db.Create(&bNorm).Error; err != nil {
|
||||||
t.Fatalf("create binding normalize: %v", err)
|
t.Fatalf("create binding normalize: %v", err)
|
||||||
}
|
}
|
||||||
@@ -104,8 +118,12 @@ func TestSyncBindings_SelectorRegexAndNormalize(t *testing.T) {
|
|||||||
if err := json.Unmarshal([]byte(raw), &snapRegex); err != nil {
|
if err := json.Unmarshal([]byte(raw), &snapRegex); err != nil {
|
||||||
t.Fatalf("unmarshal regex: %v", err)
|
t.Fatalf("unmarshal regex: %v", err)
|
||||||
}
|
}
|
||||||
if snapRegex.Upstreams[jsonID(p1.ID)] != "moonshot/kimi2" || snapRegex.Upstreams[jsonID(p2.ID)] != "moonshot/kimi2" {
|
if len(snapRegex.Candidates) != 1 {
|
||||||
t.Fatalf("unexpected regex upstreams: %+v", snapRegex.Upstreams)
|
t.Fatalf("expected 1 candidate, got %+v", snapRegex.Candidates)
|
||||||
|
}
|
||||||
|
upstreams := snapRegex.Candidates[0].Upstreams
|
||||||
|
if upstreams[jsonID(k1.ID)] != "moonshot/kimi2" || upstreams[jsonID(k2.ID)] != "moonshot/kimi2" {
|
||||||
|
t.Fatalf("unexpected regex upstreams: %+v", upstreams)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Normalize_exact binding should include p2 but exclude p1 due to multi-match (moonshot/kimi2 + kimi2).
|
// Normalize_exact binding should include p2 but exclude p1 due to multi-match (moonshot/kimi2 + kimi2).
|
||||||
@@ -114,11 +132,11 @@ func TestSyncBindings_SelectorRegexAndNormalize(t *testing.T) {
|
|||||||
if err := json.Unmarshal([]byte(raw), &snapNorm); err != nil {
|
if err := json.Unmarshal([]byte(raw), &snapNorm); err != nil {
|
||||||
t.Fatalf("unmarshal normalize: %v", err)
|
t.Fatalf("unmarshal normalize: %v", err)
|
||||||
}
|
}
|
||||||
if snapNorm.Upstreams[jsonID(p2.ID)] != "moonshot/kimi2" {
|
if len(snapNorm.Candidates) != 1 {
|
||||||
t.Fatalf("expected p2 upstream, got %+v", snapNorm.Upstreams)
|
t.Fatalf("expected 1 candidate, got %+v", snapNorm.Candidates)
|
||||||
}
|
}
|
||||||
if _, ok := snapNorm.Upstreams[jsonID(p1.ID)]; ok {
|
if len(snapNorm.Candidates[0].Upstreams) != 0 || snapNorm.Candidates[0].Error != "config_error" {
|
||||||
t.Fatalf("did not expect p1 upstream due to normalize multi-match, got %+v", snapNorm.Upstreams)
|
t.Fatalf("expected config_error with no upstreams, got %+v", snapNorm.Candidates[0])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,68 +2,76 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"reflect"
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/alicebob/miniredis/v2"
|
"github.com/alicebob/miniredis/v2"
|
||||||
"github.com/ez-api/ez-api/internal/model"
|
"github.com/ez-api/ez-api/internal/model"
|
||||||
"github.com/ez-api/foundation/contract"
|
|
||||||
"github.com/redis/go-redis/v9"
|
"github.com/redis/go-redis/v9"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSyncProvider_WritesSnapshotAndRouting(t *testing.T) {
|
func TestSyncProviders_WritesSnapshotAndRouting(t *testing.T) {
|
||||||
goldenRaw := contract.ProviderSnapshotJSON()
|
|
||||||
var golden map[string]any
|
|
||||||
if err := json.Unmarshal(goldenRaw, &golden); err != nil {
|
|
||||||
t.Fatalf("parse golden json: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
mr := miniredis.RunT(t)
|
mr := miniredis.RunT(t)
|
||||||
rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()})
|
rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()})
|
||||||
|
|
||||||
svc := NewSyncService(rdb)
|
svc := NewSyncService(rdb)
|
||||||
|
|
||||||
p := &model.Provider{
|
db, err := gorm.Open(sqlite.Open("file:"+t.Name()+"?mode=memory&cache=shared"), &gorm.Config{})
|
||||||
Name: "p1",
|
if err != nil {
|
||||||
|
t.Fatalf("open sqlite: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.AutoMigrate(&model.ProviderGroup{}, &model.APIKey{}); err != nil {
|
||||||
|
t.Fatalf("migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
group := model.ProviderGroup{
|
||||||
|
Name: "default",
|
||||||
Type: "vertex-express",
|
Type: "vertex-express",
|
||||||
Group: "default",
|
BaseURL: "https://vertex.example",
|
||||||
|
GoogleLocation: "global",
|
||||||
Models: "gemini-3-pro-preview",
|
Models: "gemini-3-pro-preview",
|
||||||
Status: "active",
|
Status: "active",
|
||||||
|
}
|
||||||
|
if err := db.Create(&group).Error; err != nil {
|
||||||
|
t.Fatalf("create group: %v", err)
|
||||||
|
}
|
||||||
|
key := model.APIKey{
|
||||||
|
GroupID: group.ID,
|
||||||
|
APIKey: "k",
|
||||||
|
Status: "active",
|
||||||
AutoBan: true,
|
AutoBan: true,
|
||||||
GoogleProject: "",
|
|
||||||
GoogleLocation: "global",
|
|
||||||
}
|
}
|
||||||
p.ID = 42
|
if err := db.Create(&key).Error; err != nil {
|
||||||
|
t.Fatalf("create key: %v", err)
|
||||||
if err := svc.SyncProvider(p); err != nil {
|
|
||||||
t.Fatalf("SyncProvider: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
raw := mr.HGet("config:providers", "42")
|
if err := svc.SyncProviders(db); err != nil {
|
||||||
|
t.Fatalf("SyncProviders: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
raw := mr.HGet("config:providers", jsonID(key.ID))
|
||||||
if raw == "" {
|
if raw == "" {
|
||||||
t.Fatalf("expected config:providers hash entry")
|
t.Fatalf("expected config:providers hash entry")
|
||||||
}
|
}
|
||||||
|
|
||||||
var snap map[string]any
|
var snap map[string]any
|
||||||
if err := json.Unmarshal([]byte(raw), &snap); err != nil {
|
if err := json.Unmarshal([]byte(raw), &snap); err != nil {
|
||||||
t.Fatalf("invalid snapshot json: %v", err)
|
t.Fatalf("invalid snapshot json: %v", err)
|
||||||
}
|
}
|
||||||
for k, v := range golden {
|
if snap["group"] != "default" {
|
||||||
if !reflect.DeepEqual(snap[k], v) {
|
t.Fatalf("expected group default, got %#v", snap["group"])
|
||||||
t.Fatalf("snapshot mismatch for %q: got=%#v want=%#v", k, snap[k], v)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
routeKey := "route:group:default:gemini-3-pro-preview"
|
routeKey := "route:group:default:gemini-3-pro-preview"
|
||||||
if !mr.Exists(routeKey) {
|
if !mr.Exists(routeKey) {
|
||||||
t.Fatalf("expected routing key %q to exist", routeKey)
|
t.Fatalf("expected routing key %q to exist", routeKey)
|
||||||
}
|
}
|
||||||
ok, err := mr.SIsMember(routeKey, "42")
|
ok, err := mr.SIsMember(routeKey, jsonID(key.ID))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("SIsMember: %v", err)
|
t.Fatalf("SIsMember: %v", err)
|
||||||
}
|
}
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("expected provider id 42 in routing set %q", routeKey)
|
t.Fatalf("expected provider id in routing set %q", routeKey)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -113,34 +121,6 @@ func TestSyncModelDelete_RemovesMeta(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSyncProviderDelete_RemovesSnapshotAndRouting(t *testing.T) {
|
func jsonID(id uint) string {
|
||||||
mr := miniredis.RunT(t)
|
return strconv.FormatUint(uint64(id), 10)
|
||||||
rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()})
|
|
||||||
svc := NewSyncService(rdb)
|
|
||||||
|
|
||||||
p := &model.Provider{
|
|
||||||
Name: "p1",
|
|
||||||
Type: "openai",
|
|
||||||
Group: "default",
|
|
||||||
Models: "gpt-4o-mini,gpt-4o",
|
|
||||||
Status: "active",
|
|
||||||
}
|
|
||||||
p.ID = 7
|
|
||||||
|
|
||||||
if err := svc.SyncProvider(p); err != nil {
|
|
||||||
t.Fatalf("SyncProvider: %v", err)
|
|
||||||
}
|
|
||||||
if err := svc.SyncProviderDelete(p); err != nil {
|
|
||||||
t.Fatalf("SyncProviderDelete: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if got := mr.HGet("config:providers", "7"); got != "" {
|
|
||||||
t.Fatalf("expected provider snapshot removed, got %q", got)
|
|
||||||
}
|
|
||||||
if ok, _ := mr.SIsMember("route:group:default:gpt-4o-mini", "7"); ok {
|
|
||||||
t.Fatalf("expected provider removed from route set")
|
|
||||||
}
|
|
||||||
if ok, _ := mr.SIsMember("route:group:default:gpt-4o", "7"); ok {
|
|
||||||
t.Fatalf("expected provider removed from route set")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user