package service

import (
	"context"
	"errors"
	"testing"
	"time"

	"github.com/ez-api/ez-api/internal/model"
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

// normalizeCIDRCase pairs an input that NormalizeCIDR must accept with the
// exact string the test expects it to return.
type normalizeCIDRCase struct {
	input    string
	expected string
}

// runNormalizeCIDRCases feeds every case to NormalizeCIDR inside its own
// subtest and fails the subtest on an unexpected error or a mismatched result.
// Inputs that must be rejected are covered by TestNormalizeCIDR_Invalid.
func runNormalizeCIDRCases(t *testing.T, cases []normalizeCIDRCase) {
	t.Helper()
	for _, tc := range cases {
		// The subtest is not parallel, so the closure runs before the loop
		// advances and capturing tc is safe on every Go version.
		t.Run(tc.input, func(t *testing.T) {
			got, err := NormalizeCIDR(tc.input)
			if err != nil {
				t.Fatalf("NormalizeCIDR(%q): unexpected error: %v", tc.input, err)
			}
			if got != tc.expected {
				t.Errorf("NormalizeCIDR(%q): expected %q, got %q", tc.input, tc.expected, got)
			}
		})
	}
}

// openIPBanTestDB opens a shared-cache in-memory SQLite database named after
// the calling test, migrates the model.IPBan schema, and registers a cleanup
// that closes the underlying connection pool.
//
// NOTE(review): closing matters because a named shared-cache in-memory
// database is expected to live for as long as any connection to it stays
// open; an unclosed pool would let rows from a previous run of the same test
// (e.g. `go test -count=2`) leak into the next one. Confirm against the
// SQLite driver in use.
func openIPBanTestDB(t *testing.T) *gorm.DB {
	t.Helper()

	db, err := gorm.Open(sqlite.Open("file:"+t.Name()+"?mode=memory&cache=shared"), &gorm.Config{})
	if err != nil {
		t.Fatalf("open sqlite: %v", err)
	}

	sqlDB, err := db.DB()
	if err != nil {
		t.Fatalf("get sql.DB: %v", err)
	}
	t.Cleanup(func() {
		if err := sqlDB.Close(); err != nil {
			t.Errorf("close sqlite: %v", err)
		}
	})

	if err := db.AutoMigrate(&model.IPBan{}); err != nil {
		t.Fatalf("migrate: %v", err)
	}
	return db
}

// TestNormalizeCIDR_IPv4Address checks that a bare IPv4 address comes back
// with a /32 suffix.
func TestNormalizeCIDR_IPv4Address(t *testing.T) {
	runNormalizeCIDRCases(t, []normalizeCIDRCase{
		{"10.0.0.1", "10.0.0.1/32"},
		{"192.168.1.100", "192.168.1.100/32"},
		{"8.8.8.8", "8.8.8.8/32"},
	})
}

// TestNormalizeCIDR_IPv4CIDR checks that IPv4 CIDR input is returned in
// canonical form.
func TestNormalizeCIDR_IPv4CIDR(t *testing.T) {
	runNormalizeCIDRCases(t, []normalizeCIDRCase{
		{"10.0.0.0/24", "10.0.0.0/24"},
		{"192.168.1.0/16", "192.168.0.0/16"}, // Normalized to network address
		{"172.16.0.1/32", "172.16.0.1/32"},
	})
}

// TestNormalizeCIDR_IPv6Address checks that a bare IPv6 address comes back
// with a /128 suffix.
func TestNormalizeCIDR_IPv6Address(t *testing.T) {
	runNormalizeCIDRCases(t, []normalizeCIDRCase{
		{"2001:db8::1", "2001:db8::1/128"},
		{"::1", "::1/128"},
	})
}

// TestNormalizeCIDR_Invalid checks that malformed addresses and prefixes are
// rejected with an error.
func TestNormalizeCIDR_Invalid(t *testing.T) {
	tests := []string{
		"",
		"not-an-ip",
		"10.0.0.256",
		"10.0.0.0/33",
		"10.0.0.0/abc",
	}

	for _, input := range tests {
		if _, err := NormalizeCIDR(input); err == nil {
			t.Errorf("NormalizeCIDR(%q): expected error for invalid input", input)
		}
	}
}

// TestCIDROverlaps_NoOverlap checks that disjoint ranges are reported as not
// overlapping.
func TestCIDROverlaps_NoOverlap(t *testing.T) {
	tests := []struct {
		a, b string
	}{
		{"10.0.0.0/24", "192.168.1.0/24"},
		{"10.0.0.0/24", "10.0.1.0/24"},
		{"10.0.0.0/32", "10.0.0.1/32"},
	}

	for _, tt := range tests {
		if CIDROverlaps(tt.a, tt.b) {
			t.Errorf("CIDROverlaps(%q, %q): expected no overlap", tt.a, tt.b)
		}
	}
}

// TestCIDROverlaps_HasOverlap checks that identical and nested ranges are
// reported as overlapping, whichever argument is the wider one.
func TestCIDROverlaps_HasOverlap(t *testing.T) {
	tests := []struct {
		a, b string
	}{
		{"10.0.0.0/24", "10.0.0.0/24"}, // Exact same
		{"10.0.0.0/24", "10.0.0.0/25"}, // a contains b
		{"10.0.0.0/25", "10.0.0.0/24"}, // b contains a
		{"10.0.0.0/16", "10.0.1.0/24"}, // a contains b
		{"10.0.0.1/32", "10.0.0.0/24"}, // b contains a
	}

	for _, tt := range tests {
		if !CIDROverlaps(tt.a, tt.b) {
			t.Errorf("CIDROverlaps(%q, %q): expected overlap", tt.a, tt.b)
		}
	}
}

// TestCIDROverlaps_InvalidCIDR checks that an unparsable argument on either
// side yields false (no overlap).
func TestCIDROverlaps_InvalidCIDR(t *testing.T) {
	// Invalid CIDR should return false (no overlap)
	if CIDROverlaps("not-valid", "10.0.0.0/24") {
		t.Error("expected no overlap with invalid CIDR")
	}
	if CIDROverlaps("10.0.0.0/24", "not-valid") {
		t.Error("expected no overlap with invalid CIDR")
	}
}

// TestIPBanService_Update_ReactivateOverlap checks that switching an expired
// ban back to active fails with ErrCIDROverlap when its range overlaps an
// already-active ban, and that the stored status stays expired.
func TestIPBanService_Update_ReactivateOverlap(t *testing.T) {
	t.Parallel()

	db := openIPBanTestDB(t)

	active := model.IPBan{CIDR: "10.0.0.0/24", Status: model.IPBanStatusActive}
	if err := db.Create(&active).Error; err != nil {
		t.Fatalf("create active ban: %v", err)
	}

	// 10.0.0.128/25 lies entirely inside the active 10.0.0.0/24 ban.
	expired := model.IPBan{CIDR: "10.0.0.128/25", Status: model.IPBanStatusExpired}
	if err := db.Create(&expired).Error; err != nil {
		t.Fatalf("create expired ban: %v", err)
	}

	svc := NewIPBanService(db, nil)

	newStatus := model.IPBanStatusActive
	_, err := svc.Update(context.Background(), expired.ID, UpdateIPBanRequest{
		Status: &newStatus,
	})
	if !errors.Is(err, ErrCIDROverlap) {
		t.Fatalf("expected overlap error, got %v", err)
	}

	var reloaded model.IPBan
	if err := db.First(&reloaded, expired.ID).Error; err != nil {
		t.Fatalf("reload expired ban: %v", err)
	}
	if reloaded.Status != model.IPBanStatusExpired {
		t.Fatalf("expected status to remain expired, got %s", reloaded.Status)
	}
}

// TestIPBanService_Update_ClearExpiresAt checks that an update carrying
// ExpiresAtSet with a nil ExpiresAt removes the stored expiry.
func TestIPBanService_Update_ClearExpiresAt(t *testing.T) {
	t.Parallel()

	db := openIPBanTestDB(t)

	exp := time.Now().Add(time.Hour).Unix()
	ban := model.IPBan{
		CIDR:      "192.168.1.1/32",
		Status:    model.IPBanStatusActive,
		ExpiresAt: &exp,
	}
	if err := db.Create(&ban).Error; err != nil {
		t.Fatalf("create ban: %v", err)
	}

	svc := NewIPBanService(db, nil)

	// ExpiresAtSet tells the service the nil ExpiresAt is intentional.
	if _, err := svc.Update(context.Background(), ban.ID, UpdateIPBanRequest{
		ExpiresAt:    nil,
		ExpiresAtSet: true,
	}); err != nil {
		t.Fatalf("update expires_at: %v", err)
	}

	var reloaded model.IPBan
	if err := db.First(&reloaded, ban.ID).Error; err != nil {
		t.Fatalf("reload ban: %v", err)
	}
	if reloaded.ExpiresAt != nil {
		t.Fatalf("expected expires_at to be cleared, got %v", *reloaded.ExpiresAt)
	}
}