package service import ( "archive/tar" "bytes" "compress/gzip" "context" "encoding/json" "fmt" "io" "net/http" "net/url" "os" "path/filepath" "regexp" "strings" "sync" "time" "github.com/ez-api/ez-api/internal/model" groupx "github.com/ez-api/foundation/group" "github.com/ez-api/foundation/jsoncodec" "github.com/ez-api/foundation/modelcap" "github.com/ez-api/foundation/routing" "github.com/pelletier/go-toml/v2" "github.com/redis/go-redis/v9" "gorm.io/gorm" ) type ModelRegistryConfig struct { Enabled bool RefreshEvery time.Duration ModelsDevBaseURL string ModelsDevAPIBaseURL string ModelsDevRef string CacheDir string Timeout time.Duration } type ModelRegistryService struct { db *gorm.DB rdb *redis.Client cfg ModelRegistryConfig client *http.Client mu sync.Mutex lastError string lastRefresh time.Time lastApplied modelcap.Meta lastUpstream string } func NewModelRegistryService(db *gorm.DB, rdb *redis.Client, cfg ModelRegistryConfig) *ModelRegistryService { if cfg.RefreshEvery <= 0 { cfg.RefreshEvery = 30 * time.Minute } if strings.TrimSpace(cfg.ModelsDevBaseURL) == "" { cfg.ModelsDevBaseURL = "https://codeload.github.com/sst/models.dev/tar.gz" } if strings.TrimSpace(cfg.ModelsDevAPIBaseURL) == "" { cfg.ModelsDevAPIBaseURL = "https://api.github.com" } if strings.TrimSpace(cfg.ModelsDevRef) == "" { cfg.ModelsDevRef = "dev" } if strings.TrimSpace(cfg.CacheDir) == "" { cfg.CacheDir = "./data/model-registry" } if cfg.Timeout <= 0 { cfg.Timeout = 30 * time.Second } return &ModelRegistryService{ db: db, rdb: rdb, cfg: cfg, client: &http.Client{ Timeout: cfg.Timeout, }, } } func (s *ModelRegistryService) Start(ctx context.Context) { if !s.cfg.Enabled { return } go func() { ticker := time.NewTicker(s.cfg.RefreshEvery) defer ticker.Stop() // Best-effort initial refresh. _ = s.Refresh(ctx, "") for { select { case <-ctx.Done(): return case <-ticker.C: _ = s.Refresh(ctx, "") } } }() } type ModelRegistryStatus struct { Enabled bool `json:"enabled"` ModelsDevRef string `json:"models_dev_ref"` ModelsDevURL string `json:"models_dev_url"` LastRefreshAt int64 `json:"last_refresh_at,omitempty"` LastError string `json:"last_error,omitempty"` RedisMeta map[string]string `json:"redis_meta,omitempty"` CacheCurrent *modelRegistryFile `json:"cache_current,omitempty"` CachePrev *modelRegistryFile `json:"cache_prev,omitempty"` } func (s *ModelRegistryService) Status(ctx context.Context) (*ModelRegistryStatus, error) { s.mu.Lock() lastErr := s.lastError lastRefresh := s.lastRefresh s.mu.Unlock() redisMeta, _ := s.rdb.HGetAll(ctx, "meta:models_meta").Result() current, _ := readModelRegistryFile(filepath.Join(s.cfg.CacheDir, "current.json")) prev, _ := readModelRegistryFile(filepath.Join(s.cfg.CacheDir, "prev.json")) out := &ModelRegistryStatus{ Enabled: s.cfg.Enabled, ModelsDevRef: s.cfg.ModelsDevRef, ModelsDevURL: strings.TrimRight(s.cfg.ModelsDevBaseURL, "/") + "/" + s.cfg.ModelsDevRef, LastError: lastErr, RedisMeta: redisMeta, CacheCurrent: current, CachePrev: prev, } if !lastRefresh.IsZero() { out.LastRefreshAt = lastRefresh.Unix() } return out, nil } type ModelRegistryCheckResult struct { Enabled bool `json:"enabled"` UpstreamRef string `json:"upstream_ref"` CurrentVersion string `json:"current_version,omitempty"` LatestVersion string `json:"latest_version,omitempty"` NeedsRefresh bool `json:"needs_refresh"` CurrentUpstreamRef string `json:"current_upstream_ref,omitempty"` } func (s *ModelRegistryService) Check(ctx context.Context, ref string) (*ModelRegistryCheckResult, error) { if strings.TrimSpace(ref) == "" { ref = s.cfg.ModelsDevRef } latest, err := s.fetchModelsDevRefSHA(ctx, ref) if err != nil { return nil, err } meta, _ := s.rdb.HGetAll(ctx, "meta:models_meta").Result() currentVersion := strings.TrimSpace(meta["version"]) currentUpstreamRef := strings.TrimSpace(meta["upstream_ref"]) needsRefresh := currentVersion == "" || !versionsEqual(currentVersion, latest) return &ModelRegistryCheckResult{ Enabled: s.cfg.Enabled, UpstreamRef: ref, CurrentVersion: currentVersion, LatestVersion: latest, NeedsRefresh: needsRefresh, CurrentUpstreamRef: currentUpstreamRef, }, nil } func (s *ModelRegistryService) Refresh(ctx context.Context, ref string) error { if strings.TrimSpace(ref) == "" { ref = s.cfg.ModelsDevRef } s.mu.Lock() s.lastUpstream = ref s.mu.Unlock() tarballURL := strings.TrimRight(s.cfg.ModelsDevBaseURL, "/") + "/" + ref reg, version, err := s.fetchModelsDev(ctx, tarballURL) if err != nil { s.setError(err) return err } models, payloads, err := s.buildBindingModels(ctx, reg) if err != nil { s.setError(err) return err } // Overlay explicit DB entries (admin overrides). var dbModels []model.Model if err := s.db.Find(&dbModels).Error; err != nil { s.setError(fmt.Errorf("load db models: %w", err)) return err } for _, m := range dbModels { snap := modelcap.Model{ Name: m.Name, Kind: string(modelcap.NormalizeKind(m.Kind)), ContextWindow: m.ContextWindow, CostPerToken: m.CostPerToken, SupportsVision: m.SupportsVision, SupportsFunction: m.SupportsFunctions, SupportsToolChoice: m.SupportsToolChoice, SupportsFim: m.SupportsFIM, SupportsStream: true, MaxOutputTokens: m.MaxOutputTokens, }.Normalized() b, err := jsoncodec.Marshal(snap) if err != nil { s.setError(fmt.Errorf("marshal db model %s: %w", m.Name, err)) return err } models[snap.Name] = snap payloads[snap.Name] = string(b) } now := time.Now().Unix() meta := modelcap.Meta{ Version: version, UpdatedAt: fmt.Sprintf("%d", now), Source: "models.dev", Checksum: modelcap.ChecksumFromPayloads(payloads), UpstreamURL: "https://github.com/sst/models.dev", UpstreamRef: ref, } if err := s.applyToRedis(ctx, models, payloads, meta); err != nil { s.setError(err) return err } if err := s.persistCache(models, meta); err != nil { s.setError(err) return err } s.mu.Lock() s.lastError = "" s.lastRefresh = time.Now() s.lastApplied = meta s.mu.Unlock() return nil } func (s *ModelRegistryService) Rollback(ctx context.Context) error { prevPath := filepath.Join(s.cfg.CacheDir, "prev.json") prev, err := readModelRegistryFile(prevPath) if err != nil { s.setError(fmt.Errorf("read prev cache: %w", err)) return err } if prev == nil { s.setError(fmt.Errorf("no prev cache")) return fmt.Errorf("no prev cache") } payloads := make(map[string]string, len(prev.Models)) for name, m := range prev.Models { m = m.Normalized() b, err := jsoncodec.Marshal(m) if err != nil { s.setError(fmt.Errorf("marshal cached model %s: %w", name, err)) return err } prev.Models[name] = m payloads[name] = string(b) } prev.Meta.Checksum = modelcap.ChecksumFromPayloads(payloads) prev.Meta.Source = "rollback" prev.Meta.UpdatedAt = fmt.Sprintf("%d", time.Now().Unix()) if err := s.applyToRedis(ctx, prev.Models, payloads, prev.Meta); err != nil { s.setError(err) return err } // Swap cache current/prev to reflect rollback. if err := s.persistCache(prev.Models, prev.Meta); err != nil { s.setError(err) return err } s.mu.Lock() s.lastError = "" s.lastRefresh = time.Now() s.lastApplied = prev.Meta s.mu.Unlock() return nil } func (s *ModelRegistryService) setError(err error) { s.mu.Lock() defer s.mu.Unlock() s.lastError = err.Error() s.lastRefresh = time.Now() } func (s *ModelRegistryService) fetchModelsDevRefSHA(ctx context.Context, ref string) (string, error) { base := strings.TrimRight(strings.TrimSpace(s.cfg.ModelsDevAPIBaseURL), "/") if base == "" { base = "https://api.github.com" } u := base + "/repos/sst/models.dev/commits/" + url.PathEscape(strings.TrimSpace(ref)) req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil) if err != nil { return "", err } req.Header.Set("Accept", "application/vnd.github+json") req.Header.Set("User-Agent", "ez-api") resp, err := s.client.Do(req) if err != nil { return "", err } defer resp.Body.Close() if resp.StatusCode/100 != 2 { b, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<10)) return "", fmt.Errorf("models.dev commit check failed: %s body=%s", resp.Status, string(b)) } var payload struct { SHA string `json:"sha"` } if err := json.NewDecoder(io.LimitReader(resp.Body, 1<<20)).Decode(&payload); err != nil { return "", fmt.Errorf("decode github commit: %w", err) } sha := strings.TrimSpace(payload.SHA) if sha == "" { return "", fmt.Errorf("github commit sha missing for ref %q", ref) } return sha, nil } func versionsEqual(current, latest string) bool { c := strings.TrimSpace(current) l := strings.TrimSpace(latest) if c == "" || l == "" { return false } if len(c) <= len(l) && strings.HasPrefix(l, c) { return true } if len(l) <= len(c) && strings.HasPrefix(c, l) { return true } return false } type boolVal struct { Known bool Val bool } type intVal struct { Known bool Val int } type upstreamCap struct { ContextWindow intVal MaxOutputTokens intVal SupportsVision boolVal SupportsTools boolVal } type modelsDevRegistry struct { ByProviderModel map[string]upstreamCap // key: providerID|modelID ByModel map[string]upstreamCap // fallback: modelID } var shaSuffix = regexp.MustCompile(`-([0-9a-f]{7,40})$`) func (s *ModelRegistryService) fetchModelsDev(ctx context.Context, tarballURL string) (*modelsDevRegistry, string, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, tarballURL, nil) if err != nil { return nil, "", err } resp, err := s.client.Do(req) if err != nil { return nil, "", err } defer resp.Body.Close() if resp.StatusCode/100 != 2 { b, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<10)) return nil, "", fmt.Errorf("models.dev fetch failed: %s body=%s", resp.Status, string(b)) } tarball, err := io.ReadAll(resp.Body) if err != nil { return nil, "", err } gr, err := gzip.NewReader(bytes.NewReader(tarball)) if err != nil { return nil, "", err } defer gr.Close() tr := tar.NewReader(gr) reg := &modelsDevRegistry{ ByProviderModel: make(map[string]upstreamCap), ByModel: make(map[string]upstreamCap), } version := "" for { hdr, err := tr.Next() if err == io.EOF { break } if err != nil { return nil, "", err } name := hdr.Name if version == "" { if parts := strings.SplitN(name, "/", 2); len(parts) > 0 { if m := shaSuffix.FindStringSubmatch(parts[0]); len(m) == 2 { version = m[1] } } } if hdr.Typeflag != tar.TypeReg { continue } if !strings.HasSuffix(name, ".toml") { continue } providerID, modelID, ok := parseModelsDevModelPath(name) if !ok { continue } data, err := io.ReadAll(io.LimitReader(tr, 1<<20)) if err != nil { return nil, "", err } cap, ok := parseModelsDevTOML(data) if !ok { continue } key := providerID + "|" + modelID reg.ByProviderModel[key] = cap if _, exists := reg.ByModel[modelID]; !exists { reg.ByModel[modelID] = cap } // Also store a flattened variant for providers that may use subdirs. if alt := strings.ReplaceAll(modelID, "/", "-"); alt != modelID { if _, exists := reg.ByModel[alt]; !exists { reg.ByModel[alt] = cap } altKey := providerID + "|" + alt if _, exists := reg.ByProviderModel[altKey]; !exists { reg.ByProviderModel[altKey] = cap } } } if version == "" { version = fmt.Sprintf("%d", time.Now().Unix()) } return reg, version, nil } func parseModelsDevModelPath(path string) (providerID string, modelID string, ok bool) { parts := strings.Split(path, "/") for i := 0; i < len(parts); i++ { if parts[i] != "providers" { continue } if i+3 >= len(parts) { return "", "", false } providerID = strings.TrimSpace(parts[i+1]) if providerID == "" { return "", "", false } if parts[i+2] != "models" { return "", "", false } modelPart := strings.Join(parts[i+3:], "/") modelPart = strings.TrimSuffix(modelPart, ".toml") modelPart = strings.TrimSpace(modelPart) if modelPart == "" { return "", "", false } return providerID, modelPart, true } return "", "", false } func parseModelsDevTOML(data []byte) (upstreamCap, bool) { var doc map[string]any if err := toml.Unmarshal(data, &doc); err != nil { return upstreamCap{}, false } var cap upstreamCap if v, ok := getBool(doc, "tool_call"); ok { cap.SupportsTools = boolVal{Known: true, Val: v} } if limit, ok := getMap(doc, "limit"); ok { if v, ok := getInt(limit, "context"); ok { cap.ContextWindow = intVal{Known: true, Val: v} } if v, ok := getInt(limit, "output"); ok { cap.MaxOutputTokens = intVal{Known: true, Val: v} } } if mods, ok := getMap(doc, "modalities"); ok { if input, ok := getStringSlice(mods["input"]); ok { hasImage := false for _, it := range input { if strings.EqualFold(strings.TrimSpace(it), "image") { hasImage = true break } } cap.SupportsVision = boolVal{Known: true, Val: hasImage} } } if !cap.SupportsTools.Known && !cap.SupportsVision.Known && !cap.ContextWindow.Known && !cap.MaxOutputTokens.Known { return upstreamCap{}, false } return cap, true } func getMap(doc map[string]any, key string) (map[string]any, bool) { if doc == nil { return nil, false } v, ok := doc[key] if !ok || v == nil { return nil, false } m, ok := v.(map[string]any) return m, ok } func getBool(doc map[string]any, key string) (bool, bool) { v, ok := doc[key] if !ok || v == nil { return false, false } switch b := v.(type) { case bool: return b, true default: return false, false } } func getInt(doc map[string]any, key string) (int, bool) { v, ok := doc[key] if !ok || v == nil { return 0, false } switch n := v.(type) { case int64: return int(n), true case int: return n, true case float64: return int(n), true default: return 0, false } } func getStringSlice(v any) ([]string, bool) { if v == nil { return nil, false } switch s := v.(type) { case []any: out := make([]string, 0, len(s)) for _, it := range s { if str, ok := it.(string); ok { out = append(out, str) } } return out, true case []string: return append([]string(nil), s...), true default: return nil, false } } type capAgg struct { kind string visionAnyTrue bool visionAllKnown bool visionKnownAny bool toolsAnyTrue bool toolsAllKnown bool toolsKnownAny bool maxOutputKnown bool maxOutputMax int contextKnown bool contextMax int } func inferKindFromPublicModel(publicModel string) string { pm := strings.ToLower(strings.TrimSpace(publicModel)) if strings.Contains(pm, "embedding") { return string(modelcap.KindEmbedding) } if strings.Contains(pm, "rerank") { return string(modelcap.KindRerank) } return string(modelcap.KindChat) } func (a *capAgg) merge(cap upstreamCap, ok bool) { if ok { // vision if cap.SupportsVision.Known { a.visionKnownAny = true if cap.SupportsVision.Val { a.visionAnyTrue = true } } else { a.visionAllKnown = false } // tools if cap.SupportsTools.Known { a.toolsKnownAny = true if cap.SupportsTools.Val { a.toolsAnyTrue = true } } else { a.toolsAllKnown = false } // limits if cap.MaxOutputTokens.Known && cap.MaxOutputTokens.Val > 0 { a.maxOutputKnown = true if cap.MaxOutputTokens.Val > a.maxOutputMax { a.maxOutputMax = cap.MaxOutputTokens.Val } } if cap.ContextWindow.Known && cap.ContextWindow.Val > 0 { a.contextKnown = true if cap.ContextWindow.Val > a.contextMax { a.contextMax = cap.ContextWindow.Val } } return } // Unknown upstream capability: do not force blocking, treat as "not all known". a.visionAllKnown = false a.toolsAllKnown = false } func (a *capAgg) finalize(name string) modelcap.Model { // Safe defaults: unknown -> allow (true) so we avoid false blocking. supportsVision := true if a.visionKnownAny { if a.visionAnyTrue { supportsVision = true } else if a.visionAllKnown { supportsVision = false } } supportsTools := true if a.toolsKnownAny { if a.toolsAnyTrue { supportsTools = true } else if a.toolsAllKnown { supportsTools = false } } out := modelcap.Model{ Name: name, Kind: string(modelcap.NormalizeKind(a.kind)), ContextWindow: 0, CostPerToken: 0, SupportsVision: supportsVision, SupportsFunction: supportsTools, SupportsToolChoice: true, SupportsFim: true, SupportsStream: true, MaxOutputTokens: 0, } if a.contextKnown { out.ContextWindow = a.contextMax } if a.maxOutputKnown { out.MaxOutputTokens = a.maxOutputMax } return out.Normalized() } func (s *ModelRegistryService) buildBindingModels(ctx context.Context, reg *modelsDevRegistry) (map[string]modelcap.Model, map[string]string, error) { var providers []model.Provider if err := s.db.Find(&providers).Error; err != nil { return nil, nil, fmt.Errorf("load providers: %w", err) } var bindings []model.Binding if err := s.db.Find(&bindings).Error; err != nil { return nil, nil, fmt.Errorf("load bindings: %w", err) } type providerLite struct { id uint group string ptype string models []string } providersByGroup := make(map[string][]providerLite) now := time.Now().Unix() for _, p := range providers { if strings.TrimSpace(p.Status) != "" && strings.TrimSpace(p.Status) != "active" { continue } if p.BanUntil != nil && p.BanUntil.UTC().Unix() > now { continue } group := groupx.Normalize(p.Group) rawModels := strings.Split(p.Models, ",") var outModels []string for _, m := range rawModels { m = strings.TrimSpace(m) if m != "" { outModels = append(outModels, m) } } if group == "" || len(outModels) == 0 { continue } providersByGroup[group] = append(providersByGroup[group], providerLite{ id: p.ID, group: group, ptype: strings.TrimSpace(p.Type), models: outModels, }) } modelsOut := make(map[string]modelcap.Model) payloads := make(map[string]string) for _, b := range bindings { if strings.TrimSpace(b.Status) != "" && strings.TrimSpace(b.Status) != "active" { continue } ns := strings.TrimSpace(b.Namespace) pm := strings.TrimSpace(b.PublicModel) if ns == "" || pm == "" { continue } key := ns + "." + pm rg := groupx.Normalize(b.RouteGroup) if rg == "" { continue } pgroup := providersByGroup[rg] if len(pgroup) == 0 { continue } agg := &capAgg{ kind: inferKindFromPublicModel(pm), visionAllKnown: true, toolsAllKnown: true, } selectorType := routing.SelectorType(strings.TrimSpace(b.SelectorType)) selectorValue := strings.TrimSpace(b.SelectorValue) for _, p := range pgroup { up, err := routing.ResolveUpstreamModel(selectorType, selectorValue, pm, p.models) if err != nil { continue } cap, ok := lookupModelsDevCap(reg, modelsDevProviderKey(p.ptype), up) agg.merge(cap, ok) } out := agg.finalize(key) bs, err := jsoncodec.Marshal(out) if err != nil { return nil, nil, fmt.Errorf("marshal model %s: %w", key, err) } modelsOut[key] = out payloads[key] = string(bs) } return modelsOut, payloads, nil } func modelsDevProviderKey(providerType string) string { pt := strings.ToLower(strings.TrimSpace(providerType)) switch pt { case "openai", "compatible": return "openai" case "anthropic", "claude": return "anthropic" case "gemini", "vertex", "vertex-express": return "google" default: return "" } } func lookupModelsDevCap(reg *modelsDevRegistry, providerID, modelID string) (upstreamCap, bool) { modelID = strings.TrimSpace(modelID) if modelID == "" || reg == nil { return upstreamCap{}, false } if strings.TrimSpace(providerID) != "" { if cap, ok := reg.ByProviderModel[providerID+"|"+modelID]; ok { return cap, true } } if cap, ok := reg.ByModel[modelID]; ok { return cap, true } return upstreamCap{}, false } func (s *ModelRegistryService) applyToRedis(ctx context.Context, models map[string]modelcap.Model, payloads map[string]string, meta modelcap.Meta) error { pipe := s.rdb.TxPipeline() pipe.Del(ctx, "meta:models", "meta:models_meta") for name := range models { if strings.TrimSpace(name) == "" { continue } payload := payloads[name] if payload == "" { b, err := jsoncodec.Marshal(models[name].Normalized()) if err != nil { return fmt.Errorf("marshal model %s: %w", name, err) } payload = string(b) } pipe.HSet(ctx, "meta:models", name, payload) } fields := map[string]string{ "version": strings.TrimSpace(meta.Version), "updated_at": strings.TrimSpace(meta.UpdatedAt), "source": strings.TrimSpace(meta.Source), "checksum": strings.TrimSpace(meta.Checksum), "upstream_url": strings.TrimSpace(meta.UpstreamURL), "upstream_ref": strings.TrimSpace(meta.UpstreamRef), } for k, v := range fields { if v == "" { delete(fields, k) } } if fields["version"] == "" { fields["version"] = fmt.Sprintf("%d", time.Now().Unix()) } if fields["updated_at"] == "" { fields["updated_at"] = fmt.Sprintf("%d", time.Now().Unix()) } if fields["source"] == "" { fields["source"] = "models.dev" } if fields["checksum"] == "" { fields["checksum"] = modelcap.ChecksumFromPayloads(payloads) } if err := pipe.HSet(ctx, "meta:models_meta", fields).Err(); err != nil { return fmt.Errorf("write meta:models_meta: %w", err) } if _, err := pipe.Exec(ctx); err != nil { return fmt.Errorf("apply registry to redis: %w", err) } return nil } type modelRegistryFile struct { Meta modelcap.Meta `json:"meta"` Models map[string]modelcap.Model `json:"models"` } func readModelRegistryFile(path string) (*modelRegistryFile, error) { b, err := os.ReadFile(path) if err != nil { return nil, err } var out modelRegistryFile if err := jsoncodec.Unmarshal(b, &out); err != nil { return nil, err } if out.Models == nil { out.Models = make(map[string]modelcap.Model) } return &out, nil } func (s *ModelRegistryService) persistCache(models map[string]modelcap.Model, meta modelcap.Meta) error { if err := os.MkdirAll(s.cfg.CacheDir, 0o755); err != nil { return err } currentPath := filepath.Join(s.cfg.CacheDir, "current.json") prevPath := filepath.Join(s.cfg.CacheDir, "prev.json") tmpPath := filepath.Join(s.cfg.CacheDir, "current.json.tmp") if _, err := os.Stat(currentPath); err == nil { _ = os.Remove(prevPath) _ = os.Rename(currentPath, prevPath) } out := modelRegistryFile{ Meta: meta, Models: models, } b, err := json.MarshalIndent(out, "", " ") if err != nil { return err } if err := os.WriteFile(tmpPath, b, 0o644); err != nil { return err } if err := os.Rename(tmpPath, currentPath); err != nil { return err } return nil }