diff --git a/internal/api/handler.go b/internal/api/handler.go index e7fbff2..339d4bc 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -86,6 +86,99 @@ func (h *Handler) CreateProvider(c *gin.Context) { c.JSON(http.StatusCreated, provider) } +// UpdateProvider godoc +// @Summary Update a provider +// @Description Update provider attributes including status/auto-ban flags +// @Tags admin +// @Accept json +// @Produce json +// @Security AdminAuth +// @Param id path int true "Provider ID" +// @Param provider body dto.ProviderDTO true "Provider Info" +// @Success 200 {object} model.Provider +// @Failure 400 {object} gin.H +// @Failure 404 {object} gin.H +// @Failure 500 {object} gin.H +// @Router /admin/providers/{id} [put] +func (h *Handler) UpdateProvider(c *gin.Context) { + idParam := c.Param("id") + id, err := strconv.Atoi(idParam) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"}) + return + } + + var existing model.Provider + if err := h.db.First(&existing, id).Error; err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "provider not found"}) + return + } + + var req dto.ProviderDTO + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + update := map[string]any{} + if strings.TrimSpace(req.Name) != "" { + update["name"] = req.Name + } + if strings.TrimSpace(req.Type) != "" { + update["type"] = req.Type + } + if strings.TrimSpace(req.BaseURL) != "" { + update["base_url"] = req.BaseURL + } + if req.APIKey != "" { + update["api_key"] = req.APIKey + } + if req.Models != nil { + update["models"] = strings.Join(req.Models, ",") + } + if strings.TrimSpace(req.Group) != "" { + update["group"] = normalizeGroup(req.Group) + } + if req.AutoBan != nil { + update["auto_ban"] = *req.AutoBan + } + if strings.TrimSpace(req.Status) != "" { + update["status"] = req.Status + } + if req.BanReason != "" || strings.TrimSpace(req.Status) == "active" { + update["ban_reason"] = req.BanReason + } + if !req.BanUntil.IsZero() { + tu := req.BanUntil.UTC() + update["ban_until"] = &tu + } + if req.BanUntil.IsZero() && strings.TrimSpace(req.Status) == "active" { + update["ban_until"] = nil + } + + if len(update) == 0 { + c.JSON(http.StatusBadRequest, gin.H{"error": "no fields to update"}) + return + } + + if err := h.db.Model(&existing).Updates(update).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update provider", "details": err.Error()}) + return + } + + if err := h.db.First(&existing, id).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to reload provider", "details": err.Error()}) + return + } + + if err := h.sync.SyncProvider(&existing); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync provider", "details": err.Error()}) + return + } + + c.JSON(http.StatusOK, existing) +} + // CreateModel godoc // @Summary Register a new model // @Description Register a supported model with its capabilities @@ -241,3 +334,10 @@ func (h *Handler) IngestLog(c *gin.Context) { h.logger.Write(rec) c.JSON(http.StatusAccepted, gin.H{"status": "queued"}) } + +func normalizeGroup(group string) string { + if strings.TrimSpace(group) == "" { + return "default" + } + return group +} diff --git a/internal/dto/provider.go b/internal/dto/provider.go index b48a4de..45714ad 100644 --- a/internal/dto/provider.go +++ b/internal/dto/provider.go @@ -14,4 +14,7 @@ type ProviderDTO struct { AutoBan *bool `json:"auto_ban,omitempty"` BanReason string `json:"ban_reason,omitempty"` BanUntil time.Time `json:"ban_until,omitempty"` + + // Optional control params + SkipRouting bool `json:"skip_routing,omitempty"` // if true, do not add to routing tables (e.g., disabled) } diff --git a/internal/service/sync.go b/internal/service/sync.go index 58871e9..ca5d7c5 100644 --- a/internal/service/sync.go +++ b/internal/service/sync.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "strings" + "time" "github.com/bytedance/sonic" "github.com/ez-api/ez-api/internal/model" @@ -93,6 +94,12 @@ func (s *SyncService) SyncProvider(provider *model.Provider) error { if m == "" { continue } + if snap.Status != "active" { + continue + } + if snap.BanUntil > 0 && time.Now().Unix() < snap.BanUntil { + continue + } routeKey := fmt.Sprintf("route:group:%s:%s", group, m) pipe.SAdd(ctx, routeKey, provider.ID) } @@ -217,6 +224,12 @@ func (s *SyncService) SyncAll(db *gorm.DB) error { if m == "" { continue } + if snap.Status != "active" { + continue + } + if snap.BanUntil > 0 && time.Now().Unix() < snap.BanUntil { + continue + } routeKey := fmt.Sprintf("route:group:%s:%s", group, m) pipe.SAdd(ctx, routeKey, p.ID) }