diff --git a/cmd/server/main.go b/cmd/server/main.go index e2f0c09..e68f11e 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -115,12 +115,13 @@ func main() { masterHandler := api.NewMasterHandler(db, masterService, syncService) featureHandler := api.NewFeatureHandler(rdb) modelRegistryService := service.NewModelRegistryService(db, rdb, service.ModelRegistryConfig{ - Enabled: cfg.ModelRegistry.Enabled, - RefreshEvery: time.Duration(cfg.ModelRegistry.RefreshSeconds) * time.Second, - ModelsDevBaseURL: cfg.ModelRegistry.ModelsDevBaseURL, - ModelsDevRef: cfg.ModelRegistry.ModelsDevRef, - CacheDir: cfg.ModelRegistry.CacheDir, - Timeout: time.Duration(cfg.ModelRegistry.TimeoutSeconds) * time.Second, + Enabled: cfg.ModelRegistry.Enabled, + RefreshEvery: time.Duration(cfg.ModelRegistry.RefreshSeconds) * time.Second, + ModelsDevBaseURL: cfg.ModelRegistry.ModelsDevBaseURL, + ModelsDevAPIBaseURL: cfg.ModelRegistry.ModelsDevAPIBaseURL, + ModelsDevRef: cfg.ModelRegistry.ModelsDevRef, + CacheDir: cfg.ModelRegistry.CacheDir, + Timeout: time.Duration(cfg.ModelRegistry.TimeoutSeconds) * time.Second, }) modelRegistryHandler := api.NewModelRegistryHandler(modelRegistryService) @@ -181,6 +182,7 @@ func main() { adminGroup.GET("/features", featureHandler.ListFeatures) adminGroup.PUT("/features", featureHandler.UpdateFeatures) adminGroup.GET("/model-registry/status", modelRegistryHandler.GetStatus) + adminGroup.POST("/model-registry/check", modelRegistryHandler.Check) adminGroup.POST("/model-registry/refresh", modelRegistryHandler.Refresh) adminGroup.POST("/model-registry/rollback", modelRegistryHandler.Rollback) // Other admin routes for managing providers, models, etc. diff --git a/internal/api/model_registry_handler.go b/internal/api/model_registry_handler.go index 4fb0977..0609118 100644 --- a/internal/api/model_registry_handler.go +++ b/internal/api/model_registry_handler.go @@ -42,6 +42,34 @@ type refreshModelRegistryRequest struct { Ref string `json:"ref"` } +// CheckModelRegistry godoc +// @Summary Check model registry upstream version +// @Description Checks models.dev commit SHA for a ref and indicates whether refresh is needed (does not apply changes) +// @Tags admin +// @Accept json +// @Produce json +// @Security AdminAuth +// @Param body body refreshModelRegistryRequest false "optional override ref" +// @Success 200 {object} service.ModelRegistryCheckResult +// @Failure 400 {object} gin.H +// @Failure 500 {object} gin.H +// @Router /admin/model-registry/check [post] +func (h *ModelRegistryHandler) Check(c *gin.Context) { + if h == nil || h.reg == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "model registry not configured"}) + return + } + var req refreshModelRegistryRequest + _ = c.ShouldBindJSON(&req) + ref := strings.TrimSpace(req.Ref) + out, err := h.reg.Check(c.Request.Context(), ref) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check model registry", "details": err.Error()}) + return + } + c.JSON(http.StatusOK, out) +} + // RefreshModelRegistry godoc // @Summary Refresh model registry from models.dev // @Description Fetches models.dev, computes per-binding capabilities, and updates Redis meta:models diff --git a/internal/config/config.go b/internal/config/config.go index 125fb01..8362bb6 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -43,12 +43,13 @@ type LogConfig struct { } type ModelRegistryConfig struct { - Enabled bool - RefreshSeconds int - ModelsDevBaseURL string - ModelsDevRef string - CacheDir string - TimeoutSeconds int + Enabled bool + RefreshSeconds int + ModelsDevBaseURL string + ModelsDevAPIBaseURL string + ModelsDevRef string + CacheDir string + TimeoutSeconds int } func Load() (*Config, error) { @@ -66,6 +67,7 @@ func Load() (*Config, error) { v.SetDefault("model_registry.enabled", false) v.SetDefault("model_registry.refresh_seconds", 1800) v.SetDefault("model_registry.models_dev_base_url", "https://codeload.github.com/sst/models.dev/tar.gz") + v.SetDefault("model_registry.models_dev_api_base_url", "https://api.github.com") v.SetDefault("model_registry.models_dev_ref", "dev") v.SetDefault("model_registry.cache_dir", "./data/model-registry") v.SetDefault("model_registry.timeout_seconds", 30) @@ -85,6 +87,7 @@ func Load() (*Config, error) { _ = v.BindEnv("model_registry.enabled", "EZ_MODEL_REGISTRY_ENABLED") _ = v.BindEnv("model_registry.refresh_seconds", "EZ_MODEL_REGISTRY_REFRESH_SECONDS") _ = v.BindEnv("model_registry.models_dev_base_url", "EZ_MODEL_REGISTRY_MODELS_DEV_BASE_URL") + _ = v.BindEnv("model_registry.models_dev_api_base_url", "EZ_MODEL_REGISTRY_MODELS_DEV_API_BASE_URL") _ = v.BindEnv("model_registry.models_dev_ref", "EZ_MODEL_REGISTRY_MODELS_DEV_REF") _ = v.BindEnv("model_registry.cache_dir", "EZ_MODEL_REGISTRY_CACHE_DIR") _ = v.BindEnv("model_registry.timeout_seconds", "EZ_MODEL_REGISTRY_TIMEOUT_SECONDS") @@ -125,12 +128,13 @@ func Load() (*Config, error) { JWTSecret: v.GetString("auth.jwt_secret"), }, ModelRegistry: ModelRegistryConfig{ - Enabled: v.GetBool("model_registry.enabled"), - RefreshSeconds: v.GetInt("model_registry.refresh_seconds"), - ModelsDevBaseURL: v.GetString("model_registry.models_dev_base_url"), - ModelsDevRef: v.GetString("model_registry.models_dev_ref"), - CacheDir: v.GetString("model_registry.cache_dir"), - TimeoutSeconds: v.GetInt("model_registry.timeout_seconds"), + Enabled: v.GetBool("model_registry.enabled"), + RefreshSeconds: v.GetInt("model_registry.refresh_seconds"), + ModelsDevBaseURL: v.GetString("model_registry.models_dev_base_url"), + ModelsDevAPIBaseURL: v.GetString("model_registry.models_dev_api_base_url"), + ModelsDevRef: v.GetString("model_registry.models_dev_ref"), + CacheDir: v.GetString("model_registry.cache_dir"), + TimeoutSeconds: v.GetInt("model_registry.timeout_seconds"), }, } diff --git a/internal/service/model_registry.go b/internal/service/model_registry.go index 411c5ef..0d059bb 100644 --- a/internal/service/model_registry.go +++ b/internal/service/model_registry.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "net/http" + "net/url" "os" "path/filepath" "regexp" @@ -30,8 +31,9 @@ type ModelRegistryConfig struct { Enabled bool RefreshEvery time.Duration - ModelsDevBaseURL string - ModelsDevRef string + ModelsDevBaseURL string + ModelsDevAPIBaseURL string + ModelsDevRef string CacheDir string Timeout time.Duration @@ -58,6 +60,9 @@ func NewModelRegistryService(db *gorm.DB, rdb *redis.Client, cfg ModelRegistryCo if strings.TrimSpace(cfg.ModelsDevBaseURL) == "" { cfg.ModelsDevBaseURL = "https://codeload.github.com/sst/models.dev/tar.gz" } + if strings.TrimSpace(cfg.ModelsDevAPIBaseURL) == "" { + cfg.ModelsDevAPIBaseURL = "https://api.github.com" + } if strings.TrimSpace(cfg.ModelsDevRef) == "" { cfg.ModelsDevRef = "dev" } @@ -136,6 +141,41 @@ func (s *ModelRegistryService) Status(ctx context.Context) (*ModelRegistryStatus return out, nil } +type ModelRegistryCheckResult struct { + Enabled bool `json:"enabled"` + UpstreamRef string `json:"upstream_ref"` + CurrentVersion string `json:"current_version,omitempty"` + LatestVersion string `json:"latest_version,omitempty"` + NeedsRefresh bool `json:"needs_refresh"` + CurrentUpstreamRef string `json:"current_upstream_ref,omitempty"` +} + +func (s *ModelRegistryService) Check(ctx context.Context, ref string) (*ModelRegistryCheckResult, error) { + if strings.TrimSpace(ref) == "" { + ref = s.cfg.ModelsDevRef + } + + latest, err := s.fetchModelsDevRefSHA(ctx, ref) + if err != nil { + return nil, err + } + + meta, _ := s.rdb.HGetAll(ctx, "meta:models_meta").Result() + currentVersion := strings.TrimSpace(meta["version"]) + currentUpstreamRef := strings.TrimSpace(meta["upstream_ref"]) + + needsRefresh := currentVersion == "" || !versionsEqual(currentVersion, latest) + + return &ModelRegistryCheckResult{ + Enabled: s.cfg.Enabled, + UpstreamRef: ref, + CurrentVersion: currentVersion, + LatestVersion: latest, + NeedsRefresh: needsRefresh, + CurrentUpstreamRef: currentUpstreamRef, + }, nil +} + func (s *ModelRegistryService) Refresh(ctx context.Context, ref string) error { if strings.TrimSpace(ref) == "" { ref = s.cfg.ModelsDevRef @@ -263,6 +303,58 @@ func (s *ModelRegistryService) setError(err error) { s.lastRefresh = time.Now() } +func (s *ModelRegistryService) fetchModelsDevRefSHA(ctx context.Context, ref string) (string, error) { + base := strings.TrimRight(strings.TrimSpace(s.cfg.ModelsDevAPIBaseURL), "/") + if base == "" { + base = "https://api.github.com" + } + + u := base + "/repos/sst/models.dev/commits/" + url.PathEscape(strings.TrimSpace(ref)) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil) + if err != nil { + return "", err + } + req.Header.Set("Accept", "application/vnd.github+json") + req.Header.Set("User-Agent", "ez-api") + + resp, err := s.client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + if resp.StatusCode/100 != 2 { + b, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<10)) + return "", fmt.Errorf("models.dev commit check failed: %s body=%s", resp.Status, string(b)) + } + + var payload struct { + SHA string `json:"sha"` + } + if err := json.NewDecoder(io.LimitReader(resp.Body, 1<<20)).Decode(&payload); err != nil { + return "", fmt.Errorf("decode github commit: %w", err) + } + sha := strings.TrimSpace(payload.SHA) + if sha == "" { + return "", fmt.Errorf("github commit sha missing for ref %q", ref) + } + return sha, nil +} + +func versionsEqual(current, latest string) bool { + c := strings.TrimSpace(current) + l := strings.TrimSpace(latest) + if c == "" || l == "" { + return false + } + if len(c) <= len(l) && strings.HasPrefix(l, c) { + return true + } + if len(l) <= len(c) && strings.HasPrefix(c, l) { + return true + } + return false +} + type boolVal struct { Known bool Val bool diff --git a/internal/service/model_registry_check_test.go b/internal/service/model_registry_check_test.go new file mode 100644 index 0000000..2ff8df1 --- /dev/null +++ b/internal/service/model_registry_check_test.go @@ -0,0 +1,75 @@ +package service + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/ez-api/ez-api/internal/model" + "github.com/redis/go-redis/v9" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +func TestModelRegistry_Check(t *testing.T) { + t.Parallel() + + dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", t.Name()) + db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{}) + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + if err := db.AutoMigrate(&model.Provider{}, &model.Binding{}, &model.Model{}); err != nil { + t.Fatalf("migrate: %v", err) + } + + const latestSHA = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/repos/sst/models.dev/commits/dev" { + http.NotFound(w, r) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"sha":"` + latestSHA + `"}`)) + })) + defer srv.Close() + + mr := miniredis.RunT(t) + rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()}) + + // Prefix-matching should treat short current as up-to-date. + mr.HSet("meta:models_meta", "version", "aaaaaaaa") + mr.HSet("meta:models_meta", "upstream_ref", "dev") + + svc := NewModelRegistryService(db, rdb, ModelRegistryConfig{ + Enabled: true, + RefreshEvery: time.Hour, + ModelsDevAPIBaseURL: srv.URL, + ModelsDevRef: "dev", + Timeout: 5 * time.Second, + }) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + out, err := svc.Check(ctx, "") + if err != nil { + t.Fatalf("check: %v", err) + } + if out.NeedsRefresh { + t.Fatalf("expected needs_refresh=false, got true (current=%q latest=%q)", out.CurrentVersion, out.LatestVersion) + } + + mr.HSet("meta:models_meta", "version", "bbbbbbbb") + out, err = svc.Check(ctx, "dev") + if err != nil { + t.Fatalf("check2: %v", err) + } + if !out.NeedsRefresh { + t.Fatalf("expected needs_refresh=true, got false (current=%q latest=%q)", out.CurrentVersion, out.LatestVersion) + } +}