mirror of
https://github.com/EZ-Api/ez-api.git
synced 2026-01-13 17:47:51 +00:00
feat(model-registry): models.dev updater + admin endpoints
This commit is contained in:
193
internal/service/model_registry_test.go
Normal file
193
internal/service/model_registry_test.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/ez-api/ez-api/internal/model"
|
||||
"github.com/ez-api/foundation/modelcap"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func mustGzipTar(t *testing.T, files map[string]string) []byte {
|
||||
t.Helper()
|
||||
var buf bytes.Buffer
|
||||
gz := gzip.NewWriter(&buf)
|
||||
tw := tar.NewWriter(gz)
|
||||
for name, body := range files {
|
||||
b := []byte(body)
|
||||
h := &tar.Header{
|
||||
Name: name,
|
||||
Mode: 0o644,
|
||||
Size: int64(len(b)),
|
||||
Typeflag: tar.TypeReg,
|
||||
}
|
||||
if err := tw.WriteHeader(h); err != nil {
|
||||
t.Fatalf("tar header: %v", err)
|
||||
}
|
||||
if _, err := tw.Write(b); err != nil {
|
||||
t.Fatalf("tar write: %v", err)
|
||||
}
|
||||
}
|
||||
if err := tw.Close(); err != nil {
|
||||
t.Fatalf("tar close: %v", err)
|
||||
}
|
||||
if err := gz.Close(); err != nil {
|
||||
t.Fatalf("gzip close: %v", err)
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func TestModelRegistry_RefreshAndRollback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", t.Name())
|
||||
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("open sqlite: %v", err)
|
||||
}
|
||||
if err := db.AutoMigrate(&model.Provider{}, &model.Binding{}, &model.Model{}); err != nil {
|
||||
t.Fatalf("migrate: %v", err)
|
||||
}
|
||||
if err := db.Create(&model.Provider{
|
||||
Name: "p1",
|
||||
Type: "openai",
|
||||
Group: "rg",
|
||||
Models: "gpt-4o-mini",
|
||||
Status: "active",
|
||||
}).Error; err != nil {
|
||||
t.Fatalf("create provider: %v", err)
|
||||
}
|
||||
if err := db.Create(&model.Binding{
|
||||
Namespace: "ns",
|
||||
PublicModel: "m",
|
||||
RouteGroup: "rg",
|
||||
SelectorType: "exact",
|
||||
SelectorValue: "gpt-4o-mini",
|
||||
Status: "active",
|
||||
}).Error; err != nil {
|
||||
t.Fatalf("create binding: %v", err)
|
||||
}
|
||||
|
||||
tar1 := mustGzipTar(t, map[string]string{
|
||||
"sst-models.dev-aaaaaaaa/providers/openai/models/gpt-4o-mini.toml": `
|
||||
tool_call = true
|
||||
[limit]
|
||||
context = 128000
|
||||
output = 8192
|
||||
[modalities]
|
||||
input = ["text","image"]
|
||||
`,
|
||||
})
|
||||
tar2 := mustGzipTar(t, map[string]string{
|
||||
"sst-models.dev-bbbbbbbb/providers/openai/models/gpt-4o-mini.toml": `
|
||||
tool_call = false
|
||||
[limit]
|
||||
context = 64000
|
||||
output = 2048
|
||||
[modalities]
|
||||
input = ["text"]
|
||||
`,
|
||||
})
|
||||
|
||||
var served int
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/dev" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/gzip")
|
||||
if served == 0 {
|
||||
served++
|
||||
_, _ = w.Write(tar1)
|
||||
return
|
||||
}
|
||||
_, _ = w.Write(tar2)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
mr := miniredis.RunT(t)
|
||||
rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()})
|
||||
|
||||
cacheDir := t.TempDir()
|
||||
svc := NewModelRegistryService(db, rdb, ModelRegistryConfig{
|
||||
Enabled: true,
|
||||
RefreshEvery: time.Hour,
|
||||
ModelsDevBaseURL: srv.URL,
|
||||
ModelsDevRef: "dev",
|
||||
CacheDir: cacheDir,
|
||||
Timeout: 5 * time.Second,
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := svc.Refresh(ctx, "dev"); err != nil {
|
||||
t.Fatalf("refresh1: %v", err)
|
||||
}
|
||||
raw1 := mr.HGet("meta:models", "ns.m")
|
||||
if raw1 == "" {
|
||||
t.Fatalf("expected meta:models[ns.m]")
|
||||
}
|
||||
var m1 modelcap.Model
|
||||
if err := json.Unmarshal([]byte(raw1), &m1); err != nil {
|
||||
t.Fatalf("unmarshal1: %v raw=%s", err, raw1)
|
||||
}
|
||||
if m1.SupportsVision != true || m1.SupportsFunction != true {
|
||||
t.Fatalf("expected vision/tools true, got %+v", m1)
|
||||
}
|
||||
if v := mr.HGet("meta:models_meta", "version"); v != "aaaaaaaa" {
|
||||
t.Fatalf("expected version aaaaaaaa, got %q", v)
|
||||
}
|
||||
|
||||
if err := svc.Refresh(ctx, "dev"); err != nil {
|
||||
t.Fatalf("refresh2: %v", err)
|
||||
}
|
||||
raw2 := mr.HGet("meta:models", "ns.m")
|
||||
var m2 modelcap.Model
|
||||
if err := json.Unmarshal([]byte(raw2), &m2); err != nil {
|
||||
t.Fatalf("unmarshal2: %v raw=%s", err, raw2)
|
||||
}
|
||||
// Second refresh says no vision/tools, but our safe defaults treat unknown as allow only when unknown;
|
||||
// here we have explicit false from models.dev and should reflect it.
|
||||
if m2.SupportsVision != false || m2.SupportsFunction != false {
|
||||
t.Fatalf("expected vision/tools false, got %+v", m2)
|
||||
}
|
||||
if v := mr.HGet("meta:models_meta", "version"); v != "bbbbbbbb" {
|
||||
t.Fatalf("expected version bbbbbbbb, got %q", v)
|
||||
}
|
||||
|
||||
if err := svc.Rollback(ctx); err != nil {
|
||||
t.Fatalf("rollback: %v", err)
|
||||
}
|
||||
if v := mr.HGet("meta:models_meta", "version"); v != "aaaaaaaa" {
|
||||
t.Fatalf("expected rollback to version aaaaaaaa, got %q", v)
|
||||
}
|
||||
raw3 := mr.HGet("meta:models", "ns.m")
|
||||
var m3 modelcap.Model
|
||||
if err := json.Unmarshal([]byte(raw3), &m3); err != nil {
|
||||
t.Fatalf("unmarshal3: %v raw=%s", err, raw3)
|
||||
}
|
||||
if m3.SupportsVision != true || m3.SupportsFunction != true {
|
||||
t.Fatalf("expected rollback vision/tools true, got %+v", m3)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(cacheDir + "/current.json"); err != nil {
|
||||
t.Fatalf("expected current cache file: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(cacheDir + "/prev.json"); err != nil {
|
||||
t.Fatalf("expected prev cache file: %v", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user