model-registry: add upstream check endpoint

This commit is contained in:
zenfun
2025-12-18 16:43:12 +08:00
parent a61eff27e7
commit 7dd3fac24e
5 changed files with 221 additions and 20 deletions

View File

@@ -115,12 +115,13 @@ func main() {
masterHandler := api.NewMasterHandler(db, masterService, syncService)
featureHandler := api.NewFeatureHandler(rdb)
modelRegistryService := service.NewModelRegistryService(db, rdb, service.ModelRegistryConfig{
Enabled: cfg.ModelRegistry.Enabled,
RefreshEvery: time.Duration(cfg.ModelRegistry.RefreshSeconds) * time.Second,
ModelsDevBaseURL: cfg.ModelRegistry.ModelsDevBaseURL,
ModelsDevRef: cfg.ModelRegistry.ModelsDevRef,
CacheDir: cfg.ModelRegistry.CacheDir,
Timeout: time.Duration(cfg.ModelRegistry.TimeoutSeconds) * time.Second,
Enabled: cfg.ModelRegistry.Enabled,
RefreshEvery: time.Duration(cfg.ModelRegistry.RefreshSeconds) * time.Second,
ModelsDevBaseURL: cfg.ModelRegistry.ModelsDevBaseURL,
ModelsDevAPIBaseURL: cfg.ModelRegistry.ModelsDevAPIBaseURL,
ModelsDevRef: cfg.ModelRegistry.ModelsDevRef,
CacheDir: cfg.ModelRegistry.CacheDir,
Timeout: time.Duration(cfg.ModelRegistry.TimeoutSeconds) * time.Second,
})
modelRegistryHandler := api.NewModelRegistryHandler(modelRegistryService)
@@ -181,6 +182,7 @@ func main() {
adminGroup.GET("/features", featureHandler.ListFeatures)
adminGroup.PUT("/features", featureHandler.UpdateFeatures)
adminGroup.GET("/model-registry/status", modelRegistryHandler.GetStatus)
adminGroup.POST("/model-registry/check", modelRegistryHandler.Check)
adminGroup.POST("/model-registry/refresh", modelRegistryHandler.Refresh)
adminGroup.POST("/model-registry/rollback", modelRegistryHandler.Rollback)
// Other admin routes for managing providers, models, etc.

View File

@@ -42,6 +42,34 @@ type refreshModelRegistryRequest struct {
Ref string `json:"ref"`
}
// CheckModelRegistry godoc
// @Summary Check model registry upstream version
// @Description Checks models.dev commit SHA for a ref and indicates whether refresh is needed (does not apply changes)
// @Tags admin
// @Accept json
// @Produce json
// @Security AdminAuth
// @Param body body refreshModelRegistryRequest false "optional override ref"
// @Success 200 {object} service.ModelRegistryCheckResult
// @Failure 400 {object} gin.H
// @Failure 500 {object} gin.H
// @Router /admin/model-registry/check [post]
func (h *ModelRegistryHandler) Check(c *gin.Context) {
	// Without a handler or its registry service there is nothing to query.
	if h == nil || h.reg == nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "model registry not configured"})
		return
	}

	// The request body is optional: the bind result is intentionally
	// discarded, so a missing or unparsable body simply leaves Ref blank.
	var body refreshModelRegistryRequest
	_ = c.ShouldBindJSON(&body)

	result, err := h.reg.Check(c.Request.Context(), strings.TrimSpace(body.Ref))
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check model registry", "details": err.Error()})
		return
	}

	c.JSON(http.StatusOK, result)
}
// RefreshModelRegistry godoc
// @Summary Refresh model registry from models.dev
// @Description Fetches models.dev, computes per-binding capabilities, and updates Redis meta:models

View File

@@ -43,12 +43,13 @@ type LogConfig struct {
}
type ModelRegistryConfig struct {
Enabled bool
RefreshSeconds int
ModelsDevBaseURL string
ModelsDevRef string
CacheDir string
TimeoutSeconds int
Enabled bool
RefreshSeconds int
ModelsDevBaseURL string
ModelsDevAPIBaseURL string
ModelsDevRef string
CacheDir string
TimeoutSeconds int
}
func Load() (*Config, error) {
@@ -66,6 +67,7 @@ func Load() (*Config, error) {
v.SetDefault("model_registry.enabled", false)
v.SetDefault("model_registry.refresh_seconds", 1800)
v.SetDefault("model_registry.models_dev_base_url", "https://codeload.github.com/sst/models.dev/tar.gz")
v.SetDefault("model_registry.models_dev_api_base_url", "https://api.github.com")
v.SetDefault("model_registry.models_dev_ref", "dev")
v.SetDefault("model_registry.cache_dir", "./data/model-registry")
v.SetDefault("model_registry.timeout_seconds", 30)
@@ -85,6 +87,7 @@ func Load() (*Config, error) {
_ = v.BindEnv("model_registry.enabled", "EZ_MODEL_REGISTRY_ENABLED")
_ = v.BindEnv("model_registry.refresh_seconds", "EZ_MODEL_REGISTRY_REFRESH_SECONDS")
_ = v.BindEnv("model_registry.models_dev_base_url", "EZ_MODEL_REGISTRY_MODELS_DEV_BASE_URL")
_ = v.BindEnv("model_registry.models_dev_api_base_url", "EZ_MODEL_REGISTRY_MODELS_DEV_API_BASE_URL")
_ = v.BindEnv("model_registry.models_dev_ref", "EZ_MODEL_REGISTRY_MODELS_DEV_REF")
_ = v.BindEnv("model_registry.cache_dir", "EZ_MODEL_REGISTRY_CACHE_DIR")
_ = v.BindEnv("model_registry.timeout_seconds", "EZ_MODEL_REGISTRY_TIMEOUT_SECONDS")
@@ -125,12 +128,13 @@ func Load() (*Config, error) {
JWTSecret: v.GetString("auth.jwt_secret"),
},
ModelRegistry: ModelRegistryConfig{
Enabled: v.GetBool("model_registry.enabled"),
RefreshSeconds: v.GetInt("model_registry.refresh_seconds"),
ModelsDevBaseURL: v.GetString("model_registry.models_dev_base_url"),
ModelsDevRef: v.GetString("model_registry.models_dev_ref"),
CacheDir: v.GetString("model_registry.cache_dir"),
TimeoutSeconds: v.GetInt("model_registry.timeout_seconds"),
Enabled: v.GetBool("model_registry.enabled"),
RefreshSeconds: v.GetInt("model_registry.refresh_seconds"),
ModelsDevBaseURL: v.GetString("model_registry.models_dev_base_url"),
ModelsDevAPIBaseURL: v.GetString("model_registry.models_dev_api_base_url"),
ModelsDevRef: v.GetString("model_registry.models_dev_ref"),
CacheDir: v.GetString("model_registry.cache_dir"),
TimeoutSeconds: v.GetInt("model_registry.timeout_seconds"),
},
}

View File

@@ -9,6 +9,7 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"regexp"
@@ -30,8 +31,9 @@ type ModelRegistryConfig struct {
Enabled bool
RefreshEvery time.Duration
ModelsDevBaseURL string
ModelsDevRef string
ModelsDevBaseURL string
ModelsDevAPIBaseURL string
ModelsDevRef string
CacheDir string
Timeout time.Duration
@@ -58,6 +60,9 @@ func NewModelRegistryService(db *gorm.DB, rdb *redis.Client, cfg ModelRegistryCo
if strings.TrimSpace(cfg.ModelsDevBaseURL) == "" {
cfg.ModelsDevBaseURL = "https://codeload.github.com/sst/models.dev/tar.gz"
}
if strings.TrimSpace(cfg.ModelsDevAPIBaseURL) == "" {
cfg.ModelsDevAPIBaseURL = "https://api.github.com"
}
if strings.TrimSpace(cfg.ModelsDevRef) == "" {
cfg.ModelsDevRef = "dev"
}
@@ -136,6 +141,41 @@ func (s *ModelRegistryService) Status(ctx context.Context) (*ModelRegistryStatus
return out, nil
}
// ModelRegistryCheckResult is the outcome of comparing the locally recorded
// model registry version with the upstream commit SHA for a ref.
type ModelRegistryCheckResult struct {
// Enabled mirrors the service's configured Enabled flag.
Enabled bool `json:"enabled"`
// UpstreamRef is the ref that was looked up upstream.
UpstreamRef string `json:"upstream_ref"`
// CurrentVersion is the trimmed "version" field of the stored metadata hash; empty when absent.
CurrentVersion string `json:"current_version,omitempty"`
// LatestVersion is the commit SHA reported by the upstream API for UpstreamRef.
LatestVersion string `json:"latest_version,omitempty"`
// NeedsRefresh is true when CurrentVersion is empty or does not match LatestVersion.
NeedsRefresh bool `json:"needs_refresh"`
// CurrentUpstreamRef is the trimmed "upstream_ref" field of the stored metadata hash; empty when absent.
CurrentUpstreamRef string `json:"current_upstream_ref,omitempty"`
}
// Check looks up the upstream commit SHA for ref and compares it with the
// version recorded in the Redis hash "meta:models_meta". It does not apply
// any changes.
//
// A blank ref falls back to the configured ModelsDevRef. NeedsRefresh is true
// when no current version is stored, or when the stored version and the
// latest SHA do not match according to versionsEqual.
//
// An error is returned when the upstream lookup fails or when the stored
// metadata cannot be read from Redis.
func (s *ModelRegistryService) Check(ctx context.Context, ref string) (*ModelRegistryCheckResult, error) {
	if strings.TrimSpace(ref) == "" {
		ref = s.cfg.ModelsDevRef
	}

	latest, err := s.fetchModelsDevRefSHA(ctx, ref)
	if err != nil {
		return nil, err
	}

	// A Redis failure must not be mistaken for "nothing installed yet":
	// silently continuing with an empty map would leave the current version
	// blank and always report that a refresh is needed.
	meta, err := s.rdb.HGetAll(ctx, "meta:models_meta").Result()
	if err != nil {
		return nil, fmt.Errorf("read model registry meta: %w", err)
	}

	currentVersion := strings.TrimSpace(meta["version"])
	currentUpstreamRef := strings.TrimSpace(meta["upstream_ref"])
	needsRefresh := currentVersion == "" || !versionsEqual(currentVersion, latest)

	return &ModelRegistryCheckResult{
		Enabled:            s.cfg.Enabled,
		UpstreamRef:        ref,
		CurrentVersion:     currentVersion,
		LatestVersion:      latest,
		NeedsRefresh:       needsRefresh,
		CurrentUpstreamRef: currentUpstreamRef,
	}, nil
}
func (s *ModelRegistryService) Refresh(ctx context.Context, ref string) error {
if strings.TrimSpace(ref) == "" {
ref = s.cfg.ModelsDevRef
@@ -263,6 +303,58 @@ func (s *ModelRegistryService) setError(err error) {
s.lastRefresh = time.Now()
}
// fetchModelsDevRefSHA requests the commit that ref resolves to in the
// sst/models.dev repository and returns its SHA.
//
// The API base comes from cfg.ModelsDevAPIBaseURL with surrounding whitespace
// and trailing slashes removed, defaulting to https://api.github.com when
// blank. A non-2xx response is reported together with up to 8 KiB of its
// body; a decoded payload without a "sha" value is also an error.
func (s *ModelRegistryService) fetchModelsDevRefSHA(ctx context.Context, ref string) (string, error) {
	apiBase := strings.TrimSpace(s.cfg.ModelsDevAPIBaseURL)
	apiBase = strings.TrimRight(apiBase, "/")
	if apiBase == "" {
		apiBase = "https://api.github.com"
	}
	endpoint := apiBase + "/repos/sst/models.dev/commits/" + url.PathEscape(strings.TrimSpace(ref))

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return "", err
	}
	req.Header.Set("User-Agent", "ez-api")
	req.Header.Set("Accept", "application/vnd.github+json")

	resp, err := s.client.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	if succeeded := resp.StatusCode >= 200 && resp.StatusCode < 300; !succeeded {
		// Cap how much of the error body ends up in the message.
		snippet, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<10))
		return "", fmt.Errorf("models.dev commit check failed: %s body=%s", resp.Status, string(snippet))
	}

	var commit struct {
		SHA string `json:"sha"`
	}
	// Bound the amount of response data handed to the decoder to 1 MiB.
	decoder := json.NewDecoder(io.LimitReader(resp.Body, 1<<20))
	if err := decoder.Decode(&commit); err != nil {
		return "", fmt.Errorf("decode github commit: %w", err)
	}

	if sha := strings.TrimSpace(commit.SHA); sha != "" {
		return sha, nil
	}
	return "", fmt.Errorf("github commit sha missing for ref %q", ref)
}
// versionsEqual reports whether two version identifiers denote the same
// revision. Surrounding whitespace is ignored, an empty identifier never
// matches, and otherwise the two match when either one is a leading prefix of
// the other (for example an abbreviated commit SHA against the full SHA).
// The comparison is case-sensitive.
func versionsEqual(current, latest string) bool {
	have, want := strings.TrimSpace(current), strings.TrimSpace(latest)
	if have == "" || want == "" {
		return false
	}
	// HasPrefix is already false when the prefix is longer than the string,
	// so testing both directions covers "the shorter is a prefix of the longer".
	return strings.HasPrefix(want, have) || strings.HasPrefix(have, want)
}
type boolVal struct {
Known bool
Val bool

View File

@@ -0,0 +1,75 @@
package service
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/ez-api/ez-api/internal/model"
"github.com/redis/go-redis/v9"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
// TestModelRegistry_Check runs Check against a stubbed commits endpoint and a
// miniredis-backed metadata hash, covering a prefix-matching stored version
// (no refresh needed) and a diverging stored version (refresh needed).
func TestModelRegistry_Check(t *testing.T) {
	t.Parallel()

	// In-memory SQLite database with a shared cache, named after the test.
	db, err := gorm.Open(sqlite.Open(fmt.Sprintf("file:%s?mode=memory&cache=shared", t.Name())), &gorm.Config{})
	if err != nil {
		t.Fatalf("open sqlite: %v", err)
	}
	if err := db.AutoMigrate(&model.Provider{}, &model.Binding{}, &model.Model{}); err != nil {
		t.Fatalf("migrate: %v", err)
	}

	const upstreamSHA = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"

	// Stub of the commits endpoint: only the "dev" ref is known.
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/repos/sst/models.dev/commits/dev" {
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"sha":"` + upstreamSHA + `"}`))
			return
		}
		http.NotFound(w, r)
	}))
	defer upstream.Close()

	redisStub := miniredis.RunT(t)
	rdb := redis.NewClient(&redis.Options{Addr: redisStub.Addr()})

	// A short stored version that is a prefix of the upstream SHA counts as up to date.
	redisStub.HSet("meta:models_meta", "version", "aaaaaaaa")
	redisStub.HSet("meta:models_meta", "upstream_ref", "dev")

	registry := NewModelRegistryService(db, rdb, ModelRegistryConfig{
		Enabled:             true,
		RefreshEvery:        time.Hour,
		ModelsDevAPIBaseURL: upstream.URL,
		ModelsDevRef:        "dev",
		Timeout:             5 * time.Second,
	})

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	// A blank ref falls back to the configured one.
	res, err := registry.Check(ctx, "")
	if err != nil {
		t.Fatalf("check: %v", err)
	}
	if res.NeedsRefresh {
		t.Fatalf("expected needs_refresh=false, got true (current=%q latest=%q)", res.CurrentVersion, res.LatestVersion)
	}

	// A stored version that diverges from the upstream SHA must request a refresh.
	redisStub.HSet("meta:models_meta", "version", "bbbbbbbb")
	res, err = registry.Check(ctx, "dev")
	if err != nil {
		t.Fatalf("check2: %v", err)
	}
	if !res.NeedsRefresh {
		t.Fatalf("expected needs_refresh=true, got false (current=%q latest=%q)", res.CurrentVersion, res.LatestVersion)
	}
}