diff --git a/cmd/server/main.go b/cmd/server/main.go index 882f99f..88f9d24 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -114,11 +114,21 @@ func main() { adminHandler := api.NewAdminHandler(masterService, syncService) masterHandler := api.NewMasterHandler(masterService, syncService) featureHandler := api.NewFeatureHandler(rdb) + modelRegistryService := service.NewModelRegistryService(db, rdb, service.ModelRegistryConfig{ + Enabled: cfg.ModelRegistry.Enabled, + RefreshEvery: time.Duration(cfg.ModelRegistry.RefreshSeconds) * time.Second, + ModelsDevBaseURL: cfg.ModelRegistry.ModelsDevBaseURL, + ModelsDevRef: cfg.ModelRegistry.ModelsDevRef, + CacheDir: cfg.ModelRegistry.CacheDir, + Timeout: time.Duration(cfg.ModelRegistry.TimeoutSeconds) * time.Second, + }) + modelRegistryHandler := api.NewModelRegistryHandler(modelRegistryService) // 4.1 Prime Redis snapshots so DP can start with data if err := syncService.SyncAll(db); err != nil { logger.Warn("initial sync warning", "err", err) } + modelRegistryService.Start(context.Background()) // 5. Setup Gin Router r := gin.Default() @@ -165,6 +175,9 @@ func main() { adminGroup.PUT("/keys/:id/access", handler.UpdateKeyAccess) adminGroup.GET("/features", featureHandler.ListFeatures) adminGroup.PUT("/features", featureHandler.UpdateFeatures) + adminGroup.GET("/model-registry/status", modelRegistryHandler.GetStatus) + adminGroup.POST("/model-registry/refresh", modelRegistryHandler.Refresh) + adminGroup.POST("/model-registry/rollback", modelRegistryHandler.Rollback) // Other admin routes for managing providers, models, etc. adminGroup.POST("/providers", handler.CreateProvider) adminGroup.POST("/providers/preset", handler.CreateProviderPreset) diff --git a/go.mod b/go.mod index 8e13999..104268a 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/alicebob/miniredis/v2 v2.35.0 github.com/ez-api/foundation v0.2.0 github.com/gin-gonic/gin v1.11.0 + github.com/pelletier/go-toml/v2 v2.2.4 github.com/redis/go-redis/v9 v9.17.2 github.com/spf13/viper v1.21.0 github.com/swaggo/files v1.0.1 @@ -17,6 +18,8 @@ require ( gorm.io/gorm v1.31.1 ) +replace github.com/ez-api/foundation => ../foundation + require ( github.com/KyleBanks/depth v1.2.1 // indirect github.com/bytedance/sonic v1.14.0 // indirect @@ -57,7 +60,6 @@ require ( github.com/mattn/go-sqlite3 v1.14.22 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect - github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/quic-go/qpack v0.5.1 // indirect github.com/quic-go/quic-go v0.54.0 // indirect github.com/rs/zerolog v1.34.0 // indirect diff --git a/go.sum b/go.sum index cfc7b99..dbbb52d 100644 --- a/go.sum +++ b/go.sum @@ -20,8 +20,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= -github.com/ez-api/foundation v0.2.0 h1:QdxbXZ2wr4O08Uxl6QK4fJZPrWb09yMNI+hS2aRSqG8= -github.com/ez-api/foundation v0.2.0/go.mod h1:bTh1LA42TW4CXi1SebDEUE+fhEssFUphzcGEzyAFFZI= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= diff --git a/internal/api/model_registry_handler.go b/internal/api/model_registry_handler.go new file mode 100644 index 0000000..4fb0977 --- /dev/null +++ b/internal/api/model_registry_handler.go @@ -0,0 +1,91 @@ +package api + +import ( + "net/http" + "strings" + + "github.com/ez-api/ez-api/internal/service" + "github.com/gin-gonic/gin" +) + +type ModelRegistryHandler struct { + reg *service.ModelRegistryService +} + +func NewModelRegistryHandler(reg *service.ModelRegistryService) *ModelRegistryHandler { + return &ModelRegistryHandler{reg: reg} +} + +// GetModelRegistryStatus godoc +// @Summary Get model registry status +// @Description Returns Redis meta and local last-good cache info for model capability registry +// @Tags admin +// @Produce json +// @Security AdminAuth +// @Success 200 {object} service.ModelRegistryStatus +// @Failure 500 {object} gin.H +// @Router /admin/model-registry/status [get] +func (h *ModelRegistryHandler) GetStatus(c *gin.Context) { + if h == nil || h.reg == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "model registry not configured"}) + return + } + st, err := h.reg.Status(c.Request.Context()) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get model registry status", "details": err.Error()}) + return + } + c.JSON(http.StatusOK, st) +} + +type refreshModelRegistryRequest struct { + Ref string `json:"ref"` +} + +// RefreshModelRegistry godoc +// @Summary Refresh model registry from models.dev +// @Description Fetches models.dev, computes per-binding capabilities, and updates Redis meta:models +// @Tags admin +// @Accept json +// @Produce json +// @Security AdminAuth +// @Param body body refreshModelRegistryRequest false "optional override ref" +// @Success 200 {object} gin.H +// @Failure 400 {object} gin.H +// @Failure 500 {object} gin.H +// @Router /admin/model-registry/refresh [post] +func (h *ModelRegistryHandler) Refresh(c *gin.Context) { + if h == nil || h.reg == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "model registry not configured"}) + return + } + var req refreshModelRegistryRequest + _ = c.ShouldBindJSON(&req) + ref := strings.TrimSpace(req.Ref) + if err := h.reg.Refresh(c.Request.Context(), ref); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to refresh model registry", "details": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"status": "refreshed"}) +} + +// RollbackModelRegistry godoc +// @Summary Rollback model registry +// @Description Rollback meta:models to previous cached last-good version +// @Tags admin +// @Produce json +// @Security AdminAuth +// @Success 200 {object} gin.H +// @Failure 500 {object} gin.H +// @Router /admin/model-registry/rollback [post] +func (h *ModelRegistryHandler) Rollback(c *gin.Context) { + if h == nil || h.reg == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "model registry not configured"}) + return + } + if err := h.reg.Rollback(c.Request.Context()); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to rollback model registry", "details": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"status": "rolled_back"}) +} diff --git a/internal/config/config.go b/internal/config/config.go index 192d08d..125fb01 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -10,11 +10,12 @@ import ( ) type Config struct { - Server ServerConfig - Postgres PostgresConfig - Redis RedisConfig - Log LogConfig - Auth AuthConfig + Server ServerConfig + Postgres PostgresConfig + Redis RedisConfig + Log LogConfig + Auth AuthConfig + ModelRegistry ModelRegistryConfig } type ServerConfig struct { @@ -41,6 +42,15 @@ type LogConfig struct { QueueCapacity int } +type ModelRegistryConfig struct { + Enabled bool + RefreshSeconds int + ModelsDevBaseURL string + ModelsDevRef string + CacheDir string + TimeoutSeconds int +} + func Load() (*Config, error) { v := viper.New() @@ -53,6 +63,12 @@ func Load() (*Config, error) { v.SetDefault("log.flush_ms", 1000) v.SetDefault("log.queue_capacity", 10000) v.SetDefault("auth.jwt_secret", "change_me_in_production") + v.SetDefault("model_registry.enabled", false) + v.SetDefault("model_registry.refresh_seconds", 1800) + v.SetDefault("model_registry.models_dev_base_url", "https://codeload.github.com/sst/models.dev/tar.gz") + v.SetDefault("model_registry.models_dev_ref", "dev") + v.SetDefault("model_registry.cache_dir", "./data/model-registry") + v.SetDefault("model_registry.timeout_seconds", 30) v.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) v.AutomaticEnv() @@ -66,6 +82,12 @@ func Load() (*Config, error) { _ = v.BindEnv("log.flush_ms", "EZ_LOG_FLUSH_MS") _ = v.BindEnv("log.queue_capacity", "EZ_LOG_QUEUE") _ = v.BindEnv("auth.jwt_secret", "EZ_JWT_SECRET") + _ = v.BindEnv("model_registry.enabled", "EZ_MODEL_REGISTRY_ENABLED") + _ = v.BindEnv("model_registry.refresh_seconds", "EZ_MODEL_REGISTRY_REFRESH_SECONDS") + _ = v.BindEnv("model_registry.models_dev_base_url", "EZ_MODEL_REGISTRY_MODELS_DEV_BASE_URL") + _ = v.BindEnv("model_registry.models_dev_ref", "EZ_MODEL_REGISTRY_MODELS_DEV_REF") + _ = v.BindEnv("model_registry.cache_dir", "EZ_MODEL_REGISTRY_CACHE_DIR") + _ = v.BindEnv("model_registry.timeout_seconds", "EZ_MODEL_REGISTRY_TIMEOUT_SECONDS") if configFile := os.Getenv("EZ_CONFIG_FILE"); configFile != "" { v.SetConfigFile(configFile) @@ -102,6 +124,14 @@ func Load() (*Config, error) { Auth: AuthConfig{ JWTSecret: v.GetString("auth.jwt_secret"), }, + ModelRegistry: ModelRegistryConfig{ + Enabled: v.GetBool("model_registry.enabled"), + RefreshSeconds: v.GetInt("model_registry.refresh_seconds"), + ModelsDevBaseURL: v.GetString("model_registry.models_dev_base_url"), + ModelsDevRef: v.GetString("model_registry.models_dev_ref"), + CacheDir: v.GetString("model_registry.cache_dir"), + TimeoutSeconds: v.GetInt("model_registry.timeout_seconds"), + }, } return cfg, nil diff --git a/internal/service/model_registry.go b/internal/service/model_registry.go new file mode 100644 index 0000000..411c5ef --- /dev/null +++ b/internal/service/model_registry.go @@ -0,0 +1,844 @@ +package service + +import ( + "archive/tar" + "bytes" + "compress/gzip" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "regexp" + "strings" + "sync" + "time" + + "github.com/ez-api/ez-api/internal/model" + groupx "github.com/ez-api/foundation/group" + "github.com/ez-api/foundation/jsoncodec" + "github.com/ez-api/foundation/modelcap" + "github.com/ez-api/foundation/routing" + "github.com/pelletier/go-toml/v2" + "github.com/redis/go-redis/v9" + "gorm.io/gorm" +) + +type ModelRegistryConfig struct { + Enabled bool + RefreshEvery time.Duration + + ModelsDevBaseURL string + ModelsDevRef string + + CacheDir string + Timeout time.Duration +} + +type ModelRegistryService struct { + db *gorm.DB + rdb *redis.Client + + cfg ModelRegistryConfig + client *http.Client + + mu sync.Mutex + lastError string + lastRefresh time.Time + lastApplied modelcap.Meta + lastUpstream string +} + +func NewModelRegistryService(db *gorm.DB, rdb *redis.Client, cfg ModelRegistryConfig) *ModelRegistryService { + if cfg.RefreshEvery <= 0 { + cfg.RefreshEvery = 30 * time.Minute + } + if strings.TrimSpace(cfg.ModelsDevBaseURL) == "" { + cfg.ModelsDevBaseURL = "https://codeload.github.com/sst/models.dev/tar.gz" + } + if strings.TrimSpace(cfg.ModelsDevRef) == "" { + cfg.ModelsDevRef = "dev" + } + if strings.TrimSpace(cfg.CacheDir) == "" { + cfg.CacheDir = "./data/model-registry" + } + if cfg.Timeout <= 0 { + cfg.Timeout = 30 * time.Second + } + return &ModelRegistryService{ + db: db, + rdb: rdb, + cfg: cfg, + client: &http.Client{ + Timeout: cfg.Timeout, + }, + } +} + +func (s *ModelRegistryService) Start(ctx context.Context) { + if !s.cfg.Enabled { + return + } + go func() { + ticker := time.NewTicker(s.cfg.RefreshEvery) + defer ticker.Stop() + + // Best-effort initial refresh. + _ = s.Refresh(ctx, "") + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + _ = s.Refresh(ctx, "") + } + } + }() +} + +type ModelRegistryStatus struct { + Enabled bool `json:"enabled"` + ModelsDevRef string `json:"models_dev_ref"` + ModelsDevURL string `json:"models_dev_url"` + LastRefreshAt int64 `json:"last_refresh_at,omitempty"` + LastError string `json:"last_error,omitempty"` + RedisMeta map[string]string `json:"redis_meta,omitempty"` + CacheCurrent *modelRegistryFile `json:"cache_current,omitempty"` + CachePrev *modelRegistryFile `json:"cache_prev,omitempty"` +} + +func (s *ModelRegistryService) Status(ctx context.Context) (*ModelRegistryStatus, error) { + s.mu.Lock() + lastErr := s.lastError + lastRefresh := s.lastRefresh + s.mu.Unlock() + + redisMeta, _ := s.rdb.HGetAll(ctx, "meta:models_meta").Result() + + current, _ := readModelRegistryFile(filepath.Join(s.cfg.CacheDir, "current.json")) + prev, _ := readModelRegistryFile(filepath.Join(s.cfg.CacheDir, "prev.json")) + + out := &ModelRegistryStatus{ + Enabled: s.cfg.Enabled, + ModelsDevRef: s.cfg.ModelsDevRef, + ModelsDevURL: strings.TrimRight(s.cfg.ModelsDevBaseURL, "/") + "/" + s.cfg.ModelsDevRef, + LastError: lastErr, + RedisMeta: redisMeta, + CacheCurrent: current, + CachePrev: prev, + } + if !lastRefresh.IsZero() { + out.LastRefreshAt = lastRefresh.Unix() + } + return out, nil +} + +func (s *ModelRegistryService) Refresh(ctx context.Context, ref string) error { + if strings.TrimSpace(ref) == "" { + ref = s.cfg.ModelsDevRef + } + + s.mu.Lock() + s.lastUpstream = ref + s.mu.Unlock() + + tarballURL := strings.TrimRight(s.cfg.ModelsDevBaseURL, "/") + "/" + ref + reg, version, err := s.fetchModelsDev(ctx, tarballURL) + if err != nil { + s.setError(err) + return err + } + + models, payloads, err := s.buildBindingModels(ctx, reg) + if err != nil { + s.setError(err) + return err + } + + // Overlay explicit DB entries (admin overrides). + var dbModels []model.Model + if err := s.db.Find(&dbModels).Error; err != nil { + s.setError(fmt.Errorf("load db models: %w", err)) + return err + } + for _, m := range dbModels { + snap := modelcap.Model{ + Name: m.Name, + Kind: string(modelcap.NormalizeKind(m.Kind)), + ContextWindow: m.ContextWindow, + CostPerToken: m.CostPerToken, + SupportsVision: m.SupportsVision, + SupportsFunction: m.SupportsFunctions, + SupportsToolChoice: m.SupportsToolChoice, + SupportsFim: m.SupportsFIM, + MaxOutputTokens: m.MaxOutputTokens, + }.Normalized() + b, err := jsoncodec.Marshal(snap) + if err != nil { + s.setError(fmt.Errorf("marshal db model %s: %w", m.Name, err)) + return err + } + models[snap.Name] = snap + payloads[snap.Name] = string(b) + } + + now := time.Now().Unix() + meta := modelcap.Meta{ + Version: version, + UpdatedAt: fmt.Sprintf("%d", now), + Source: "models.dev", + Checksum: modelcap.ChecksumFromPayloads(payloads), + UpstreamURL: "https://github.com/sst/models.dev", + UpstreamRef: ref, + } + + if err := s.applyToRedis(ctx, models, payloads, meta); err != nil { + s.setError(err) + return err + } + if err := s.persistCache(models, meta); err != nil { + s.setError(err) + return err + } + + s.mu.Lock() + s.lastError = "" + s.lastRefresh = time.Now() + s.lastApplied = meta + s.mu.Unlock() + return nil +} + +func (s *ModelRegistryService) Rollback(ctx context.Context) error { + prevPath := filepath.Join(s.cfg.CacheDir, "prev.json") + prev, err := readModelRegistryFile(prevPath) + if err != nil { + s.setError(fmt.Errorf("read prev cache: %w", err)) + return err + } + if prev == nil { + s.setError(fmt.Errorf("no prev cache")) + return fmt.Errorf("no prev cache") + } + + payloads := make(map[string]string, len(prev.Models)) + for name, m := range prev.Models { + m = m.Normalized() + b, err := jsoncodec.Marshal(m) + if err != nil { + s.setError(fmt.Errorf("marshal cached model %s: %w", name, err)) + return err + } + prev.Models[name] = m + payloads[name] = string(b) + } + prev.Meta.Checksum = modelcap.ChecksumFromPayloads(payloads) + prev.Meta.Source = "rollback" + prev.Meta.UpdatedAt = fmt.Sprintf("%d", time.Now().Unix()) + + if err := s.applyToRedis(ctx, prev.Models, payloads, prev.Meta); err != nil { + s.setError(err) + return err + } + // Swap cache current/prev to reflect rollback. + if err := s.persistCache(prev.Models, prev.Meta); err != nil { + s.setError(err) + return err + } + s.mu.Lock() + s.lastError = "" + s.lastRefresh = time.Now() + s.lastApplied = prev.Meta + s.mu.Unlock() + return nil +} + +func (s *ModelRegistryService) setError(err error) { + s.mu.Lock() + defer s.mu.Unlock() + s.lastError = err.Error() + s.lastRefresh = time.Now() +} + +type boolVal struct { + Known bool + Val bool +} + +type intVal struct { + Known bool + Val int +} + +type upstreamCap struct { + ContextWindow intVal + MaxOutputTokens intVal + SupportsVision boolVal + SupportsTools boolVal +} + +type modelsDevRegistry struct { + ByProviderModel map[string]upstreamCap // key: providerID|modelID + ByModel map[string]upstreamCap // fallback: modelID +} + +var shaSuffix = regexp.MustCompile(`-([0-9a-f]{7,40})$`) + +func (s *ModelRegistryService) fetchModelsDev(ctx context.Context, tarballURL string) (*modelsDevRegistry, string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, tarballURL, nil) + if err != nil { + return nil, "", err + } + resp, err := s.client.Do(req) + if err != nil { + return nil, "", err + } + defer resp.Body.Close() + if resp.StatusCode/100 != 2 { + b, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<10)) + return nil, "", fmt.Errorf("models.dev fetch failed: %s body=%s", resp.Status, string(b)) + } + tarball, err := io.ReadAll(resp.Body) + if err != nil { + return nil, "", err + } + + gr, err := gzip.NewReader(bytes.NewReader(tarball)) + if err != nil { + return nil, "", err + } + defer gr.Close() + + tr := tar.NewReader(gr) + reg := &modelsDevRegistry{ + ByProviderModel: make(map[string]upstreamCap), + ByModel: make(map[string]upstreamCap), + } + + version := "" + for { + hdr, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + return nil, "", err + } + name := hdr.Name + if version == "" { + if parts := strings.SplitN(name, "/", 2); len(parts) > 0 { + if m := shaSuffix.FindStringSubmatch(parts[0]); len(m) == 2 { + version = m[1] + } + } + } + if hdr.Typeflag != tar.TypeReg { + continue + } + if !strings.HasSuffix(name, ".toml") { + continue + } + providerID, modelID, ok := parseModelsDevModelPath(name) + if !ok { + continue + } + + data, err := io.ReadAll(io.LimitReader(tr, 1<<20)) + if err != nil { + return nil, "", err + } + cap, ok := parseModelsDevTOML(data) + if !ok { + continue + } + key := providerID + "|" + modelID + reg.ByProviderModel[key] = cap + if _, exists := reg.ByModel[modelID]; !exists { + reg.ByModel[modelID] = cap + } + // Also store a flattened variant for providers that may use subdirs. + if alt := strings.ReplaceAll(modelID, "/", "-"); alt != modelID { + if _, exists := reg.ByModel[alt]; !exists { + reg.ByModel[alt] = cap + } + altKey := providerID + "|" + alt + if _, exists := reg.ByProviderModel[altKey]; !exists { + reg.ByProviderModel[altKey] = cap + } + } + } + if version == "" { + version = fmt.Sprintf("%d", time.Now().Unix()) + } + return reg, version, nil +} + +func parseModelsDevModelPath(path string) (providerID string, modelID string, ok bool) { + parts := strings.Split(path, "/") + for i := 0; i < len(parts); i++ { + if parts[i] != "providers" { + continue + } + if i+3 >= len(parts) { + return "", "", false + } + providerID = strings.TrimSpace(parts[i+1]) + if providerID == "" { + return "", "", false + } + if parts[i+2] != "models" { + return "", "", false + } + modelPart := strings.Join(parts[i+3:], "/") + modelPart = strings.TrimSuffix(modelPart, ".toml") + modelPart = strings.TrimSpace(modelPart) + if modelPart == "" { + return "", "", false + } + return providerID, modelPart, true + } + return "", "", false +} + +func parseModelsDevTOML(data []byte) (upstreamCap, bool) { + var doc map[string]any + if err := toml.Unmarshal(data, &doc); err != nil { + return upstreamCap{}, false + } + + var cap upstreamCap + + if v, ok := getBool(doc, "tool_call"); ok { + cap.SupportsTools = boolVal{Known: true, Val: v} + } + if limit, ok := getMap(doc, "limit"); ok { + if v, ok := getInt(limit, "context"); ok { + cap.ContextWindow = intVal{Known: true, Val: v} + } + if v, ok := getInt(limit, "output"); ok { + cap.MaxOutputTokens = intVal{Known: true, Val: v} + } + } + if mods, ok := getMap(doc, "modalities"); ok { + if input, ok := getStringSlice(mods["input"]); ok { + hasImage := false + for _, it := range input { + if strings.EqualFold(strings.TrimSpace(it), "image") { + hasImage = true + break + } + } + cap.SupportsVision = boolVal{Known: true, Val: hasImage} + } + } + + if !cap.SupportsTools.Known && !cap.SupportsVision.Known && !cap.ContextWindow.Known && !cap.MaxOutputTokens.Known { + return upstreamCap{}, false + } + return cap, true +} + +func getMap(doc map[string]any, key string) (map[string]any, bool) { + if doc == nil { + return nil, false + } + v, ok := doc[key] + if !ok || v == nil { + return nil, false + } + m, ok := v.(map[string]any) + return m, ok +} + +func getBool(doc map[string]any, key string) (bool, bool) { + v, ok := doc[key] + if !ok || v == nil { + return false, false + } + switch b := v.(type) { + case bool: + return b, true + default: + return false, false + } +} + +func getInt(doc map[string]any, key string) (int, bool) { + v, ok := doc[key] + if !ok || v == nil { + return 0, false + } + switch n := v.(type) { + case int64: + return int(n), true + case int: + return n, true + case float64: + return int(n), true + default: + return 0, false + } +} + +func getStringSlice(v any) ([]string, bool) { + if v == nil { + return nil, false + } + switch s := v.(type) { + case []any: + out := make([]string, 0, len(s)) + for _, it := range s { + if str, ok := it.(string); ok { + out = append(out, str) + } + } + return out, true + case []string: + return append([]string(nil), s...), true + default: + return nil, false + } +} + +type capAgg struct { + kind string + + visionAnyTrue bool + visionAllKnown bool + visionKnownAny bool + + toolsAnyTrue bool + toolsAllKnown bool + toolsKnownAny bool + + maxOutputKnown bool + maxOutputMax int + + contextKnown bool + contextMax int +} + +func inferKindFromPublicModel(publicModel string) string { + pm := strings.ToLower(strings.TrimSpace(publicModel)) + if strings.Contains(pm, "embedding") { + return string(modelcap.KindEmbedding) + } + if strings.Contains(pm, "rerank") { + return string(modelcap.KindRerank) + } + return string(modelcap.KindChat) +} + +func (a *capAgg) merge(cap upstreamCap, ok bool) { + if ok { + // vision + if cap.SupportsVision.Known { + a.visionKnownAny = true + if cap.SupportsVision.Val { + a.visionAnyTrue = true + } + } else { + a.visionAllKnown = false + } + // tools + if cap.SupportsTools.Known { + a.toolsKnownAny = true + if cap.SupportsTools.Val { + a.toolsAnyTrue = true + } + } else { + a.toolsAllKnown = false + } + // limits + if cap.MaxOutputTokens.Known && cap.MaxOutputTokens.Val > 0 { + a.maxOutputKnown = true + if cap.MaxOutputTokens.Val > a.maxOutputMax { + a.maxOutputMax = cap.MaxOutputTokens.Val + } + } + if cap.ContextWindow.Known && cap.ContextWindow.Val > 0 { + a.contextKnown = true + if cap.ContextWindow.Val > a.contextMax { + a.contextMax = cap.ContextWindow.Val + } + } + return + } + // Unknown upstream capability: do not force blocking, treat as "not all known". + a.visionAllKnown = false + a.toolsAllKnown = false +} + +func (a *capAgg) finalize(name string) modelcap.Model { + // Safe defaults: unknown -> allow (true) so we avoid false blocking. + supportsVision := true + if a.visionKnownAny { + if a.visionAnyTrue { + supportsVision = true + } else if a.visionAllKnown { + supportsVision = false + } + } + + supportsTools := true + if a.toolsKnownAny { + if a.toolsAnyTrue { + supportsTools = true + } else if a.toolsAllKnown { + supportsTools = false + } + } + + out := modelcap.Model{ + Name: name, + Kind: string(modelcap.NormalizeKind(a.kind)), + ContextWindow: 0, + CostPerToken: 0, + SupportsVision: supportsVision, + SupportsFunction: supportsTools, + SupportsToolChoice: true, + SupportsFim: true, + MaxOutputTokens: 0, + } + if a.contextKnown { + out.ContextWindow = a.contextMax + } + if a.maxOutputKnown { + out.MaxOutputTokens = a.maxOutputMax + } + return out.Normalized() +} + +func (s *ModelRegistryService) buildBindingModels(ctx context.Context, reg *modelsDevRegistry) (map[string]modelcap.Model, map[string]string, error) { + var providers []model.Provider + if err := s.db.Find(&providers).Error; err != nil { + return nil, nil, fmt.Errorf("load providers: %w", err) + } + var bindings []model.Binding + if err := s.db.Find(&bindings).Error; err != nil { + return nil, nil, fmt.Errorf("load bindings: %w", err) + } + + type providerLite struct { + id uint + group string + ptype string + models []string + } + providersByGroup := make(map[string][]providerLite) + now := time.Now().Unix() + for _, p := range providers { + if strings.TrimSpace(p.Status) != "" && strings.TrimSpace(p.Status) != "active" { + continue + } + if p.BanUntil != nil && p.BanUntil.UTC().Unix() > now { + continue + } + group := groupx.Normalize(p.Group) + rawModels := strings.Split(p.Models, ",") + var outModels []string + for _, m := range rawModels { + m = strings.TrimSpace(m) + if m != "" { + outModels = append(outModels, m) + } + } + if group == "" || len(outModels) == 0 { + continue + } + providersByGroup[group] = append(providersByGroup[group], providerLite{ + id: p.ID, + group: group, + ptype: strings.TrimSpace(p.Type), + models: outModels, + }) + } + + modelsOut := make(map[string]modelcap.Model) + payloads := make(map[string]string) + + for _, b := range bindings { + if strings.TrimSpace(b.Status) != "" && strings.TrimSpace(b.Status) != "active" { + continue + } + ns := strings.TrimSpace(b.Namespace) + pm := strings.TrimSpace(b.PublicModel) + if ns == "" || pm == "" { + continue + } + key := ns + "." + pm + rg := groupx.Normalize(b.RouteGroup) + if rg == "" { + continue + } + pgroup := providersByGroup[rg] + if len(pgroup) == 0 { + continue + } + + agg := &capAgg{ + kind: inferKindFromPublicModel(pm), + visionAllKnown: true, + toolsAllKnown: true, + } + + selectorType := routing.SelectorType(strings.TrimSpace(b.SelectorType)) + selectorValue := strings.TrimSpace(b.SelectorValue) + + for _, p := range pgroup { + up, err := routing.ResolveUpstreamModel(selectorType, selectorValue, pm, p.models) + if err != nil { + continue + } + cap, ok := lookupModelsDevCap(reg, modelsDevProviderKey(p.ptype), up) + agg.merge(cap, ok) + } + + out := agg.finalize(key) + bs, err := jsoncodec.Marshal(out) + if err != nil { + return nil, nil, fmt.Errorf("marshal model %s: %w", key, err) + } + modelsOut[key] = out + payloads[key] = string(bs) + } + + return modelsOut, payloads, nil +} + +func modelsDevProviderKey(providerType string) string { + pt := strings.ToLower(strings.TrimSpace(providerType)) + switch pt { + case "openai", "compatible": + return "openai" + case "anthropic", "claude": + return "anthropic" + case "gemini", "vertex", "vertex-express": + return "google" + default: + return "" + } +} + +func lookupModelsDevCap(reg *modelsDevRegistry, providerID, modelID string) (upstreamCap, bool) { + modelID = strings.TrimSpace(modelID) + if modelID == "" || reg == nil { + return upstreamCap{}, false + } + if strings.TrimSpace(providerID) != "" { + if cap, ok := reg.ByProviderModel[providerID+"|"+modelID]; ok { + return cap, true + } + } + if cap, ok := reg.ByModel[modelID]; ok { + return cap, true + } + return upstreamCap{}, false +} + +func (s *ModelRegistryService) applyToRedis(ctx context.Context, models map[string]modelcap.Model, payloads map[string]string, meta modelcap.Meta) error { + pipe := s.rdb.TxPipeline() + pipe.Del(ctx, "meta:models", "meta:models_meta") + + for name := range models { + if strings.TrimSpace(name) == "" { + continue + } + payload := payloads[name] + if payload == "" { + b, err := jsoncodec.Marshal(models[name].Normalized()) + if err != nil { + return fmt.Errorf("marshal model %s: %w", name, err) + } + payload = string(b) + } + pipe.HSet(ctx, "meta:models", name, payload) + } + + fields := map[string]string{ + "version": strings.TrimSpace(meta.Version), + "updated_at": strings.TrimSpace(meta.UpdatedAt), + "source": strings.TrimSpace(meta.Source), + "checksum": strings.TrimSpace(meta.Checksum), + "upstream_url": strings.TrimSpace(meta.UpstreamURL), + "upstream_ref": strings.TrimSpace(meta.UpstreamRef), + } + for k, v := range fields { + if v == "" { + delete(fields, k) + } + } + if fields["version"] == "" { + fields["version"] = fmt.Sprintf("%d", time.Now().Unix()) + } + if fields["updated_at"] == "" { + fields["updated_at"] = fmt.Sprintf("%d", time.Now().Unix()) + } + if fields["source"] == "" { + fields["source"] = "models.dev" + } + if fields["checksum"] == "" { + fields["checksum"] = modelcap.ChecksumFromPayloads(payloads) + } + if err := pipe.HSet(ctx, "meta:models_meta", fields).Err(); err != nil { + return fmt.Errorf("write meta:models_meta: %w", err) + } + + if _, err := pipe.Exec(ctx); err != nil { + return fmt.Errorf("apply registry to redis: %w", err) + } + return nil +} + +type modelRegistryFile struct { + Meta modelcap.Meta `json:"meta"` + Models map[string]modelcap.Model `json:"models"` +} + +func readModelRegistryFile(path string) (*modelRegistryFile, error) { + b, err := os.ReadFile(path) + if err != nil { + return nil, err + } + var out modelRegistryFile + if err := jsoncodec.Unmarshal(b, &out); err != nil { + return nil, err + } + if out.Models == nil { + out.Models = make(map[string]modelcap.Model) + } + return &out, nil +} + +func (s *ModelRegistryService) persistCache(models map[string]modelcap.Model, meta modelcap.Meta) error { + if err := os.MkdirAll(s.cfg.CacheDir, 0o755); err != nil { + return err + } + currentPath := filepath.Join(s.cfg.CacheDir, "current.json") + prevPath := filepath.Join(s.cfg.CacheDir, "prev.json") + tmpPath := filepath.Join(s.cfg.CacheDir, "current.json.tmp") + + if _, err := os.Stat(currentPath); err == nil { + _ = os.Remove(prevPath) + _ = os.Rename(currentPath, prevPath) + } + + out := modelRegistryFile{ + Meta: meta, + Models: models, + } + b, err := json.MarshalIndent(out, "", " ") + if err != nil { + return err + } + if err := os.WriteFile(tmpPath, b, 0o644); err != nil { + return err + } + if err := os.Rename(tmpPath, currentPath); err != nil { + return err + } + return nil +} diff --git a/internal/service/model_registry_test.go b/internal/service/model_registry_test.go new file mode 100644 index 0000000..82e600c --- /dev/null +++ b/internal/service/model_registry_test.go @@ -0,0 +1,193 @@ +package service + +import ( + "archive/tar" + "bytes" + "compress/gzip" + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/ez-api/ez-api/internal/model" + "github.com/ez-api/foundation/modelcap" + "github.com/redis/go-redis/v9" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +func mustGzipTar(t *testing.T, files map[string]string) []byte { + t.Helper() + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + tw := tar.NewWriter(gz) + for name, body := range files { + b := []byte(body) + h := &tar.Header{ + Name: name, + Mode: 0o644, + Size: int64(len(b)), + Typeflag: tar.TypeReg, + } + if err := tw.WriteHeader(h); err != nil { + t.Fatalf("tar header: %v", err) + } + if _, err := tw.Write(b); err != nil { + t.Fatalf("tar write: %v", err) + } + } + if err := tw.Close(); err != nil { + t.Fatalf("tar close: %v", err) + } + if err := gz.Close(); err != nil { + t.Fatalf("gzip close: %v", err) + } + return buf.Bytes() +} + +func TestModelRegistry_RefreshAndRollback(t *testing.T) { + t.Parallel() + + dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", t.Name()) + db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{}) + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + if err := db.AutoMigrate(&model.Provider{}, &model.Binding{}, &model.Model{}); err != nil { + t.Fatalf("migrate: %v", err) + } + if err := db.Create(&model.Provider{ + Name: "p1", + Type: "openai", + Group: "rg", + Models: "gpt-4o-mini", + Status: "active", + }).Error; err != nil { + t.Fatalf("create provider: %v", err) + } + if err := db.Create(&model.Binding{ + Namespace: "ns", + PublicModel: "m", + RouteGroup: "rg", + SelectorType: "exact", + SelectorValue: "gpt-4o-mini", + Status: "active", + }).Error; err != nil { + t.Fatalf("create binding: %v", err) + } + + tar1 := mustGzipTar(t, map[string]string{ + "sst-models.dev-aaaaaaaa/providers/openai/models/gpt-4o-mini.toml": ` +tool_call = true +[limit] +context = 128000 +output = 8192 +[modalities] +input = ["text","image"] +`, + }) + tar2 := mustGzipTar(t, map[string]string{ + "sst-models.dev-bbbbbbbb/providers/openai/models/gpt-4o-mini.toml": ` +tool_call = false +[limit] +context = 64000 +output = 2048 +[modalities] +input = ["text"] +`, + }) + + var served int + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/dev" { + http.NotFound(w, r) + return + } + w.Header().Set("Content-Type", "application/gzip") + if served == 0 { + served++ + _, _ = w.Write(tar1) + return + } + _, _ = w.Write(tar2) + })) + defer srv.Close() + + mr := miniredis.RunT(t) + rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()}) + + cacheDir := t.TempDir() + svc := NewModelRegistryService(db, rdb, ModelRegistryConfig{ + Enabled: true, + RefreshEvery: time.Hour, + ModelsDevBaseURL: srv.URL, + ModelsDevRef: "dev", + CacheDir: cacheDir, + Timeout: 5 * time.Second, + }) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := svc.Refresh(ctx, "dev"); err != nil { + t.Fatalf("refresh1: %v", err) + } + raw1 := mr.HGet("meta:models", "ns.m") + if raw1 == "" { + t.Fatalf("expected meta:models[ns.m]") + } + var m1 modelcap.Model + if err := json.Unmarshal([]byte(raw1), &m1); err != nil { + t.Fatalf("unmarshal1: %v raw=%s", err, raw1) + } + if m1.SupportsVision != true || m1.SupportsFunction != true { + t.Fatalf("expected vision/tools true, got %+v", m1) + } + if v := mr.HGet("meta:models_meta", "version"); v != "aaaaaaaa" { + t.Fatalf("expected version aaaaaaaa, got %q", v) + } + + if err := svc.Refresh(ctx, "dev"); err != nil { + t.Fatalf("refresh2: %v", err) + } + raw2 := mr.HGet("meta:models", "ns.m") + var m2 modelcap.Model + if err := json.Unmarshal([]byte(raw2), &m2); err != nil { + t.Fatalf("unmarshal2: %v raw=%s", err, raw2) + } + // Second refresh says no vision/tools, but our safe defaults treat unknown as allow only when unknown; + // here we have explicit false from models.dev and should reflect it. + if m2.SupportsVision != false || m2.SupportsFunction != false { + t.Fatalf("expected vision/tools false, got %+v", m2) + } + if v := mr.HGet("meta:models_meta", "version"); v != "bbbbbbbb" { + t.Fatalf("expected version bbbbbbbb, got %q", v) + } + + if err := svc.Rollback(ctx); err != nil { + t.Fatalf("rollback: %v", err) + } + if v := mr.HGet("meta:models_meta", "version"); v != "aaaaaaaa" { + t.Fatalf("expected rollback to version aaaaaaaa, got %q", v) + } + raw3 := mr.HGet("meta:models", "ns.m") + var m3 modelcap.Model + if err := json.Unmarshal([]byte(raw3), &m3); err != nil { + t.Fatalf("unmarshal3: %v raw=%s", err, raw3) + } + if m3.SupportsVision != true || m3.SupportsFunction != true { + t.Fatalf("expected rollback vision/tools true, got %+v", m3) + } + + if _, err := os.Stat(cacheDir + "/current.json"); err != nil { + t.Fatalf("expected current cache file: %v", err) + } + if _, err := os.Stat(cacheDir + "/prev.json"); err != nil { + t.Fatalf("expected prev cache file: %v", err) + } +} diff --git a/internal/service/sync.go b/internal/service/sync.go index 73ebc90..976d189 100644 --- a/internal/service/sync.go +++ b/internal/service/sync.go @@ -2,16 +2,14 @@ package service import ( "context" - "crypto/sha256" - "encoding/hex" "fmt" - "sort" "strings" "time" "github.com/ez-api/ez-api/internal/model" groupx "github.com/ez-api/foundation/group" "github.com/ez-api/foundation/jsoncodec" + "github.com/ez-api/foundation/modelcap" "github.com/ez-api/foundation/routing" "github.com/ez-api/foundation/tokenhash" "github.com/redis/go-redis/v9" @@ -120,17 +118,17 @@ func (s *SyncService) SyncProvider(provider *model.Provider) error { // SyncModel writes a single model metadata record. func (s *SyncService) SyncModel(m *model.Model) error { ctx := context.Background() - snap := modelSnapshot{ + snap := modelcap.Model{ Name: m.Name, - Kind: normalizeModelKind(m.Kind), + Kind: string(modelcap.NormalizeKind(m.Kind)), ContextWindow: m.ContextWindow, CostPerToken: m.CostPerToken, SupportsVision: m.SupportsVision, SupportsFunction: m.SupportsFunctions, SupportsToolChoice: m.SupportsToolChoice, - SupportsFIM: m.SupportsFIM, + SupportsFim: m.SupportsFIM, MaxOutputTokens: m.MaxOutputTokens, - } + }.Normalized() if err := s.hsetJSON(ctx, "meta:models", snap.Name, snap); err != nil { return err } @@ -159,18 +157,6 @@ type providerSnapshot struct { // keySnapshot is no longer needed as we write directly to auth:token:* -type modelSnapshot struct { - Name string `json:"name"` - Kind string `json:"kind"` - ContextWindow int `json:"context_window"` - CostPerToken float64 `json:"cost_per_token"` - SupportsVision bool `json:"supports_vision"` - SupportsFunction bool `json:"supports_functions"` - SupportsToolChoice bool `json:"supports_tool_choice"` - SupportsFIM bool `json:"supports_fim"` - MaxOutputTokens int `json:"max_output_tokens"` -} - // SyncAll rebuilds Redis hashes from the database; use for cold starts or forced refreshes. func (s *SyncService) SyncAll(db *gorm.DB) error { ctx := context.Background() @@ -194,6 +180,7 @@ func (s *SyncService) SyncAll(db *gorm.DB) error { if err := db.Find(&models).Error; err != nil { return fmt.Errorf("load models: %w", err) } + var modelsPayloads map[string]string var bindings []model.Binding if err := db.Find(&bindings).Error; err != nil { @@ -292,29 +279,35 @@ func (s *SyncService) SyncAll(db *gorm.DB) error { } for _, m := range models { - snap := modelSnapshot{ + snap := modelcap.Model{ Name: m.Name, - Kind: normalizeModelKind(m.Kind), + Kind: string(modelcap.NormalizeKind(m.Kind)), ContextWindow: m.ContextWindow, CostPerToken: m.CostPerToken, SupportsVision: m.SupportsVision, SupportsFunction: m.SupportsFunctions, SupportsToolChoice: m.SupportsToolChoice, - SupportsFIM: m.SupportsFIM, + SupportsFim: m.SupportsFIM, MaxOutputTokens: m.MaxOutputTokens, - } + }.Normalized() payload, err := jsoncodec.Marshal(snap) if err != nil { return fmt.Errorf("marshal model %s: %w", m.Name, err) } + // Capture payloads so we can compute deterministic checksum for meta:models_meta. + if modelsPayloads == nil { + modelsPayloads = make(map[string]string, len(models)) + } + modelsPayloads[snap.Name] = string(payload) pipe.HSet(ctx, "meta:models", snap.Name, payload) } - if err := writeModelsMeta(ctx, pipe, modelsMetaInput{ - Source: "db", - Version: fmt.Sprintf("%d", time.Now().Unix()), - UpdatedAtSec: time.Now().Unix(), - Models: models, + now := time.Now().Unix() + if err := writeModelsMeta(ctx, pipe, modelcap.Meta{ + Version: fmt.Sprintf("%d", now), + UpdatedAt: fmt.Sprintf("%d", now), + Source: "db", + Checksum: modelcap.ChecksumFromPayloads(modelsPayloads), }); err != nil { return err } @@ -469,36 +462,6 @@ func normalizeStatus(status string) string { } } -func normalizeModelKind(kind string) string { - k := strings.ToLower(strings.TrimSpace(kind)) - if k == "" { - return "chat" - } - switch k { - case "chat", "embedding", "rerank", "other": - return k - default: - return "other" - } -} - -func checksumModelPayloads(payloads map[string]string) string { - keys := make([]string, 0, len(payloads)) - for k := range payloads { - keys = append(keys, k) - } - sort.Strings(keys) - - h := sha256.New() - for _, k := range keys { - _, _ = h.Write([]byte(k)) - _, _ = h.Write([]byte{'\n'}) - _, _ = h.Write([]byte(payloads[k])) - _, _ = h.Write([]byte{'\n'}) - } - return hex.EncodeToString(h.Sum(nil)) -} - func (s *SyncService) refreshModelsMetaFromRedis(ctx context.Context, source string) error { raw, err := s.rdb.HGetAll(ctx, "meta:models").Result() if err != nil { @@ -509,7 +472,7 @@ func (s *SyncService) refreshModelsMetaFromRedis(ctx context.Context, source str "version": fmt.Sprintf("%d", now), "updated_at": fmt.Sprintf("%d", now), "source": source, - "checksum": checksumModelPayloads(raw), + "checksum": modelcap.ChecksumFromPayloads(raw), } if err := s.rdb.HSet(ctx, "meta:models_meta", meta).Err(); err != nil { return fmt.Errorf("write meta:models_meta: %w", err) @@ -517,47 +480,33 @@ func (s *SyncService) refreshModelsMetaFromRedis(ctx context.Context, source str return nil } -type modelsMetaInput struct { - Source string - Version string - UpdatedAtSec int64 - Models []model.Model -} - -func writeModelsMeta(ctx context.Context, pipe redis.Pipeliner, in modelsMetaInput) error { - payloads := make(map[string]string, len(in.Models)) - for _, m := range in.Models { - snap := modelSnapshot{ - Name: m.Name, - Kind: normalizeModelKind(m.Kind), - ContextWindow: m.ContextWindow, - CostPerToken: m.CostPerToken, - SupportsVision: m.SupportsVision, - SupportsFunction: m.SupportsFunctions, - SupportsToolChoice: m.SupportsToolChoice, - SupportsFIM: m.SupportsFIM, - MaxOutputTokens: m.MaxOutputTokens, +func writeModelsMeta(ctx context.Context, pipe redis.Pipeliner, meta modelcap.Meta) error { + fields := map[string]string{ + "version": strings.TrimSpace(meta.Version), + "updated_at": strings.TrimSpace(meta.UpdatedAt), + "source": strings.TrimSpace(meta.Source), + "checksum": strings.TrimSpace(meta.Checksum), + "upstream_url": strings.TrimSpace(meta.UpstreamURL), + "upstream_ref": strings.TrimSpace(meta.UpstreamRef), + } + for k, v := range fields { + if v == "" { + delete(fields, k) } - b, err := jsoncodec.Marshal(snap) - if err != nil { - return fmt.Errorf("marshal model %s for meta: %w", m.Name, err) - } - payloads[snap.Name] = string(b) } - - meta := map[string]string{ - "version": strings.TrimSpace(in.Version), - "updated_at": fmt.Sprintf("%d", in.UpdatedAtSec), - "source": strings.TrimSpace(in.Source), - "checksum": checksumModelPayloads(payloads), + if fields["version"] == "" { + fields["version"] = fmt.Sprintf("%d", time.Now().Unix()) } - if strings.TrimSpace(meta["version"]) == "" { - meta["version"] = fmt.Sprintf("%d", time.Now().Unix()) + if fields["updated_at"] == "" { + fields["updated_at"] = fmt.Sprintf("%d", time.Now().Unix()) } - if strings.TrimSpace(meta["source"]) == "" { - meta["source"] = "db" + if fields["source"] == "" { + fields["source"] = "db" } - if err := pipe.HSet(ctx, "meta:models_meta", meta).Err(); err != nil { + if fields["checksum"] == "" { + fields["checksum"] = "unknown" + } + if err := pipe.HSet(ctx, "meta:models_meta", fields).Err(); err != nil { return fmt.Errorf("write meta:models_meta: %w", err) } return nil