package service import ( "context" "errors" "fmt" "log/slog" "net" "strings" "time" "github.com/ez-api/ez-api/internal/model" "github.com/redis/go-redis/v9" "gorm.io/gorm" ) var ( ErrInvalidCIDR = errors.New("invalid CIDR format") ErrCIDROverlap = errors.New("CIDR overlaps with existing active rule") ErrIPBanNotFound = errors.New("IP ban not found") ErrDuplicateCIDR = errors.New("CIDR already exists") ) // IPBanService handles global IP ban operations. type IPBanService struct { db *gorm.DB rdb *redis.Client logger *slog.Logger } // NewIPBanService creates a new IPBanService. func NewIPBanService(db *gorm.DB, rdb *redis.Client) *IPBanService { return &IPBanService{ db: db, rdb: rdb, logger: slog.Default(), } } // NormalizeCIDR normalizes an IP or CIDR string to canonical CIDR format. // - IPv4 addresses are converted to /32 // - IPv6 addresses are converted to /128 // - CIDR strings are validated and normalized func NormalizeCIDR(input string) (string, error) { input = strings.TrimSpace(input) if input == "" { return "", ErrInvalidCIDR } // Check if it's a CIDR if strings.Contains(input, "/") { _, ipnet, err := net.ParseCIDR(input) if err != nil { return "", ErrInvalidCIDR } // Return the network base IP with mask ones, _ := ipnet.Mask.Size() return fmt.Sprintf("%s/%d", ipnet.IP.String(), ones), nil } // It's a plain IP address ip := net.ParseIP(input) if ip == nil { return "", ErrInvalidCIDR } // Determine if IPv4 or IPv6 if ip.To4() != nil { return ip.String() + "/32", nil } return ip.String() + "/128", nil } // CIDROverlaps checks if two CIDR ranges overlap. func CIDROverlaps(a, b string) bool { _, netA, errA := net.ParseCIDR(a) _, netB, errB := net.ParseCIDR(b) if errA != nil || errB != nil { return false } // Check if either network contains the other's base IP return netA.Contains(netB.IP) || netB.Contains(netA.IP) } // CreateIPBanRequest represents a request to create an IP ban. type CreateIPBanRequest struct { CIDR string `json:"cidr" binding:"required"` Reason string `json:"reason,omitempty"` ExpiresAt *int64 `json:"expires_at,omitempty"` CreatedBy string `json:"created_by,omitempty"` } // UpdateIPBanRequest represents a request to update an IP ban. type UpdateIPBanRequest struct { Reason *string `json:"reason,omitempty"` ExpiresAt *int64 `json:"expires_at,omitempty"` // Use pointer to distinguish between "not set" and "set to null" Status *string `json:"status,omitempty"` } // Create creates a new IP ban with validation. func (s *IPBanService) Create(ctx context.Context, req CreateIPBanRequest) (*model.IPBan, error) { // Normalize CIDR normalized, err := NormalizeCIDR(req.CIDR) if err != nil { return nil, err } // Check for existing rule with same CIDR var existing model.IPBan if err := s.db.WithContext(ctx).Where("cidr = ?", normalized).First(&existing).Error; err == nil { return nil, ErrDuplicateCIDR } else if !errors.Is(err, gorm.ErrRecordNotFound) { return nil, err } // Check for overlapping active rules var activeRules []model.IPBan if err := s.db.WithContext(ctx).Where("status = ?", model.IPBanStatusActive).Find(&activeRules).Error; err != nil { return nil, err } for _, rule := range activeRules { if CIDROverlaps(normalized, rule.CIDR) { return nil, fmt.Errorf("%w: overlaps with %s", ErrCIDROverlap, rule.CIDR) } } // Create the ban ban := &model.IPBan{ CIDR: normalized, Status: model.IPBanStatusActive, Reason: req.Reason, ExpiresAt: req.ExpiresAt, CreatedBy: req.CreatedBy, } if err := s.db.WithContext(ctx).Create(ban).Error; err != nil { return nil, err } // Sync to Redis if err := s.syncBanToRedis(ctx, ban); err != nil { s.logger.Error("failed to sync IP ban to Redis", "cidr", ban.CIDR, "err", err) // Don't fail the create operation, just log the error } return ban, nil } // Get retrieves an IP ban by ID. func (s *IPBanService) Get(ctx context.Context, id uint) (*model.IPBan, error) { var ban model.IPBan if err := s.db.WithContext(ctx).First(&ban, id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrIPBanNotFound } return nil, err } return &ban, nil } // List retrieves IP bans with optional status filter. func (s *IPBanService) List(ctx context.Context, status string) ([]model.IPBan, error) { var bans []model.IPBan query := s.db.WithContext(ctx) if status != "" { query = query.Where("status = ?", status) } if err := query.Order("created_at DESC").Find(&bans).Error; err != nil { return nil, err } return bans, nil } // Update updates an existing IP ban. func (s *IPBanService) Update(ctx context.Context, id uint, req UpdateIPBanRequest) (*model.IPBan, error) { ban, err := s.Get(ctx, id) if err != nil { return nil, err } updates := make(map[string]interface{}) if req.Reason != nil { updates["reason"] = *req.Reason } if req.ExpiresAt != nil { updates["expires_at"] = req.ExpiresAt } if req.Status != nil { if *req.Status != model.IPBanStatusActive && *req.Status != model.IPBanStatusExpired { return nil, fmt.Errorf("invalid status: %s", *req.Status) } updates["status"] = *req.Status } if len(updates) > 0 { if err := s.db.WithContext(ctx).Model(ban).Updates(updates).Error; err != nil { return nil, err } // Reload to get updated values if err := s.db.WithContext(ctx).First(ban, id).Error; err != nil { return nil, err } } // Sync to Redis if err := s.syncBanToRedis(ctx, ban); err != nil { s.logger.Error("failed to sync IP ban to Redis", "cidr", ban.CIDR, "err", err) } return ban, nil } // Delete removes an IP ban. func (s *IPBanService) Delete(ctx context.Context, id uint) error { ban, err := s.Get(ctx, id) if err != nil { return err } // Remove from Redis first if err := s.removeBanFromRedis(ctx, ban.CIDR); err != nil { s.logger.Error("failed to remove IP ban from Redis", "cidr", ban.CIDR, "err", err) } // Delete from database (hard delete, not soft delete) if err := s.db.WithContext(ctx).Unscoped().Delete(ban).Error; err != nil { return err } return nil } // Redis key for global IP bans SET const redisIPBansKey = "global:ip-bans" // syncBanToRedis syncs a ban to Redis. func (s *IPBanService) syncBanToRedis(ctx context.Context, ban *model.IPBan) error { if s.rdb == nil { return nil } // Only sync active, non-expired bans if ban.Status != model.IPBanStatusActive || ban.IsExpired() { return s.removeBanFromRedis(ctx, ban.CIDR) } return s.rdb.SAdd(ctx, redisIPBansKey, ban.CIDR).Err() } // removeBanFromRedis removes a CIDR from Redis. func (s *IPBanService) removeBanFromRedis(ctx context.Context, cidr string) error { if s.rdb == nil { return nil } return s.rdb.SRem(ctx, redisIPBansKey, cidr).Err() } // SyncAllToRedis rebuilds the Redis SET with all active, non-expired bans. func (s *IPBanService) SyncAllToRedis(ctx context.Context) error { if s.rdb == nil { return nil } // Get all active bans var bans []model.IPBan if err := s.db.WithContext(ctx).Where("status = ?", model.IPBanStatusActive).Find(&bans).Error; err != nil { return err } // Filter out expired ones var activeCIDRs []interface{} for _, ban := range bans { if !ban.IsExpired() { activeCIDRs = append(activeCIDRs, ban.CIDR) } } // Use a pipeline to atomically replace the set pipe := s.rdb.Pipeline() pipe.Del(ctx, redisIPBansKey) if len(activeCIDRs) > 0 { pipe.SAdd(ctx, redisIPBansKey, activeCIDRs...) } _, err := pipe.Exec(ctx) return err } // ExpireOutdatedBans marks expired bans and removes them from Redis. func (s *IPBanService) ExpireOutdatedBans(ctx context.Context) (int64, error) { now := time.Now().Unix() // Find active bans that have expired var expiredBans []model.IPBan if err := s.db.WithContext(ctx). Where("status = ? AND expires_at IS NOT NULL AND expires_at <= ?", model.IPBanStatusActive, now). Find(&expiredBans).Error; err != nil { return 0, err } if len(expiredBans) == 0 { return 0, nil } // Mark as expired var ids []uint for _, ban := range expiredBans { ids = append(ids, ban.ID) } result := s.db.WithContext(ctx). Model(&model.IPBan{}). Where("id IN ?", ids). Update("status", model.IPBanStatusExpired) if result.Error != nil { return 0, result.Error } // Remove from Redis for _, ban := range expiredBans { if err := s.removeBanFromRedis(ctx, ban.CIDR); err != nil { s.logger.Error("failed to remove expired ban from Redis", "cidr", ban.CIDR, "err", err) } } return result.RowsAffected, nil } // SyncHitCounts syncs hit counts from Redis to database. func (s *IPBanService) SyncHitCounts(ctx context.Context) error { if s.rdb == nil { return nil } // Get all active bans var bans []model.IPBan if err := s.db.WithContext(ctx).Where("status = ?", model.IPBanStatusActive).Find(&bans).Error; err != nil { return err } for _, ban := range bans { hitKey := fmt.Sprintf("global:ip-bans:hits:%s", ban.CIDR) // Get and reset the counter atomically count, err := s.rdb.GetDel(ctx, hitKey).Int64() if err != nil && !errors.Is(err, redis.Nil) { s.logger.Error("failed to get hit count from Redis", "cidr", ban.CIDR, "err", err) continue } if count > 0 { // Add to database hit_count if err := s.db.WithContext(ctx). Model(&model.IPBan{}). Where("id = ?", ban.ID). Update("hit_count", gorm.Expr("hit_count + ?", count)).Error; err != nil { s.logger.Error("failed to update hit count in database", "cidr", ban.CIDR, "err", err) } } } return nil }