From 830c6fa6e74662387326099b4a0f864a89d3e9b6 Mon Sep 17 00:00:00 2001 From: zenfun Date: Sun, 4 Jan 2026 01:28:43 +0800 Subject: [PATCH] feat(cron): implement IP ban maintenance tasks Add IPBanManager to handle periodic background jobs including: - Expiring outdated bans - Syncing hit counts from Redis to DB - Performing full Redis state synchronization Additionally, update the service expiration logic to use system time and add unit tests for CIDR normalization and overlap checking. --- internal/cron/ip_ban_manager.go | 61 +++++++++++++ internal/service/ip_ban.go | 5 +- internal/service/ip_ban_test.go | 152 ++++++++++++++++++++++++++++++++ 3 files changed, 217 insertions(+), 1 deletion(-) create mode 100644 internal/cron/ip_ban_manager.go create mode 100644 internal/service/ip_ban_test.go diff --git a/internal/cron/ip_ban_manager.go b/internal/cron/ip_ban_manager.go new file mode 100644 index 0000000..cb849f6 --- /dev/null +++ b/internal/cron/ip_ban_manager.go @@ -0,0 +1,61 @@ +package cron + +import ( + "context" + "log/slog" + + "github.com/ez-api/ez-api/internal/service" +) + +// IPBanManager handles periodic IP ban maintenance tasks. +type IPBanManager struct { + service *service.IPBanService + logger *slog.Logger +} + +// NewIPBanManager creates a new IPBanManager. +func NewIPBanManager(service *service.IPBanService) *IPBanManager { + return &IPBanManager{ + service: service, + logger: slog.Default(), + } +} + +// ExpireRunOnce checks for expired bans and marks them. Called by scheduler. +func (m *IPBanManager) ExpireRunOnce(ctx context.Context) { + if m == nil || m.service == nil { + return + } + + count, err := m.service.ExpireOutdatedBans(ctx) + if err != nil { + m.logger.Error("failed to expire outdated IP bans", "err", err) + return + } + + if count > 0 { + m.logger.Info("expired outdated IP bans", "count", count) + } +} + +// HitSyncRunOnce syncs hit counts from Redis to database. Called by scheduler. +func (m *IPBanManager) HitSyncRunOnce(ctx context.Context) { + if m == nil || m.service == nil { + return + } + + if err := m.service.SyncHitCounts(ctx); err != nil { + m.logger.Error("failed to sync IP ban hit counts", "err", err) + } +} + +// FullSyncRunOnce performs a full sync of all active bans to Redis. Called by scheduler. +func (m *IPBanManager) FullSyncRunOnce(ctx context.Context) { + if m == nil || m.service == nil { + return + } + + if err := m.service.SyncAllToRedis(ctx); err != nil { + m.logger.Error("failed to sync IP bans to Redis", "err", err) + } +} diff --git a/internal/service/ip_ban.go b/internal/service/ip_ban.go index 1b1afff..9d0d167 100644 --- a/internal/service/ip_ban.go +++ b/internal/service/ip_ban.go @@ -7,6 +7,7 @@ import ( "log/slog" "net" "strings" + "time" "github.com/ez-api/ez-api/internal/model" "github.com/redis/go-redis/v9" @@ -288,12 +289,14 @@ func (s *IPBanService) SyncAllToRedis(ctx context.Context) error { // ExpireOutdatedBans marks expired bans and removes them from Redis. func (s *IPBanService) ExpireOutdatedBans(ctx context.Context) (int64, error) { + now := time.Now().Unix() + // Find active bans that have expired var expiredBans []model.IPBan if err := s.db.WithContext(ctx). Where("status = ? AND expires_at IS NOT NULL AND expires_at <= ?", model.IPBanStatusActive, - ctx.Value("now")). // This should be time.Now().Unix() in caller + now). Find(&expiredBans).Error; err != nil { return 0, err } diff --git a/internal/service/ip_ban_test.go b/internal/service/ip_ban_test.go new file mode 100644 index 0000000..264f13d --- /dev/null +++ b/internal/service/ip_ban_test.go @@ -0,0 +1,152 @@ +package service + +import ( + "testing" +) + +func TestNormalizeCIDR_IPv4Address(t *testing.T) { + tests := []struct { + input string + expected string + hasError bool + }{ + {"10.0.0.1", "10.0.0.1/32", false}, + {"192.168.1.100", "192.168.1.100/32", false}, + {"8.8.8.8", "8.8.8.8/32", false}, + } + + for _, tt := range tests { + result, err := NormalizeCIDR(tt.input) + if tt.hasError { + if err == nil { + t.Errorf("NormalizeCIDR(%q): expected error, got nil", tt.input) + } + continue + } + if err != nil { + t.Errorf("NormalizeCIDR(%q): unexpected error: %v", tt.input, err) + continue + } + if result != tt.expected { + t.Errorf("NormalizeCIDR(%q): expected %q, got %q", tt.input, tt.expected, result) + } + } +} + +func TestNormalizeCIDR_IPv4CIDR(t *testing.T) { + tests := []struct { + input string + expected string + hasError bool + }{ + {"10.0.0.0/24", "10.0.0.0/24", false}, + {"192.168.1.0/16", "192.168.0.0/16", false}, // Normalized to network address + {"172.16.0.1/32", "172.16.0.1/32", false}, + } + + for _, tt := range tests { + result, err := NormalizeCIDR(tt.input) + if tt.hasError { + if err == nil { + t.Errorf("NormalizeCIDR(%q): expected error, got nil", tt.input) + } + continue + } + if err != nil { + t.Errorf("NormalizeCIDR(%q): unexpected error: %v", tt.input, err) + continue + } + if result != tt.expected { + t.Errorf("NormalizeCIDR(%q): expected %q, got %q", tt.input, tt.expected, result) + } + } +} + +func TestNormalizeCIDR_IPv6Address(t *testing.T) { + tests := []struct { + input string + expected string + hasError bool + }{ + {"2001:db8::1", "2001:db8::1/128", false}, + {"::1", "::1/128", false}, + } + + for _, tt := range tests { + result, err := NormalizeCIDR(tt.input) + if tt.hasError { + if err == nil { + t.Errorf("NormalizeCIDR(%q): expected error, got nil", tt.input) + } + continue + } + if err != nil { + t.Errorf("NormalizeCIDR(%q): unexpected error: %v", tt.input, err) + continue + } + if result != tt.expected { + t.Errorf("NormalizeCIDR(%q): expected %q, got %q", tt.input, tt.expected, result) + } + } +} + +func TestNormalizeCIDR_Invalid(t *testing.T) { + tests := []string{ + "", + "not-an-ip", + "10.0.0.256", + "10.0.0.0/33", + "10.0.0.0/abc", + } + + for _, input := range tests { + _, err := NormalizeCIDR(input) + if err == nil { + t.Errorf("NormalizeCIDR(%q): expected error for invalid input", input) + } + } +} + +func TestCIDROverlaps_NoOverlap(t *testing.T) { + tests := []struct { + a, b string + }{ + {"10.0.0.0/24", "192.168.1.0/24"}, + {"10.0.0.0/24", "10.0.1.0/24"}, + {"10.0.0.0/32", "10.0.0.1/32"}, + } + + for _, tt := range tests { + if CIDROverlaps(tt.a, tt.b) { + t.Errorf("CIDROverlaps(%q, %q): expected no overlap", tt.a, tt.b) + } + } +} + +func TestCIDROverlaps_HasOverlap(t *testing.T) { + tests := []struct { + a, b string + }{ + {"10.0.0.0/24", "10.0.0.0/24"}, // Exact same + {"10.0.0.0/24", "10.0.0.0/25"}, // a contains b + {"10.0.0.0/25", "10.0.0.0/24"}, // b contains a + {"10.0.0.0/16", "10.0.1.0/24"}, // a contains b + {"10.0.0.1/32", "10.0.0.0/24"}, // b contains a + } + + for _, tt := range tests { + if !CIDROverlaps(tt.a, tt.b) { + t.Errorf("CIDROverlaps(%q, %q): expected overlap", tt.a, tt.b) + } + } +} + +func TestCIDROverlaps_InvalidCIDR(t *testing.T) { + // Invalid CIDR should return false (no overlap) + if CIDROverlaps("not-valid", "10.0.0.0/24") { + t.Error("expected no overlap with invalid CIDR") + } + if CIDROverlaps("10.0.0.0/24", "not-valid") { + t.Error("expected no overlap with invalid CIDR") + } +}