mirror of
https://github.com/EZ-Api/ez-api.git
synced 2026-01-13 17:47:51 +00:00
Add integration tests for `IPBanService.Update` to verify: - Reactivating an expired ban correctly detects overlaps with existing active bans. - Explicitly clearing the `expires_at` field (setting to null) works as expected.
239 lines
5.7 KiB
Go
239 lines
5.7 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/ez-api/ez-api/internal/model"
|
|
"gorm.io/driver/sqlite"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
func TestNormalizeCIDR_IPv4Address(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
expected string
|
|
hasError bool
|
|
}{
|
|
{"10.0.0.1", "10.0.0.1/32", false},
|
|
{"192.168.1.100", "192.168.1.100/32", false},
|
|
{"8.8.8.8", "8.8.8.8/32", false},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
result, err := NormalizeCIDR(tt.input)
|
|
if tt.hasError {
|
|
if err == nil {
|
|
t.Errorf("NormalizeCIDR(%q): expected error, got nil", tt.input)
|
|
}
|
|
continue
|
|
}
|
|
if err != nil {
|
|
t.Errorf("NormalizeCIDR(%q): unexpected error: %v", tt.input, err)
|
|
continue
|
|
}
|
|
if result != tt.expected {
|
|
t.Errorf("NormalizeCIDR(%q): expected %q, got %q", tt.input, tt.expected, result)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestNormalizeCIDR_IPv4CIDR(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
expected string
|
|
hasError bool
|
|
}{
|
|
{"10.0.0.0/24", "10.0.0.0/24", false},
|
|
{"192.168.1.0/16", "192.168.0.0/16", false}, // Normalized to network address
|
|
{"172.16.0.1/32", "172.16.0.1/32", false},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
result, err := NormalizeCIDR(tt.input)
|
|
if tt.hasError {
|
|
if err == nil {
|
|
t.Errorf("NormalizeCIDR(%q): expected error, got nil", tt.input)
|
|
}
|
|
continue
|
|
}
|
|
if err != nil {
|
|
t.Errorf("NormalizeCIDR(%q): unexpected error: %v", tt.input, err)
|
|
continue
|
|
}
|
|
if result != tt.expected {
|
|
t.Errorf("NormalizeCIDR(%q): expected %q, got %q", tt.input, tt.expected, result)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestNormalizeCIDR_IPv6Address(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
expected string
|
|
hasError bool
|
|
}{
|
|
{"2001:db8::1", "2001:db8::1/128", false},
|
|
{"::1", "::1/128", false},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
result, err := NormalizeCIDR(tt.input)
|
|
if tt.hasError {
|
|
if err == nil {
|
|
t.Errorf("NormalizeCIDR(%q): expected error, got nil", tt.input)
|
|
}
|
|
continue
|
|
}
|
|
if err != nil {
|
|
t.Errorf("NormalizeCIDR(%q): unexpected error: %v", tt.input, err)
|
|
continue
|
|
}
|
|
if result != tt.expected {
|
|
t.Errorf("NormalizeCIDR(%q): expected %q, got %q", tt.input, tt.expected, result)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestNormalizeCIDR_Invalid(t *testing.T) {
|
|
tests := []string{
|
|
"",
|
|
"not-an-ip",
|
|
"10.0.0.256",
|
|
"10.0.0.0/33",
|
|
"10.0.0.0/abc",
|
|
}
|
|
|
|
for _, input := range tests {
|
|
_, err := NormalizeCIDR(input)
|
|
if err == nil {
|
|
t.Errorf("NormalizeCIDR(%q): expected error for invalid input", input)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestCIDROverlaps_NoOverlap(t *testing.T) {
|
|
tests := []struct {
|
|
a, b string
|
|
}{
|
|
{"10.0.0.0/24", "192.168.1.0/24"},
|
|
{"10.0.0.0/24", "10.0.1.0/24"},
|
|
{"10.0.0.0/32", "10.0.0.1/32"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
if CIDROverlaps(tt.a, tt.b) {
|
|
t.Errorf("CIDROverlaps(%q, %q): expected no overlap", tt.a, tt.b)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestCIDROverlaps_HasOverlap(t *testing.T) {
|
|
tests := []struct {
|
|
a, b string
|
|
}{
|
|
{"10.0.0.0/24", "10.0.0.0/24"}, // Exact same
|
|
{"10.0.0.0/24", "10.0.0.0/25"}, // a contains b
|
|
{"10.0.0.0/25", "10.0.0.0/24"}, // b contains a
|
|
{"10.0.0.0/16", "10.0.1.0/24"}, // a contains b
|
|
{"10.0.0.1/32", "10.0.0.0/24"}, // b contains a
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
if !CIDROverlaps(tt.a, tt.b) {
|
|
t.Errorf("CIDROverlaps(%q, %q): expected overlap", tt.a, tt.b)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestCIDROverlaps_InvalidCIDR(t *testing.T) {
|
|
// Invalid CIDR should return false (no overlap)
|
|
if CIDROverlaps("not-valid", "10.0.0.0/24") {
|
|
t.Error("expected no overlap with invalid CIDR")
|
|
}
|
|
if CIDROverlaps("10.0.0.0/24", "not-valid") {
|
|
t.Error("expected no overlap with invalid CIDR")
|
|
}
|
|
}
|
|
|
|
func TestIPBanService_Update_ReactivateOverlap(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, err := gorm.Open(sqlite.Open("file:"+t.Name()+"?mode=memory&cache=shared"), &gorm.Config{})
|
|
if err != nil {
|
|
t.Fatalf("open sqlite: %v", err)
|
|
}
|
|
if err := db.AutoMigrate(&model.IPBan{}); err != nil {
|
|
t.Fatalf("migrate: %v", err)
|
|
}
|
|
|
|
active := model.IPBan{CIDR: "10.0.0.0/24", Status: model.IPBanStatusActive}
|
|
if err := db.Create(&active).Error; err != nil {
|
|
t.Fatalf("create active ban: %v", err)
|
|
}
|
|
|
|
expired := model.IPBan{CIDR: "10.0.0.128/25", Status: model.IPBanStatusExpired}
|
|
if err := db.Create(&expired).Error; err != nil {
|
|
t.Fatalf("create expired ban: %v", err)
|
|
}
|
|
|
|
svc := NewIPBanService(db, nil)
|
|
_, err = svc.Update(context.Background(), expired.ID, UpdateIPBanRequest{
|
|
Status: func() *string {
|
|
s := model.IPBanStatusActive
|
|
return &s
|
|
}(),
|
|
})
|
|
if !errors.Is(err, ErrCIDROverlap) {
|
|
t.Fatalf("expected overlap error, got %v", err)
|
|
}
|
|
|
|
var reloaded model.IPBan
|
|
if err := db.First(&reloaded, expired.ID).Error; err != nil {
|
|
t.Fatalf("reload expired ban: %v", err)
|
|
}
|
|
if reloaded.Status != model.IPBanStatusExpired {
|
|
t.Fatalf("expected status to remain expired, got %s", reloaded.Status)
|
|
}
|
|
}
|
|
|
|
func TestIPBanService_Update_ClearExpiresAt(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, err := gorm.Open(sqlite.Open("file:"+t.Name()+"?mode=memory&cache=shared"), &gorm.Config{})
|
|
if err != nil {
|
|
t.Fatalf("open sqlite: %v", err)
|
|
}
|
|
if err := db.AutoMigrate(&model.IPBan{}); err != nil {
|
|
t.Fatalf("migrate: %v", err)
|
|
}
|
|
|
|
exp := time.Now().Add(time.Hour).Unix()
|
|
ban := model.IPBan{
|
|
CIDR: "192.168.1.1/32",
|
|
Status: model.IPBanStatusActive,
|
|
ExpiresAt: &exp,
|
|
}
|
|
if err := db.Create(&ban).Error; err != nil {
|
|
t.Fatalf("create ban: %v", err)
|
|
}
|
|
|
|
svc := NewIPBanService(db, nil)
|
|
if _, err := svc.Update(context.Background(), ban.ID, UpdateIPBanRequest{
|
|
ExpiresAt: nil,
|
|
ExpiresAtSet: true,
|
|
}); err != nil {
|
|
t.Fatalf("update expires_at: %v", err)
|
|
}
|
|
|
|
var reloaded model.IPBan
|
|
if err := db.First(&reloaded, ban.ID).Error; err != nil {
|
|
t.Fatalf("reload ban: %v", err)
|
|
}
|
|
if reloaded.ExpiresAt != nil {
|
|
t.Fatalf("expected expires_at to be cleared, got %v", *reloaded.ExpiresAt)
|
|
}
|
|
}
|