feat(model-registry): models.dev updater + admin endpoints

This commit is contained in:
zenfun
2025-12-17 23:59:34 +08:00
parent 96e1fe41a5
commit b2d2df18c5
8 changed files with 1223 additions and 103 deletions

View File

@@ -0,0 +1,193 @@
package service
import (
"archive/tar"
"bytes"
"compress/gzip"
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/ez-api/ez-api/internal/model"
"github.com/ez-api/foundation/modelcap"
"github.com/redis/go-redis/v9"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
func mustGzipTar(t *testing.T, files map[string]string) []byte {
t.Helper()
var buf bytes.Buffer
gz := gzip.NewWriter(&buf)
tw := tar.NewWriter(gz)
for name, body := range files {
b := []byte(body)
h := &tar.Header{
Name: name,
Mode: 0o644,
Size: int64(len(b)),
Typeflag: tar.TypeReg,
}
if err := tw.WriteHeader(h); err != nil {
t.Fatalf("tar header: %v", err)
}
if _, err := tw.Write(b); err != nil {
t.Fatalf("tar write: %v", err)
}
}
if err := tw.Close(); err != nil {
t.Fatalf("tar close: %v", err)
}
if err := gz.Close(); err != nil {
t.Fatalf("gzip close: %v", err)
}
return buf.Bytes()
}
func TestModelRegistry_RefreshAndRollback(t *testing.T) {
t.Parallel()
dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", t.Name())
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.Provider{}, &model.Binding{}, &model.Model{}); err != nil {
t.Fatalf("migrate: %v", err)
}
if err := db.Create(&model.Provider{
Name: "p1",
Type: "openai",
Group: "rg",
Models: "gpt-4o-mini",
Status: "active",
}).Error; err != nil {
t.Fatalf("create provider: %v", err)
}
if err := db.Create(&model.Binding{
Namespace: "ns",
PublicModel: "m",
RouteGroup: "rg",
SelectorType: "exact",
SelectorValue: "gpt-4o-mini",
Status: "active",
}).Error; err != nil {
t.Fatalf("create binding: %v", err)
}
tar1 := mustGzipTar(t, map[string]string{
"sst-models.dev-aaaaaaaa/providers/openai/models/gpt-4o-mini.toml": `
tool_call = true
[limit]
context = 128000
output = 8192
[modalities]
input = ["text","image"]
`,
})
tar2 := mustGzipTar(t, map[string]string{
"sst-models.dev-bbbbbbbb/providers/openai/models/gpt-4o-mini.toml": `
tool_call = false
[limit]
context = 64000
output = 2048
[modalities]
input = ["text"]
`,
})
var served int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/dev" {
http.NotFound(w, r)
return
}
w.Header().Set("Content-Type", "application/gzip")
if served == 0 {
served++
_, _ = w.Write(tar1)
return
}
_, _ = w.Write(tar2)
}))
defer srv.Close()
mr := miniredis.RunT(t)
rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()})
cacheDir := t.TempDir()
svc := NewModelRegistryService(db, rdb, ModelRegistryConfig{
Enabled: true,
RefreshEvery: time.Hour,
ModelsDevBaseURL: srv.URL,
ModelsDevRef: "dev",
CacheDir: cacheDir,
Timeout: 5 * time.Second,
})
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := svc.Refresh(ctx, "dev"); err != nil {
t.Fatalf("refresh1: %v", err)
}
raw1 := mr.HGet("meta:models", "ns.m")
if raw1 == "" {
t.Fatalf("expected meta:models[ns.m]")
}
var m1 modelcap.Model
if err := json.Unmarshal([]byte(raw1), &m1); err != nil {
t.Fatalf("unmarshal1: %v raw=%s", err, raw1)
}
if m1.SupportsVision != true || m1.SupportsFunction != true {
t.Fatalf("expected vision/tools true, got %+v", m1)
}
if v := mr.HGet("meta:models_meta", "version"); v != "aaaaaaaa" {
t.Fatalf("expected version aaaaaaaa, got %q", v)
}
if err := svc.Refresh(ctx, "dev"); err != nil {
t.Fatalf("refresh2: %v", err)
}
raw2 := mr.HGet("meta:models", "ns.m")
var m2 modelcap.Model
if err := json.Unmarshal([]byte(raw2), &m2); err != nil {
t.Fatalf("unmarshal2: %v raw=%s", err, raw2)
}
// Second refresh says no vision/tools, but our safe defaults treat unknown as allow only when unknown;
// here we have explicit false from models.dev and should reflect it.
if m2.SupportsVision != false || m2.SupportsFunction != false {
t.Fatalf("expected vision/tools false, got %+v", m2)
}
if v := mr.HGet("meta:models_meta", "version"); v != "bbbbbbbb" {
t.Fatalf("expected version bbbbbbbb, got %q", v)
}
if err := svc.Rollback(ctx); err != nil {
t.Fatalf("rollback: %v", err)
}
if v := mr.HGet("meta:models_meta", "version"); v != "aaaaaaaa" {
t.Fatalf("expected rollback to version aaaaaaaa, got %q", v)
}
raw3 := mr.HGet("meta:models", "ns.m")
var m3 modelcap.Model
if err := json.Unmarshal([]byte(raw3), &m3); err != nil {
t.Fatalf("unmarshal3: %v raw=%s", err, raw3)
}
if m3.SupportsVision != true || m3.SupportsFunction != true {
t.Fatalf("expected rollback vision/tools true, got %+v", m3)
}
if _, err := os.Stat(cacheDir + "/current.json"); err != nil {
t.Fatalf("expected current cache file: %v", err)
}
if _, err := os.Stat(cacheDir + "/prev.json"); err != nil {
t.Fatalf("expected prev cache file: %v", err)
}
}