mirror of
https://github.com/EZ-Api/ez-api.git
synced 2026-01-13 17:47:51 +00:00
- Initialize and schedule IP ban maintenance tasks in server entry point - Perform initial IP ban sync to Redis on startup - Implement optional JSON unmarshalling to handle null `expires_at` in API - Add CIDR overlap validation when updating rule status to active
383 lines
9.9 KiB
Go
383 lines
9.9 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"net"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/ez-api/ez-api/internal/model"
|
|
"github.com/redis/go-redis/v9"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// ErrInvalidCIDR is returned when the input is neither a valid IP address nor
// a valid CIDR prefix.
var ErrInvalidCIDR = errors.New("invalid CIDR format")

// ErrCIDROverlap is returned when a CIDR would overlap a rule that is already
// active. The service wraps it together with the conflicting CIDR, so callers
// should match it with errors.Is.
var ErrCIDROverlap = errors.New("CIDR overlaps with existing active rule")

// ErrIPBanNotFound is returned when no IP ban exists for the requested ID.
var ErrIPBanNotFound = errors.New("IP ban not found")

// ErrDuplicateCIDR is returned when a ban with the same normalized CIDR
// already exists.
var ErrDuplicateCIDR = errors.New("CIDR already exists")
|
|
|
|
// IPBanService handles global IP ban operations.
|
|
type IPBanService struct {
|
|
db *gorm.DB
|
|
rdb *redis.Client
|
|
logger *slog.Logger
|
|
}
|
|
|
|
// NewIPBanService creates a new IPBanService.
|
|
func NewIPBanService(db *gorm.DB, rdb *redis.Client) *IPBanService {
|
|
return &IPBanService{
|
|
db: db,
|
|
rdb: rdb,
|
|
logger: slog.Default(),
|
|
}
|
|
}
|
|
|
|
// NormalizeCIDR normalizes an IP or CIDR string to canonical CIDR format.
|
|
// - IPv4 addresses are converted to /32
|
|
// - IPv6 addresses are converted to /128
|
|
// - CIDR strings are validated and normalized
|
|
func NormalizeCIDR(input string) (string, error) {
|
|
input = strings.TrimSpace(input)
|
|
if input == "" {
|
|
return "", ErrInvalidCIDR
|
|
}
|
|
|
|
// Check if it's a CIDR
|
|
if strings.Contains(input, "/") {
|
|
_, ipnet, err := net.ParseCIDR(input)
|
|
if err != nil {
|
|
return "", ErrInvalidCIDR
|
|
}
|
|
// Return the network base IP with mask
|
|
ones, _ := ipnet.Mask.Size()
|
|
return fmt.Sprintf("%s/%d", ipnet.IP.String(), ones), nil
|
|
}
|
|
|
|
// It's a plain IP address
|
|
ip := net.ParseIP(input)
|
|
if ip == nil {
|
|
return "", ErrInvalidCIDR
|
|
}
|
|
|
|
// Determine if IPv4 or IPv6
|
|
if ip.To4() != nil {
|
|
return ip.String() + "/32", nil
|
|
}
|
|
return ip.String() + "/128", nil
|
|
}
|
|
|
|
// CIDROverlaps reports whether the CIDR ranges a and b share any address.
// It returns false if either string fails to parse as a CIDR, and for ranges
// of different address families.
func CIDROverlaps(a, b string) bool {
	_, first, err := net.ParseCIDR(a)
	if err != nil {
		return false
	}
	_, second, err := net.ParseCIDR(b)
	if err != nil {
		return false
	}

	// Two CIDR prefixes are either nested or disjoint, so they overlap exactly
	// when one of them contains the other's network address.
	if first.Contains(second.IP) {
		return true
	}
	return second.Contains(first.IP)
}
|
|
|
|
type (
	// CreateIPBanRequest is the payload for creating an IP ban.
	CreateIPBanRequest struct {
		// CIDR is an IP address or CIDR prefix; it is normalized before storage.
		CIDR string `json:"cidr" binding:"required"`
		// Reason is free-form text stored alongside the ban.
		Reason string `json:"reason,omitempty"`
		// ExpiresAt is an optional expiry as a Unix timestamp in seconds;
		// nil means the ban has no expiry.
		ExpiresAt *int64 `json:"expires_at,omitempty"`
		// CreatedBy records who created the ban.
		CreatedBy string `json:"created_by,omitempty"`
	}

	// UpdateIPBanRequest is a partial update; nil fields are left unchanged.
	UpdateIPBanRequest struct {
		// Reason, when non-nil, replaces the stored reason.
		Reason *string `json:"reason,omitempty"`
		// ExpiresAt is the new expiry (Unix seconds). It is applied only when
		// ExpiresAtSet is true, in which case a nil value clears the expiry.
		ExpiresAt *int64 `json:"expires_at,omitempty"`
		// ExpiresAtSet reports whether the caller supplied expires_at at all,
		// including an explicit null. encoding/json ignores this field, so the
		// code that decodes the request must set it.
		ExpiresAtSet bool `json:"-"`
		// Status, when non-nil, must be model.IPBanStatusActive or
		// model.IPBanStatusExpired.
		Status *string `json:"status,omitempty"`
	}
)
|
|
|
|
// Create creates a new IP ban with validation.
|
|
func (s *IPBanService) Create(ctx context.Context, req CreateIPBanRequest) (*model.IPBan, error) {
|
|
// Normalize CIDR
|
|
normalized, err := NormalizeCIDR(req.CIDR)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Check for existing rule with same CIDR
|
|
var existing model.IPBan
|
|
if err := s.db.WithContext(ctx).Where("cidr = ?", normalized).First(&existing).Error; err == nil {
|
|
return nil, ErrDuplicateCIDR
|
|
} else if !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, err
|
|
}
|
|
|
|
// Check for overlapping active rules
|
|
var activeRules []model.IPBan
|
|
if err := s.db.WithContext(ctx).Where("status = ?", model.IPBanStatusActive).Find(&activeRules).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, rule := range activeRules {
|
|
if CIDROverlaps(normalized, rule.CIDR) {
|
|
return nil, fmt.Errorf("%w: overlaps with %s", ErrCIDROverlap, rule.CIDR)
|
|
}
|
|
}
|
|
|
|
// Create the ban
|
|
ban := &model.IPBan{
|
|
CIDR: normalized,
|
|
Status: model.IPBanStatusActive,
|
|
Reason: req.Reason,
|
|
ExpiresAt: req.ExpiresAt,
|
|
CreatedBy: req.CreatedBy,
|
|
}
|
|
|
|
if err := s.db.WithContext(ctx).Create(ban).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Sync to Redis
|
|
if err := s.syncBanToRedis(ctx, ban); err != nil {
|
|
s.logger.Error("failed to sync IP ban to Redis", "cidr", ban.CIDR, "err", err)
|
|
// Don't fail the create operation, just log the error
|
|
}
|
|
|
|
return ban, nil
|
|
}
|
|
|
|
// Get retrieves an IP ban by ID.
|
|
func (s *IPBanService) Get(ctx context.Context, id uint) (*model.IPBan, error) {
|
|
var ban model.IPBan
|
|
if err := s.db.WithContext(ctx).First(&ban, id).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrIPBanNotFound
|
|
}
|
|
return nil, err
|
|
}
|
|
return &ban, nil
|
|
}
|
|
|
|
// List retrieves IP bans with optional status filter.
|
|
func (s *IPBanService) List(ctx context.Context, status string) ([]model.IPBan, error) {
|
|
var bans []model.IPBan
|
|
query := s.db.WithContext(ctx)
|
|
if status != "" {
|
|
query = query.Where("status = ?", status)
|
|
}
|
|
if err := query.Order("created_at DESC").Find(&bans).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return bans, nil
|
|
}
|
|
|
|
// Update updates an existing IP ban.
|
|
func (s *IPBanService) Update(ctx context.Context, id uint, req UpdateIPBanRequest) (*model.IPBan, error) {
|
|
ban, err := s.Get(ctx, id)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if req.Status != nil && *req.Status == model.IPBanStatusActive && ban.Status != model.IPBanStatusActive {
|
|
var activeRules []model.IPBan
|
|
if err := s.db.WithContext(ctx).
|
|
Where("status = ? AND id <> ?", model.IPBanStatusActive, ban.ID).
|
|
Find(&activeRules).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
for _, rule := range activeRules {
|
|
if CIDROverlaps(ban.CIDR, rule.CIDR) {
|
|
return nil, fmt.Errorf("%w: overlaps with %s", ErrCIDROverlap, rule.CIDR)
|
|
}
|
|
}
|
|
}
|
|
|
|
updates := make(map[string]interface{})
|
|
if req.Reason != nil {
|
|
updates["reason"] = *req.Reason
|
|
}
|
|
if req.ExpiresAtSet {
|
|
updates["expires_at"] = req.ExpiresAt
|
|
}
|
|
if req.Status != nil {
|
|
if *req.Status != model.IPBanStatusActive && *req.Status != model.IPBanStatusExpired {
|
|
return nil, fmt.Errorf("invalid status: %s", *req.Status)
|
|
}
|
|
updates["status"] = *req.Status
|
|
}
|
|
|
|
if len(updates) > 0 {
|
|
if err := s.db.WithContext(ctx).Model(ban).Updates(updates).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
// Reload to get updated values
|
|
if err := s.db.WithContext(ctx).First(ban, id).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
// Sync to Redis
|
|
if err := s.syncBanToRedis(ctx, ban); err != nil {
|
|
s.logger.Error("failed to sync IP ban to Redis", "cidr", ban.CIDR, "err", err)
|
|
}
|
|
|
|
return ban, nil
|
|
}
|
|
|
|
// Delete removes an IP ban.
|
|
func (s *IPBanService) Delete(ctx context.Context, id uint) error {
|
|
ban, err := s.Get(ctx, id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Remove from Redis first
|
|
if err := s.removeBanFromRedis(ctx, ban.CIDR); err != nil {
|
|
s.logger.Error("failed to remove IP ban from Redis", "cidr", ban.CIDR, "err", err)
|
|
}
|
|
|
|
// Delete from database (hard delete, not soft delete)
|
|
if err := s.db.WithContext(ctx).Unscoped().Delete(ban).Error; err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Redis key for global IP bans SET
|
|
const redisIPBansKey = "global:ip-bans"
|
|
|
|
// syncBanToRedis syncs a ban to Redis.
|
|
func (s *IPBanService) syncBanToRedis(ctx context.Context, ban *model.IPBan) error {
|
|
if s.rdb == nil {
|
|
return nil
|
|
}
|
|
|
|
// Only sync active, non-expired bans
|
|
if ban.Status != model.IPBanStatusActive || ban.IsExpired() {
|
|
return s.removeBanFromRedis(ctx, ban.CIDR)
|
|
}
|
|
|
|
return s.rdb.SAdd(ctx, redisIPBansKey, ban.CIDR).Err()
|
|
}
|
|
|
|
// removeBanFromRedis removes a CIDR from Redis.
|
|
func (s *IPBanService) removeBanFromRedis(ctx context.Context, cidr string) error {
|
|
if s.rdb == nil {
|
|
return nil
|
|
}
|
|
return s.rdb.SRem(ctx, redisIPBansKey, cidr).Err()
|
|
}
|
|
|
|
// SyncAllToRedis rebuilds the Redis SET with all active, non-expired bans.
|
|
func (s *IPBanService) SyncAllToRedis(ctx context.Context) error {
|
|
if s.rdb == nil {
|
|
return nil
|
|
}
|
|
|
|
// Get all active bans
|
|
var bans []model.IPBan
|
|
if err := s.db.WithContext(ctx).Where("status = ?", model.IPBanStatusActive).Find(&bans).Error; err != nil {
|
|
return err
|
|
}
|
|
|
|
// Filter out expired ones
|
|
var activeCIDRs []interface{}
|
|
for _, ban := range bans {
|
|
if !ban.IsExpired() {
|
|
activeCIDRs = append(activeCIDRs, ban.CIDR)
|
|
}
|
|
}
|
|
|
|
// Use a pipeline to atomically replace the set
|
|
pipe := s.rdb.Pipeline()
|
|
pipe.Del(ctx, redisIPBansKey)
|
|
if len(activeCIDRs) > 0 {
|
|
pipe.SAdd(ctx, redisIPBansKey, activeCIDRs...)
|
|
}
|
|
_, err := pipe.Exec(ctx)
|
|
return err
|
|
}
|
|
|
|
// ExpireOutdatedBans marks expired bans and removes them from Redis.
|
|
func (s *IPBanService) ExpireOutdatedBans(ctx context.Context) (int64, error) {
|
|
now := time.Now().Unix()
|
|
|
|
// Find active bans that have expired
|
|
var expiredBans []model.IPBan
|
|
if err := s.db.WithContext(ctx).
|
|
Where("status = ? AND expires_at IS NOT NULL AND expires_at <= ?",
|
|
model.IPBanStatusActive,
|
|
now).
|
|
Find(&expiredBans).Error; err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
if len(expiredBans) == 0 {
|
|
return 0, nil
|
|
}
|
|
|
|
// Mark as expired
|
|
var ids []uint
|
|
for _, ban := range expiredBans {
|
|
ids = append(ids, ban.ID)
|
|
}
|
|
|
|
result := s.db.WithContext(ctx).
|
|
Model(&model.IPBan{}).
|
|
Where("id IN ?", ids).
|
|
Update("status", model.IPBanStatusExpired)
|
|
|
|
if result.Error != nil {
|
|
return 0, result.Error
|
|
}
|
|
|
|
// Remove from Redis
|
|
for _, ban := range expiredBans {
|
|
if err := s.removeBanFromRedis(ctx, ban.CIDR); err != nil {
|
|
s.logger.Error("failed to remove expired ban from Redis", "cidr", ban.CIDR, "err", err)
|
|
}
|
|
}
|
|
|
|
return result.RowsAffected, nil
|
|
}
|
|
|
|
// SyncHitCounts syncs hit counts from Redis to database.
|
|
func (s *IPBanService) SyncHitCounts(ctx context.Context) error {
|
|
if s.rdb == nil {
|
|
return nil
|
|
}
|
|
|
|
// Get all active bans
|
|
var bans []model.IPBan
|
|
if err := s.db.WithContext(ctx).Where("status = ?", model.IPBanStatusActive).Find(&bans).Error; err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, ban := range bans {
|
|
hitKey := fmt.Sprintf("global:ip-bans:hits:%s", ban.CIDR)
|
|
|
|
// Get and reset the counter atomically
|
|
count, err := s.rdb.GetDel(ctx, hitKey).Int64()
|
|
if err != nil && !errors.Is(err, redis.Nil) {
|
|
s.logger.Error("failed to get hit count from Redis", "cidr", ban.CIDR, "err", err)
|
|
continue
|
|
}
|
|
|
|
if count > 0 {
|
|
// Add to database hit_count
|
|
if err := s.db.WithContext(ctx).
|
|
Model(&model.IPBan{}).
|
|
Where("id = ?", ban.ID).
|
|
Update("hit_count", gorm.Expr("hit_count + ?", count)).Error; err != nil {
|
|
s.logger.Error("failed to update hit count in database", "cidr", ban.CIDR, "err", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|