mirror of
https://github.com/EZ-Api/ez-api.git
synced 2026-01-13 17:47:51 +00:00
feat(model-registry): models.dev updater + admin endpoints
This commit is contained in:
@@ -114,11 +114,21 @@ func main() {
|
||||
adminHandler := api.NewAdminHandler(masterService, syncService)
|
||||
masterHandler := api.NewMasterHandler(masterService, syncService)
|
||||
featureHandler := api.NewFeatureHandler(rdb)
|
||||
modelRegistryService := service.NewModelRegistryService(db, rdb, service.ModelRegistryConfig{
|
||||
Enabled: cfg.ModelRegistry.Enabled,
|
||||
RefreshEvery: time.Duration(cfg.ModelRegistry.RefreshSeconds) * time.Second,
|
||||
ModelsDevBaseURL: cfg.ModelRegistry.ModelsDevBaseURL,
|
||||
ModelsDevRef: cfg.ModelRegistry.ModelsDevRef,
|
||||
CacheDir: cfg.ModelRegistry.CacheDir,
|
||||
Timeout: time.Duration(cfg.ModelRegistry.TimeoutSeconds) * time.Second,
|
||||
})
|
||||
modelRegistryHandler := api.NewModelRegistryHandler(modelRegistryService)
|
||||
|
||||
// 4.1 Prime Redis snapshots so DP can start with data
|
||||
if err := syncService.SyncAll(db); err != nil {
|
||||
logger.Warn("initial sync warning", "err", err)
|
||||
}
|
||||
modelRegistryService.Start(context.Background())
|
||||
|
||||
// 5. Setup Gin Router
|
||||
r := gin.Default()
|
||||
@@ -165,6 +175,9 @@ func main() {
|
||||
adminGroup.PUT("/keys/:id/access", handler.UpdateKeyAccess)
|
||||
adminGroup.GET("/features", featureHandler.ListFeatures)
|
||||
adminGroup.PUT("/features", featureHandler.UpdateFeatures)
|
||||
adminGroup.GET("/model-registry/status", modelRegistryHandler.GetStatus)
|
||||
adminGroup.POST("/model-registry/refresh", modelRegistryHandler.Refresh)
|
||||
adminGroup.POST("/model-registry/rollback", modelRegistryHandler.Rollback)
|
||||
// Other admin routes for managing providers, models, etc.
|
||||
adminGroup.POST("/providers", handler.CreateProvider)
|
||||
adminGroup.POST("/providers/preset", handler.CreateProviderPreset)
|
||||
|
||||
4
go.mod
4
go.mod
@@ -6,6 +6,7 @@ require (
|
||||
github.com/alicebob/miniredis/v2 v2.35.0
|
||||
github.com/ez-api/foundation v0.2.0
|
||||
github.com/gin-gonic/gin v1.11.0
|
||||
github.com/pelletier/go-toml/v2 v2.2.4
|
||||
github.com/redis/go-redis/v9 v9.17.2
|
||||
github.com/spf13/viper v1.21.0
|
||||
github.com/swaggo/files v1.0.1
|
||||
@@ -17,6 +18,8 @@ require (
|
||||
gorm.io/gorm v1.31.1
|
||||
)
|
||||
|
||||
replace github.com/ez-api/foundation => ../foundation
|
||||
|
||||
require (
|
||||
github.com/KyleBanks/depth v1.2.1 // indirect
|
||||
github.com/bytedance/sonic v1.14.0 // indirect
|
||||
@@ -57,7 +60,6 @@ require (
|
||||
github.com/mattn/go-sqlite3 v1.14.22 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||
github.com/quic-go/qpack v0.5.1 // indirect
|
||||
github.com/quic-go/quic-go v0.54.0 // indirect
|
||||
github.com/rs/zerolog v1.34.0 // indirect
|
||||
|
||||
2
go.sum
2
go.sum
@@ -20,8 +20,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/ez-api/foundation v0.2.0 h1:QdxbXZ2wr4O08Uxl6QK4fJZPrWb09yMNI+hS2aRSqG8=
|
||||
github.com/ez-api/foundation v0.2.0/go.mod h1:bTh1LA42TW4CXi1SebDEUE+fhEssFUphzcGEzyAFFZI=
|
||||
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||
|
||||
91
internal/api/model_registry_handler.go
Normal file
91
internal/api/model_registry_handler.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/ez-api/ez-api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type ModelRegistryHandler struct {
|
||||
reg *service.ModelRegistryService
|
||||
}
|
||||
|
||||
func NewModelRegistryHandler(reg *service.ModelRegistryService) *ModelRegistryHandler {
|
||||
return &ModelRegistryHandler{reg: reg}
|
||||
}
|
||||
|
||||
// GetModelRegistryStatus godoc
|
||||
// @Summary Get model registry status
|
||||
// @Description Returns Redis meta and local last-good cache info for model capability registry
|
||||
// @Tags admin
|
||||
// @Produce json
|
||||
// @Security AdminAuth
|
||||
// @Success 200 {object} service.ModelRegistryStatus
|
||||
// @Failure 500 {object} gin.H
|
||||
// @Router /admin/model-registry/status [get]
|
||||
func (h *ModelRegistryHandler) GetStatus(c *gin.Context) {
|
||||
if h == nil || h.reg == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "model registry not configured"})
|
||||
return
|
||||
}
|
||||
st, err := h.reg.Status(c.Request.Context())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get model registry status", "details": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, st)
|
||||
}
|
||||
|
||||
type refreshModelRegistryRequest struct {
|
||||
Ref string `json:"ref"`
|
||||
}
|
||||
|
||||
// RefreshModelRegistry godoc
|
||||
// @Summary Refresh model registry from models.dev
|
||||
// @Description Fetches models.dev, computes per-binding capabilities, and updates Redis meta:models
|
||||
// @Tags admin
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security AdminAuth
|
||||
// @Param body body refreshModelRegistryRequest false "optional override ref"
|
||||
// @Success 200 {object} gin.H
|
||||
// @Failure 400 {object} gin.H
|
||||
// @Failure 500 {object} gin.H
|
||||
// @Router /admin/model-registry/refresh [post]
|
||||
func (h *ModelRegistryHandler) Refresh(c *gin.Context) {
|
||||
if h == nil || h.reg == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "model registry not configured"})
|
||||
return
|
||||
}
|
||||
var req refreshModelRegistryRequest
|
||||
_ = c.ShouldBindJSON(&req)
|
||||
ref := strings.TrimSpace(req.Ref)
|
||||
if err := h.reg.Refresh(c.Request.Context(), ref); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to refresh model registry", "details": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"status": "refreshed"})
|
||||
}
|
||||
|
||||
// RollbackModelRegistry godoc
|
||||
// @Summary Rollback model registry
|
||||
// @Description Rollback meta:models to previous cached last-good version
|
||||
// @Tags admin
|
||||
// @Produce json
|
||||
// @Security AdminAuth
|
||||
// @Success 200 {object} gin.H
|
||||
// @Failure 500 {object} gin.H
|
||||
// @Router /admin/model-registry/rollback [post]
|
||||
func (h *ModelRegistryHandler) Rollback(c *gin.Context) {
|
||||
if h == nil || h.reg == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "model registry not configured"})
|
||||
return
|
||||
}
|
||||
if err := h.reg.Rollback(c.Request.Context()); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to rollback model registry", "details": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"status": "rolled_back"})
|
||||
}
|
||||
@@ -15,6 +15,7 @@ type Config struct {
|
||||
Redis RedisConfig
|
||||
Log LogConfig
|
||||
Auth AuthConfig
|
||||
ModelRegistry ModelRegistryConfig
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
@@ -41,6 +42,15 @@ type LogConfig struct {
|
||||
QueueCapacity int
|
||||
}
|
||||
|
||||
type ModelRegistryConfig struct {
|
||||
Enabled bool
|
||||
RefreshSeconds int
|
||||
ModelsDevBaseURL string
|
||||
ModelsDevRef string
|
||||
CacheDir string
|
||||
TimeoutSeconds int
|
||||
}
|
||||
|
||||
func Load() (*Config, error) {
|
||||
v := viper.New()
|
||||
|
||||
@@ -53,6 +63,12 @@ func Load() (*Config, error) {
|
||||
v.SetDefault("log.flush_ms", 1000)
|
||||
v.SetDefault("log.queue_capacity", 10000)
|
||||
v.SetDefault("auth.jwt_secret", "change_me_in_production")
|
||||
v.SetDefault("model_registry.enabled", false)
|
||||
v.SetDefault("model_registry.refresh_seconds", 1800)
|
||||
v.SetDefault("model_registry.models_dev_base_url", "https://codeload.github.com/sst/models.dev/tar.gz")
|
||||
v.SetDefault("model_registry.models_dev_ref", "dev")
|
||||
v.SetDefault("model_registry.cache_dir", "./data/model-registry")
|
||||
v.SetDefault("model_registry.timeout_seconds", 30)
|
||||
|
||||
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||||
v.AutomaticEnv()
|
||||
@@ -66,6 +82,12 @@ func Load() (*Config, error) {
|
||||
_ = v.BindEnv("log.flush_ms", "EZ_LOG_FLUSH_MS")
|
||||
_ = v.BindEnv("log.queue_capacity", "EZ_LOG_QUEUE")
|
||||
_ = v.BindEnv("auth.jwt_secret", "EZ_JWT_SECRET")
|
||||
_ = v.BindEnv("model_registry.enabled", "EZ_MODEL_REGISTRY_ENABLED")
|
||||
_ = v.BindEnv("model_registry.refresh_seconds", "EZ_MODEL_REGISTRY_REFRESH_SECONDS")
|
||||
_ = v.BindEnv("model_registry.models_dev_base_url", "EZ_MODEL_REGISTRY_MODELS_DEV_BASE_URL")
|
||||
_ = v.BindEnv("model_registry.models_dev_ref", "EZ_MODEL_REGISTRY_MODELS_DEV_REF")
|
||||
_ = v.BindEnv("model_registry.cache_dir", "EZ_MODEL_REGISTRY_CACHE_DIR")
|
||||
_ = v.BindEnv("model_registry.timeout_seconds", "EZ_MODEL_REGISTRY_TIMEOUT_SECONDS")
|
||||
|
||||
if configFile := os.Getenv("EZ_CONFIG_FILE"); configFile != "" {
|
||||
v.SetConfigFile(configFile)
|
||||
@@ -102,6 +124,14 @@ func Load() (*Config, error) {
|
||||
Auth: AuthConfig{
|
||||
JWTSecret: v.GetString("auth.jwt_secret"),
|
||||
},
|
||||
ModelRegistry: ModelRegistryConfig{
|
||||
Enabled: v.GetBool("model_registry.enabled"),
|
||||
RefreshSeconds: v.GetInt("model_registry.refresh_seconds"),
|
||||
ModelsDevBaseURL: v.GetString("model_registry.models_dev_base_url"),
|
||||
ModelsDevRef: v.GetString("model_registry.models_dev_ref"),
|
||||
CacheDir: v.GetString("model_registry.cache_dir"),
|
||||
TimeoutSeconds: v.GetInt("model_registry.timeout_seconds"),
|
||||
},
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
|
||||
844
internal/service/model_registry.go
Normal file
844
internal/service/model_registry.go
Normal file
@@ -0,0 +1,844 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ez-api/ez-api/internal/model"
|
||||
groupx "github.com/ez-api/foundation/group"
|
||||
"github.com/ez-api/foundation/jsoncodec"
|
||||
"github.com/ez-api/foundation/modelcap"
|
||||
"github.com/ez-api/foundation/routing"
|
||||
"github.com/pelletier/go-toml/v2"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ModelRegistryConfig struct {
|
||||
Enabled bool
|
||||
RefreshEvery time.Duration
|
||||
|
||||
ModelsDevBaseURL string
|
||||
ModelsDevRef string
|
||||
|
||||
CacheDir string
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
type ModelRegistryService struct {
|
||||
db *gorm.DB
|
||||
rdb *redis.Client
|
||||
|
||||
cfg ModelRegistryConfig
|
||||
client *http.Client
|
||||
|
||||
mu sync.Mutex
|
||||
lastError string
|
||||
lastRefresh time.Time
|
||||
lastApplied modelcap.Meta
|
||||
lastUpstream string
|
||||
}
|
||||
|
||||
func NewModelRegistryService(db *gorm.DB, rdb *redis.Client, cfg ModelRegistryConfig) *ModelRegistryService {
|
||||
if cfg.RefreshEvery <= 0 {
|
||||
cfg.RefreshEvery = 30 * time.Minute
|
||||
}
|
||||
if strings.TrimSpace(cfg.ModelsDevBaseURL) == "" {
|
||||
cfg.ModelsDevBaseURL = "https://codeload.github.com/sst/models.dev/tar.gz"
|
||||
}
|
||||
if strings.TrimSpace(cfg.ModelsDevRef) == "" {
|
||||
cfg.ModelsDevRef = "dev"
|
||||
}
|
||||
if strings.TrimSpace(cfg.CacheDir) == "" {
|
||||
cfg.CacheDir = "./data/model-registry"
|
||||
}
|
||||
if cfg.Timeout <= 0 {
|
||||
cfg.Timeout = 30 * time.Second
|
||||
}
|
||||
return &ModelRegistryService{
|
||||
db: db,
|
||||
rdb: rdb,
|
||||
cfg: cfg,
|
||||
client: &http.Client{
|
||||
Timeout: cfg.Timeout,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ModelRegistryService) Start(ctx context.Context) {
|
||||
if !s.cfg.Enabled {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
ticker := time.NewTicker(s.cfg.RefreshEvery)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Best-effort initial refresh.
|
||||
_ = s.Refresh(ctx, "")
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
_ = s.Refresh(ctx, "")
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
type ModelRegistryStatus struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
ModelsDevRef string `json:"models_dev_ref"`
|
||||
ModelsDevURL string `json:"models_dev_url"`
|
||||
LastRefreshAt int64 `json:"last_refresh_at,omitempty"`
|
||||
LastError string `json:"last_error,omitempty"`
|
||||
RedisMeta map[string]string `json:"redis_meta,omitempty"`
|
||||
CacheCurrent *modelRegistryFile `json:"cache_current,omitempty"`
|
||||
CachePrev *modelRegistryFile `json:"cache_prev,omitempty"`
|
||||
}
|
||||
|
||||
func (s *ModelRegistryService) Status(ctx context.Context) (*ModelRegistryStatus, error) {
|
||||
s.mu.Lock()
|
||||
lastErr := s.lastError
|
||||
lastRefresh := s.lastRefresh
|
||||
s.mu.Unlock()
|
||||
|
||||
redisMeta, _ := s.rdb.HGetAll(ctx, "meta:models_meta").Result()
|
||||
|
||||
current, _ := readModelRegistryFile(filepath.Join(s.cfg.CacheDir, "current.json"))
|
||||
prev, _ := readModelRegistryFile(filepath.Join(s.cfg.CacheDir, "prev.json"))
|
||||
|
||||
out := &ModelRegistryStatus{
|
||||
Enabled: s.cfg.Enabled,
|
||||
ModelsDevRef: s.cfg.ModelsDevRef,
|
||||
ModelsDevURL: strings.TrimRight(s.cfg.ModelsDevBaseURL, "/") + "/" + s.cfg.ModelsDevRef,
|
||||
LastError: lastErr,
|
||||
RedisMeta: redisMeta,
|
||||
CacheCurrent: current,
|
||||
CachePrev: prev,
|
||||
}
|
||||
if !lastRefresh.IsZero() {
|
||||
out.LastRefreshAt = lastRefresh.Unix()
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *ModelRegistryService) Refresh(ctx context.Context, ref string) error {
|
||||
if strings.TrimSpace(ref) == "" {
|
||||
ref = s.cfg.ModelsDevRef
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.lastUpstream = ref
|
||||
s.mu.Unlock()
|
||||
|
||||
tarballURL := strings.TrimRight(s.cfg.ModelsDevBaseURL, "/") + "/" + ref
|
||||
reg, version, err := s.fetchModelsDev(ctx, tarballURL)
|
||||
if err != nil {
|
||||
s.setError(err)
|
||||
return err
|
||||
}
|
||||
|
||||
models, payloads, err := s.buildBindingModels(ctx, reg)
|
||||
if err != nil {
|
||||
s.setError(err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Overlay explicit DB entries (admin overrides).
|
||||
var dbModels []model.Model
|
||||
if err := s.db.Find(&dbModels).Error; err != nil {
|
||||
s.setError(fmt.Errorf("load db models: %w", err))
|
||||
return err
|
||||
}
|
||||
for _, m := range dbModels {
|
||||
snap := modelcap.Model{
|
||||
Name: m.Name,
|
||||
Kind: string(modelcap.NormalizeKind(m.Kind)),
|
||||
ContextWindow: m.ContextWindow,
|
||||
CostPerToken: m.CostPerToken,
|
||||
SupportsVision: m.SupportsVision,
|
||||
SupportsFunction: m.SupportsFunctions,
|
||||
SupportsToolChoice: m.SupportsToolChoice,
|
||||
SupportsFim: m.SupportsFIM,
|
||||
MaxOutputTokens: m.MaxOutputTokens,
|
||||
}.Normalized()
|
||||
b, err := jsoncodec.Marshal(snap)
|
||||
if err != nil {
|
||||
s.setError(fmt.Errorf("marshal db model %s: %w", m.Name, err))
|
||||
return err
|
||||
}
|
||||
models[snap.Name] = snap
|
||||
payloads[snap.Name] = string(b)
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
meta := modelcap.Meta{
|
||||
Version: version,
|
||||
UpdatedAt: fmt.Sprintf("%d", now),
|
||||
Source: "models.dev",
|
||||
Checksum: modelcap.ChecksumFromPayloads(payloads),
|
||||
UpstreamURL: "https://github.com/sst/models.dev",
|
||||
UpstreamRef: ref,
|
||||
}
|
||||
|
||||
if err := s.applyToRedis(ctx, models, payloads, meta); err != nil {
|
||||
s.setError(err)
|
||||
return err
|
||||
}
|
||||
if err := s.persistCache(models, meta); err != nil {
|
||||
s.setError(err)
|
||||
return err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.lastError = ""
|
||||
s.lastRefresh = time.Now()
|
||||
s.lastApplied = meta
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ModelRegistryService) Rollback(ctx context.Context) error {
|
||||
prevPath := filepath.Join(s.cfg.CacheDir, "prev.json")
|
||||
prev, err := readModelRegistryFile(prevPath)
|
||||
if err != nil {
|
||||
s.setError(fmt.Errorf("read prev cache: %w", err))
|
||||
return err
|
||||
}
|
||||
if prev == nil {
|
||||
s.setError(fmt.Errorf("no prev cache"))
|
||||
return fmt.Errorf("no prev cache")
|
||||
}
|
||||
|
||||
payloads := make(map[string]string, len(prev.Models))
|
||||
for name, m := range prev.Models {
|
||||
m = m.Normalized()
|
||||
b, err := jsoncodec.Marshal(m)
|
||||
if err != nil {
|
||||
s.setError(fmt.Errorf("marshal cached model %s: %w", name, err))
|
||||
return err
|
||||
}
|
||||
prev.Models[name] = m
|
||||
payloads[name] = string(b)
|
||||
}
|
||||
prev.Meta.Checksum = modelcap.ChecksumFromPayloads(payloads)
|
||||
prev.Meta.Source = "rollback"
|
||||
prev.Meta.UpdatedAt = fmt.Sprintf("%d", time.Now().Unix())
|
||||
|
||||
if err := s.applyToRedis(ctx, prev.Models, payloads, prev.Meta); err != nil {
|
||||
s.setError(err)
|
||||
return err
|
||||
}
|
||||
// Swap cache current/prev to reflect rollback.
|
||||
if err := s.persistCache(prev.Models, prev.Meta); err != nil {
|
||||
s.setError(err)
|
||||
return err
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.lastError = ""
|
||||
s.lastRefresh = time.Now()
|
||||
s.lastApplied = prev.Meta
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ModelRegistryService) setError(err error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.lastError = err.Error()
|
||||
s.lastRefresh = time.Now()
|
||||
}
|
||||
|
||||
type boolVal struct {
|
||||
Known bool
|
||||
Val bool
|
||||
}
|
||||
|
||||
type intVal struct {
|
||||
Known bool
|
||||
Val int
|
||||
}
|
||||
|
||||
type upstreamCap struct {
|
||||
ContextWindow intVal
|
||||
MaxOutputTokens intVal
|
||||
SupportsVision boolVal
|
||||
SupportsTools boolVal
|
||||
}
|
||||
|
||||
type modelsDevRegistry struct {
|
||||
ByProviderModel map[string]upstreamCap // key: providerID|modelID
|
||||
ByModel map[string]upstreamCap // fallback: modelID
|
||||
}
|
||||
|
||||
var shaSuffix = regexp.MustCompile(`-([0-9a-f]{7,40})$`)
|
||||
|
||||
func (s *ModelRegistryService) fetchModelsDev(ctx context.Context, tarballURL string) (*modelsDevRegistry, string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, tarballURL, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
resp, err := s.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode/100 != 2 {
|
||||
b, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<10))
|
||||
return nil, "", fmt.Errorf("models.dev fetch failed: %s body=%s", resp.Status, string(b))
|
||||
}
|
||||
tarball, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
gr, err := gzip.NewReader(bytes.NewReader(tarball))
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
defer gr.Close()
|
||||
|
||||
tr := tar.NewReader(gr)
|
||||
reg := &modelsDevRegistry{
|
||||
ByProviderModel: make(map[string]upstreamCap),
|
||||
ByModel: make(map[string]upstreamCap),
|
||||
}
|
||||
|
||||
version := ""
|
||||
for {
|
||||
hdr, err := tr.Next()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
name := hdr.Name
|
||||
if version == "" {
|
||||
if parts := strings.SplitN(name, "/", 2); len(parts) > 0 {
|
||||
if m := shaSuffix.FindStringSubmatch(parts[0]); len(m) == 2 {
|
||||
version = m[1]
|
||||
}
|
||||
}
|
||||
}
|
||||
if hdr.Typeflag != tar.TypeReg {
|
||||
continue
|
||||
}
|
||||
if !strings.HasSuffix(name, ".toml") {
|
||||
continue
|
||||
}
|
||||
providerID, modelID, ok := parseModelsDevModelPath(name)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(io.LimitReader(tr, 1<<20))
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
cap, ok := parseModelsDevTOML(data)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
key := providerID + "|" + modelID
|
||||
reg.ByProviderModel[key] = cap
|
||||
if _, exists := reg.ByModel[modelID]; !exists {
|
||||
reg.ByModel[modelID] = cap
|
||||
}
|
||||
// Also store a flattened variant for providers that may use subdirs.
|
||||
if alt := strings.ReplaceAll(modelID, "/", "-"); alt != modelID {
|
||||
if _, exists := reg.ByModel[alt]; !exists {
|
||||
reg.ByModel[alt] = cap
|
||||
}
|
||||
altKey := providerID + "|" + alt
|
||||
if _, exists := reg.ByProviderModel[altKey]; !exists {
|
||||
reg.ByProviderModel[altKey] = cap
|
||||
}
|
||||
}
|
||||
}
|
||||
if version == "" {
|
||||
version = fmt.Sprintf("%d", time.Now().Unix())
|
||||
}
|
||||
return reg, version, nil
|
||||
}
|
||||
|
||||
func parseModelsDevModelPath(path string) (providerID string, modelID string, ok bool) {
|
||||
parts := strings.Split(path, "/")
|
||||
for i := 0; i < len(parts); i++ {
|
||||
if parts[i] != "providers" {
|
||||
continue
|
||||
}
|
||||
if i+3 >= len(parts) {
|
||||
return "", "", false
|
||||
}
|
||||
providerID = strings.TrimSpace(parts[i+1])
|
||||
if providerID == "" {
|
||||
return "", "", false
|
||||
}
|
||||
if parts[i+2] != "models" {
|
||||
return "", "", false
|
||||
}
|
||||
modelPart := strings.Join(parts[i+3:], "/")
|
||||
modelPart = strings.TrimSuffix(modelPart, ".toml")
|
||||
modelPart = strings.TrimSpace(modelPart)
|
||||
if modelPart == "" {
|
||||
return "", "", false
|
||||
}
|
||||
return providerID, modelPart, true
|
||||
}
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
func parseModelsDevTOML(data []byte) (upstreamCap, bool) {
|
||||
var doc map[string]any
|
||||
if err := toml.Unmarshal(data, &doc); err != nil {
|
||||
return upstreamCap{}, false
|
||||
}
|
||||
|
||||
var cap upstreamCap
|
||||
|
||||
if v, ok := getBool(doc, "tool_call"); ok {
|
||||
cap.SupportsTools = boolVal{Known: true, Val: v}
|
||||
}
|
||||
if limit, ok := getMap(doc, "limit"); ok {
|
||||
if v, ok := getInt(limit, "context"); ok {
|
||||
cap.ContextWindow = intVal{Known: true, Val: v}
|
||||
}
|
||||
if v, ok := getInt(limit, "output"); ok {
|
||||
cap.MaxOutputTokens = intVal{Known: true, Val: v}
|
||||
}
|
||||
}
|
||||
if mods, ok := getMap(doc, "modalities"); ok {
|
||||
if input, ok := getStringSlice(mods["input"]); ok {
|
||||
hasImage := false
|
||||
for _, it := range input {
|
||||
if strings.EqualFold(strings.TrimSpace(it), "image") {
|
||||
hasImage = true
|
||||
break
|
||||
}
|
||||
}
|
||||
cap.SupportsVision = boolVal{Known: true, Val: hasImage}
|
||||
}
|
||||
}
|
||||
|
||||
if !cap.SupportsTools.Known && !cap.SupportsVision.Known && !cap.ContextWindow.Known && !cap.MaxOutputTokens.Known {
|
||||
return upstreamCap{}, false
|
||||
}
|
||||
return cap, true
|
||||
}
|
||||
|
||||
func getMap(doc map[string]any, key string) (map[string]any, bool) {
|
||||
if doc == nil {
|
||||
return nil, false
|
||||
}
|
||||
v, ok := doc[key]
|
||||
if !ok || v == nil {
|
||||
return nil, false
|
||||
}
|
||||
m, ok := v.(map[string]any)
|
||||
return m, ok
|
||||
}
|
||||
|
||||
func getBool(doc map[string]any, key string) (bool, bool) {
|
||||
v, ok := doc[key]
|
||||
if !ok || v == nil {
|
||||
return false, false
|
||||
}
|
||||
switch b := v.(type) {
|
||||
case bool:
|
||||
return b, true
|
||||
default:
|
||||
return false, false
|
||||
}
|
||||
}
|
||||
|
||||
func getInt(doc map[string]any, key string) (int, bool) {
|
||||
v, ok := doc[key]
|
||||
if !ok || v == nil {
|
||||
return 0, false
|
||||
}
|
||||
switch n := v.(type) {
|
||||
case int64:
|
||||
return int(n), true
|
||||
case int:
|
||||
return n, true
|
||||
case float64:
|
||||
return int(n), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func getStringSlice(v any) ([]string, bool) {
|
||||
if v == nil {
|
||||
return nil, false
|
||||
}
|
||||
switch s := v.(type) {
|
||||
case []any:
|
||||
out := make([]string, 0, len(s))
|
||||
for _, it := range s {
|
||||
if str, ok := it.(string); ok {
|
||||
out = append(out, str)
|
||||
}
|
||||
}
|
||||
return out, true
|
||||
case []string:
|
||||
return append([]string(nil), s...), true
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
type capAgg struct {
|
||||
kind string
|
||||
|
||||
visionAnyTrue bool
|
||||
visionAllKnown bool
|
||||
visionKnownAny bool
|
||||
|
||||
toolsAnyTrue bool
|
||||
toolsAllKnown bool
|
||||
toolsKnownAny bool
|
||||
|
||||
maxOutputKnown bool
|
||||
maxOutputMax int
|
||||
|
||||
contextKnown bool
|
||||
contextMax int
|
||||
}
|
||||
|
||||
func inferKindFromPublicModel(publicModel string) string {
|
||||
pm := strings.ToLower(strings.TrimSpace(publicModel))
|
||||
if strings.Contains(pm, "embedding") {
|
||||
return string(modelcap.KindEmbedding)
|
||||
}
|
||||
if strings.Contains(pm, "rerank") {
|
||||
return string(modelcap.KindRerank)
|
||||
}
|
||||
return string(modelcap.KindChat)
|
||||
}
|
||||
|
||||
func (a *capAgg) merge(cap upstreamCap, ok bool) {
|
||||
if ok {
|
||||
// vision
|
||||
if cap.SupportsVision.Known {
|
||||
a.visionKnownAny = true
|
||||
if cap.SupportsVision.Val {
|
||||
a.visionAnyTrue = true
|
||||
}
|
||||
} else {
|
||||
a.visionAllKnown = false
|
||||
}
|
||||
// tools
|
||||
if cap.SupportsTools.Known {
|
||||
a.toolsKnownAny = true
|
||||
if cap.SupportsTools.Val {
|
||||
a.toolsAnyTrue = true
|
||||
}
|
||||
} else {
|
||||
a.toolsAllKnown = false
|
||||
}
|
||||
// limits
|
||||
if cap.MaxOutputTokens.Known && cap.MaxOutputTokens.Val > 0 {
|
||||
a.maxOutputKnown = true
|
||||
if cap.MaxOutputTokens.Val > a.maxOutputMax {
|
||||
a.maxOutputMax = cap.MaxOutputTokens.Val
|
||||
}
|
||||
}
|
||||
if cap.ContextWindow.Known && cap.ContextWindow.Val > 0 {
|
||||
a.contextKnown = true
|
||||
if cap.ContextWindow.Val > a.contextMax {
|
||||
a.contextMax = cap.ContextWindow.Val
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
// Unknown upstream capability: do not force blocking, treat as "not all known".
|
||||
a.visionAllKnown = false
|
||||
a.toolsAllKnown = false
|
||||
}
|
||||
|
||||
func (a *capAgg) finalize(name string) modelcap.Model {
|
||||
// Safe defaults: unknown -> allow (true) so we avoid false blocking.
|
||||
supportsVision := true
|
||||
if a.visionKnownAny {
|
||||
if a.visionAnyTrue {
|
||||
supportsVision = true
|
||||
} else if a.visionAllKnown {
|
||||
supportsVision = false
|
||||
}
|
||||
}
|
||||
|
||||
supportsTools := true
|
||||
if a.toolsKnownAny {
|
||||
if a.toolsAnyTrue {
|
||||
supportsTools = true
|
||||
} else if a.toolsAllKnown {
|
||||
supportsTools = false
|
||||
}
|
||||
}
|
||||
|
||||
out := modelcap.Model{
|
||||
Name: name,
|
||||
Kind: string(modelcap.NormalizeKind(a.kind)),
|
||||
ContextWindow: 0,
|
||||
CostPerToken: 0,
|
||||
SupportsVision: supportsVision,
|
||||
SupportsFunction: supportsTools,
|
||||
SupportsToolChoice: true,
|
||||
SupportsFim: true,
|
||||
MaxOutputTokens: 0,
|
||||
}
|
||||
if a.contextKnown {
|
||||
out.ContextWindow = a.contextMax
|
||||
}
|
||||
if a.maxOutputKnown {
|
||||
out.MaxOutputTokens = a.maxOutputMax
|
||||
}
|
||||
return out.Normalized()
|
||||
}
|
||||
|
||||
func (s *ModelRegistryService) buildBindingModels(ctx context.Context, reg *modelsDevRegistry) (map[string]modelcap.Model, map[string]string, error) {
|
||||
var providers []model.Provider
|
||||
if err := s.db.Find(&providers).Error; err != nil {
|
||||
return nil, nil, fmt.Errorf("load providers: %w", err)
|
||||
}
|
||||
var bindings []model.Binding
|
||||
if err := s.db.Find(&bindings).Error; err != nil {
|
||||
return nil, nil, fmt.Errorf("load bindings: %w", err)
|
||||
}
|
||||
|
||||
type providerLite struct {
|
||||
id uint
|
||||
group string
|
||||
ptype string
|
||||
models []string
|
||||
}
|
||||
providersByGroup := make(map[string][]providerLite)
|
||||
now := time.Now().Unix()
|
||||
for _, p := range providers {
|
||||
if strings.TrimSpace(p.Status) != "" && strings.TrimSpace(p.Status) != "active" {
|
||||
continue
|
||||
}
|
||||
if p.BanUntil != nil && p.BanUntil.UTC().Unix() > now {
|
||||
continue
|
||||
}
|
||||
group := groupx.Normalize(p.Group)
|
||||
rawModels := strings.Split(p.Models, ",")
|
||||
var outModels []string
|
||||
for _, m := range rawModels {
|
||||
m = strings.TrimSpace(m)
|
||||
if m != "" {
|
||||
outModels = append(outModels, m)
|
||||
}
|
||||
}
|
||||
if group == "" || len(outModels) == 0 {
|
||||
continue
|
||||
}
|
||||
providersByGroup[group] = append(providersByGroup[group], providerLite{
|
||||
id: p.ID,
|
||||
group: group,
|
||||
ptype: strings.TrimSpace(p.Type),
|
||||
models: outModels,
|
||||
})
|
||||
}
|
||||
|
||||
modelsOut := make(map[string]modelcap.Model)
|
||||
payloads := make(map[string]string)
|
||||
|
||||
for _, b := range bindings {
|
||||
if strings.TrimSpace(b.Status) != "" && strings.TrimSpace(b.Status) != "active" {
|
||||
continue
|
||||
}
|
||||
ns := strings.TrimSpace(b.Namespace)
|
||||
pm := strings.TrimSpace(b.PublicModel)
|
||||
if ns == "" || pm == "" {
|
||||
continue
|
||||
}
|
||||
key := ns + "." + pm
|
||||
rg := groupx.Normalize(b.RouteGroup)
|
||||
if rg == "" {
|
||||
continue
|
||||
}
|
||||
pgroup := providersByGroup[rg]
|
||||
if len(pgroup) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
agg := &capAgg{
|
||||
kind: inferKindFromPublicModel(pm),
|
||||
visionAllKnown: true,
|
||||
toolsAllKnown: true,
|
||||
}
|
||||
|
||||
selectorType := routing.SelectorType(strings.TrimSpace(b.SelectorType))
|
||||
selectorValue := strings.TrimSpace(b.SelectorValue)
|
||||
|
||||
for _, p := range pgroup {
|
||||
up, err := routing.ResolveUpstreamModel(selectorType, selectorValue, pm, p.models)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
cap, ok := lookupModelsDevCap(reg, modelsDevProviderKey(p.ptype), up)
|
||||
agg.merge(cap, ok)
|
||||
}
|
||||
|
||||
out := agg.finalize(key)
|
||||
bs, err := jsoncodec.Marshal(out)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("marshal model %s: %w", key, err)
|
||||
}
|
||||
modelsOut[key] = out
|
||||
payloads[key] = string(bs)
|
||||
}
|
||||
|
||||
return modelsOut, payloads, nil
|
||||
}
|
||||
|
||||
func modelsDevProviderKey(providerType string) string {
|
||||
pt := strings.ToLower(strings.TrimSpace(providerType))
|
||||
switch pt {
|
||||
case "openai", "compatible":
|
||||
return "openai"
|
||||
case "anthropic", "claude":
|
||||
return "anthropic"
|
||||
case "gemini", "vertex", "vertex-express":
|
||||
return "google"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func lookupModelsDevCap(reg *modelsDevRegistry, providerID, modelID string) (upstreamCap, bool) {
|
||||
modelID = strings.TrimSpace(modelID)
|
||||
if modelID == "" || reg == nil {
|
||||
return upstreamCap{}, false
|
||||
}
|
||||
if strings.TrimSpace(providerID) != "" {
|
||||
if cap, ok := reg.ByProviderModel[providerID+"|"+modelID]; ok {
|
||||
return cap, true
|
||||
}
|
||||
}
|
||||
if cap, ok := reg.ByModel[modelID]; ok {
|
||||
return cap, true
|
||||
}
|
||||
return upstreamCap{}, false
|
||||
}
|
||||
|
||||
func (s *ModelRegistryService) applyToRedis(ctx context.Context, models map[string]modelcap.Model, payloads map[string]string, meta modelcap.Meta) error {
|
||||
pipe := s.rdb.TxPipeline()
|
||||
pipe.Del(ctx, "meta:models", "meta:models_meta")
|
||||
|
||||
for name := range models {
|
||||
if strings.TrimSpace(name) == "" {
|
||||
continue
|
||||
}
|
||||
payload := payloads[name]
|
||||
if payload == "" {
|
||||
b, err := jsoncodec.Marshal(models[name].Normalized())
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal model %s: %w", name, err)
|
||||
}
|
||||
payload = string(b)
|
||||
}
|
||||
pipe.HSet(ctx, "meta:models", name, payload)
|
||||
}
|
||||
|
||||
fields := map[string]string{
|
||||
"version": strings.TrimSpace(meta.Version),
|
||||
"updated_at": strings.TrimSpace(meta.UpdatedAt),
|
||||
"source": strings.TrimSpace(meta.Source),
|
||||
"checksum": strings.TrimSpace(meta.Checksum),
|
||||
"upstream_url": strings.TrimSpace(meta.UpstreamURL),
|
||||
"upstream_ref": strings.TrimSpace(meta.UpstreamRef),
|
||||
}
|
||||
for k, v := range fields {
|
||||
if v == "" {
|
||||
delete(fields, k)
|
||||
}
|
||||
}
|
||||
if fields["version"] == "" {
|
||||
fields["version"] = fmt.Sprintf("%d", time.Now().Unix())
|
||||
}
|
||||
if fields["updated_at"] == "" {
|
||||
fields["updated_at"] = fmt.Sprintf("%d", time.Now().Unix())
|
||||
}
|
||||
if fields["source"] == "" {
|
||||
fields["source"] = "models.dev"
|
||||
}
|
||||
if fields["checksum"] == "" {
|
||||
fields["checksum"] = modelcap.ChecksumFromPayloads(payloads)
|
||||
}
|
||||
if err := pipe.HSet(ctx, "meta:models_meta", fields).Err(); err != nil {
|
||||
return fmt.Errorf("write meta:models_meta: %w", err)
|
||||
}
|
||||
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
return fmt.Errorf("apply registry to redis: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type modelRegistryFile struct {
|
||||
Meta modelcap.Meta `json:"meta"`
|
||||
Models map[string]modelcap.Model `json:"models"`
|
||||
}
|
||||
|
||||
func readModelRegistryFile(path string) (*modelRegistryFile, error) {
|
||||
b, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var out modelRegistryFile
|
||||
if err := jsoncodec.Unmarshal(b, &out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if out.Models == nil {
|
||||
out.Models = make(map[string]modelcap.Model)
|
||||
}
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func (s *ModelRegistryService) persistCache(models map[string]modelcap.Model, meta modelcap.Meta) error {
|
||||
if err := os.MkdirAll(s.cfg.CacheDir, 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
currentPath := filepath.Join(s.cfg.CacheDir, "current.json")
|
||||
prevPath := filepath.Join(s.cfg.CacheDir, "prev.json")
|
||||
tmpPath := filepath.Join(s.cfg.CacheDir, "current.json.tmp")
|
||||
|
||||
if _, err := os.Stat(currentPath); err == nil {
|
||||
_ = os.Remove(prevPath)
|
||||
_ = os.Rename(currentPath, prevPath)
|
||||
}
|
||||
|
||||
out := modelRegistryFile{
|
||||
Meta: meta,
|
||||
Models: models,
|
||||
}
|
||||
b, err := json.MarshalIndent(out, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(tmpPath, b, 0o644); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.Rename(tmpPath, currentPath); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
193
internal/service/model_registry_test.go
Normal file
193
internal/service/model_registry_test.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/ez-api/ez-api/internal/model"
|
||||
"github.com/ez-api/foundation/modelcap"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func mustGzipTar(t *testing.T, files map[string]string) []byte {
|
||||
t.Helper()
|
||||
var buf bytes.Buffer
|
||||
gz := gzip.NewWriter(&buf)
|
||||
tw := tar.NewWriter(gz)
|
||||
for name, body := range files {
|
||||
b := []byte(body)
|
||||
h := &tar.Header{
|
||||
Name: name,
|
||||
Mode: 0o644,
|
||||
Size: int64(len(b)),
|
||||
Typeflag: tar.TypeReg,
|
||||
}
|
||||
if err := tw.WriteHeader(h); err != nil {
|
||||
t.Fatalf("tar header: %v", err)
|
||||
}
|
||||
if _, err := tw.Write(b); err != nil {
|
||||
t.Fatalf("tar write: %v", err)
|
||||
}
|
||||
}
|
||||
if err := tw.Close(); err != nil {
|
||||
t.Fatalf("tar close: %v", err)
|
||||
}
|
||||
if err := gz.Close(); err != nil {
|
||||
t.Fatalf("gzip close: %v", err)
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func TestModelRegistry_RefreshAndRollback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", t.Name())
|
||||
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("open sqlite: %v", err)
|
||||
}
|
||||
if err := db.AutoMigrate(&model.Provider{}, &model.Binding{}, &model.Model{}); err != nil {
|
||||
t.Fatalf("migrate: %v", err)
|
||||
}
|
||||
if err := db.Create(&model.Provider{
|
||||
Name: "p1",
|
||||
Type: "openai",
|
||||
Group: "rg",
|
||||
Models: "gpt-4o-mini",
|
||||
Status: "active",
|
||||
}).Error; err != nil {
|
||||
t.Fatalf("create provider: %v", err)
|
||||
}
|
||||
if err := db.Create(&model.Binding{
|
||||
Namespace: "ns",
|
||||
PublicModel: "m",
|
||||
RouteGroup: "rg",
|
||||
SelectorType: "exact",
|
||||
SelectorValue: "gpt-4o-mini",
|
||||
Status: "active",
|
||||
}).Error; err != nil {
|
||||
t.Fatalf("create binding: %v", err)
|
||||
}
|
||||
|
||||
tar1 := mustGzipTar(t, map[string]string{
|
||||
"sst-models.dev-aaaaaaaa/providers/openai/models/gpt-4o-mini.toml": `
|
||||
tool_call = true
|
||||
[limit]
|
||||
context = 128000
|
||||
output = 8192
|
||||
[modalities]
|
||||
input = ["text","image"]
|
||||
`,
|
||||
})
|
||||
tar2 := mustGzipTar(t, map[string]string{
|
||||
"sst-models.dev-bbbbbbbb/providers/openai/models/gpt-4o-mini.toml": `
|
||||
tool_call = false
|
||||
[limit]
|
||||
context = 64000
|
||||
output = 2048
|
||||
[modalities]
|
||||
input = ["text"]
|
||||
`,
|
||||
})
|
||||
|
||||
var served int
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/dev" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/gzip")
|
||||
if served == 0 {
|
||||
served++
|
||||
_, _ = w.Write(tar1)
|
||||
return
|
||||
}
|
||||
_, _ = w.Write(tar2)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
mr := miniredis.RunT(t)
|
||||
rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()})
|
||||
|
||||
cacheDir := t.TempDir()
|
||||
svc := NewModelRegistryService(db, rdb, ModelRegistryConfig{
|
||||
Enabled: true,
|
||||
RefreshEvery: time.Hour,
|
||||
ModelsDevBaseURL: srv.URL,
|
||||
ModelsDevRef: "dev",
|
||||
CacheDir: cacheDir,
|
||||
Timeout: 5 * time.Second,
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := svc.Refresh(ctx, "dev"); err != nil {
|
||||
t.Fatalf("refresh1: %v", err)
|
||||
}
|
||||
raw1 := mr.HGet("meta:models", "ns.m")
|
||||
if raw1 == "" {
|
||||
t.Fatalf("expected meta:models[ns.m]")
|
||||
}
|
||||
var m1 modelcap.Model
|
||||
if err := json.Unmarshal([]byte(raw1), &m1); err != nil {
|
||||
t.Fatalf("unmarshal1: %v raw=%s", err, raw1)
|
||||
}
|
||||
if m1.SupportsVision != true || m1.SupportsFunction != true {
|
||||
t.Fatalf("expected vision/tools true, got %+v", m1)
|
||||
}
|
||||
if v := mr.HGet("meta:models_meta", "version"); v != "aaaaaaaa" {
|
||||
t.Fatalf("expected version aaaaaaaa, got %q", v)
|
||||
}
|
||||
|
||||
if err := svc.Refresh(ctx, "dev"); err != nil {
|
||||
t.Fatalf("refresh2: %v", err)
|
||||
}
|
||||
raw2 := mr.HGet("meta:models", "ns.m")
|
||||
var m2 modelcap.Model
|
||||
if err := json.Unmarshal([]byte(raw2), &m2); err != nil {
|
||||
t.Fatalf("unmarshal2: %v raw=%s", err, raw2)
|
||||
}
|
||||
// Second refresh says no vision/tools, but our safe defaults treat unknown as allow only when unknown;
|
||||
// here we have explicit false from models.dev and should reflect it.
|
||||
if m2.SupportsVision != false || m2.SupportsFunction != false {
|
||||
t.Fatalf("expected vision/tools false, got %+v", m2)
|
||||
}
|
||||
if v := mr.HGet("meta:models_meta", "version"); v != "bbbbbbbb" {
|
||||
t.Fatalf("expected version bbbbbbbb, got %q", v)
|
||||
}
|
||||
|
||||
if err := svc.Rollback(ctx); err != nil {
|
||||
t.Fatalf("rollback: %v", err)
|
||||
}
|
||||
if v := mr.HGet("meta:models_meta", "version"); v != "aaaaaaaa" {
|
||||
t.Fatalf("expected rollback to version aaaaaaaa, got %q", v)
|
||||
}
|
||||
raw3 := mr.HGet("meta:models", "ns.m")
|
||||
var m3 modelcap.Model
|
||||
if err := json.Unmarshal([]byte(raw3), &m3); err != nil {
|
||||
t.Fatalf("unmarshal3: %v raw=%s", err, raw3)
|
||||
}
|
||||
if m3.SupportsVision != true || m3.SupportsFunction != true {
|
||||
t.Fatalf("expected rollback vision/tools true, got %+v", m3)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(cacheDir + "/current.json"); err != nil {
|
||||
t.Fatalf("expected current cache file: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(cacheDir + "/prev.json"); err != nil {
|
||||
t.Fatalf("expected prev cache file: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -2,16 +2,14 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ez-api/ez-api/internal/model"
|
||||
groupx "github.com/ez-api/foundation/group"
|
||||
"github.com/ez-api/foundation/jsoncodec"
|
||||
"github.com/ez-api/foundation/modelcap"
|
||||
"github.com/ez-api/foundation/routing"
|
||||
"github.com/ez-api/foundation/tokenhash"
|
||||
"github.com/redis/go-redis/v9"
|
||||
@@ -120,17 +118,17 @@ func (s *SyncService) SyncProvider(provider *model.Provider) error {
|
||||
// SyncModel writes a single model metadata record.
|
||||
func (s *SyncService) SyncModel(m *model.Model) error {
|
||||
ctx := context.Background()
|
||||
snap := modelSnapshot{
|
||||
snap := modelcap.Model{
|
||||
Name: m.Name,
|
||||
Kind: normalizeModelKind(m.Kind),
|
||||
Kind: string(modelcap.NormalizeKind(m.Kind)),
|
||||
ContextWindow: m.ContextWindow,
|
||||
CostPerToken: m.CostPerToken,
|
||||
SupportsVision: m.SupportsVision,
|
||||
SupportsFunction: m.SupportsFunctions,
|
||||
SupportsToolChoice: m.SupportsToolChoice,
|
||||
SupportsFIM: m.SupportsFIM,
|
||||
SupportsFim: m.SupportsFIM,
|
||||
MaxOutputTokens: m.MaxOutputTokens,
|
||||
}
|
||||
}.Normalized()
|
||||
if err := s.hsetJSON(ctx, "meta:models", snap.Name, snap); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -159,18 +157,6 @@ type providerSnapshot struct {
|
||||
|
||||
// keySnapshot is no longer needed as we write directly to auth:token:*
|
||||
|
||||
type modelSnapshot struct {
|
||||
Name string `json:"name"`
|
||||
Kind string `json:"kind"`
|
||||
ContextWindow int `json:"context_window"`
|
||||
CostPerToken float64 `json:"cost_per_token"`
|
||||
SupportsVision bool `json:"supports_vision"`
|
||||
SupportsFunction bool `json:"supports_functions"`
|
||||
SupportsToolChoice bool `json:"supports_tool_choice"`
|
||||
SupportsFIM bool `json:"supports_fim"`
|
||||
MaxOutputTokens int `json:"max_output_tokens"`
|
||||
}
|
||||
|
||||
// SyncAll rebuilds Redis hashes from the database; use for cold starts or forced refreshes.
|
||||
func (s *SyncService) SyncAll(db *gorm.DB) error {
|
||||
ctx := context.Background()
|
||||
@@ -194,6 +180,7 @@ func (s *SyncService) SyncAll(db *gorm.DB) error {
|
||||
if err := db.Find(&models).Error; err != nil {
|
||||
return fmt.Errorf("load models: %w", err)
|
||||
}
|
||||
var modelsPayloads map[string]string
|
||||
|
||||
var bindings []model.Binding
|
||||
if err := db.Find(&bindings).Error; err != nil {
|
||||
@@ -292,29 +279,35 @@ func (s *SyncService) SyncAll(db *gorm.DB) error {
|
||||
}
|
||||
|
||||
for _, m := range models {
|
||||
snap := modelSnapshot{
|
||||
snap := modelcap.Model{
|
||||
Name: m.Name,
|
||||
Kind: normalizeModelKind(m.Kind),
|
||||
Kind: string(modelcap.NormalizeKind(m.Kind)),
|
||||
ContextWindow: m.ContextWindow,
|
||||
CostPerToken: m.CostPerToken,
|
||||
SupportsVision: m.SupportsVision,
|
||||
SupportsFunction: m.SupportsFunctions,
|
||||
SupportsToolChoice: m.SupportsToolChoice,
|
||||
SupportsFIM: m.SupportsFIM,
|
||||
SupportsFim: m.SupportsFIM,
|
||||
MaxOutputTokens: m.MaxOutputTokens,
|
||||
}
|
||||
}.Normalized()
|
||||
payload, err := jsoncodec.Marshal(snap)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal model %s: %w", m.Name, err)
|
||||
}
|
||||
// Capture payloads so we can compute deterministic checksum for meta:models_meta.
|
||||
if modelsPayloads == nil {
|
||||
modelsPayloads = make(map[string]string, len(models))
|
||||
}
|
||||
modelsPayloads[snap.Name] = string(payload)
|
||||
pipe.HSet(ctx, "meta:models", snap.Name, payload)
|
||||
}
|
||||
|
||||
if err := writeModelsMeta(ctx, pipe, modelsMetaInput{
|
||||
now := time.Now().Unix()
|
||||
if err := writeModelsMeta(ctx, pipe, modelcap.Meta{
|
||||
Version: fmt.Sprintf("%d", now),
|
||||
UpdatedAt: fmt.Sprintf("%d", now),
|
||||
Source: "db",
|
||||
Version: fmt.Sprintf("%d", time.Now().Unix()),
|
||||
UpdatedAtSec: time.Now().Unix(),
|
||||
Models: models,
|
||||
Checksum: modelcap.ChecksumFromPayloads(modelsPayloads),
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -469,36 +462,6 @@ func normalizeStatus(status string) string {
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeModelKind(kind string) string {
|
||||
k := strings.ToLower(strings.TrimSpace(kind))
|
||||
if k == "" {
|
||||
return "chat"
|
||||
}
|
||||
switch k {
|
||||
case "chat", "embedding", "rerank", "other":
|
||||
return k
|
||||
default:
|
||||
return "other"
|
||||
}
|
||||
}
|
||||
|
||||
func checksumModelPayloads(payloads map[string]string) string {
|
||||
keys := make([]string, 0, len(payloads))
|
||||
for k := range payloads {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
h := sha256.New()
|
||||
for _, k := range keys {
|
||||
_, _ = h.Write([]byte(k))
|
||||
_, _ = h.Write([]byte{'\n'})
|
||||
_, _ = h.Write([]byte(payloads[k]))
|
||||
_, _ = h.Write([]byte{'\n'})
|
||||
}
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
func (s *SyncService) refreshModelsMetaFromRedis(ctx context.Context, source string) error {
|
||||
raw, err := s.rdb.HGetAll(ctx, "meta:models").Result()
|
||||
if err != nil {
|
||||
@@ -509,7 +472,7 @@ func (s *SyncService) refreshModelsMetaFromRedis(ctx context.Context, source str
|
||||
"version": fmt.Sprintf("%d", now),
|
||||
"updated_at": fmt.Sprintf("%d", now),
|
||||
"source": source,
|
||||
"checksum": checksumModelPayloads(raw),
|
||||
"checksum": modelcap.ChecksumFromPayloads(raw),
|
||||
}
|
||||
if err := s.rdb.HSet(ctx, "meta:models_meta", meta).Err(); err != nil {
|
||||
return fmt.Errorf("write meta:models_meta: %w", err)
|
||||
@@ -517,47 +480,33 @@ func (s *SyncService) refreshModelsMetaFromRedis(ctx context.Context, source str
|
||||
return nil
|
||||
}
|
||||
|
||||
type modelsMetaInput struct {
|
||||
Source string
|
||||
Version string
|
||||
UpdatedAtSec int64
|
||||
Models []model.Model
|
||||
func writeModelsMeta(ctx context.Context, pipe redis.Pipeliner, meta modelcap.Meta) error {
|
||||
fields := map[string]string{
|
||||
"version": strings.TrimSpace(meta.Version),
|
||||
"updated_at": strings.TrimSpace(meta.UpdatedAt),
|
||||
"source": strings.TrimSpace(meta.Source),
|
||||
"checksum": strings.TrimSpace(meta.Checksum),
|
||||
"upstream_url": strings.TrimSpace(meta.UpstreamURL),
|
||||
"upstream_ref": strings.TrimSpace(meta.UpstreamRef),
|
||||
}
|
||||
|
||||
func writeModelsMeta(ctx context.Context, pipe redis.Pipeliner, in modelsMetaInput) error {
|
||||
payloads := make(map[string]string, len(in.Models))
|
||||
for _, m := range in.Models {
|
||||
snap := modelSnapshot{
|
||||
Name: m.Name,
|
||||
Kind: normalizeModelKind(m.Kind),
|
||||
ContextWindow: m.ContextWindow,
|
||||
CostPerToken: m.CostPerToken,
|
||||
SupportsVision: m.SupportsVision,
|
||||
SupportsFunction: m.SupportsFunctions,
|
||||
SupportsToolChoice: m.SupportsToolChoice,
|
||||
SupportsFIM: m.SupportsFIM,
|
||||
MaxOutputTokens: m.MaxOutputTokens,
|
||||
for k, v := range fields {
|
||||
if v == "" {
|
||||
delete(fields, k)
|
||||
}
|
||||
b, err := jsoncodec.Marshal(snap)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal model %s for meta: %w", m.Name, err)
|
||||
}
|
||||
payloads[snap.Name] = string(b)
|
||||
if fields["version"] == "" {
|
||||
fields["version"] = fmt.Sprintf("%d", time.Now().Unix())
|
||||
}
|
||||
|
||||
meta := map[string]string{
|
||||
"version": strings.TrimSpace(in.Version),
|
||||
"updated_at": fmt.Sprintf("%d", in.UpdatedAtSec),
|
||||
"source": strings.TrimSpace(in.Source),
|
||||
"checksum": checksumModelPayloads(payloads),
|
||||
if fields["updated_at"] == "" {
|
||||
fields["updated_at"] = fmt.Sprintf("%d", time.Now().Unix())
|
||||
}
|
||||
if strings.TrimSpace(meta["version"]) == "" {
|
||||
meta["version"] = fmt.Sprintf("%d", time.Now().Unix())
|
||||
if fields["source"] == "" {
|
||||
fields["source"] = "db"
|
||||
}
|
||||
if strings.TrimSpace(meta["source"]) == "" {
|
||||
meta["source"] = "db"
|
||||
if fields["checksum"] == "" {
|
||||
fields["checksum"] = "unknown"
|
||||
}
|
||||
if err := pipe.HSet(ctx, "meta:models_meta", meta).Err(); err != nil {
|
||||
if err := pipe.HSet(ctx, "meta:models_meta", fields).Err(); err != nil {
|
||||
return fmt.Errorf("write meta:models_meta: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user