mirror of
https://github.com/EZ-Api/ez-api.git
synced 2026-01-13 17:47:51 +00:00
feat(service): implement IP ban service logic
Add IPBanService to manage global IP bans with Redis synchronization for high-performance filtering. Includes logic for CIDR normalization, overlap detection, hit count tracking, and rule expiration.
This commit is contained in:
364
internal/service/ip_ban.go
Normal file
364
internal/service/ip_ban.go
Normal file
@@ -0,0 +1,364 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/ez-api/ez-api/internal/model"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidCIDR = errors.New("invalid CIDR format")
|
||||
ErrCIDROverlap = errors.New("CIDR overlaps with existing active rule")
|
||||
ErrIPBanNotFound = errors.New("IP ban not found")
|
||||
ErrDuplicateCIDR = errors.New("CIDR already exists")
|
||||
)
|
||||
|
||||
// IPBanService handles global IP ban operations.
|
||||
type IPBanService struct {
|
||||
db *gorm.DB
|
||||
rdb *redis.Client
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewIPBanService creates a new IPBanService.
|
||||
func NewIPBanService(db *gorm.DB, rdb *redis.Client) *IPBanService {
|
||||
return &IPBanService{
|
||||
db: db,
|
||||
rdb: rdb,
|
||||
logger: slog.Default(),
|
||||
}
|
||||
}
|
||||
|
||||
// NormalizeCIDR normalizes an IP or CIDR string to canonical CIDR format.
|
||||
// - IPv4 addresses are converted to /32
|
||||
// - IPv6 addresses are converted to /128
|
||||
// - CIDR strings are validated and normalized
|
||||
func NormalizeCIDR(input string) (string, error) {
|
||||
input = strings.TrimSpace(input)
|
||||
if input == "" {
|
||||
return "", ErrInvalidCIDR
|
||||
}
|
||||
|
||||
// Check if it's a CIDR
|
||||
if strings.Contains(input, "/") {
|
||||
_, ipnet, err := net.ParseCIDR(input)
|
||||
if err != nil {
|
||||
return "", ErrInvalidCIDR
|
||||
}
|
||||
// Return the network base IP with mask
|
||||
ones, _ := ipnet.Mask.Size()
|
||||
return fmt.Sprintf("%s/%d", ipnet.IP.String(), ones), nil
|
||||
}
|
||||
|
||||
// It's a plain IP address
|
||||
ip := net.ParseIP(input)
|
||||
if ip == nil {
|
||||
return "", ErrInvalidCIDR
|
||||
}
|
||||
|
||||
// Determine if IPv4 or IPv6
|
||||
if ip.To4() != nil {
|
||||
return ip.String() + "/32", nil
|
||||
}
|
||||
return ip.String() + "/128", nil
|
||||
}
|
||||
|
||||
// CIDROverlaps checks if two CIDR ranges overlap.
|
||||
func CIDROverlaps(a, b string) bool {
|
||||
_, netA, errA := net.ParseCIDR(a)
|
||||
_, netB, errB := net.ParseCIDR(b)
|
||||
if errA != nil || errB != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if either network contains the other's base IP
|
||||
return netA.Contains(netB.IP) || netB.Contains(netA.IP)
|
||||
}
|
||||
|
||||
// CreateIPBanRequest represents a request to create an IP ban.
|
||||
type CreateIPBanRequest struct {
|
||||
CIDR string `json:"cidr" binding:"required"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
ExpiresAt *int64 `json:"expires_at,omitempty"`
|
||||
CreatedBy string `json:"created_by,omitempty"`
|
||||
}
|
||||
|
||||
// UpdateIPBanRequest represents a request to update an IP ban.
|
||||
type UpdateIPBanRequest struct {
|
||||
Reason *string `json:"reason,omitempty"`
|
||||
ExpiresAt *int64 `json:"expires_at,omitempty"` // Use pointer to distinguish between "not set" and "set to null"
|
||||
Status *string `json:"status,omitempty"`
|
||||
}
|
||||
|
||||
// Create creates a new IP ban with validation.
|
||||
func (s *IPBanService) Create(ctx context.Context, req CreateIPBanRequest) (*model.IPBan, error) {
|
||||
// Normalize CIDR
|
||||
normalized, err := NormalizeCIDR(req.CIDR)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check for existing rule with same CIDR
|
||||
var existing model.IPBan
|
||||
if err := s.db.WithContext(ctx).Where("cidr = ?", normalized).First(&existing).Error; err == nil {
|
||||
return nil, ErrDuplicateCIDR
|
||||
} else if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check for overlapping active rules
|
||||
var activeRules []model.IPBan
|
||||
if err := s.db.WithContext(ctx).Where("status = ?", model.IPBanStatusActive).Find(&activeRules).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, rule := range activeRules {
|
||||
if CIDROverlaps(normalized, rule.CIDR) {
|
||||
return nil, fmt.Errorf("%w: overlaps with %s", ErrCIDROverlap, rule.CIDR)
|
||||
}
|
||||
}
|
||||
|
||||
// Create the ban
|
||||
ban := &model.IPBan{
|
||||
CIDR: normalized,
|
||||
Status: model.IPBanStatusActive,
|
||||
Reason: req.Reason,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
CreatedBy: req.CreatedBy,
|
||||
}
|
||||
|
||||
if err := s.db.WithContext(ctx).Create(ban).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Sync to Redis
|
||||
if err := s.syncBanToRedis(ctx, ban); err != nil {
|
||||
s.logger.Error("failed to sync IP ban to Redis", "cidr", ban.CIDR, "err", err)
|
||||
// Don't fail the create operation, just log the error
|
||||
}
|
||||
|
||||
return ban, nil
|
||||
}
|
||||
|
||||
// Get retrieves an IP ban by ID.
|
||||
func (s *IPBanService) Get(ctx context.Context, id uint) (*model.IPBan, error) {
|
||||
var ban model.IPBan
|
||||
if err := s.db.WithContext(ctx).First(&ban, id).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrIPBanNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &ban, nil
|
||||
}
|
||||
|
||||
// List retrieves IP bans with optional status filter.
|
||||
func (s *IPBanService) List(ctx context.Context, status string) ([]model.IPBan, error) {
|
||||
var bans []model.IPBan
|
||||
query := s.db.WithContext(ctx)
|
||||
if status != "" {
|
||||
query = query.Where("status = ?", status)
|
||||
}
|
||||
if err := query.Order("created_at DESC").Find(&bans).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bans, nil
|
||||
}
|
||||
|
||||
// Update updates an existing IP ban.
|
||||
func (s *IPBanService) Update(ctx context.Context, id uint, req UpdateIPBanRequest) (*model.IPBan, error) {
|
||||
ban, err := s.Get(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
updates := make(map[string]interface{})
|
||||
if req.Reason != nil {
|
||||
updates["reason"] = *req.Reason
|
||||
}
|
||||
if req.ExpiresAt != nil {
|
||||
updates["expires_at"] = req.ExpiresAt
|
||||
}
|
||||
if req.Status != nil {
|
||||
if *req.Status != model.IPBanStatusActive && *req.Status != model.IPBanStatusExpired {
|
||||
return nil, fmt.Errorf("invalid status: %s", *req.Status)
|
||||
}
|
||||
updates["status"] = *req.Status
|
||||
}
|
||||
|
||||
if len(updates) > 0 {
|
||||
if err := s.db.WithContext(ctx).Model(ban).Updates(updates).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Reload to get updated values
|
||||
if err := s.db.WithContext(ctx).First(ban, id).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Sync to Redis
|
||||
if err := s.syncBanToRedis(ctx, ban); err != nil {
|
||||
s.logger.Error("failed to sync IP ban to Redis", "cidr", ban.CIDR, "err", err)
|
||||
}
|
||||
|
||||
return ban, nil
|
||||
}
|
||||
|
||||
// Delete removes an IP ban.
|
||||
func (s *IPBanService) Delete(ctx context.Context, id uint) error {
|
||||
ban, err := s.Get(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove from Redis first
|
||||
if err := s.removeBanFromRedis(ctx, ban.CIDR); err != nil {
|
||||
s.logger.Error("failed to remove IP ban from Redis", "cidr", ban.CIDR, "err", err)
|
||||
}
|
||||
|
||||
// Delete from database (hard delete, not soft delete)
|
||||
if err := s.db.WithContext(ctx).Unscoped().Delete(ban).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Redis key for global IP bans SET
|
||||
const redisIPBansKey = "global:ip-bans"
|
||||
|
||||
// syncBanToRedis syncs a ban to Redis.
|
||||
func (s *IPBanService) syncBanToRedis(ctx context.Context, ban *model.IPBan) error {
|
||||
if s.rdb == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Only sync active, non-expired bans
|
||||
if ban.Status != model.IPBanStatusActive || ban.IsExpired() {
|
||||
return s.removeBanFromRedis(ctx, ban.CIDR)
|
||||
}
|
||||
|
||||
return s.rdb.SAdd(ctx, redisIPBansKey, ban.CIDR).Err()
|
||||
}
|
||||
|
||||
// removeBanFromRedis removes a CIDR from Redis.
|
||||
func (s *IPBanService) removeBanFromRedis(ctx context.Context, cidr string) error {
|
||||
if s.rdb == nil {
|
||||
return nil
|
||||
}
|
||||
return s.rdb.SRem(ctx, redisIPBansKey, cidr).Err()
|
||||
}
|
||||
|
||||
// SyncAllToRedis rebuilds the Redis SET with all active, non-expired bans.
|
||||
func (s *IPBanService) SyncAllToRedis(ctx context.Context) error {
|
||||
if s.rdb == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get all active bans
|
||||
var bans []model.IPBan
|
||||
if err := s.db.WithContext(ctx).Where("status = ?", model.IPBanStatusActive).Find(&bans).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Filter out expired ones
|
||||
var activeCIDRs []interface{}
|
||||
for _, ban := range bans {
|
||||
if !ban.IsExpired() {
|
||||
activeCIDRs = append(activeCIDRs, ban.CIDR)
|
||||
}
|
||||
}
|
||||
|
||||
// Use a pipeline to atomically replace the set
|
||||
pipe := s.rdb.Pipeline()
|
||||
pipe.Del(ctx, redisIPBansKey)
|
||||
if len(activeCIDRs) > 0 {
|
||||
pipe.SAdd(ctx, redisIPBansKey, activeCIDRs...)
|
||||
}
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// ExpireOutdatedBans marks expired bans and removes them from Redis.
|
||||
func (s *IPBanService) ExpireOutdatedBans(ctx context.Context) (int64, error) {
|
||||
// Find active bans that have expired
|
||||
var expiredBans []model.IPBan
|
||||
if err := s.db.WithContext(ctx).
|
||||
Where("status = ? AND expires_at IS NOT NULL AND expires_at <= ?",
|
||||
model.IPBanStatusActive,
|
||||
ctx.Value("now")). // This should be time.Now().Unix() in caller
|
||||
Find(&expiredBans).Error; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if len(expiredBans) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Mark as expired
|
||||
var ids []uint
|
||||
for _, ban := range expiredBans {
|
||||
ids = append(ids, ban.ID)
|
||||
}
|
||||
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&model.IPBan{}).
|
||||
Where("id IN ?", ids).
|
||||
Update("status", model.IPBanStatusExpired)
|
||||
|
||||
if result.Error != nil {
|
||||
return 0, result.Error
|
||||
}
|
||||
|
||||
// Remove from Redis
|
||||
for _, ban := range expiredBans {
|
||||
if err := s.removeBanFromRedis(ctx, ban.CIDR); err != nil {
|
||||
s.logger.Error("failed to remove expired ban from Redis", "cidr", ban.CIDR, "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
return result.RowsAffected, nil
|
||||
}
|
||||
|
||||
// SyncHitCounts syncs hit counts from Redis to database.
|
||||
func (s *IPBanService) SyncHitCounts(ctx context.Context) error {
|
||||
if s.rdb == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get all active bans
|
||||
var bans []model.IPBan
|
||||
if err := s.db.WithContext(ctx).Where("status = ?", model.IPBanStatusActive).Find(&bans).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, ban := range bans {
|
||||
hitKey := fmt.Sprintf("global:ip-bans:hits:%s", ban.CIDR)
|
||||
|
||||
// Get and reset the counter atomically
|
||||
count, err := s.rdb.GetDel(ctx, hitKey).Int64()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
s.logger.Error("failed to get hit count from Redis", "cidr", ban.CIDR, "err", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
// Add to database hit_count
|
||||
if err := s.db.WithContext(ctx).
|
||||
Model(&model.IPBan{}).
|
||||
Where("id = ?", ban.ID).
|
||||
Update("hit_count", gorm.Expr("hit_count + ?", count)).Error; err != nil {
|
||||
s.logger.Error("failed to update hit count in database", "cidr", ban.CIDR, "err", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user