From a7571dd4ad7ed6fd22f7ab31bd21095fbc8554d9 Mon Sep 17 00:00:00 2001 From: zenfun Date: Sun, 4 Jan 2026 01:44:45 +0800 Subject: [PATCH] feat(server): integrate ip ban cron and refine updates - Initialize and schedule IP ban maintenance tasks in server entry point - Perform initial IP ban sync to Redis on startup - Implement optional JSON unmarshalling to handle null `expires_at` in API - Add CIDR overlap validation when updating rule status to active --- cmd/server/main.go | 7 ++++++ internal/api/ip_ban_handler.go | 41 +++++++++++++++++++++++++++------- internal/service/ip_ban.go | 31 ++++++++++++++++++------- 3 files changed, 63 insertions(+), 16 deletions(-) diff --git a/cmd/server/main.go b/cmd/server/main.go index 58e9ed9..a3cfc05 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -235,6 +235,7 @@ func main() { authHandler := api.NewAuthHandler(db, rdb, adminService, masterService) ipBanService := service.NewIPBanService(db, rdb) ipBanHandler := api.NewIPBanHandler(ipBanService) + ipBanManager := cron.NewIPBanManager(ipBanService) modelRegistryService := service.NewModelRegistryService(db, rdb, service.ModelRegistryConfig{ Enabled: cfg.ModelRegistry.Enabled, RefreshEvery: time.Duration(cfg.ModelRegistry.RefreshSeconds) * time.Second, @@ -250,11 +251,17 @@ func main() { if err := syncService.SyncAll(db); err != nil { logger.Warn("initial sync warning", "err", err) } + if err := ipBanService.SyncAllToRedis(context.Background()); err != nil { + logger.Warn("initial IP ban sync warning", "err", err) + } // Initial model registry refresh before scheduler starts if modelRegistryService.Enabled() { modelRegistryService.RunOnce(context.Background()) sched.Every("model-registry-refresh", modelRegistryService.RefreshEvery(), modelRegistryService.RunOnce) } + sched.Every("ip-ban-expire", time.Minute, ipBanManager.ExpireRunOnce) + sched.Every("ip-ban-hit-sync", 5*time.Minute, ipBanManager.HitSyncRunOnce) + sched.Every("ip-ban-full-sync", 5*time.Minute, ipBanManager.FullSyncRunOnce) sched.Start() // 5. Setup Gin Router diff --git a/internal/api/ip_ban_handler.go b/internal/api/ip_ban_handler.go index 1ff115e..733e2b6 100644 --- a/internal/api/ip_ban_handler.go +++ b/internal/api/ip_ban_handler.go @@ -1,6 +1,7 @@ package api import ( + "encoding/json" "errors" "net/http" "strconv" @@ -29,9 +30,9 @@ type CreateIPBanRequest struct { // UpdateIPBanRequest represents a request to update an IP ban. type UpdateIPBanRequest struct { - Reason *string `json:"reason,omitempty"` - ExpiresAt *int64 `json:"expires_at,omitempty"` - Status *string `json:"status,omitempty"` + Reason *string `json:"reason,omitempty"` + ExpiresAt optionalInt64 `json:"expires_at,omitempty"` + Status *string `json:"status,omitempty"` } // IPBanView represents the API response for an IP ban. @@ -47,6 +48,25 @@ type IPBanView struct { UpdatedAt int64 `json:"updated_at"` } +type optionalInt64 struct { + Value *int64 + Set bool +} + +func (o *optionalInt64) UnmarshalJSON(data []byte) error { + o.Set = true + if string(data) == "null" { + o.Value = nil + return nil + } + var v int64 + if err := json.Unmarshal(data, &v); err != nil { + return err + } + o.Value = &v + return nil +} + func toIPBanView(ban *model.IPBan) IPBanView { return IPBanView{ ID: ban.ID, @@ -183,6 +203,7 @@ func (h *IPBanHandler) Get(c *gin.Context) { // @Success 200 {object} IPBanView // @Failure 400 {object} gin.H // @Failure 404 {object} gin.H +// @Failure 409 {object} gin.H // @Failure 500 {object} gin.H // @Router /admin/ip-bans/{id} [put] func (h *IPBanHandler) Update(c *gin.Context) { @@ -199,15 +220,19 @@ func (h *IPBanHandler) Update(c *gin.Context) { } ban, err := h.ipBanService.Update(c.Request.Context(), uint(id), service.UpdateIPBanRequest{ - Reason: req.Reason, - ExpiresAt: req.ExpiresAt, - Status: req.Status, + Reason: req.Reason, + ExpiresAt: req.ExpiresAt.Value, + ExpiresAtSet: req.ExpiresAt.Set, + Status: req.Status, }) if err != nil { - if errors.Is(err, service.ErrIPBanNotFound) { + switch { + case errors.Is(err, service.ErrIPBanNotFound): c.JSON(http.StatusNotFound, gin.H{"error": "IP ban not found"}) - } else { + case errors.Is(err, service.ErrCIDROverlap): + c.JSON(http.StatusConflict, gin.H{"error": err.Error()}) + default: c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to update IP ban", "details": err.Error()}) } return diff --git a/internal/service/ip_ban.go b/internal/service/ip_ban.go index 9d0d167..76bd366 100644 --- a/internal/service/ip_ban.go +++ b/internal/service/ip_ban.go @@ -15,10 +15,10 @@ import ( ) var ( - ErrInvalidCIDR = errors.New("invalid CIDR format") - ErrCIDROverlap = errors.New("CIDR overlaps with existing active rule") - ErrIPBanNotFound = errors.New("IP ban not found") - ErrDuplicateCIDR = errors.New("CIDR already exists") + ErrInvalidCIDR = errors.New("invalid CIDR format") + ErrCIDROverlap = errors.New("CIDR overlaps with existing active rule") + ErrIPBanNotFound = errors.New("IP ban not found") + ErrDuplicateCIDR = errors.New("CIDR already exists") ) // IPBanService handles global IP ban operations. @@ -93,9 +93,10 @@ type CreateIPBanRequest struct { // UpdateIPBanRequest represents a request to update an IP ban. type UpdateIPBanRequest struct { - Reason *string `json:"reason,omitempty"` - ExpiresAt *int64 `json:"expires_at,omitempty"` // Use pointer to distinguish between "not set" and "set to null" - Status *string `json:"status,omitempty"` + Reason *string `json:"reason,omitempty"` + ExpiresAt *int64 `json:"expires_at,omitempty"` + ExpiresAtSet bool `json:"-"` + Status *string `json:"status,omitempty"` } // Create creates a new IP ban with validation. @@ -180,11 +181,25 @@ func (s *IPBanService) Update(ctx context.Context, id uint, req UpdateIPBanReque return nil, err } + if req.Status != nil && *req.Status == model.IPBanStatusActive && ban.Status != model.IPBanStatusActive { + var activeRules []model.IPBan + if err := s.db.WithContext(ctx). + Where("status = ? AND id <> ?", model.IPBanStatusActive, ban.ID). + Find(&activeRules).Error; err != nil { + return nil, err + } + for _, rule := range activeRules { + if CIDROverlaps(ban.CIDR, rule.CIDR) { + return nil, fmt.Errorf("%w: overlaps with %s", ErrCIDROverlap, rule.CIDR) + } + } + } + updates := make(map[string]interface{}) if req.Reason != nil { updates["reason"] = *req.Reason } - if req.ExpiresAt != nil { + if req.ExpiresAtSet { updates["expires_at"] = req.ExpiresAt } if req.Status != nil {