mirror of
https://github.com/EZ-Api/ez-api.git
synced 2026-01-13 17:47:51 +00:00
model-registry: add upstream check endpoint
This commit is contained in:
@@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
@@ -30,8 +31,9 @@ type ModelRegistryConfig struct {
|
||||
Enabled bool
|
||||
RefreshEvery time.Duration
|
||||
|
||||
ModelsDevBaseURL string
|
||||
ModelsDevRef string
|
||||
ModelsDevBaseURL string
|
||||
ModelsDevAPIBaseURL string
|
||||
ModelsDevRef string
|
||||
|
||||
CacheDir string
|
||||
Timeout time.Duration
|
||||
@@ -58,6 +60,9 @@ func NewModelRegistryService(db *gorm.DB, rdb *redis.Client, cfg ModelRegistryCo
|
||||
if strings.TrimSpace(cfg.ModelsDevBaseURL) == "" {
|
||||
cfg.ModelsDevBaseURL = "https://codeload.github.com/sst/models.dev/tar.gz"
|
||||
}
|
||||
if strings.TrimSpace(cfg.ModelsDevAPIBaseURL) == "" {
|
||||
cfg.ModelsDevAPIBaseURL = "https://api.github.com"
|
||||
}
|
||||
if strings.TrimSpace(cfg.ModelsDevRef) == "" {
|
||||
cfg.ModelsDevRef = "dev"
|
||||
}
|
||||
@@ -136,6 +141,41 @@ func (s *ModelRegistryService) Status(ctx context.Context) (*ModelRegistryStatus
|
||||
return out, nil
|
||||
}
|
||||
|
||||
type ModelRegistryCheckResult struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
UpstreamRef string `json:"upstream_ref"`
|
||||
CurrentVersion string `json:"current_version,omitempty"`
|
||||
LatestVersion string `json:"latest_version,omitempty"`
|
||||
NeedsRefresh bool `json:"needs_refresh"`
|
||||
CurrentUpstreamRef string `json:"current_upstream_ref,omitempty"`
|
||||
}
|
||||
|
||||
func (s *ModelRegistryService) Check(ctx context.Context, ref string) (*ModelRegistryCheckResult, error) {
|
||||
if strings.TrimSpace(ref) == "" {
|
||||
ref = s.cfg.ModelsDevRef
|
||||
}
|
||||
|
||||
latest, err := s.fetchModelsDevRefSHA(ctx, ref)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
meta, _ := s.rdb.HGetAll(ctx, "meta:models_meta").Result()
|
||||
currentVersion := strings.TrimSpace(meta["version"])
|
||||
currentUpstreamRef := strings.TrimSpace(meta["upstream_ref"])
|
||||
|
||||
needsRefresh := currentVersion == "" || !versionsEqual(currentVersion, latest)
|
||||
|
||||
return &ModelRegistryCheckResult{
|
||||
Enabled: s.cfg.Enabled,
|
||||
UpstreamRef: ref,
|
||||
CurrentVersion: currentVersion,
|
||||
LatestVersion: latest,
|
||||
NeedsRefresh: needsRefresh,
|
||||
CurrentUpstreamRef: currentUpstreamRef,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *ModelRegistryService) Refresh(ctx context.Context, ref string) error {
|
||||
if strings.TrimSpace(ref) == "" {
|
||||
ref = s.cfg.ModelsDevRef
|
||||
@@ -263,6 +303,58 @@ func (s *ModelRegistryService) setError(err error) {
|
||||
s.lastRefresh = time.Now()
|
||||
}
|
||||
|
||||
func (s *ModelRegistryService) fetchModelsDevRefSHA(ctx context.Context, ref string) (string, error) {
|
||||
base := strings.TrimRight(strings.TrimSpace(s.cfg.ModelsDevAPIBaseURL), "/")
|
||||
if base == "" {
|
||||
base = "https://api.github.com"
|
||||
}
|
||||
|
||||
u := base + "/repos/sst/models.dev/commits/" + url.PathEscape(strings.TrimSpace(ref))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Accept", "application/vnd.github+json")
|
||||
req.Header.Set("User-Agent", "ez-api")
|
||||
|
||||
resp, err := s.client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode/100 != 2 {
|
||||
b, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<10))
|
||||
return "", fmt.Errorf("models.dev commit check failed: %s body=%s", resp.Status, string(b))
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
SHA string `json:"sha"`
|
||||
}
|
||||
if err := json.NewDecoder(io.LimitReader(resp.Body, 1<<20)).Decode(&payload); err != nil {
|
||||
return "", fmt.Errorf("decode github commit: %w", err)
|
||||
}
|
||||
sha := strings.TrimSpace(payload.SHA)
|
||||
if sha == "" {
|
||||
return "", fmt.Errorf("github commit sha missing for ref %q", ref)
|
||||
}
|
||||
return sha, nil
|
||||
}
|
||||
|
||||
func versionsEqual(current, latest string) bool {
|
||||
c := strings.TrimSpace(current)
|
||||
l := strings.TrimSpace(latest)
|
||||
if c == "" || l == "" {
|
||||
return false
|
||||
}
|
||||
if len(c) <= len(l) && strings.HasPrefix(l, c) {
|
||||
return true
|
||||
}
|
||||
if len(l) <= len(c) && strings.HasPrefix(c, l) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type boolVal struct {
|
||||
Known bool
|
||||
Val bool
|
||||
|
||||
Reference in New Issue
Block a user