From 305f2ebf18cc07899fbd0b6d973718bf8e0e8c79 Mon Sep 17 00:00:00 2001 From: zenfun Date: Fri, 12 Dec 2025 23:44:52 +0800 Subject: [PATCH] feat(provider): add update endpoint and enforce status checks Add `PUT /admin/providers/{id}` endpoint to allow updating provider configurations, including status and ban details. Update synchronization logic to exclude inactive or banned providers from routing tables to ensure traffic is not routed to them. --- internal/api/handler.go | 100 +++++++++++++++++++++++++++++++++++++++ internal/dto/provider.go | 3 ++ internal/service/sync.go | 13 +++++ 3 files changed, 116 insertions(+) diff --git a/internal/api/handler.go b/internal/api/handler.go index e7fbff2..339d4bc 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -86,6 +86,99 @@ func (h *Handler) CreateProvider(c *gin.Context) { c.JSON(http.StatusCreated, provider) } +// UpdateProvider godoc +// @Summary Update a provider +// @Description Update provider attributes including status/auto-ban flags +// @Tags admin +// @Accept json +// @Produce json +// @Security AdminAuth +// @Param id path int true "Provider ID" +// @Param provider body dto.ProviderDTO true "Provider Info" +// @Success 200 {object} model.Provider +// @Failure 400 {object} gin.H +// @Failure 404 {object} gin.H +// @Failure 500 {object} gin.H +// @Router /admin/providers/{id} [put] +func (h *Handler) UpdateProvider(c *gin.Context) { + idParam := c.Param("id") + id, err := strconv.Atoi(idParam) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"}) + return + } + + var existing model.Provider + if err := h.db.First(&existing, id).Error; err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "provider not found"}) + return + } + + var req dto.ProviderDTO + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + update := map[string]any{} + if strings.TrimSpace(req.Name) != "" { + update["name"] = req.Name + } + if strings.TrimSpace(req.Type) != "" { + update["type"] = req.Type + } + if strings.TrimSpace(req.BaseURL) != "" { + update["base_url"] = req.BaseURL + } + if req.APIKey != "" { + update["api_key"] = req.APIKey + } + if req.Models != nil { + update["models"] = strings.Join(req.Models, ",") + } + if strings.TrimSpace(req.Group) != "" { + update["group"] = normalizeGroup(req.Group) + } + if req.AutoBan != nil { + update["auto_ban"] = *req.AutoBan + } + if strings.TrimSpace(req.Status) != "" { + update["status"] = req.Status + } + if req.BanReason != "" || strings.TrimSpace(req.Status) == "active" { + update["ban_reason"] = req.BanReason + } + if !req.BanUntil.IsZero() { + tu := req.BanUntil.UTC() + update["ban_until"] = &tu + } + if req.BanUntil.IsZero() && strings.TrimSpace(req.Status) == "active" { + update["ban_until"] = nil + } + + if len(update) == 0 { + c.JSON(http.StatusBadRequest, gin.H{"error": "no fields to update"}) + return + } + + if err := h.db.Model(&existing).Updates(update).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update provider", "details": err.Error()}) + return + } + + if err := h.db.First(&existing, id).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to reload provider", "details": err.Error()}) + return + } + + if err := h.sync.SyncProvider(&existing); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync provider", "details": err.Error()}) + return + } + + c.JSON(http.StatusOK, existing) +} + // CreateModel godoc // @Summary Register a new model // @Description Register a supported model with its capabilities @@ -241,3 +334,10 @@ func (h *Handler) IngestLog(c *gin.Context) { h.logger.Write(rec) c.JSON(http.StatusAccepted, gin.H{"status": "queued"}) } + +func normalizeGroup(group string) string { + if strings.TrimSpace(group) == "" { + return "default" + } + return group +} diff --git a/internal/dto/provider.go b/internal/dto/provider.go index b48a4de..45714ad 100644 --- a/internal/dto/provider.go +++ b/internal/dto/provider.go @@ -14,4 +14,7 @@ type ProviderDTO struct { AutoBan *bool `json:"auto_ban,omitempty"` BanReason string `json:"ban_reason,omitempty"` BanUntil time.Time `json:"ban_until,omitempty"` + + // Optional control params + SkipRouting bool `json:"skip_routing,omitempty"` // if true, do not add to routing tables (e.g., disabled) } diff --git a/internal/service/sync.go b/internal/service/sync.go index 58871e9..ca5d7c5 100644 --- a/internal/service/sync.go +++ b/internal/service/sync.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "strings" + "time" "github.com/bytedance/sonic" "github.com/ez-api/ez-api/internal/model" @@ -93,6 +94,12 @@ func (s *SyncService) SyncProvider(provider *model.Provider) error { if m == "" { continue } + if snap.Status != "active" { + continue + } + if snap.BanUntil > 0 && time.Now().Unix() < snap.BanUntil { + continue + } routeKey := fmt.Sprintf("route:group:%s:%s", group, m) pipe.SAdd(ctx, routeKey, provider.ID) } @@ -217,6 +224,12 @@ func (s *SyncService) SyncAll(db *gorm.DB) error { if m == "" { continue } + if snap.Status != "active" { + continue + } + if snap.BanUntil > 0 && time.Now().Unix() < snap.BanUntil { + continue + } routeKey := fmt.Sprintf("route:group:%s:%s", group, m) pipe.SAdd(ctx, routeKey, p.ID) }