mirror of
https://github.com/EZ-Api/ez-api.git
synced 2026-01-13 17:47:51 +00:00
model-registry: add upstream check endpoint
This commit is contained in:
@@ -42,6 +42,34 @@ type refreshModelRegistryRequest struct {
|
||||
Ref string `json:"ref"`
|
||||
}
|
||||
|
||||
// CheckModelRegistry godoc
|
||||
// @Summary Check model registry upstream version
|
||||
// @Description Checks models.dev commit SHA for a ref and indicates whether refresh is needed (does not apply changes)
|
||||
// @Tags admin
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security AdminAuth
|
||||
// @Param body body refreshModelRegistryRequest false "optional override ref"
|
||||
// @Success 200 {object} service.ModelRegistryCheckResult
|
||||
// @Failure 400 {object} gin.H
|
||||
// @Failure 500 {object} gin.H
|
||||
// @Router /admin/model-registry/check [post]
|
||||
func (h *ModelRegistryHandler) Check(c *gin.Context) {
|
||||
if h == nil || h.reg == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "model registry not configured"})
|
||||
return
|
||||
}
|
||||
var req refreshModelRegistryRequest
|
||||
_ = c.ShouldBindJSON(&req)
|
||||
ref := strings.TrimSpace(req.Ref)
|
||||
out, err := h.reg.Check(c.Request.Context(), ref)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check model registry", "details": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, out)
|
||||
}
|
||||
|
||||
// RefreshModelRegistry godoc
|
||||
// @Summary Refresh model registry from models.dev
|
||||
// @Description Fetches models.dev, computes per-binding capabilities, and updates Redis meta:models
|
||||
|
||||
@@ -43,12 +43,13 @@ type LogConfig struct {
|
||||
}
|
||||
|
||||
type ModelRegistryConfig struct {
|
||||
Enabled bool
|
||||
RefreshSeconds int
|
||||
ModelsDevBaseURL string
|
||||
ModelsDevRef string
|
||||
CacheDir string
|
||||
TimeoutSeconds int
|
||||
Enabled bool
|
||||
RefreshSeconds int
|
||||
ModelsDevBaseURL string
|
||||
ModelsDevAPIBaseURL string
|
||||
ModelsDevRef string
|
||||
CacheDir string
|
||||
TimeoutSeconds int
|
||||
}
|
||||
|
||||
func Load() (*Config, error) {
|
||||
@@ -66,6 +67,7 @@ func Load() (*Config, error) {
|
||||
v.SetDefault("model_registry.enabled", false)
|
||||
v.SetDefault("model_registry.refresh_seconds", 1800)
|
||||
v.SetDefault("model_registry.models_dev_base_url", "https://codeload.github.com/sst/models.dev/tar.gz")
|
||||
v.SetDefault("model_registry.models_dev_api_base_url", "https://api.github.com")
|
||||
v.SetDefault("model_registry.models_dev_ref", "dev")
|
||||
v.SetDefault("model_registry.cache_dir", "./data/model-registry")
|
||||
v.SetDefault("model_registry.timeout_seconds", 30)
|
||||
@@ -85,6 +87,7 @@ func Load() (*Config, error) {
|
||||
_ = v.BindEnv("model_registry.enabled", "EZ_MODEL_REGISTRY_ENABLED")
|
||||
_ = v.BindEnv("model_registry.refresh_seconds", "EZ_MODEL_REGISTRY_REFRESH_SECONDS")
|
||||
_ = v.BindEnv("model_registry.models_dev_base_url", "EZ_MODEL_REGISTRY_MODELS_DEV_BASE_URL")
|
||||
_ = v.BindEnv("model_registry.models_dev_api_base_url", "EZ_MODEL_REGISTRY_MODELS_DEV_API_BASE_URL")
|
||||
_ = v.BindEnv("model_registry.models_dev_ref", "EZ_MODEL_REGISTRY_MODELS_DEV_REF")
|
||||
_ = v.BindEnv("model_registry.cache_dir", "EZ_MODEL_REGISTRY_CACHE_DIR")
|
||||
_ = v.BindEnv("model_registry.timeout_seconds", "EZ_MODEL_REGISTRY_TIMEOUT_SECONDS")
|
||||
@@ -125,12 +128,13 @@ func Load() (*Config, error) {
|
||||
JWTSecret: v.GetString("auth.jwt_secret"),
|
||||
},
|
||||
ModelRegistry: ModelRegistryConfig{
|
||||
Enabled: v.GetBool("model_registry.enabled"),
|
||||
RefreshSeconds: v.GetInt("model_registry.refresh_seconds"),
|
||||
ModelsDevBaseURL: v.GetString("model_registry.models_dev_base_url"),
|
||||
ModelsDevRef: v.GetString("model_registry.models_dev_ref"),
|
||||
CacheDir: v.GetString("model_registry.cache_dir"),
|
||||
TimeoutSeconds: v.GetInt("model_registry.timeout_seconds"),
|
||||
Enabled: v.GetBool("model_registry.enabled"),
|
||||
RefreshSeconds: v.GetInt("model_registry.refresh_seconds"),
|
||||
ModelsDevBaseURL: v.GetString("model_registry.models_dev_base_url"),
|
||||
ModelsDevAPIBaseURL: v.GetString("model_registry.models_dev_api_base_url"),
|
||||
ModelsDevRef: v.GetString("model_registry.models_dev_ref"),
|
||||
CacheDir: v.GetString("model_registry.cache_dir"),
|
||||
TimeoutSeconds: v.GetInt("model_registry.timeout_seconds"),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
@@ -30,8 +31,9 @@ type ModelRegistryConfig struct {
|
||||
Enabled bool
|
||||
RefreshEvery time.Duration
|
||||
|
||||
ModelsDevBaseURL string
|
||||
ModelsDevRef string
|
||||
ModelsDevBaseURL string
|
||||
ModelsDevAPIBaseURL string
|
||||
ModelsDevRef string
|
||||
|
||||
CacheDir string
|
||||
Timeout time.Duration
|
||||
@@ -58,6 +60,9 @@ func NewModelRegistryService(db *gorm.DB, rdb *redis.Client, cfg ModelRegistryCo
|
||||
if strings.TrimSpace(cfg.ModelsDevBaseURL) == "" {
|
||||
cfg.ModelsDevBaseURL = "https://codeload.github.com/sst/models.dev/tar.gz"
|
||||
}
|
||||
if strings.TrimSpace(cfg.ModelsDevAPIBaseURL) == "" {
|
||||
cfg.ModelsDevAPIBaseURL = "https://api.github.com"
|
||||
}
|
||||
if strings.TrimSpace(cfg.ModelsDevRef) == "" {
|
||||
cfg.ModelsDevRef = "dev"
|
||||
}
|
||||
@@ -136,6 +141,41 @@ func (s *ModelRegistryService) Status(ctx context.Context) (*ModelRegistryStatus
|
||||
return out, nil
|
||||
}
|
||||
|
||||
type ModelRegistryCheckResult struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
UpstreamRef string `json:"upstream_ref"`
|
||||
CurrentVersion string `json:"current_version,omitempty"`
|
||||
LatestVersion string `json:"latest_version,omitempty"`
|
||||
NeedsRefresh bool `json:"needs_refresh"`
|
||||
CurrentUpstreamRef string `json:"current_upstream_ref,omitempty"`
|
||||
}
|
||||
|
||||
func (s *ModelRegistryService) Check(ctx context.Context, ref string) (*ModelRegistryCheckResult, error) {
|
||||
if strings.TrimSpace(ref) == "" {
|
||||
ref = s.cfg.ModelsDevRef
|
||||
}
|
||||
|
||||
latest, err := s.fetchModelsDevRefSHA(ctx, ref)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
meta, _ := s.rdb.HGetAll(ctx, "meta:models_meta").Result()
|
||||
currentVersion := strings.TrimSpace(meta["version"])
|
||||
currentUpstreamRef := strings.TrimSpace(meta["upstream_ref"])
|
||||
|
||||
needsRefresh := currentVersion == "" || !versionsEqual(currentVersion, latest)
|
||||
|
||||
return &ModelRegistryCheckResult{
|
||||
Enabled: s.cfg.Enabled,
|
||||
UpstreamRef: ref,
|
||||
CurrentVersion: currentVersion,
|
||||
LatestVersion: latest,
|
||||
NeedsRefresh: needsRefresh,
|
||||
CurrentUpstreamRef: currentUpstreamRef,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *ModelRegistryService) Refresh(ctx context.Context, ref string) error {
|
||||
if strings.TrimSpace(ref) == "" {
|
||||
ref = s.cfg.ModelsDevRef
|
||||
@@ -263,6 +303,58 @@ func (s *ModelRegistryService) setError(err error) {
|
||||
s.lastRefresh = time.Now()
|
||||
}
|
||||
|
||||
func (s *ModelRegistryService) fetchModelsDevRefSHA(ctx context.Context, ref string) (string, error) {
|
||||
base := strings.TrimRight(strings.TrimSpace(s.cfg.ModelsDevAPIBaseURL), "/")
|
||||
if base == "" {
|
||||
base = "https://api.github.com"
|
||||
}
|
||||
|
||||
u := base + "/repos/sst/models.dev/commits/" + url.PathEscape(strings.TrimSpace(ref))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Accept", "application/vnd.github+json")
|
||||
req.Header.Set("User-Agent", "ez-api")
|
||||
|
||||
resp, err := s.client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode/100 != 2 {
|
||||
b, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<10))
|
||||
return "", fmt.Errorf("models.dev commit check failed: %s body=%s", resp.Status, string(b))
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
SHA string `json:"sha"`
|
||||
}
|
||||
if err := json.NewDecoder(io.LimitReader(resp.Body, 1<<20)).Decode(&payload); err != nil {
|
||||
return "", fmt.Errorf("decode github commit: %w", err)
|
||||
}
|
||||
sha := strings.TrimSpace(payload.SHA)
|
||||
if sha == "" {
|
||||
return "", fmt.Errorf("github commit sha missing for ref %q", ref)
|
||||
}
|
||||
return sha, nil
|
||||
}
|
||||
|
||||
func versionsEqual(current, latest string) bool {
|
||||
c := strings.TrimSpace(current)
|
||||
l := strings.TrimSpace(latest)
|
||||
if c == "" || l == "" {
|
||||
return false
|
||||
}
|
||||
if len(c) <= len(l) && strings.HasPrefix(l, c) {
|
||||
return true
|
||||
}
|
||||
if len(l) <= len(c) && strings.HasPrefix(c, l) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type boolVal struct {
|
||||
Known bool
|
||||
Val bool
|
||||
|
||||
75
internal/service/model_registry_check_test.go
Normal file
75
internal/service/model_registry_check_test.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/ez-api/ez-api/internal/model"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func TestModelRegistry_Check(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", t.Name())
|
||||
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("open sqlite: %v", err)
|
||||
}
|
||||
if err := db.AutoMigrate(&model.Provider{}, &model.Binding{}, &model.Model{}); err != nil {
|
||||
t.Fatalf("migrate: %v", err)
|
||||
}
|
||||
|
||||
const latestSHA = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/repos/sst/models.dev/commits/dev" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"sha":"` + latestSHA + `"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
mr := miniredis.RunT(t)
|
||||
rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()})
|
||||
|
||||
// Prefix-matching should treat short current as up-to-date.
|
||||
mr.HSet("meta:models_meta", "version", "aaaaaaaa")
|
||||
mr.HSet("meta:models_meta", "upstream_ref", "dev")
|
||||
|
||||
svc := NewModelRegistryService(db, rdb, ModelRegistryConfig{
|
||||
Enabled: true,
|
||||
RefreshEvery: time.Hour,
|
||||
ModelsDevAPIBaseURL: srv.URL,
|
||||
ModelsDevRef: "dev",
|
||||
Timeout: 5 * time.Second,
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
out, err := svc.Check(ctx, "")
|
||||
if err != nil {
|
||||
t.Fatalf("check: %v", err)
|
||||
}
|
||||
if out.NeedsRefresh {
|
||||
t.Fatalf("expected needs_refresh=false, got true (current=%q latest=%q)", out.CurrentVersion, out.LatestVersion)
|
||||
}
|
||||
|
||||
mr.HSet("meta:models_meta", "version", "bbbbbbbb")
|
||||
out, err = svc.Check(ctx, "dev")
|
||||
if err != nil {
|
||||
t.Fatalf("check2: %v", err)
|
||||
}
|
||||
if !out.NeedsRefresh {
|
||||
t.Fatalf("expected needs_refresh=true, got false (current=%q latest=%q)", out.CurrentVersion, out.LatestVersion)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user