diff --git a/cmd/server/main.go b/cmd/server/main.go index 43f03d5..0bc07cd 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -167,6 +167,8 @@ func main() { adminGroup.PUT("/features", featureHandler.UpdateFeatures) // Other admin routes for managing providers, models, etc. adminGroup.POST("/providers", handler.CreateProvider) + adminGroup.POST("/providers/preset", handler.CreateProviderPreset) + adminGroup.POST("/providers/custom", handler.CreateProviderCustom) adminGroup.PUT("/providers/:id", handler.UpdateProvider) adminGroup.POST("/models", handler.CreateModel) adminGroup.GET("/models", handler.ListModels) diff --git a/internal/api/handler.go b/internal/api/handler.go index 3c05982..e0b7a83 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -74,6 +74,7 @@ func (h *Handler) CreateProvider(c *gin.Context) { Status: status, AutoBan: autoBan, BanReason: req.BanReason, + Weight: req.Weight, } if !req.BanUntil.IsZero() { tu := req.BanUntil.UTC() @@ -162,6 +163,9 @@ func (h *Handler) UpdateProvider(c *gin.Context) { if req.Models != nil { update["models"] = strings.Join(req.Models, ",") } + if req.Weight > 0 { + update["weight"] = req.Weight + } if strings.TrimSpace(req.Group) != "" { update["group"] = groupx.Normalize(req.Group) } diff --git a/internal/api/provider_create_handler.go b/internal/api/provider_create_handler.go new file mode 100644 index 0000000..c70b32a --- /dev/null +++ b/internal/api/provider_create_handler.go @@ -0,0 +1,194 @@ +package api + +import ( + "crypto/rand" + "encoding/hex" + "net/http" + "strings" + + "github.com/ez-api/ez-api/internal/dto" + "github.com/ez-api/ez-api/internal/model" + groupx "github.com/ez-api/foundation/group" + providerx "github.com/ez-api/foundation/provider" + "github.com/gin-gonic/gin" +) + +// CreateProviderPreset godoc +// @Summary Create a preset provider +// @Description Create an official OpenAI/Anthropic/Gemini provider (only api_key is typically required) +// @Tags admin +// @Accept json +// @Produce json +// @Security AdminAuth +// @Param provider body dto.ProviderPresetCreateDTO true "Provider preset payload" +// @Success 201 {object} model.Provider +// @Failure 400 {object} gin.H +// @Failure 500 {object} gin.H +// @Router /admin/providers/preset [post] +func (h *Handler) CreateProviderPreset(c *gin.Context) { + var req dto.ProviderPresetCreateDTO + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + preset := providerx.NormalizeType(req.Preset) + if preset == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "preset required"}) + return + } + + var providerType string + var baseURL string + switch preset { + case providerx.TypeOpenAI: + providerType = providerx.TypeOpenAI + baseURL = "https://api.openai.com" + case providerx.TypeAnthropic, providerx.TypeClaude: + providerType = providerx.TypeAnthropic + baseURL = "https://api.anthropic.com" + case providerx.TypeGemini, providerx.TypeAIStudio, providerx.TypeGoogle: + // Gemini API / AI Studio (SDK transport). BaseURL is optional but we provide the official endpoint. + providerType = providerx.TypeGemini + baseURL = "https://generativelanguage.googleapis.com" + default: + c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported preset: " + preset}) + return + } + + name := strings.TrimSpace(req.Name) + if name == "" { + name = providerType + "-" + randomSuffix(4) + } + group := strings.TrimSpace(req.Group) + if group == "" { + group = "default" + } + status := strings.TrimSpace(req.Status) + if status == "" { + status = "active" + } + autoBan := true + if req.AutoBan != nil { + autoBan = *req.AutoBan + } + + googleLocation := providerx.DefaultGoogleLocation(providerType, req.GoogleLocation) + + p := model.Provider{ + Name: name, + Type: providerType, + BaseURL: baseURL, + APIKey: strings.TrimSpace(req.APIKey), + GoogleProject: strings.TrimSpace(req.GoogleProject), + GoogleLocation: googleLocation, + Group: groupx.Normalize(group), + Models: strings.Join(req.Models, ","), + Status: status, + AutoBan: autoBan, + } + if req.Weight > 0 { + p.Weight = req.Weight + } + + if err := h.db.Create(&p).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create provider", "details": err.Error()}) + return + } + + if err := h.sync.SyncProvider(&p); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync provider", "details": err.Error()}) + return + } + if err := h.sync.SyncBindings(h.db); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync bindings", "details": err.Error()}) + return + } + + c.JSON(http.StatusCreated, p) +} + +// CreateProviderCustom godoc +// @Summary Create a custom provider +// @Description Create an OpenAI-compatible provider (base_url + api_key required) +// @Tags admin +// @Accept json +// @Produce json +// @Security AdminAuth +// @Param provider body dto.ProviderCustomCreateDTO true "Provider custom payload" +// @Success 201 {object} model.Provider +// @Failure 400 {object} gin.H +// @Failure 500 {object} gin.H +// @Router /admin/providers/custom [post] +func (h *Handler) CreateProviderCustom(c *gin.Context) { + var req dto.ProviderCustomCreateDTO + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + name := strings.TrimSpace(req.Name) + if name == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "name required"}) + return + } + baseURL := strings.TrimSpace(req.BaseURL) + if baseURL == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "base_url required"}) + return + } + + group := strings.TrimSpace(req.Group) + if group == "" { + group = "default" + } + status := strings.TrimSpace(req.Status) + if status == "" { + status = "active" + } + autoBan := true + if req.AutoBan != nil { + autoBan = *req.AutoBan + } + + p := model.Provider{ + Name: name, + Type: providerx.TypeCompatible, + BaseURL: baseURL, + APIKey: strings.TrimSpace(req.APIKey), + Group: groupx.Normalize(group), + Models: strings.Join(req.Models, ","), + Status: status, + AutoBan: autoBan, + } + if req.Weight > 0 { + p.Weight = req.Weight + } + + if err := h.db.Create(&p).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create provider", "details": err.Error()}) + return + } + + if err := h.sync.SyncProvider(&p); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync provider", "details": err.Error()}) + return + } + if err := h.sync.SyncBindings(h.db); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync bindings", "details": err.Error()}) + return + } + + c.JSON(http.StatusCreated, p) +} + +func randomSuffix(bytesLen int) string { + if bytesLen <= 0 { + bytesLen = 4 + } + b := make([]byte, bytesLen) + if _, err := rand.Read(b); err != nil { + return "rand" + } + return hex.EncodeToString(b) +} diff --git a/internal/api/provider_create_handler_test.go b/internal/api/provider_create_handler_test.go new file mode 100644 index 0000000..6218fbd --- /dev/null +++ b/internal/api/provider_create_handler_test.go @@ -0,0 +1,70 @@ +package api + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/ez-api/ez-api/internal/model" + "github.com/gin-gonic/gin" +) + +func TestCreateProviderPreset_OpenAI_SetsBaseURL(t *testing.T) { + h, _ := newTestHandler(t) + + r := gin.New() + r.POST("/admin/providers/preset", h.CreateProviderPreset) + + reqBody := map[string]any{ + "preset": "openai", + "api_key": "k", + "models": []string{"gpt-4o-mini"}, + } + b, _ := json.Marshal(reqBody) + + req := httptest.NewRequest(http.MethodPost, "/admin/providers/preset", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusCreated { + t.Fatalf("expected 201, got %d body=%s", rr.Code, rr.Body.String()) + } + var got model.Provider + if err := json.Unmarshal(rr.Body.Bytes(), &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if got.Type != "openai" { + t.Fatalf("expected type openai, got %q", got.Type) + } + if got.BaseURL != "https://api.openai.com" { + t.Fatalf("expected base_url=https://api.openai.com, got %q", got.BaseURL) + } + if got.Name == "" { + t.Fatalf("expected generated name") + } +} + +func TestCreateProviderCustom_RequiresBaseURL(t *testing.T) { + h, _ := newTestHandler(t) + + r := gin.New() + r.POST("/admin/providers/custom", h.CreateProviderCustom) + + reqBody := map[string]any{ + "name": "c1", + "api_key": "k", + } + b, _ := json.Marshal(reqBody) + + req := httptest.NewRequest(http.MethodPost, "/admin/providers/custom", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d body=%s", rr.Code, rr.Body.String()) + } +} diff --git a/internal/dto/provider.go b/internal/dto/provider.go index 304a18f..9e31be3 100644 --- a/internal/dto/provider.go +++ b/internal/dto/provider.go @@ -13,6 +13,7 @@ type ProviderDTO struct { Group string `json:"group"` Models []string `json:"models"` // List of supported model names Status string `json:"status"` + Weight int `json:"weight,omitempty"` AutoBan *bool `json:"auto_ban,omitempty"` BanReason string `json:"ban_reason,omitempty"` BanUntil time.Time `json:"ban_until,omitempty"` diff --git a/internal/dto/provider_create.go b/internal/dto/provider_create.go new file mode 100644 index 0000000..9357883 --- /dev/null +++ b/internal/dto/provider_create.go @@ -0,0 +1,35 @@ +package dto + +// ProviderPresetCreateDTO creates an official provider with sensible defaults. +// For preset providers, base_url is derived automatically; users typically only provide api_key. +type ProviderPresetCreateDTO struct { + Preset string `json:"preset"` // openai | anthropic | gemini + + // Optional fields. + Name string `json:"name"` + Group string `json:"group"` + Models []string `json:"models"` + + APIKey string `json:"api_key"` + GoogleProject string `json:"google_project,omitempty"` + GoogleLocation string `json:"google_location,omitempty"` + + Status string `json:"status"` + Weight int `json:"weight,omitempty"` + AutoBan *bool `json:"auto_ban,omitempty"` +} + +// ProviderCustomCreateDTO creates an OpenAI-compatible provider. +// For custom providers, base_url is required. +type ProviderCustomCreateDTO struct { + Name string `json:"name"` + Group string `json:"group"` + Models []string `json:"models"` + + BaseURL string `json:"base_url"` + APIKey string `json:"api_key"` + + Status string `json:"status"` + Weight int `json:"weight,omitempty"` + AutoBan *bool `json:"auto_ban,omitempty"` +} diff --git a/internal/model/models.go b/internal/model/models.go index 25f31ae..838d905 100644 --- a/internal/model/models.go +++ b/internal/model/models.go @@ -49,6 +49,7 @@ type Provider struct { GoogleLocation string `gorm:"size:64" json:"google_location,omitempty"` Group string `gorm:"default:'default'" json:"group"` // routing group/tier Models string `json:"models"` // comma-separated list of supported models (e.g. "gpt-4,gpt-3.5-turbo") + Weight int `gorm:"default:1" json:"weight"` // routing weight inside route_group Status string `gorm:"size:50;default:'active'" json:"status"` // active, auto_disabled, manual_disabled AutoBan bool `gorm:"default:true" json:"auto_ban"` // whether DP-triggered disable is allowed BanReason string `gorm:"size:255" json:"ban_reason"` // reason for current disable diff --git a/internal/service/sync.go b/internal/service/sync.go index 2c1ad21..157785e 100644 --- a/internal/service/sync.go +++ b/internal/service/sync.go @@ -79,6 +79,7 @@ func (s *SyncService) SyncProvider(provider *model.Provider) error { GoogleLocation: provider.GoogleLocation, Group: group, Models: models, + Weight: provider.Weight, Status: normalizeStatus(provider.Status), AutoBan: provider.AutoBan, BanReason: provider.BanReason, @@ -139,6 +140,7 @@ type providerSnapshot struct { GoogleLocation string `json:"google_location,omitempty"` Group string `json:"group"` Models []string `json:"models"` + Weight int `json:"weight,omitempty"` Status string `json:"status"` AutoBan bool `json:"auto_ban"` BanReason string `json:"ban_reason,omitempty"` @@ -220,6 +222,7 @@ func (s *SyncService) SyncAll(db *gorm.DB) error { GoogleLocation: p.GoogleLocation, Group: group, Models: models, + Weight: p.Weight, Status: normalizeStatus(p.Status), AutoBan: p.AutoBan, BanReason: p.BanReason,