Files
ez-api/internal/service/ip_ban.go
zenfun 2359603666 feat(service): implement IP ban service logic
Add IPBanService to manage global IP bans with Redis synchronization
for high-performance filtering. Includes logic for CIDR normalization,
overlap detection, hit count tracking, and rule expiration.
2026-01-04 00:59:03 +08:00

365 lines
9.5 KiB
Go

package service
import (
"context"
"errors"
"fmt"
"log/slog"
"net"
"strings"
"github.com/ez-api/ez-api/internal/model"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
)
// Sentinel errors returned by IPBanService and the CIDR helpers. Some are
// returned wrapped with extra context (e.g. ErrCIDROverlap), so callers should
// compare with errors.Is rather than ==.
var (
	// ErrInvalidCIDR reports input that is neither a valid IP nor a valid CIDR.
	ErrInvalidCIDR = errors.New("invalid CIDR format")
	// ErrCIDROverlap reports that a rule would overlap a currently active rule.
	ErrCIDROverlap = errors.New("CIDR overlaps with existing active rule")
	// ErrIPBanNotFound reports that no ban exists with the requested ID.
	ErrIPBanNotFound = errors.New("IP ban not found")
	// ErrDuplicateCIDR reports that a rule with the same normalized CIDR exists.
	ErrDuplicateCIDR = errors.New("CIDR already exists")
)
// IPBanService manages global IP ban rules. The database is the source of
// truth; active, non-expired rules are mirrored into a Redis set so request
// filtering can consult Redis instead of the database.
type IPBanService struct {
	db     *gorm.DB      // persistent storage for ban rules
	rdb    *redis.Client // optional Redis mirror; nil skips every Redis operation
	logger *slog.Logger  // receives best-effort Redis sync failures
}
// NewIPBanService builds an IPBanService on top of db and rdb. rdb may be nil,
// in which case all Redis mirroring is skipped. Errors are logged through
// slog.Default().
func NewIPBanService(db *gorm.DB, rdb *redis.Client) *IPBanService {
	svc := &IPBanService{db: db, rdb: rdb}
	svc.logger = slog.Default()
	return svc
}
// NormalizeCIDR normalizes an IP or CIDR string to canonical CIDR format.
//   - Surrounding whitespace is ignored.
//   - A bare IPv4 address (including an IPv4-mapped IPv6 address such as
//     "::ffff:1.2.3.4") becomes a /32.
//   - A bare IPv6 address becomes a /128.
//   - A CIDR string is validated and its host bits are cleared, e.g.
//     "10.1.2.3/8" -> "10.0.0.0/8".
//   - An IPv4-mapped IPv6 prefix is rewritten in IPv4 terms, e.g.
//     "::ffff:1.2.3.4/120" -> "1.2.3.0/24", so the result is always a
//     valid CIDR that parses back to the same network.
//
// It returns ErrInvalidCIDR for empty or unparseable input.
func NormalizeCIDR(input string) (string, error) {
	// Bits occupied by the ::ffff:0:0/96 prefix of an IPv4-mapped IPv6 address.
	const v4MappedPrefixBits = 96

	input = strings.TrimSpace(input)
	if input == "" {
		return "", ErrInvalidCIDR
	}

	if strings.Contains(input, "/") {
		_, ipnet, err := net.ParseCIDR(input)
		if err != nil {
			return "", ErrInvalidCIDR
		}
		ones, bits := ipnet.Mask.Size()
		// net.IP.String prints an IPv4-mapped IPv6 address in dotted IPv4
		// form. Pairing that with the 128-bit prefix length would produce an
		// invalid string such as "1.2.3.0/120". Translate the prefix into IPv4
		// bits instead, matching how net.IPNet.Contains treats such a network.
		// To4 can only succeed here when ones >= 96: a shorter mask would have
		// cleared the ::ffff marker bits.
		if ip4 := ipnet.IP.To4(); ip4 != nil && bits == 8*net.IPv6len {
			return fmt.Sprintf("%s/%d", ip4.String(), ones-v4MappedPrefixBits), nil
		}
		return fmt.Sprintf("%s/%d", ipnet.IP.String(), ones), nil
	}

	// Plain IP address.
	ip := net.ParseIP(input)
	if ip == nil {
		return "", ErrInvalidCIDR
	}
	if ip.To4() != nil {
		return ip.String() + "/32", nil
	}
	return ip.String() + "/128", nil
}
// CIDROverlaps reports whether the CIDR ranges a and b share any address.
// Two prefixes overlap exactly when one contains the other, so it suffices to
// test whether either network contains the other's base address. Unparseable
// input never overlaps anything.
func CIDROverlaps(a, b string) bool {
	_, first, err := net.ParseCIDR(a)
	if err != nil {
		return false
	}
	_, second, err := net.ParseCIDR(b)
	if err != nil {
		return false
	}
	if first.Contains(second.IP) {
		return true
	}
	return second.Contains(first.IP)
}
// CreateIPBanRequest is the input to IPBanService.Create.
type CreateIPBanRequest struct {
	// CIDR is a bare IP address or CIDR range; Create normalizes it with
	// NormalizeCIDR before storing it.
	CIDR string `json:"cidr" binding:"required"`
	// Reason is optional free text stored alongside the rule.
	Reason string `json:"reason,omitempty"`
	// ExpiresAt is an optional expiry timestamp. When nil the rule has no
	// expiry and ExpireOutdatedBans never selects it.
	ExpiresAt *int64 `json:"expires_at,omitempty"`
	// CreatedBy optionally records who created the rule.
	CreatedBy string `json:"created_by,omitempty"`
}
// UpdateIPBanRequest is a partial update for IPBanService.Update.
//
// Every field is optional: a nil pointer means "leave this column unchanged".
// With encoding/json (and compatible decoders), an omitted key and an explicit
// JSON null both decode to nil. As a result, ExpiresAt cannot be used to clear
// an existing expiry; it can only set a new one.
type UpdateIPBanRequest struct {
	// Reason, when non-nil, replaces the rule's reason text.
	Reason *string `json:"reason,omitempty"`
	// ExpiresAt, when non-nil, replaces the expiry timestamp. Omitted and
	// null are indistinguishable here; both leave the expiry unchanged.
	ExpiresAt *int64 `json:"expires_at,omitempty"`
	// Status, when non-nil, must be model.IPBanStatusActive or
	// model.IPBanStatusExpired; IPBanService.Update rejects anything else.
	Status *string `json:"status,omitempty"`
}
// Create validates req and stores it as a new active IP ban, then mirrors it
// into Redis.
//
// The CIDR is normalized first (ErrInvalidCIDR on bad input). Creation fails
// with ErrDuplicateCIDR if any stored rule already has the same normalized
// CIDR, regardless of its status. It fails with an error wrapping
// ErrCIDROverlap if the CIDR overlaps a rule whose status is active. A Redis
// sync failure is only logged, because the database row has already been
// written by then.
func (s *IPBanService) Create(ctx context.Context, req CreateIPBanRequest) (*model.IPBan, error) {
	cidr, err := NormalizeCIDR(req.CIDR)
	if err != nil {
		return nil, err
	}

	var dup model.IPBan
	lookupErr := s.db.WithContext(ctx).Where("cidr = ?", cidr).First(&dup).Error
	switch {
	case lookupErr == nil:
		return nil, ErrDuplicateCIDR
	case !errors.Is(lookupErr, gorm.ErrRecordNotFound):
		return nil, lookupErr
	}

	var active []model.IPBan
	if err := s.db.WithContext(ctx).Where("status = ?", model.IPBanStatusActive).Find(&active).Error; err != nil {
		return nil, err
	}
	for i := range active {
		if other := active[i].CIDR; CIDROverlaps(cidr, other) {
			return nil, fmt.Errorf("%w: overlaps with %s", ErrCIDROverlap, other)
		}
	}

	ban := &model.IPBan{
		CIDR:      cidr,
		Status:    model.IPBanStatusActive,
		Reason:    req.Reason,
		ExpiresAt: req.ExpiresAt,
		CreatedBy: req.CreatedBy,
	}
	if err := s.db.WithContext(ctx).Create(ban).Error; err != nil {
		return nil, err
	}

	if syncErr := s.syncBanToRedis(ctx, ban); syncErr != nil {
		// Best effort: the rule is persisted; a later full resync can repair Redis.
		s.logger.Error("failed to sync IP ban to Redis", "cidr", ban.CIDR, "err", syncErr)
	}
	return ban, nil
}
// Get loads the IP ban with primary key id. It returns ErrIPBanNotFound when
// no such row exists and passes through any other database error unchanged.
func (s *IPBanService) Get(ctx context.Context, id uint) (*model.IPBan, error) {
	ban := new(model.IPBan)
	err := s.db.WithContext(ctx).First(ban, id).Error
	switch {
	case err == nil:
		return ban, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, ErrIPBanNotFound
	default:
		return nil, err
	}
}
// List returns IP bans ordered newest first by created_at. A non-empty status
// restricts the result to rules with exactly that status; an empty status
// returns every rule.
func (s *IPBanService) List(ctx context.Context, status string) ([]model.IPBan, error) {
	tx := s.db.WithContext(ctx)
	if len(status) > 0 {
		tx = tx.Where("status = ?", status)
	}

	var out []model.IPBan
	if err := tx.Order("created_at DESC").Find(&out).Error; err != nil {
		return nil, err
	}
	return out, nil
}
// Update applies a partial update to the IP ban with the given id and then
// re-syncs that rule's Redis membership.
//
// Only non-nil fields of req are written. Status must be
// model.IPBanStatusActive or model.IPBanStatusExpired. Moving a non-active
// rule back to active is rejected with an error wrapping ErrCIDROverlap if it
// would overlap another active rule, preserving the invariant Create enforces.
// Returns ErrIPBanNotFound if id does not exist. Redis sync failures are logged,
// not returned.
func (s *IPBanService) Update(ctx context.Context, id uint, req UpdateIPBanRequest) (*model.IPBan, error) {
	ban, err := s.Get(ctx, id)
	if err != nil {
		return nil, err
	}

	updates := make(map[string]interface{})
	if req.Reason != nil {
		updates["reason"] = *req.Reason
	}
	if req.ExpiresAt != nil {
		updates["expires_at"] = req.ExpiresAt
	}
	if req.Status != nil {
		if *req.Status != model.IPBanStatusActive && *req.Status != model.IPBanStatusExpired {
			return nil, fmt.Errorf("invalid status: %s", *req.Status)
		}
		// Reactivation must not produce two overlapping active rules; Create
		// rejects that case, so Update has to as well.
		if *req.Status == model.IPBanStatusActive && ban.Status != model.IPBanStatusActive {
			var active []model.IPBan
			if err := s.db.WithContext(ctx).
				Where("status = ? AND id <> ?", model.IPBanStatusActive, ban.ID).
				Find(&active).Error; err != nil {
				return nil, err
			}
			for _, rule := range active {
				if CIDROverlaps(ban.CIDR, rule.CIDR) {
					return nil, fmt.Errorf("%w: overlaps with %s", ErrCIDROverlap, rule.CIDR)
				}
			}
		}
		updates["status"] = *req.Status
	}

	if len(updates) > 0 {
		if err := s.db.WithContext(ctx).Model(ban).Updates(updates).Error; err != nil {
			return nil, err
		}
		// Reload so the returned value and the Redis sync see the stored state.
		if err := s.db.WithContext(ctx).First(ban, id).Error; err != nil {
			return nil, err
		}
	}

	if err := s.syncBanToRedis(ctx, ban); err != nil {
		s.logger.Error("failed to sync IP ban to Redis", "cidr", ban.CIDR, "err", err)
	}
	return ban, nil
}
// Delete permanently removes the IP ban with the given id (a hard delete via
// Unscoped) and drops its CIDR from the Redis set.
//
// The database row is deleted first. If that fails, the rule is still
// enforced in Redis, matching the database, and the error is returned. Redis
// cleanup after a successful delete is best-effort: a failure is logged, and
// the stale entry is dropped the next time SyncAllToRedis rebuilds the set.
// Returns ErrIPBanNotFound if id does not exist.
func (s *IPBanService) Delete(ctx context.Context, id uint) error {
	ban, err := s.Get(ctx, id)
	if err != nil {
		return err
	}

	if err := s.db.WithContext(ctx).Unscoped().Delete(ban).Error; err != nil {
		return err
	}

	if err := s.removeBanFromRedis(ctx, ban.CIDR); err != nil {
		s.logger.Error("failed to remove IP ban from Redis", "cidr", ban.CIDR, "err", err)
	}
	return nil
}
const (
	// redisIPBansKey names the Redis SET holding the CIDR of every active,
	// non-expired global IP ban.
	redisIPBansKey = "global:ip-bans"
)
// syncBanToRedis makes the Redis set reflect a single rule: the CIDR is added
// when the ban is active and not expired, and removed otherwise. It does
// nothing when Redis is not configured.
func (s *IPBanService) syncBanToRedis(ctx context.Context, ban *model.IPBan) error {
	switch {
	case s.rdb == nil:
		return nil
	case ban.Status == model.IPBanStatusActive && !ban.IsExpired():
		return s.rdb.SAdd(ctx, redisIPBansKey, ban.CIDR).Err()
	default:
		return s.removeBanFromRedis(ctx, ban.CIDR)
	}
}
// removeBanFromRedis deletes cidr from the Redis ban set (SREM). It does
// nothing when Redis is not configured.
func (s *IPBanService) removeBanFromRedis(ctx context.Context, cidr string) error {
	if s.rdb != nil {
		return s.rdb.SRem(ctx, redisIPBansKey, cidr).Err()
	}
	return nil
}
// SyncAllToRedis rebuilds the Redis ban set from the database so that it holds
// exactly the CIDRs of rules that are active and not expired.
//
// DEL and SADD are sent in a single MULTI/EXEC transaction. Concurrent readers
// therefore see either the previous set or the rebuilt one, never the
// intermediate state after DEL but before SADD. It does nothing when Redis is
// not configured.
func (s *IPBanService) SyncAllToRedis(ctx context.Context) error {
	if s.rdb == nil {
		return nil
	}

	var bans []model.IPBan
	if err := s.db.WithContext(ctx).Where("status = ?", model.IPBanStatusActive).Find(&bans).Error; err != nil {
		return err
	}

	// Status alone is not enough: rules past their expiry may not have been
	// marked expired yet.
	members := make([]interface{}, 0, len(bans))
	for i := range bans {
		if !bans[i].IsExpired() {
			members = append(members, bans[i].CIDR)
		}
	}

	// A plain Pipeline only batches round-trips and is not atomic;
	// TxPipeline wraps the queued commands in MULTI/EXEC.
	pipe := s.rdb.TxPipeline()
	pipe.Del(ctx, redisIPBansKey)
	if len(members) > 0 {
		pipe.SAdd(ctx, redisIPBansKey, members...)
	}
	_, err := pipe.Exec(ctx)
	return err
}
// ExpireOutdatedBans marks active bans whose expires_at has passed as expired,
// removes their CIDRs from the Redis set, and returns how many rows it updated.
//
// The cutoff is the current time in Unix seconds, taken from the database
// handle's clock (gorm's NowFunc, which gorm.Open defaults to time.Now). A
// non-nil ctx.Value("now") overrides it, for callers that inject their own
// clock; that value should be a Unix-seconds timestamp such as
// time.Now().Unix(). Redis removal failures are logged, not returned.
func (s *IPBanService) ExpireOutdatedBans(ctx context.Context) (int64, error) {
	// Always have a cutoff. Binding a nil "now" would produce
	// "expires_at <= NULL", which matches no row, so nothing would expire.
	var cutoff interface{} = s.db.NowFunc().Unix()
	if v := ctx.Value("now"); v != nil {
		cutoff = v
	}

	var expiredBans []model.IPBan
	if err := s.db.WithContext(ctx).
		Where("status = ? AND expires_at IS NOT NULL AND expires_at <= ?",
			model.IPBanStatusActive,
			cutoff).
		Find(&expiredBans).Error; err != nil {
		return 0, err
	}
	if len(expiredBans) == 0 {
		return 0, nil
	}

	ids := make([]uint, 0, len(expiredBans))
	for _, ban := range expiredBans {
		ids = append(ids, ban.ID)
	}

	result := s.db.WithContext(ctx).
		Model(&model.IPBan{}).
		Where("id IN ?", ids).
		Update("status", model.IPBanStatusExpired)
	if result.Error != nil {
		return 0, result.Error
	}

	for _, ban := range expiredBans {
		if err := s.removeBanFromRedis(ctx, ban.CIDR); err != nil {
			s.logger.Error("failed to remove expired ban from Redis", "cidr", ban.CIDR, "err", err)
		}
	}
	return result.RowsAffected, nil
}
// SyncHitCounts moves per-rule hit counters from Redis into the database.
//
// For each active ban, it atomically reads and deletes the Redis counter
// "global:ip-bans:hits:<cidr>" (GETDEL) and adds a positive value to the row's
// hit_count. If the database update fails, the hits are added back to the Redis
// counter (INCRBY) so a later run can retry them instead of dropping them.
// Per-rule failures are logged and skipped. An error is returned only when
// listing bans fails or ctx is done. It does nothing when Redis is not
// configured.
func (s *IPBanService) SyncHitCounts(ctx context.Context) error {
	if s.rdb == nil {
		return nil
	}

	var bans []model.IPBan
	if err := s.db.WithContext(ctx).Where("status = ?", model.IPBanStatusActive).Find(&bans).Error; err != nil {
		return err
	}

	for _, ban := range bans {
		// Stop once the caller gives up; untouched counters stay in Redis.
		if err := ctx.Err(); err != nil {
			return err
		}

		hitKey := fmt.Sprintf("global:ip-bans:hits:%s", ban.CIDR)
		count, err := s.rdb.GetDel(ctx, hitKey).Int64()
		if err != nil && !errors.Is(err, redis.Nil) {
			s.logger.Error("failed to get hit count from Redis", "cidr", ban.CIDR, "err", err)
			continue
		}
		if count <= 0 {
			continue
		}

		if err := s.db.WithContext(ctx).
			Model(&model.IPBan{}).
			Where("id = ?", ban.ID).
			Update("hit_count", gorm.Expr("hit_count + ?", count)).Error; err != nil {
			s.logger.Error("failed to update hit count in database", "cidr", ban.CIDR, "err", err)
			// GETDEL already cleared the counter; put the hits back so the
			// next run retries them rather than losing them.
			if rerr := s.rdb.IncrBy(ctx, hitKey, count).Err(); rerr != nil {
				s.logger.Error("failed to restore hit count in Redis", "cidr", ban.CIDR, "count", count, "err", rerr)
			}
		}
	}
	return nil
}