mirror of
https://github.com/EZ-Api/ez-api.git
synced 2026-01-13 17:47:51 +00:00
feat(cron): implement IP ban maintenance tasks
Add IPBanManager to handle periodic background jobs including: - Expiring outdated bans - Syncing hit counts from Redis to DB - Performing full Redis state synchronization Additionally, update the service expiration logic to use system time and add unit tests for CIDR normalization and overlap checking.
This commit is contained in:
61
internal/cron/ip_ban_manager.go
Normal file
61
internal/cron/ip_ban_manager.go
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
package cron
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
|
|
||||||
|
"github.com/ez-api/ez-api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IPBanManager handles periodic IP ban maintenance tasks.
|
||||||
|
type IPBanManager struct {
|
||||||
|
service *service.IPBanService
|
||||||
|
logger *slog.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewIPBanManager creates a new IPBanManager.
|
||||||
|
func NewIPBanManager(service *service.IPBanService) *IPBanManager {
|
||||||
|
return &IPBanManager{
|
||||||
|
service: service,
|
||||||
|
logger: slog.Default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpireRunOnce checks for expired bans and marks them. Called by scheduler.
|
||||||
|
func (m *IPBanManager) ExpireRunOnce(ctx context.Context) {
|
||||||
|
if m == nil || m.service == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
count, err := m.service.ExpireOutdatedBans(ctx)
|
||||||
|
if err != nil {
|
||||||
|
m.logger.Error("failed to expire outdated IP bans", "err", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if count > 0 {
|
||||||
|
m.logger.Info("expired outdated IP bans", "count", count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HitSyncRunOnce syncs hit counts from Redis to database. Called by scheduler.
|
||||||
|
func (m *IPBanManager) HitSyncRunOnce(ctx context.Context) {
|
||||||
|
if m == nil || m.service == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.service.SyncHitCounts(ctx); err != nil {
|
||||||
|
m.logger.Error("failed to sync IP ban hit counts", "err", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FullSyncRunOnce performs a full sync of all active bans to Redis. Called by scheduler.
|
||||||
|
func (m *IPBanManager) FullSyncRunOnce(ctx context.Context) {
|
||||||
|
if m == nil || m.service == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.service.SyncAllToRedis(ctx); err != nil {
|
||||||
|
m.logger.Error("failed to sync IP bans to Redis", "err", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/ez-api/ez-api/internal/model"
|
"github.com/ez-api/ez-api/internal/model"
|
||||||
"github.com/redis/go-redis/v9"
|
"github.com/redis/go-redis/v9"
|
||||||
@@ -288,12 +289,14 @@ func (s *IPBanService) SyncAllToRedis(ctx context.Context) error {
|
|||||||
|
|
||||||
// ExpireOutdatedBans marks expired bans and removes them from Redis.
|
// ExpireOutdatedBans marks expired bans and removes them from Redis.
|
||||||
func (s *IPBanService) ExpireOutdatedBans(ctx context.Context) (int64, error) {
|
func (s *IPBanService) ExpireOutdatedBans(ctx context.Context) (int64, error) {
|
||||||
|
now := time.Now().Unix()
|
||||||
|
|
||||||
// Find active bans that have expired
|
// Find active bans that have expired
|
||||||
var expiredBans []model.IPBan
|
var expiredBans []model.IPBan
|
||||||
if err := s.db.WithContext(ctx).
|
if err := s.db.WithContext(ctx).
|
||||||
Where("status = ? AND expires_at IS NOT NULL AND expires_at <= ?",
|
Where("status = ? AND expires_at IS NOT NULL AND expires_at <= ?",
|
||||||
model.IPBanStatusActive,
|
model.IPBanStatusActive,
|
||||||
ctx.Value("now")). // This should be time.Now().Unix() in caller
|
now).
|
||||||
Find(&expiredBans).Error; err != nil {
|
Find(&expiredBans).Error; err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|||||||
152
internal/service/ip_ban_test.go
Normal file
152
internal/service/ip_ban_test.go
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNormalizeCIDR_IPv4Address(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
hasError bool
|
||||||
|
}{
|
||||||
|
{"10.0.0.1", "10.0.0.1/32", false},
|
||||||
|
{"192.168.1.100", "192.168.1.100/32", false},
|
||||||
|
{"8.8.8.8", "8.8.8.8/32", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
result, err := NormalizeCIDR(tt.input)
|
||||||
|
if tt.hasError {
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("NormalizeCIDR(%q): expected error, got nil", tt.input)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("NormalizeCIDR(%q): unexpected error: %v", tt.input, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("NormalizeCIDR(%q): expected %q, got %q", tt.input, tt.expected, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeCIDR_IPv4CIDR(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
hasError bool
|
||||||
|
}{
|
||||||
|
{"10.0.0.0/24", "10.0.0.0/24", false},
|
||||||
|
{"192.168.1.0/16", "192.168.0.0/16", false}, // Normalized to network address
|
||||||
|
{"172.16.0.1/32", "172.16.0.1/32", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
result, err := NormalizeCIDR(tt.input)
|
||||||
|
if tt.hasError {
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("NormalizeCIDR(%q): expected error, got nil", tt.input)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("NormalizeCIDR(%q): unexpected error: %v", tt.input, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("NormalizeCIDR(%q): expected %q, got %q", tt.input, tt.expected, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeCIDR_IPv6Address(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
hasError bool
|
||||||
|
}{
|
||||||
|
{"2001:db8::1", "2001:db8::1/128", false},
|
||||||
|
{"::1", "::1/128", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
result, err := NormalizeCIDR(tt.input)
|
||||||
|
if tt.hasError {
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("NormalizeCIDR(%q): expected error, got nil", tt.input)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("NormalizeCIDR(%q): unexpected error: %v", tt.input, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("NormalizeCIDR(%q): expected %q, got %q", tt.input, tt.expected, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeCIDR_Invalid(t *testing.T) {
|
||||||
|
tests := []string{
|
||||||
|
"",
|
||||||
|
"not-an-ip",
|
||||||
|
"10.0.0.256",
|
||||||
|
"10.0.0.0/33",
|
||||||
|
"10.0.0.0/abc",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, input := range tests {
|
||||||
|
_, err := NormalizeCIDR(input)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("NormalizeCIDR(%q): expected error for invalid input", input)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCIDROverlaps_NoOverlap(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
a, b string
|
||||||
|
}{
|
||||||
|
{"10.0.0.0/24", "192.168.1.0/24"},
|
||||||
|
{"10.0.0.0/24", "10.0.1.0/24"},
|
||||||
|
{"10.0.0.0/32", "10.0.0.1/32"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
if CIDROverlaps(tt.a, tt.b) {
|
||||||
|
t.Errorf("CIDROverlaps(%q, %q): expected no overlap", tt.a, tt.b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCIDROverlaps_HasOverlap(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
a, b string
|
||||||
|
}{
|
||||||
|
{"10.0.0.0/24", "10.0.0.0/24"}, // Exact same
|
||||||
|
{"10.0.0.0/24", "10.0.0.0/25"}, // a contains b
|
||||||
|
{"10.0.0.0/25", "10.0.0.0/24"}, // b contains a
|
||||||
|
{"10.0.0.0/16", "10.0.1.0/24"}, // a contains b
|
||||||
|
{"10.0.0.1/32", "10.0.0.0/24"}, // b contains a
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
if !CIDROverlaps(tt.a, tt.b) {
|
||||||
|
t.Errorf("CIDROverlaps(%q, %q): expected overlap", tt.a, tt.b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCIDROverlaps_InvalidCIDR(t *testing.T) {
|
||||||
|
// Invalid CIDR should return false (no overlap)
|
||||||
|
if CIDROverlaps("not-valid", "10.0.0.0/24") {
|
||||||
|
t.Error("expected no overlap with invalid CIDR")
|
||||||
|
}
|
||||||
|
if CIDROverlaps("10.0.0.0/24", "not-valid") {
|
||||||
|
t.Error("expected no overlap with invalid CIDR")
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user