mirror of
https://github.com/EZ-Api/ez-api.git
synced 2026-01-13 17:47:51 +00:00
model-registry: add upstream check endpoint
This commit is contained in:
@@ -115,12 +115,13 @@ func main() {
|
|||||||
masterHandler := api.NewMasterHandler(db, masterService, syncService)
|
masterHandler := api.NewMasterHandler(db, masterService, syncService)
|
||||||
featureHandler := api.NewFeatureHandler(rdb)
|
featureHandler := api.NewFeatureHandler(rdb)
|
||||||
modelRegistryService := service.NewModelRegistryService(db, rdb, service.ModelRegistryConfig{
|
modelRegistryService := service.NewModelRegistryService(db, rdb, service.ModelRegistryConfig{
|
||||||
Enabled: cfg.ModelRegistry.Enabled,
|
Enabled: cfg.ModelRegistry.Enabled,
|
||||||
RefreshEvery: time.Duration(cfg.ModelRegistry.RefreshSeconds) * time.Second,
|
RefreshEvery: time.Duration(cfg.ModelRegistry.RefreshSeconds) * time.Second,
|
||||||
ModelsDevBaseURL: cfg.ModelRegistry.ModelsDevBaseURL,
|
ModelsDevBaseURL: cfg.ModelRegistry.ModelsDevBaseURL,
|
||||||
ModelsDevRef: cfg.ModelRegistry.ModelsDevRef,
|
ModelsDevAPIBaseURL: cfg.ModelRegistry.ModelsDevAPIBaseURL,
|
||||||
CacheDir: cfg.ModelRegistry.CacheDir,
|
ModelsDevRef: cfg.ModelRegistry.ModelsDevRef,
|
||||||
Timeout: time.Duration(cfg.ModelRegistry.TimeoutSeconds) * time.Second,
|
CacheDir: cfg.ModelRegistry.CacheDir,
|
||||||
|
Timeout: time.Duration(cfg.ModelRegistry.TimeoutSeconds) * time.Second,
|
||||||
})
|
})
|
||||||
modelRegistryHandler := api.NewModelRegistryHandler(modelRegistryService)
|
modelRegistryHandler := api.NewModelRegistryHandler(modelRegistryService)
|
||||||
|
|
||||||
@@ -181,6 +182,7 @@ func main() {
|
|||||||
adminGroup.GET("/features", featureHandler.ListFeatures)
|
adminGroup.GET("/features", featureHandler.ListFeatures)
|
||||||
adminGroup.PUT("/features", featureHandler.UpdateFeatures)
|
adminGroup.PUT("/features", featureHandler.UpdateFeatures)
|
||||||
adminGroup.GET("/model-registry/status", modelRegistryHandler.GetStatus)
|
adminGroup.GET("/model-registry/status", modelRegistryHandler.GetStatus)
|
||||||
|
adminGroup.POST("/model-registry/check", modelRegistryHandler.Check)
|
||||||
adminGroup.POST("/model-registry/refresh", modelRegistryHandler.Refresh)
|
adminGroup.POST("/model-registry/refresh", modelRegistryHandler.Refresh)
|
||||||
adminGroup.POST("/model-registry/rollback", modelRegistryHandler.Rollback)
|
adminGroup.POST("/model-registry/rollback", modelRegistryHandler.Rollback)
|
||||||
// Other admin routes for managing providers, models, etc.
|
// Other admin routes for managing providers, models, etc.
|
||||||
|
|||||||
@@ -42,6 +42,34 @@ type refreshModelRegistryRequest struct {
|
|||||||
Ref string `json:"ref"`
|
Ref string `json:"ref"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CheckModelRegistry godoc
|
||||||
|
// @Summary Check model registry upstream version
|
||||||
|
// @Description Checks models.dev commit SHA for a ref and indicates whether refresh is needed (does not apply changes)
|
||||||
|
// @Tags admin
|
||||||
|
// @Accept json
|
||||||
|
// @Produce json
|
||||||
|
// @Security AdminAuth
|
||||||
|
// @Param body body refreshModelRegistryRequest false "optional override ref"
|
||||||
|
// @Success 200 {object} service.ModelRegistryCheckResult
|
||||||
|
// @Failure 400 {object} gin.H
|
||||||
|
// @Failure 500 {object} gin.H
|
||||||
|
// @Router /admin/model-registry/check [post]
|
||||||
|
func (h *ModelRegistryHandler) Check(c *gin.Context) {
|
||||||
|
if h == nil || h.reg == nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "model registry not configured"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var req refreshModelRegistryRequest
|
||||||
|
_ = c.ShouldBindJSON(&req)
|
||||||
|
ref := strings.TrimSpace(req.Ref)
|
||||||
|
out, err := h.reg.Check(c.Request.Context(), ref)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check model registry", "details": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, out)
|
||||||
|
}
|
||||||
|
|
||||||
// RefreshModelRegistry godoc
|
// RefreshModelRegistry godoc
|
||||||
// @Summary Refresh model registry from models.dev
|
// @Summary Refresh model registry from models.dev
|
||||||
// @Description Fetches models.dev, computes per-binding capabilities, and updates Redis meta:models
|
// @Description Fetches models.dev, computes per-binding capabilities, and updates Redis meta:models
|
||||||
|
|||||||
@@ -43,12 +43,13 @@ type LogConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ModelRegistryConfig struct {
|
type ModelRegistryConfig struct {
|
||||||
Enabled bool
|
Enabled bool
|
||||||
RefreshSeconds int
|
RefreshSeconds int
|
||||||
ModelsDevBaseURL string
|
ModelsDevBaseURL string
|
||||||
ModelsDevRef string
|
ModelsDevAPIBaseURL string
|
||||||
CacheDir string
|
ModelsDevRef string
|
||||||
TimeoutSeconds int
|
CacheDir string
|
||||||
|
TimeoutSeconds int
|
||||||
}
|
}
|
||||||
|
|
||||||
func Load() (*Config, error) {
|
func Load() (*Config, error) {
|
||||||
@@ -66,6 +67,7 @@ func Load() (*Config, error) {
|
|||||||
v.SetDefault("model_registry.enabled", false)
|
v.SetDefault("model_registry.enabled", false)
|
||||||
v.SetDefault("model_registry.refresh_seconds", 1800)
|
v.SetDefault("model_registry.refresh_seconds", 1800)
|
||||||
v.SetDefault("model_registry.models_dev_base_url", "https://codeload.github.com/sst/models.dev/tar.gz")
|
v.SetDefault("model_registry.models_dev_base_url", "https://codeload.github.com/sst/models.dev/tar.gz")
|
||||||
|
v.SetDefault("model_registry.models_dev_api_base_url", "https://api.github.com")
|
||||||
v.SetDefault("model_registry.models_dev_ref", "dev")
|
v.SetDefault("model_registry.models_dev_ref", "dev")
|
||||||
v.SetDefault("model_registry.cache_dir", "./data/model-registry")
|
v.SetDefault("model_registry.cache_dir", "./data/model-registry")
|
||||||
v.SetDefault("model_registry.timeout_seconds", 30)
|
v.SetDefault("model_registry.timeout_seconds", 30)
|
||||||
@@ -85,6 +87,7 @@ func Load() (*Config, error) {
|
|||||||
_ = v.BindEnv("model_registry.enabled", "EZ_MODEL_REGISTRY_ENABLED")
|
_ = v.BindEnv("model_registry.enabled", "EZ_MODEL_REGISTRY_ENABLED")
|
||||||
_ = v.BindEnv("model_registry.refresh_seconds", "EZ_MODEL_REGISTRY_REFRESH_SECONDS")
|
_ = v.BindEnv("model_registry.refresh_seconds", "EZ_MODEL_REGISTRY_REFRESH_SECONDS")
|
||||||
_ = v.BindEnv("model_registry.models_dev_base_url", "EZ_MODEL_REGISTRY_MODELS_DEV_BASE_URL")
|
_ = v.BindEnv("model_registry.models_dev_base_url", "EZ_MODEL_REGISTRY_MODELS_DEV_BASE_URL")
|
||||||
|
_ = v.BindEnv("model_registry.models_dev_api_base_url", "EZ_MODEL_REGISTRY_MODELS_DEV_API_BASE_URL")
|
||||||
_ = v.BindEnv("model_registry.models_dev_ref", "EZ_MODEL_REGISTRY_MODELS_DEV_REF")
|
_ = v.BindEnv("model_registry.models_dev_ref", "EZ_MODEL_REGISTRY_MODELS_DEV_REF")
|
||||||
_ = v.BindEnv("model_registry.cache_dir", "EZ_MODEL_REGISTRY_CACHE_DIR")
|
_ = v.BindEnv("model_registry.cache_dir", "EZ_MODEL_REGISTRY_CACHE_DIR")
|
||||||
_ = v.BindEnv("model_registry.timeout_seconds", "EZ_MODEL_REGISTRY_TIMEOUT_SECONDS")
|
_ = v.BindEnv("model_registry.timeout_seconds", "EZ_MODEL_REGISTRY_TIMEOUT_SECONDS")
|
||||||
@@ -125,12 +128,13 @@ func Load() (*Config, error) {
|
|||||||
JWTSecret: v.GetString("auth.jwt_secret"),
|
JWTSecret: v.GetString("auth.jwt_secret"),
|
||||||
},
|
},
|
||||||
ModelRegistry: ModelRegistryConfig{
|
ModelRegistry: ModelRegistryConfig{
|
||||||
Enabled: v.GetBool("model_registry.enabled"),
|
Enabled: v.GetBool("model_registry.enabled"),
|
||||||
RefreshSeconds: v.GetInt("model_registry.refresh_seconds"),
|
RefreshSeconds: v.GetInt("model_registry.refresh_seconds"),
|
||||||
ModelsDevBaseURL: v.GetString("model_registry.models_dev_base_url"),
|
ModelsDevBaseURL: v.GetString("model_registry.models_dev_base_url"),
|
||||||
ModelsDevRef: v.GetString("model_registry.models_dev_ref"),
|
ModelsDevAPIBaseURL: v.GetString("model_registry.models_dev_api_base_url"),
|
||||||
CacheDir: v.GetString("model_registry.cache_dir"),
|
ModelsDevRef: v.GetString("model_registry.models_dev_ref"),
|
||||||
TimeoutSeconds: v.GetInt("model_registry.timeout_seconds"),
|
CacheDir: v.GetString("model_registry.cache_dir"),
|
||||||
|
TimeoutSeconds: v.GetInt("model_registry.timeout_seconds"),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"regexp"
|
"regexp"
|
||||||
@@ -30,8 +31,9 @@ type ModelRegistryConfig struct {
|
|||||||
Enabled bool
|
Enabled bool
|
||||||
RefreshEvery time.Duration
|
RefreshEvery time.Duration
|
||||||
|
|
||||||
ModelsDevBaseURL string
|
ModelsDevBaseURL string
|
||||||
ModelsDevRef string
|
ModelsDevAPIBaseURL string
|
||||||
|
ModelsDevRef string
|
||||||
|
|
||||||
CacheDir string
|
CacheDir string
|
||||||
Timeout time.Duration
|
Timeout time.Duration
|
||||||
@@ -58,6 +60,9 @@ func NewModelRegistryService(db *gorm.DB, rdb *redis.Client, cfg ModelRegistryCo
|
|||||||
if strings.TrimSpace(cfg.ModelsDevBaseURL) == "" {
|
if strings.TrimSpace(cfg.ModelsDevBaseURL) == "" {
|
||||||
cfg.ModelsDevBaseURL = "https://codeload.github.com/sst/models.dev/tar.gz"
|
cfg.ModelsDevBaseURL = "https://codeload.github.com/sst/models.dev/tar.gz"
|
||||||
}
|
}
|
||||||
|
if strings.TrimSpace(cfg.ModelsDevAPIBaseURL) == "" {
|
||||||
|
cfg.ModelsDevAPIBaseURL = "https://api.github.com"
|
||||||
|
}
|
||||||
if strings.TrimSpace(cfg.ModelsDevRef) == "" {
|
if strings.TrimSpace(cfg.ModelsDevRef) == "" {
|
||||||
cfg.ModelsDevRef = "dev"
|
cfg.ModelsDevRef = "dev"
|
||||||
}
|
}
|
||||||
@@ -136,6 +141,41 @@ func (s *ModelRegistryService) Status(ctx context.Context) (*ModelRegistryStatus
|
|||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ModelRegistryCheckResult struct {
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
UpstreamRef string `json:"upstream_ref"`
|
||||||
|
CurrentVersion string `json:"current_version,omitempty"`
|
||||||
|
LatestVersion string `json:"latest_version,omitempty"`
|
||||||
|
NeedsRefresh bool `json:"needs_refresh"`
|
||||||
|
CurrentUpstreamRef string `json:"current_upstream_ref,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ModelRegistryService) Check(ctx context.Context, ref string) (*ModelRegistryCheckResult, error) {
|
||||||
|
if strings.TrimSpace(ref) == "" {
|
||||||
|
ref = s.cfg.ModelsDevRef
|
||||||
|
}
|
||||||
|
|
||||||
|
latest, err := s.fetchModelsDevRefSHA(ctx, ref)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
meta, _ := s.rdb.HGetAll(ctx, "meta:models_meta").Result()
|
||||||
|
currentVersion := strings.TrimSpace(meta["version"])
|
||||||
|
currentUpstreamRef := strings.TrimSpace(meta["upstream_ref"])
|
||||||
|
|
||||||
|
needsRefresh := currentVersion == "" || !versionsEqual(currentVersion, latest)
|
||||||
|
|
||||||
|
return &ModelRegistryCheckResult{
|
||||||
|
Enabled: s.cfg.Enabled,
|
||||||
|
UpstreamRef: ref,
|
||||||
|
CurrentVersion: currentVersion,
|
||||||
|
LatestVersion: latest,
|
||||||
|
NeedsRefresh: needsRefresh,
|
||||||
|
CurrentUpstreamRef: currentUpstreamRef,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *ModelRegistryService) Refresh(ctx context.Context, ref string) error {
|
func (s *ModelRegistryService) Refresh(ctx context.Context, ref string) error {
|
||||||
if strings.TrimSpace(ref) == "" {
|
if strings.TrimSpace(ref) == "" {
|
||||||
ref = s.cfg.ModelsDevRef
|
ref = s.cfg.ModelsDevRef
|
||||||
@@ -263,6 +303,58 @@ func (s *ModelRegistryService) setError(err error) {
|
|||||||
s.lastRefresh = time.Now()
|
s.lastRefresh = time.Now()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *ModelRegistryService) fetchModelsDevRefSHA(ctx context.Context, ref string) (string, error) {
|
||||||
|
base := strings.TrimRight(strings.TrimSpace(s.cfg.ModelsDevAPIBaseURL), "/")
|
||||||
|
if base == "" {
|
||||||
|
base = "https://api.github.com"
|
||||||
|
}
|
||||||
|
|
||||||
|
u := base + "/repos/sst/models.dev/commits/" + url.PathEscape(strings.TrimSpace(ref))
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
req.Header.Set("Accept", "application/vnd.github+json")
|
||||||
|
req.Header.Set("User-Agent", "ez-api")
|
||||||
|
|
||||||
|
resp, err := s.client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode/100 != 2 {
|
||||||
|
b, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<10))
|
||||||
|
return "", fmt.Errorf("models.dev commit check failed: %s body=%s", resp.Status, string(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload struct {
|
||||||
|
SHA string `json:"sha"`
|
||||||
|
}
|
||||||
|
if err := json.NewDecoder(io.LimitReader(resp.Body, 1<<20)).Decode(&payload); err != nil {
|
||||||
|
return "", fmt.Errorf("decode github commit: %w", err)
|
||||||
|
}
|
||||||
|
sha := strings.TrimSpace(payload.SHA)
|
||||||
|
if sha == "" {
|
||||||
|
return "", fmt.Errorf("github commit sha missing for ref %q", ref)
|
||||||
|
}
|
||||||
|
return sha, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func versionsEqual(current, latest string) bool {
|
||||||
|
c := strings.TrimSpace(current)
|
||||||
|
l := strings.TrimSpace(latest)
|
||||||
|
if c == "" || l == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(c) <= len(l) && strings.HasPrefix(l, c) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if len(l) <= len(c) && strings.HasPrefix(c, l) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
type boolVal struct {
|
type boolVal struct {
|
||||||
Known bool
|
Known bool
|
||||||
Val bool
|
Val bool
|
||||||
|
|||||||
75
internal/service/model_registry_check_test.go
Normal file
75
internal/service/model_registry_check_test.go
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/alicebob/miniredis/v2"
|
||||||
|
"github.com/ez-api/ez-api/internal/model"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestModelRegistry_Check(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", t.Name())
|
||||||
|
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open sqlite: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.AutoMigrate(&model.Provider{}, &model.Binding{}, &model.Model{}); err != nil {
|
||||||
|
t.Fatalf("migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
const latestSHA = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/repos/sst/models.dev/commits/dev" {
|
||||||
|
http.NotFound(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"sha":"` + latestSHA + `"}`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
mr := miniredis.RunT(t)
|
||||||
|
rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()})
|
||||||
|
|
||||||
|
// Prefix-matching should treat short current as up-to-date.
|
||||||
|
mr.HSet("meta:models_meta", "version", "aaaaaaaa")
|
||||||
|
mr.HSet("meta:models_meta", "upstream_ref", "dev")
|
||||||
|
|
||||||
|
svc := NewModelRegistryService(db, rdb, ModelRegistryConfig{
|
||||||
|
Enabled: true,
|
||||||
|
RefreshEvery: time.Hour,
|
||||||
|
ModelsDevAPIBaseURL: srv.URL,
|
||||||
|
ModelsDevRef: "dev",
|
||||||
|
Timeout: 5 * time.Second,
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
out, err := svc.Check(ctx, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("check: %v", err)
|
||||||
|
}
|
||||||
|
if out.NeedsRefresh {
|
||||||
|
t.Fatalf("expected needs_refresh=false, got true (current=%q latest=%q)", out.CurrentVersion, out.LatestVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
mr.HSet("meta:models_meta", "version", "bbbbbbbb")
|
||||||
|
out, err = svc.Check(ctx, "dev")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("check2: %v", err)
|
||||||
|
}
|
||||||
|
if !out.NeedsRefresh {
|
||||||
|
t.Fatalf("expected needs_refresh=true, got false (current=%q latest=%q)", out.CurrentVersion, out.LatestVersion)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user