From 2359603666bbec7e44288f032057b1d3af20b92d Mon Sep 17 00:00:00 2001 From: zenfun Date: Sun, 4 Jan 2026 00:59:03 +0800 Subject: [PATCH] feat(service): implement IP ban service logic Add IPBanService to manage global IP bans with Redis synchronization for high-performance filtering. Includes logic for CIDR normalization, overlap detection, hit count tracking, and rule expiration. --- internal/service/ip_ban.go | 364 +++++++++++++++++++++++++++++++++++++ 1 file changed, 364 insertions(+) create mode 100644 internal/service/ip_ban.go diff --git a/internal/service/ip_ban.go b/internal/service/ip_ban.go new file mode 100644 index 0000000..1b1afff --- /dev/null +++ b/internal/service/ip_ban.go @@ -0,0 +1,364 @@ +package service + +import ( + "context" + "errors" + "fmt" + "log/slog" + "net" + "strings" + + "github.com/ez-api/ez-api/internal/model" + "github.com/redis/go-redis/v9" + "gorm.io/gorm" +) + +var ( + ErrInvalidCIDR = errors.New("invalid CIDR format") + ErrCIDROverlap = errors.New("CIDR overlaps with existing active rule") + ErrIPBanNotFound = errors.New("IP ban not found") + ErrDuplicateCIDR = errors.New("CIDR already exists") +) + +// IPBanService handles global IP ban operations. +type IPBanService struct { + db *gorm.DB + rdb *redis.Client + logger *slog.Logger +} + +// NewIPBanService creates a new IPBanService. +func NewIPBanService(db *gorm.DB, rdb *redis.Client) *IPBanService { + return &IPBanService{ + db: db, + rdb: rdb, + logger: slog.Default(), + } +} + +// NormalizeCIDR normalizes an IP or CIDR string to canonical CIDR format. +// - IPv4 addresses are converted to /32 +// - IPv6 addresses are converted to /128 +// - CIDR strings are validated and normalized +func NormalizeCIDR(input string) (string, error) { + input = strings.TrimSpace(input) + if input == "" { + return "", ErrInvalidCIDR + } + + // Check if it's a CIDR + if strings.Contains(input, "/") { + _, ipnet, err := net.ParseCIDR(input) + if err != nil { + return "", ErrInvalidCIDR + } + // Return the network base IP with mask + ones, _ := ipnet.Mask.Size() + return fmt.Sprintf("%s/%d", ipnet.IP.String(), ones), nil + } + + // It's a plain IP address + ip := net.ParseIP(input) + if ip == nil { + return "", ErrInvalidCIDR + } + + // Determine if IPv4 or IPv6 + if ip.To4() != nil { + return ip.String() + "/32", nil + } + return ip.String() + "/128", nil +} + +// CIDROverlaps checks if two CIDR ranges overlap. +func CIDROverlaps(a, b string) bool { + _, netA, errA := net.ParseCIDR(a) + _, netB, errB := net.ParseCIDR(b) + if errA != nil || errB != nil { + return false + } + + // Check if either network contains the other's base IP + return netA.Contains(netB.IP) || netB.Contains(netA.IP) +} + +// CreateIPBanRequest represents a request to create an IP ban. +type CreateIPBanRequest struct { + CIDR string `json:"cidr" binding:"required"` + Reason string `json:"reason,omitempty"` + ExpiresAt *int64 `json:"expires_at,omitempty"` + CreatedBy string `json:"created_by,omitempty"` +} + +// UpdateIPBanRequest represents a request to update an IP ban. +type UpdateIPBanRequest struct { + Reason *string `json:"reason,omitempty"` + ExpiresAt *int64 `json:"expires_at,omitempty"` // Use pointer to distinguish between "not set" and "set to null" + Status *string `json:"status,omitempty"` +} + +// Create creates a new IP ban with validation. +func (s *IPBanService) Create(ctx context.Context, req CreateIPBanRequest) (*model.IPBan, error) { + // Normalize CIDR + normalized, err := NormalizeCIDR(req.CIDR) + if err != nil { + return nil, err + } + + // Check for existing rule with same CIDR + var existing model.IPBan + if err := s.db.WithContext(ctx).Where("cidr = ?", normalized).First(&existing).Error; err == nil { + return nil, ErrDuplicateCIDR + } else if !errors.Is(err, gorm.ErrRecordNotFound) { + return nil, err + } + + // Check for overlapping active rules + var activeRules []model.IPBan + if err := s.db.WithContext(ctx).Where("status = ?", model.IPBanStatusActive).Find(&activeRules).Error; err != nil { + return nil, err + } + + for _, rule := range activeRules { + if CIDROverlaps(normalized, rule.CIDR) { + return nil, fmt.Errorf("%w: overlaps with %s", ErrCIDROverlap, rule.CIDR) + } + } + + // Create the ban + ban := &model.IPBan{ + CIDR: normalized, + Status: model.IPBanStatusActive, + Reason: req.Reason, + ExpiresAt: req.ExpiresAt, + CreatedBy: req.CreatedBy, + } + + if err := s.db.WithContext(ctx).Create(ban).Error; err != nil { + return nil, err + } + + // Sync to Redis + if err := s.syncBanToRedis(ctx, ban); err != nil { + s.logger.Error("failed to sync IP ban to Redis", "cidr", ban.CIDR, "err", err) + // Don't fail the create operation, just log the error + } + + return ban, nil +} + +// Get retrieves an IP ban by ID. +func (s *IPBanService) Get(ctx context.Context, id uint) (*model.IPBan, error) { + var ban model.IPBan + if err := s.db.WithContext(ctx).First(&ban, id).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrIPBanNotFound + } + return nil, err + } + return &ban, nil +} + +// List retrieves IP bans with optional status filter. +func (s *IPBanService) List(ctx context.Context, status string) ([]model.IPBan, error) { + var bans []model.IPBan + query := s.db.WithContext(ctx) + if status != "" { + query = query.Where("status = ?", status) + } + if err := query.Order("created_at DESC").Find(&bans).Error; err != nil { + return nil, err + } + return bans, nil +} + +// Update updates an existing IP ban. +func (s *IPBanService) Update(ctx context.Context, id uint, req UpdateIPBanRequest) (*model.IPBan, error) { + ban, err := s.Get(ctx, id) + if err != nil { + return nil, err + } + + updates := make(map[string]interface{}) + if req.Reason != nil { + updates["reason"] = *req.Reason + } + if req.ExpiresAt != nil { + updates["expires_at"] = req.ExpiresAt + } + if req.Status != nil { + if *req.Status != model.IPBanStatusActive && *req.Status != model.IPBanStatusExpired { + return nil, fmt.Errorf("invalid status: %s", *req.Status) + } + updates["status"] = *req.Status + } + + if len(updates) > 0 { + if err := s.db.WithContext(ctx).Model(ban).Updates(updates).Error; err != nil { + return nil, err + } + // Reload to get updated values + if err := s.db.WithContext(ctx).First(ban, id).Error; err != nil { + return nil, err + } + } + + // Sync to Redis + if err := s.syncBanToRedis(ctx, ban); err != nil { + s.logger.Error("failed to sync IP ban to Redis", "cidr", ban.CIDR, "err", err) + } + + return ban, nil +} + +// Delete removes an IP ban. +func (s *IPBanService) Delete(ctx context.Context, id uint) error { + ban, err := s.Get(ctx, id) + if err != nil { + return err + } + + // Remove from Redis first + if err := s.removeBanFromRedis(ctx, ban.CIDR); err != nil { + s.logger.Error("failed to remove IP ban from Redis", "cidr", ban.CIDR, "err", err) + } + + // Delete from database (hard delete, not soft delete) + if err := s.db.WithContext(ctx).Unscoped().Delete(ban).Error; err != nil { + return err + } + + return nil +} + +// Redis key for global IP bans SET +const redisIPBansKey = "global:ip-bans" + +// syncBanToRedis syncs a ban to Redis. +func (s *IPBanService) syncBanToRedis(ctx context.Context, ban *model.IPBan) error { + if s.rdb == nil { + return nil + } + + // Only sync active, non-expired bans + if ban.Status != model.IPBanStatusActive || ban.IsExpired() { + return s.removeBanFromRedis(ctx, ban.CIDR) + } + + return s.rdb.SAdd(ctx, redisIPBansKey, ban.CIDR).Err() +} + +// removeBanFromRedis removes a CIDR from Redis. +func (s *IPBanService) removeBanFromRedis(ctx context.Context, cidr string) error { + if s.rdb == nil { + return nil + } + return s.rdb.SRem(ctx, redisIPBansKey, cidr).Err() +} + +// SyncAllToRedis rebuilds the Redis SET with all active, non-expired bans. +func (s *IPBanService) SyncAllToRedis(ctx context.Context) error { + if s.rdb == nil { + return nil + } + + // Get all active bans + var bans []model.IPBan + if err := s.db.WithContext(ctx).Where("status = ?", model.IPBanStatusActive).Find(&bans).Error; err != nil { + return err + } + + // Filter out expired ones + var activeCIDRs []interface{} + for _, ban := range bans { + if !ban.IsExpired() { + activeCIDRs = append(activeCIDRs, ban.CIDR) + } + } + + // Use a pipeline to atomically replace the set + pipe := s.rdb.Pipeline() + pipe.Del(ctx, redisIPBansKey) + if len(activeCIDRs) > 0 { + pipe.SAdd(ctx, redisIPBansKey, activeCIDRs...) + } + _, err := pipe.Exec(ctx) + return err +} + +// ExpireOutdatedBans marks expired bans and removes them from Redis. +func (s *IPBanService) ExpireOutdatedBans(ctx context.Context) (int64, error) { + // Find active bans that have expired + var expiredBans []model.IPBan + if err := s.db.WithContext(ctx). + Where("status = ? AND expires_at IS NOT NULL AND expires_at <= ?", + model.IPBanStatusActive, + ctx.Value("now")). // This should be time.Now().Unix() in caller + Find(&expiredBans).Error; err != nil { + return 0, err + } + + if len(expiredBans) == 0 { + return 0, nil + } + + // Mark as expired + var ids []uint + for _, ban := range expiredBans { + ids = append(ids, ban.ID) + } + + result := s.db.WithContext(ctx). + Model(&model.IPBan{}). + Where("id IN ?", ids). + Update("status", model.IPBanStatusExpired) + + if result.Error != nil { + return 0, result.Error + } + + // Remove from Redis + for _, ban := range expiredBans { + if err := s.removeBanFromRedis(ctx, ban.CIDR); err != nil { + s.logger.Error("failed to remove expired ban from Redis", "cidr", ban.CIDR, "err", err) + } + } + + return result.RowsAffected, nil +} + +// SyncHitCounts syncs hit counts from Redis to database. +func (s *IPBanService) SyncHitCounts(ctx context.Context) error { + if s.rdb == nil { + return nil + } + + // Get all active bans + var bans []model.IPBan + if err := s.db.WithContext(ctx).Where("status = ?", model.IPBanStatusActive).Find(&bans).Error; err != nil { + return err + } + + for _, ban := range bans { + hitKey := fmt.Sprintf("global:ip-bans:hits:%s", ban.CIDR) + + // Get and reset the counter atomically + count, err := s.rdb.GetDel(ctx, hitKey).Int64() + if err != nil && !errors.Is(err, redis.Nil) { + s.logger.Error("failed to get hit count from Redis", "cidr", ban.CIDR, "err", err) + continue + } + + if count > 0 { + // Add to database hit_count + if err := s.db.WithContext(ctx). + Model(&model.IPBan{}). + Where("id = ?", ban.ID). + Update("hit_count", gorm.Expr("hit_count + ?", count)).Error; err != nil { + s.logger.Error("failed to update hit count in database", "cidr", ban.CIDR, "err", err) + } + } + } + + return nil +}