mirror of
https://github.com/EZ-Api/ez-api.git
synced 2026-01-13 17:47:51 +00:00
194 lines
4.9 KiB
Go
194 lines
4.9 KiB
Go
package service
|
|
|
|
import (
|
|
"archive/tar"
|
|
"bytes"
|
|
"compress/gzip"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/alicebob/miniredis/v2"
|
|
"github.com/ez-api/ez-api/internal/model"
|
|
"github.com/ez-api/foundation/modelcap"
|
|
"github.com/redis/go-redis/v9"
|
|
"gorm.io/driver/sqlite"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
func mustGzipTar(t *testing.T, files map[string]string) []byte {
|
|
t.Helper()
|
|
var buf bytes.Buffer
|
|
gz := gzip.NewWriter(&buf)
|
|
tw := tar.NewWriter(gz)
|
|
for name, body := range files {
|
|
b := []byte(body)
|
|
h := &tar.Header{
|
|
Name: name,
|
|
Mode: 0o644,
|
|
Size: int64(len(b)),
|
|
Typeflag: tar.TypeReg,
|
|
}
|
|
if err := tw.WriteHeader(h); err != nil {
|
|
t.Fatalf("tar header: %v", err)
|
|
}
|
|
if _, err := tw.Write(b); err != nil {
|
|
t.Fatalf("tar write: %v", err)
|
|
}
|
|
}
|
|
if err := tw.Close(); err != nil {
|
|
t.Fatalf("tar close: %v", err)
|
|
}
|
|
if err := gz.Close(); err != nil {
|
|
t.Fatalf("gzip close: %v", err)
|
|
}
|
|
return buf.Bytes()
|
|
}
|
|
|
|
func TestModelRegistry_RefreshAndRollback(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", t.Name())
|
|
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
|
|
if err != nil {
|
|
t.Fatalf("open sqlite: %v", err)
|
|
}
|
|
if err := db.AutoMigrate(&model.Provider{}, &model.Binding{}, &model.Model{}); err != nil {
|
|
t.Fatalf("migrate: %v", err)
|
|
}
|
|
if err := db.Create(&model.Provider{
|
|
Name: "p1",
|
|
Type: "openai",
|
|
Group: "rg",
|
|
Models: "gpt-4o-mini",
|
|
Status: "active",
|
|
}).Error; err != nil {
|
|
t.Fatalf("create provider: %v", err)
|
|
}
|
|
if err := db.Create(&model.Binding{
|
|
Namespace: "ns",
|
|
PublicModel: "m",
|
|
RouteGroup: "rg",
|
|
SelectorType: "exact",
|
|
SelectorValue: "gpt-4o-mini",
|
|
Status: "active",
|
|
}).Error; err != nil {
|
|
t.Fatalf("create binding: %v", err)
|
|
}
|
|
|
|
tar1 := mustGzipTar(t, map[string]string{
|
|
"sst-models.dev-aaaaaaaa/providers/openai/models/gpt-4o-mini.toml": `
|
|
tool_call = true
|
|
[limit]
|
|
context = 128000
|
|
output = 8192
|
|
[modalities]
|
|
input = ["text","image"]
|
|
`,
|
|
})
|
|
tar2 := mustGzipTar(t, map[string]string{
|
|
"sst-models.dev-bbbbbbbb/providers/openai/models/gpt-4o-mini.toml": `
|
|
tool_call = false
|
|
[limit]
|
|
context = 64000
|
|
output = 2048
|
|
[modalities]
|
|
input = ["text"]
|
|
`,
|
|
})
|
|
|
|
var served int
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/dev" {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
w.Header().Set("Content-Type", "application/gzip")
|
|
if served == 0 {
|
|
served++
|
|
_, _ = w.Write(tar1)
|
|
return
|
|
}
|
|
_, _ = w.Write(tar2)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
mr := miniredis.RunT(t)
|
|
rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()})
|
|
|
|
cacheDir := t.TempDir()
|
|
svc := NewModelRegistryService(db, rdb, ModelRegistryConfig{
|
|
Enabled: true,
|
|
RefreshEvery: time.Hour,
|
|
ModelsDevBaseURL: srv.URL,
|
|
ModelsDevRef: "dev",
|
|
CacheDir: cacheDir,
|
|
Timeout: 5 * time.Second,
|
|
})
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
if err := svc.Refresh(ctx, "dev"); err != nil {
|
|
t.Fatalf("refresh1: %v", err)
|
|
}
|
|
raw1 := mr.HGet("meta:models", "ns.m")
|
|
if raw1 == "" {
|
|
t.Fatalf("expected meta:models[ns.m]")
|
|
}
|
|
var m1 modelcap.Model
|
|
if err := json.Unmarshal([]byte(raw1), &m1); err != nil {
|
|
t.Fatalf("unmarshal1: %v raw=%s", err, raw1)
|
|
}
|
|
if m1.SupportsVision != true || m1.SupportsFunction != true {
|
|
t.Fatalf("expected vision/tools true, got %+v", m1)
|
|
}
|
|
if v := mr.HGet("meta:models_meta", "version"); v != "aaaaaaaa" {
|
|
t.Fatalf("expected version aaaaaaaa, got %q", v)
|
|
}
|
|
|
|
if err := svc.Refresh(ctx, "dev"); err != nil {
|
|
t.Fatalf("refresh2: %v", err)
|
|
}
|
|
raw2 := mr.HGet("meta:models", "ns.m")
|
|
var m2 modelcap.Model
|
|
if err := json.Unmarshal([]byte(raw2), &m2); err != nil {
|
|
t.Fatalf("unmarshal2: %v raw=%s", err, raw2)
|
|
}
|
|
// Second refresh says no vision/tools, but our safe defaults treat unknown as allow only when unknown;
|
|
// here we have explicit false from models.dev and should reflect it.
|
|
if m2.SupportsVision != false || m2.SupportsFunction != false {
|
|
t.Fatalf("expected vision/tools false, got %+v", m2)
|
|
}
|
|
if v := mr.HGet("meta:models_meta", "version"); v != "bbbbbbbb" {
|
|
t.Fatalf("expected version bbbbbbbb, got %q", v)
|
|
}
|
|
|
|
if err := svc.Rollback(ctx); err != nil {
|
|
t.Fatalf("rollback: %v", err)
|
|
}
|
|
if v := mr.HGet("meta:models_meta", "version"); v != "aaaaaaaa" {
|
|
t.Fatalf("expected rollback to version aaaaaaaa, got %q", v)
|
|
}
|
|
raw3 := mr.HGet("meta:models", "ns.m")
|
|
var m3 modelcap.Model
|
|
if err := json.Unmarshal([]byte(raw3), &m3); err != nil {
|
|
t.Fatalf("unmarshal3: %v raw=%s", err, raw3)
|
|
}
|
|
if m3.SupportsVision != true || m3.SupportsFunction != true {
|
|
t.Fatalf("expected rollback vision/tools true, got %+v", m3)
|
|
}
|
|
|
|
if _, err := os.Stat(cacheDir + "/current.json"); err != nil {
|
|
t.Fatalf("expected current cache file: %v", err)
|
|
}
|
|
if _, err := os.Stat(cacheDir + "/prev.json"); err != nil {
|
|
t.Fatalf("expected prev cache file: %v", err)
|
|
}
|
|
}
|