package service import ( "archive/tar" "bytes" "compress/gzip" "context" "encoding/json" "fmt" "net/http" "net/http/httptest" "os" "testing" "time" "github.com/alicebob/miniredis/v2" "github.com/ez-api/ez-api/internal/model" "github.com/ez-api/foundation/modelcap" "github.com/redis/go-redis/v9" "gorm.io/driver/sqlite" "gorm.io/gorm" ) func mustGzipTar(t *testing.T, files map[string]string) []byte { t.Helper() var buf bytes.Buffer gz := gzip.NewWriter(&buf) tw := tar.NewWriter(gz) for name, body := range files { b := []byte(body) h := &tar.Header{ Name: name, Mode: 0o644, Size: int64(len(b)), Typeflag: tar.TypeReg, } if err := tw.WriteHeader(h); err != nil { t.Fatalf("tar header: %v", err) } if _, err := tw.Write(b); err != nil { t.Fatalf("tar write: %v", err) } } if err := tw.Close(); err != nil { t.Fatalf("tar close: %v", err) } if err := gz.Close(); err != nil { t.Fatalf("gzip close: %v", err) } return buf.Bytes() } func TestModelRegistry_RefreshAndRollback(t *testing.T) { t.Parallel() dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", t.Name()) db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{}) if err != nil { t.Fatalf("open sqlite: %v", err) } if err := db.AutoMigrate(&model.ProviderGroup{}, &model.APIKey{}, &model.Binding{}, &model.Model{}); err != nil { t.Fatalf("migrate: %v", err) } group := model.ProviderGroup{ Name: "rg", Type: "openai", BaseURL: "https://api.openai.com/v1", Models: "gpt-4o-mini", Status: "active", } if err := db.Create(&group).Error; err != nil { t.Fatalf("create provider group: %v", err) } if err := db.Create(&model.APIKey{ GroupID: group.ID, APIKey: "k", Status: "active", }).Error; err != nil { t.Fatalf("create api key: %v", err) } if err := db.Create(&model.Binding{ Namespace: "ns", PublicModel: "m", GroupID: group.ID, Weight: 1, SelectorType: "exact", SelectorValue: "gpt-4o-mini", Status: "active", }).Error; err != nil { t.Fatalf("create binding: %v", err) } tar1 := mustGzipTar(t, map[string]string{ "sst-models.dev-aaaaaaaa/providers/openai/models/gpt-4o-mini.toml": ` tool_call = true [limit] context = 128000 output = 8192 [modalities] input = ["text","image"] `, }) tar2 := mustGzipTar(t, map[string]string{ "sst-models.dev-bbbbbbbb/providers/openai/models/gpt-4o-mini.toml": ` tool_call = false [limit] context = 64000 output = 2048 [modalities] input = ["text"] `, }) var served int srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/dev" { http.NotFound(w, r) return } w.Header().Set("Content-Type", "application/gzip") if served == 0 { served++ _, _ = w.Write(tar1) return } _, _ = w.Write(tar2) })) defer srv.Close() mr := miniredis.RunT(t) rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()}) cacheDir := t.TempDir() svc := NewModelRegistryService(db, rdb, ModelRegistryConfig{ Enabled: true, RefreshEvery: time.Hour, ModelsDevBaseURL: srv.URL, ModelsDevRef: "dev", CacheDir: cacheDir, Timeout: 5 * time.Second, }) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := svc.Refresh(ctx, "dev"); err != nil { t.Fatalf("refresh1: %v", err) } raw1 := mr.HGet("meta:models", "ns.m") if raw1 == "" { t.Fatalf("expected meta:models[ns.m]") } var m1 modelcap.Model if err := json.Unmarshal([]byte(raw1), &m1); err != nil { t.Fatalf("unmarshal1: %v raw=%s", err, raw1) } if m1.SupportsVision != true || m1.SupportsFunction != true { t.Fatalf("expected vision/tools true, got %+v", m1) } if v := mr.HGet("meta:models_meta", "version"); v != "aaaaaaaa" { t.Fatalf("expected version aaaaaaaa, got %q", v) } if err := svc.Refresh(ctx, "dev"); err != nil { t.Fatalf("refresh2: %v", err) } raw2 := mr.HGet("meta:models", "ns.m") var m2 modelcap.Model if err := json.Unmarshal([]byte(raw2), &m2); err != nil { t.Fatalf("unmarshal2: %v raw=%s", err, raw2) } // Second refresh says no vision/tools, but our safe defaults treat unknown as allow only when unknown; // here we have explicit false from models.dev and should reflect it. if m2.SupportsVision != false || m2.SupportsFunction != false { t.Fatalf("expected vision/tools false, got %+v", m2) } if v := mr.HGet("meta:models_meta", "version"); v != "bbbbbbbb" { t.Fatalf("expected version bbbbbbbb, got %q", v) } if err := svc.Rollback(ctx); err != nil { t.Fatalf("rollback: %v", err) } if v := mr.HGet("meta:models_meta", "version"); v != "aaaaaaaa" { t.Fatalf("expected rollback to version aaaaaaaa, got %q", v) } raw3 := mr.HGet("meta:models", "ns.m") var m3 modelcap.Model if err := json.Unmarshal([]byte(raw3), &m3); err != nil { t.Fatalf("unmarshal3: %v raw=%s", err, raw3) } if m3.SupportsVision != true || m3.SupportsFunction != true { t.Fatalf("expected rollback vision/tools true, got %+v", m3) } if _, err := os.Stat(cacheDir + "/current.json"); err != nil { t.Fatalf("expected current cache file: %v", err) } if _, err := os.Stat(cacheDir + "/prev.json"); err != nil { t.Fatalf("expected prev cache file: %v", err) } }