mirror of
https://github.com/EZ-Api/ez-api.git
synced 2026-01-13 17:47:51 +00:00
feat(cron): implement IP ban maintenance tasks
Add IPBanManager to handle periodic background jobs including: - Expiring outdated bans - Syncing hit counts from Redis to DB - Performing full Redis state synchronization Additionally, update the service expiration logic to use system time and add unit tests for CIDR normalization and overlap checking.
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
||||
"log/slog"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ez-api/ez-api/internal/model"
|
||||
"github.com/redis/go-redis/v9"
|
||||
@@ -288,12 +289,14 @@ func (s *IPBanService) SyncAllToRedis(ctx context.Context) error {
|
||||
|
||||
// ExpireOutdatedBans marks expired bans and removes them from Redis.
|
||||
func (s *IPBanService) ExpireOutdatedBans(ctx context.Context) (int64, error) {
|
||||
now := time.Now().Unix()
|
||||
|
||||
// Find active bans that have expired
|
||||
var expiredBans []model.IPBan
|
||||
if err := s.db.WithContext(ctx).
|
||||
Where("status = ? AND expires_at IS NOT NULL AND expires_at <= ?",
|
||||
model.IPBanStatusActive,
|
||||
ctx.Value("now")). // This should be time.Now().Unix() in caller
|
||||
now).
|
||||
Find(&expiredBans).Error; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
152
internal/service/ip_ban_test.go
Normal file
152
internal/service/ip_ban_test.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNormalizeCIDR_IPv4Address(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
hasError bool
|
||||
}{
|
||||
{"10.0.0.1", "10.0.0.1/32", false},
|
||||
{"192.168.1.100", "192.168.1.100/32", false},
|
||||
{"8.8.8.8", "8.8.8.8/32", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result, err := NormalizeCIDR(tt.input)
|
||||
if tt.hasError {
|
||||
if err == nil {
|
||||
t.Errorf("NormalizeCIDR(%q): expected error, got nil", tt.input)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("NormalizeCIDR(%q): unexpected error: %v", tt.input, err)
|
||||
continue
|
||||
}
|
||||
if result != tt.expected {
|
||||
t.Errorf("NormalizeCIDR(%q): expected %q, got %q", tt.input, tt.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeCIDR_IPv4CIDR(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
hasError bool
|
||||
}{
|
||||
{"10.0.0.0/24", "10.0.0.0/24", false},
|
||||
{"192.168.1.0/16", "192.168.0.0/16", false}, // Normalized to network address
|
||||
{"172.16.0.1/32", "172.16.0.1/32", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result, err := NormalizeCIDR(tt.input)
|
||||
if tt.hasError {
|
||||
if err == nil {
|
||||
t.Errorf("NormalizeCIDR(%q): expected error, got nil", tt.input)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("NormalizeCIDR(%q): unexpected error: %v", tt.input, err)
|
||||
continue
|
||||
}
|
||||
if result != tt.expected {
|
||||
t.Errorf("NormalizeCIDR(%q): expected %q, got %q", tt.input, tt.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeCIDR_IPv6Address(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
hasError bool
|
||||
}{
|
||||
{"2001:db8::1", "2001:db8::1/128", false},
|
||||
{"::1", "::1/128", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result, err := NormalizeCIDR(tt.input)
|
||||
if tt.hasError {
|
||||
if err == nil {
|
||||
t.Errorf("NormalizeCIDR(%q): expected error, got nil", tt.input)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("NormalizeCIDR(%q): unexpected error: %v", tt.input, err)
|
||||
continue
|
||||
}
|
||||
if result != tt.expected {
|
||||
t.Errorf("NormalizeCIDR(%q): expected %q, got %q", tt.input, tt.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeCIDR_Invalid(t *testing.T) {
|
||||
tests := []string{
|
||||
"",
|
||||
"not-an-ip",
|
||||
"10.0.0.256",
|
||||
"10.0.0.0/33",
|
||||
"10.0.0.0/abc",
|
||||
}
|
||||
|
||||
for _, input := range tests {
|
||||
_, err := NormalizeCIDR(input)
|
||||
if err == nil {
|
||||
t.Errorf("NormalizeCIDR(%q): expected error for invalid input", input)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCIDROverlaps_NoOverlap(t *testing.T) {
|
||||
tests := []struct {
|
||||
a, b string
|
||||
}{
|
||||
{"10.0.0.0/24", "192.168.1.0/24"},
|
||||
{"10.0.0.0/24", "10.0.1.0/24"},
|
||||
{"10.0.0.0/32", "10.0.0.1/32"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if CIDROverlaps(tt.a, tt.b) {
|
||||
t.Errorf("CIDROverlaps(%q, %q): expected no overlap", tt.a, tt.b)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCIDROverlaps_HasOverlap(t *testing.T) {
|
||||
tests := []struct {
|
||||
a, b string
|
||||
}{
|
||||
{"10.0.0.0/24", "10.0.0.0/24"}, // Exact same
|
||||
{"10.0.0.0/24", "10.0.0.0/25"}, // a contains b
|
||||
{"10.0.0.0/25", "10.0.0.0/24"}, // b contains a
|
||||
{"10.0.0.0/16", "10.0.1.0/24"}, // a contains b
|
||||
{"10.0.0.1/32", "10.0.0.0/24"}, // b contains a
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if !CIDROverlaps(tt.a, tt.b) {
|
||||
t.Errorf("CIDROverlaps(%q, %q): expected overlap", tt.a, tt.b)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCIDROverlaps_InvalidCIDR(t *testing.T) {
|
||||
// Invalid CIDR should return false (no overlap)
|
||||
if CIDROverlaps("not-valid", "10.0.0.0/24") {
|
||||
t.Error("expected no overlap with invalid CIDR")
|
||||
}
|
||||
if CIDROverlaps("10.0.0.0/24", "not-valid") {
|
||||
t.Error("expected no overlap with invalid CIDR")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user