mirror of
https://github.com/EZ-Api/ez-api.git
synced 2026-01-13 17:47:51 +00:00
Add IPBanManager to handle periodic background jobs, including:
- Expiring outdated bans
- Syncing hit counts from Redis to the DB
- Performing full Redis state synchronization

Additionally, update the service expiration logic to use system time, and add unit tests for CIDR normalization and overlap checking.
368 lines
9.4 KiB
Go
368 lines
9.4 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"net"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/ez-api/ez-api/internal/model"
|
|
"github.com/redis/go-redis/v9"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// Sentinel errors produced by IPBanService and the CIDR helpers. Match them
// with errors.Is: ErrCIDROverlap, for one, is returned wrapped with the
// conflicting CIDR.
var (
	// ErrIPBanNotFound reports that no ban exists for the requested ID.
	ErrIPBanNotFound = errors.New("IP ban not found")
	// ErrInvalidCIDR reports an empty or unparseable address/CIDR string.
	ErrInvalidCIDR = errors.New("invalid CIDR format")
	// ErrDuplicateCIDR reports that a ban for the same normalized CIDR exists.
	ErrDuplicateCIDR = errors.New("CIDR already exists")
	// ErrCIDROverlap reports that a new range intersects an active ban.
	ErrCIDROverlap = errors.New("CIDR overlaps with existing active rule")
)
|
|
|
|
// IPBanService handles global IP ban operations.
//
// Ban records live in the database (via GORM); SyncAllToRedis rebuilds the
// Redis ban set from those rows, so the database is the source of truth and
// Redis is a mirror of the active, unexpired bans.
type IPBanService struct {
	// db stores and queries ban rows.
	db *gorm.DB
	// rdb holds the ban set and the per-ban hit counters. It may be nil, in
	// which case every Redis operation in this service is skipped.
	rdb *redis.Client
	// logger receives errors from best-effort steps such as Redis sync.
	logger *slog.Logger
}
|
|
|
|
// NewIPBanService creates a new IPBanService.
|
|
func NewIPBanService(db *gorm.DB, rdb *redis.Client) *IPBanService {
|
|
return &IPBanService{
|
|
db: db,
|
|
rdb: rdb,
|
|
logger: slog.Default(),
|
|
}
|
|
}
|
|
|
|
// NormalizeCIDR normalizes an IP or CIDR string to canonical CIDR format.
|
|
// - IPv4 addresses are converted to /32
|
|
// - IPv6 addresses are converted to /128
|
|
// - CIDR strings are validated and normalized
|
|
func NormalizeCIDR(input string) (string, error) {
|
|
input = strings.TrimSpace(input)
|
|
if input == "" {
|
|
return "", ErrInvalidCIDR
|
|
}
|
|
|
|
// Check if it's a CIDR
|
|
if strings.Contains(input, "/") {
|
|
_, ipnet, err := net.ParseCIDR(input)
|
|
if err != nil {
|
|
return "", ErrInvalidCIDR
|
|
}
|
|
// Return the network base IP with mask
|
|
ones, _ := ipnet.Mask.Size()
|
|
return fmt.Sprintf("%s/%d", ipnet.IP.String(), ones), nil
|
|
}
|
|
|
|
// It's a plain IP address
|
|
ip := net.ParseIP(input)
|
|
if ip == nil {
|
|
return "", ErrInvalidCIDR
|
|
}
|
|
|
|
// Determine if IPv4 or IPv6
|
|
if ip.To4() != nil {
|
|
return ip.String() + "/32", nil
|
|
}
|
|
return ip.String() + "/128", nil
|
|
}
|
|
|
|
// CIDROverlaps reports whether CIDR ranges a and b share at least one
// address. If either argument fails to parse as a CIDR, it reports false.
func CIDROverlaps(a, b string) bool {
	_, first, err := net.ParseCIDR(a)
	if err != nil {
		return false
	}
	_, second, err := net.ParseCIDR(b)
	if err != nil {
		return false
	}

	// Prefix ranges are aligned blocks, so two of them intersect exactly when
	// one contains the other's network address.
	switch {
	case first.Contains(second.IP):
		return true
	case second.Contains(first.IP):
		return true
	default:
		return false
	}
}
|
|
|
|
// CreateIPBanRequest represents a request to create an IP ban; it is the input
// to IPBanService.Create, which copies these fields onto the new ban.
type CreateIPBanRequest struct {
	// CIDR is an IP address or CIDR range; Create normalizes it with
	// NormalizeCIDR before storing it. The `binding:"required"` tag is
	// presumably consumed by a gin-style request binder — confirm in the handler.
	CIDR string `json:"cidr" binding:"required"`
	// Reason is stored on the ban unchanged.
	Reason string `json:"reason,omitempty"`
	// ExpiresAt is the expiry time in Unix seconds (ExpireOutdatedBans compares
	// it with time.Now().Unix()). nil leaves expires_at NULL, which
	// ExpireOutdatedBans never expires.
	ExpiresAt *int64 `json:"expires_at,omitempty"`
	// CreatedBy is stored on the ban unchanged.
	CreatedBy string `json:"created_by,omitempty"`
}
|
|
|
|
// UpdateIPBanRequest represents a request to update an IP ban; it is the input
// to IPBanService.Update. Every nil field leaves the corresponding column
// unchanged.
type UpdateIPBanRequest struct {
	// Reason replaces the stored reason when non-nil.
	Reason *string `json:"reason,omitempty"`
	// ExpiresAt replaces the expiry (Unix seconds) when non-nil. Assuming the
	// payload is decoded with encoding/json, an absent field and an explicit
	// JSON null both decode to nil. The pointer therefore cannot tell "not set"
	// apart from "null", and an existing expiry cannot be cleared through it.
	ExpiresAt *int64 `json:"expires_at,omitempty"`
	// Status, when non-nil, must equal model.IPBanStatusActive or
	// model.IPBanStatusExpired; Update rejects any other value.
	Status *string `json:"status,omitempty"`
}
|
|
|
|
// Create creates a new IP ban with validation.
|
|
func (s *IPBanService) Create(ctx context.Context, req CreateIPBanRequest) (*model.IPBan, error) {
|
|
// Normalize CIDR
|
|
normalized, err := NormalizeCIDR(req.CIDR)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Check for existing rule with same CIDR
|
|
var existing model.IPBan
|
|
if err := s.db.WithContext(ctx).Where("cidr = ?", normalized).First(&existing).Error; err == nil {
|
|
return nil, ErrDuplicateCIDR
|
|
} else if !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, err
|
|
}
|
|
|
|
// Check for overlapping active rules
|
|
var activeRules []model.IPBan
|
|
if err := s.db.WithContext(ctx).Where("status = ?", model.IPBanStatusActive).Find(&activeRules).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, rule := range activeRules {
|
|
if CIDROverlaps(normalized, rule.CIDR) {
|
|
return nil, fmt.Errorf("%w: overlaps with %s", ErrCIDROverlap, rule.CIDR)
|
|
}
|
|
}
|
|
|
|
// Create the ban
|
|
ban := &model.IPBan{
|
|
CIDR: normalized,
|
|
Status: model.IPBanStatusActive,
|
|
Reason: req.Reason,
|
|
ExpiresAt: req.ExpiresAt,
|
|
CreatedBy: req.CreatedBy,
|
|
}
|
|
|
|
if err := s.db.WithContext(ctx).Create(ban).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Sync to Redis
|
|
if err := s.syncBanToRedis(ctx, ban); err != nil {
|
|
s.logger.Error("failed to sync IP ban to Redis", "cidr", ban.CIDR, "err", err)
|
|
// Don't fail the create operation, just log the error
|
|
}
|
|
|
|
return ban, nil
|
|
}
|
|
|
|
// Get retrieves an IP ban by ID.
|
|
func (s *IPBanService) Get(ctx context.Context, id uint) (*model.IPBan, error) {
|
|
var ban model.IPBan
|
|
if err := s.db.WithContext(ctx).First(&ban, id).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrIPBanNotFound
|
|
}
|
|
return nil, err
|
|
}
|
|
return &ban, nil
|
|
}
|
|
|
|
// List retrieves IP bans with optional status filter.
|
|
func (s *IPBanService) List(ctx context.Context, status string) ([]model.IPBan, error) {
|
|
var bans []model.IPBan
|
|
query := s.db.WithContext(ctx)
|
|
if status != "" {
|
|
query = query.Where("status = ?", status)
|
|
}
|
|
if err := query.Order("created_at DESC").Find(&bans).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return bans, nil
|
|
}
|
|
|
|
// Update updates an existing IP ban.
|
|
func (s *IPBanService) Update(ctx context.Context, id uint, req UpdateIPBanRequest) (*model.IPBan, error) {
|
|
ban, err := s.Get(ctx, id)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
updates := make(map[string]interface{})
|
|
if req.Reason != nil {
|
|
updates["reason"] = *req.Reason
|
|
}
|
|
if req.ExpiresAt != nil {
|
|
updates["expires_at"] = req.ExpiresAt
|
|
}
|
|
if req.Status != nil {
|
|
if *req.Status != model.IPBanStatusActive && *req.Status != model.IPBanStatusExpired {
|
|
return nil, fmt.Errorf("invalid status: %s", *req.Status)
|
|
}
|
|
updates["status"] = *req.Status
|
|
}
|
|
|
|
if len(updates) > 0 {
|
|
if err := s.db.WithContext(ctx).Model(ban).Updates(updates).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
// Reload to get updated values
|
|
if err := s.db.WithContext(ctx).First(ban, id).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
// Sync to Redis
|
|
if err := s.syncBanToRedis(ctx, ban); err != nil {
|
|
s.logger.Error("failed to sync IP ban to Redis", "cidr", ban.CIDR, "err", err)
|
|
}
|
|
|
|
return ban, nil
|
|
}
|
|
|
|
// Delete removes an IP ban.
|
|
func (s *IPBanService) Delete(ctx context.Context, id uint) error {
|
|
ban, err := s.Get(ctx, id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Remove from Redis first
|
|
if err := s.removeBanFromRedis(ctx, ban.CIDR); err != nil {
|
|
s.logger.Error("failed to remove IP ban from Redis", "cidr", ban.CIDR, "err", err)
|
|
}
|
|
|
|
// Delete from database (hard delete, not soft delete)
|
|
if err := s.db.WithContext(ctx).Unscoped().Delete(ban).Error; err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Redis key for global IP bans SET.
//
// The set holds one member per enforced ban: the normalized CIDR of each
// active, unexpired ban. syncBanToRedis adds members, removeBanFromRedis
// removes them, and SyncAllToRedis rebuilds the whole set. Per-ban hit
// counters use the separate keys "global:ip-bans:hits:<cidr>" (see
// SyncHitCounts).
const redisIPBansKey = "global:ip-bans"
|
|
|
|
// syncBanToRedis syncs a ban to Redis.
|
|
func (s *IPBanService) syncBanToRedis(ctx context.Context, ban *model.IPBan) error {
|
|
if s.rdb == nil {
|
|
return nil
|
|
}
|
|
|
|
// Only sync active, non-expired bans
|
|
if ban.Status != model.IPBanStatusActive || ban.IsExpired() {
|
|
return s.removeBanFromRedis(ctx, ban.CIDR)
|
|
}
|
|
|
|
return s.rdb.SAdd(ctx, redisIPBansKey, ban.CIDR).Err()
|
|
}
|
|
|
|
// removeBanFromRedis removes a CIDR from Redis.
|
|
func (s *IPBanService) removeBanFromRedis(ctx context.Context, cidr string) error {
|
|
if s.rdb == nil {
|
|
return nil
|
|
}
|
|
return s.rdb.SRem(ctx, redisIPBansKey, cidr).Err()
|
|
}
|
|
|
|
// SyncAllToRedis rebuilds the Redis SET with all active, non-expired bans.
|
|
func (s *IPBanService) SyncAllToRedis(ctx context.Context) error {
|
|
if s.rdb == nil {
|
|
return nil
|
|
}
|
|
|
|
// Get all active bans
|
|
var bans []model.IPBan
|
|
if err := s.db.WithContext(ctx).Where("status = ?", model.IPBanStatusActive).Find(&bans).Error; err != nil {
|
|
return err
|
|
}
|
|
|
|
// Filter out expired ones
|
|
var activeCIDRs []interface{}
|
|
for _, ban := range bans {
|
|
if !ban.IsExpired() {
|
|
activeCIDRs = append(activeCIDRs, ban.CIDR)
|
|
}
|
|
}
|
|
|
|
// Use a pipeline to atomically replace the set
|
|
pipe := s.rdb.Pipeline()
|
|
pipe.Del(ctx, redisIPBansKey)
|
|
if len(activeCIDRs) > 0 {
|
|
pipe.SAdd(ctx, redisIPBansKey, activeCIDRs...)
|
|
}
|
|
_, err := pipe.Exec(ctx)
|
|
return err
|
|
}
|
|
|
|
// ExpireOutdatedBans marks expired bans and removes them from Redis.
|
|
func (s *IPBanService) ExpireOutdatedBans(ctx context.Context) (int64, error) {
|
|
now := time.Now().Unix()
|
|
|
|
// Find active bans that have expired
|
|
var expiredBans []model.IPBan
|
|
if err := s.db.WithContext(ctx).
|
|
Where("status = ? AND expires_at IS NOT NULL AND expires_at <= ?",
|
|
model.IPBanStatusActive,
|
|
now).
|
|
Find(&expiredBans).Error; err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
if len(expiredBans) == 0 {
|
|
return 0, nil
|
|
}
|
|
|
|
// Mark as expired
|
|
var ids []uint
|
|
for _, ban := range expiredBans {
|
|
ids = append(ids, ban.ID)
|
|
}
|
|
|
|
result := s.db.WithContext(ctx).
|
|
Model(&model.IPBan{}).
|
|
Where("id IN ?", ids).
|
|
Update("status", model.IPBanStatusExpired)
|
|
|
|
if result.Error != nil {
|
|
return 0, result.Error
|
|
}
|
|
|
|
// Remove from Redis
|
|
for _, ban := range expiredBans {
|
|
if err := s.removeBanFromRedis(ctx, ban.CIDR); err != nil {
|
|
s.logger.Error("failed to remove expired ban from Redis", "cidr", ban.CIDR, "err", err)
|
|
}
|
|
}
|
|
|
|
return result.RowsAffected, nil
|
|
}
|
|
|
|
// SyncHitCounts syncs hit counts from Redis to database.
|
|
func (s *IPBanService) SyncHitCounts(ctx context.Context) error {
|
|
if s.rdb == nil {
|
|
return nil
|
|
}
|
|
|
|
// Get all active bans
|
|
var bans []model.IPBan
|
|
if err := s.db.WithContext(ctx).Where("status = ?", model.IPBanStatusActive).Find(&bans).Error; err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, ban := range bans {
|
|
hitKey := fmt.Sprintf("global:ip-bans:hits:%s", ban.CIDR)
|
|
|
|
// Get and reset the counter atomically
|
|
count, err := s.rdb.GetDel(ctx, hitKey).Int64()
|
|
if err != nil && !errors.Is(err, redis.Nil) {
|
|
s.logger.Error("failed to get hit count from Redis", "cidr", ban.CIDR, "err", err)
|
|
continue
|
|
}
|
|
|
|
if count > 0 {
|
|
// Add to database hit_count
|
|
if err := s.db.WithContext(ctx).
|
|
Model(&model.IPBan{}).
|
|
Where("id = ?", ban.ID).
|
|
Update("hit_count", gorm.Expr("hit_count + ?", count)).Error; err != nil {
|
|
s.logger.Error("failed to update hit count in database", "cidr", ban.CIDR, "err", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|