Files
ez-api/internal/service/model_registry.go
2025-12-18 16:48:43 +08:00

939 lines
23 KiB
Go

package service
import (
"archive/tar"
"bytes"
"compress/gzip"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"regexp"
"strings"
"sync"
"time"
"github.com/ez-api/ez-api/internal/model"
groupx "github.com/ez-api/foundation/group"
"github.com/ez-api/foundation/jsoncodec"
"github.com/ez-api/foundation/modelcap"
"github.com/ez-api/foundation/routing"
"github.com/pelletier/go-toml/v2"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
)
// ModelRegistryConfig configures a ModelRegistryService. Zero durations and
// blank (whitespace-only) strings are replaced with defaults by
// NewModelRegistryService.
type ModelRegistryConfig struct {
	// Enabled turns on the background refresh loop in Start; it is also
	// reported by Status and Check.
	Enabled bool
	// RefreshEvery is the interval between background refreshes (default 30m).
	RefreshEvery time.Duration
	// ModelsDevBaseURL is the tarball download prefix; the git ref is appended
	// as the last path segment.
	ModelsDevBaseURL string
	// ModelsDevAPIBaseURL is the GitHub API base used to resolve a ref to a
	// commit SHA.
	ModelsDevAPIBaseURL string
	// ModelsDevRef is the git ref used when callers pass an empty ref.
	ModelsDevRef string
	// CacheDir holds the current.json / prev.json registry snapshots.
	CacheDir string
	// Timeout is the http.Client timeout applied to every outbound request.
	Timeout time.Duration
}
// ModelRegistryService builds the model capability registry and publishes it
// to Redis ("meta:models" and "meta:models_meta"). Its inputs are models.dev
// data, the providers/bindings tables and model overrides in the DB. It also
// keeps the last two published snapshots on disk under cfg.CacheDir.
type ModelRegistryService struct {
	db     *gorm.DB      // source of providers, bindings and model overrides
	rdb    *redis.Client // destination of the published registry
	cfg    ModelRegistryConfig
	client *http.Client // shared client for models.dev / GitHub requests

	// mu guards the bookkeeping fields below.
	mu           sync.Mutex
	lastError    string        // message of the latest failure; cleared on success
	lastRefresh  time.Time     // time of the latest refresh/rollback attempt (success or failure)
	lastApplied  modelcap.Meta // meta of the last successfully applied snapshot (not read in this file)
	lastUpstream string        // ref most recently passed to Refresh, after defaulting (not read in this file)
}
// NewModelRegistryService returns a service for db/rdb and fills unset parts of
// cfg with defaults:
//
//   - RefreshEvery <= 0: 30 minutes
//   - Timeout <= 0: 30 seconds
//   - blank ModelsDevBaseURL: https://codeload.github.com/sst/models.dev/tar.gz
//   - blank ModelsDevAPIBaseURL: https://api.github.com
//   - blank ModelsDevRef: "dev"
//   - blank CacheDir: ./data/model-registry
//
// "Blank" means empty after trimming whitespace. Non-blank values are kept
// verbatim. The service owns a single http.Client with cfg.Timeout as its
// timeout.
func NewModelRegistryService(db *gorm.DB, rdb *redis.Client, cfg ModelRegistryConfig) *ModelRegistryService {
	withDefault := func(field *string, fallback string) {
		if strings.TrimSpace(*field) == "" {
			*field = fallback
		}
	}

	if cfg.RefreshEvery <= 0 {
		cfg.RefreshEvery = 30 * time.Minute
	}
	if cfg.Timeout <= 0 {
		cfg.Timeout = 30 * time.Second
	}
	withDefault(&cfg.ModelsDevBaseURL, "https://codeload.github.com/sst/models.dev/tar.gz")
	withDefault(&cfg.ModelsDevAPIBaseURL, "https://api.github.com")
	withDefault(&cfg.ModelsDevRef, "dev")
	withDefault(&cfg.CacheDir, "./data/model-registry")

	svc := &ModelRegistryService{db: db, rdb: rdb, cfg: cfg}
	svc.client = &http.Client{Timeout: cfg.Timeout}
	return svc
}
// Start launches the background refresh loop when cfg.Enabled is set; it
// returns immediately either way.
//
// The goroutine refreshes once right away, then again on every
// cfg.RefreshEvery tick, and exits when ctx is done. Refresh errors are
// deliberately ignored here: Refresh records them itself via setError, and they
// are visible through Status.
func (s *ModelRegistryService) Start(ctx context.Context) {
	if !s.cfg.Enabled {
		return
	}
	go func() {
		ticker := time.NewTicker(s.cfg.RefreshEvery)
		defer ticker.Stop()
		for {
			_ = s.Refresh(ctx, "") // best effort; see doc comment
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
			}
		}
	}()
}
// ModelRegistryStatus is the snapshot returned by ModelRegistryService.Status.
type ModelRegistryStatus struct {
	Enabled       bool               `json:"enabled"`                   // cfg.Enabled
	ModelsDevRef  string             `json:"models_dev_ref"`            // configured default ref
	ModelsDevURL  string             `json:"models_dev_url"`            // tarball URL for the default ref
	LastRefreshAt int64              `json:"last_refresh_at,omitempty"` // Unix seconds of the latest attempt (success or failure)
	LastError     string             `json:"last_error,omitempty"`      // empty after a successful refresh/rollback
	RedisMeta     map[string]string  `json:"redis_meta,omitempty"`      // contents of the meta:models_meta hash
	CacheCurrent  *modelRegistryFile `json:"cache_current,omitempty"`   // CacheDir/current.json, nil if unreadable
	CachePrev     *modelRegistryFile `json:"cache_prev,omitempty"`      // CacheDir/prev.json, nil if unreadable
}
// Status reports the registry's configuration, its last refresh outcome, the
// metadata currently in Redis and both on-disk cache snapshots.
//
// Every lookup is best effort: an unreachable Redis or a missing/corrupt cache
// file just leaves the corresponding field empty. The returned error is
// always nil.
func (s *ModelRegistryService) Status(ctx context.Context) (*ModelRegistryStatus, error) {
	s.mu.Lock()
	lastErr, lastRefresh := s.lastError, s.lastRefresh
	s.mu.Unlock()

	st := &ModelRegistryStatus{
		Enabled:      s.cfg.Enabled,
		ModelsDevRef: s.cfg.ModelsDevRef,
		ModelsDevURL: strings.TrimRight(s.cfg.ModelsDevBaseURL, "/") + "/" + s.cfg.ModelsDevRef,
		LastError:    lastErr,
	}
	st.RedisMeta, _ = s.rdb.HGetAll(ctx, "meta:models_meta").Result()
	st.CacheCurrent, _ = readModelRegistryFile(filepath.Join(s.cfg.CacheDir, "current.json"))
	st.CachePrev, _ = readModelRegistryFile(filepath.Join(s.cfg.CacheDir, "prev.json"))
	if !lastRefresh.IsZero() {
		st.LastRefreshAt = lastRefresh.Unix()
	}
	return st, nil
}
// ModelRegistryCheckResult is returned by ModelRegistryService.Check.
type ModelRegistryCheckResult struct {
	Enabled            bool   `json:"enabled"`                        // cfg.Enabled
	UpstreamRef        string `json:"upstream_ref"`                   // ref that was checked (after defaulting)
	CurrentVersion     string `json:"current_version,omitempty"`      // "version" field of meta:models_meta
	LatestVersion      string `json:"latest_version,omitempty"`       // commit SHA GitHub reports for UpstreamRef
	NeedsRefresh       bool   `json:"needs_refresh"`                  // true unless the versions match (see versionsEqual)
	CurrentUpstreamRef string `json:"current_upstream_ref,omitempty"` // "upstream_ref" field of meta:models_meta
}
// Check compares the registry version published in Redis with the commit that
// ref currently points to upstream. A blank ref falls back to cfg.ModelsDevRef.
//
// Only a failure to resolve the upstream SHA is returned as an error. A Redis
// read failure is treated like an empty registry, so NeedsRefresh is true.
func (s *ModelRegistryService) Check(ctx context.Context, ref string) (*ModelRegistryCheckResult, error) {
	upstream := ref
	if strings.TrimSpace(upstream) == "" {
		upstream = s.cfg.ModelsDevRef
	}

	latest, err := s.fetchModelsDevRefSHA(ctx, upstream)
	if err != nil {
		return nil, err
	}

	meta, _ := s.rdb.HGetAll(ctx, "meta:models_meta").Result()
	res := &ModelRegistryCheckResult{
		Enabled:            s.cfg.Enabled,
		UpstreamRef:        upstream,
		CurrentVersion:     strings.TrimSpace(meta["version"]),
		LatestVersion:      latest,
		CurrentUpstreamRef: strings.TrimSpace(meta["upstream_ref"]),
	}
	// versionsEqual already reports false for an empty current version.
	res.NeedsRefresh = !versionsEqual(res.CurrentVersion, latest)
	return res, nil
}
// Refresh rebuilds the model registry from models.dev and publishes it.
//
// ref selects the models.dev git ref to download; a blank ref falls back to
// cfg.ModelsDevRef. The steps are:
//  1. download and index the models.dev tarball for ref;
//  2. derive one capability snapshot per active binding (buildBindingModels);
//  3. overlay every row of the models table; these rows win over derived
//     entries with the same name;
//  4. write models and metadata to Redis, then persist the on-disk cache.
//
// Any failure is recorded via setError (and thus shown by Status) and that same
// error is returned. On success the last error is cleared. ctx is forwarded to
// the download, the binding build, the DB overlay query and the Redis write.
func (s *ModelRegistryService) Refresh(ctx context.Context, ref string) error {
	if strings.TrimSpace(ref) == "" {
		ref = s.cfg.ModelsDevRef
	}
	s.mu.Lock()
	s.lastUpstream = ref
	s.mu.Unlock()

	// fail records err for Status and hands it back, so callers always see
	// exactly the error that was recorded (including its context).
	fail := func(err error) error {
		s.setError(err)
		return err
	}

	tarballURL := strings.TrimRight(s.cfg.ModelsDevBaseURL, "/") + "/" + ref
	reg, version, err := s.fetchModelsDev(ctx, tarballURL)
	if err != nil {
		return fail(err)
	}
	models, payloads, err := s.buildBindingModels(ctx, reg)
	if err != nil {
		return fail(err)
	}

	// Overlay explicit DB entries (admin overrides).
	var dbModels []model.Model
	if err := s.db.WithContext(ctx).Find(&dbModels).Error; err != nil {
		return fail(fmt.Errorf("load db models: %w", err))
	}
	for _, m := range dbModels {
		snap := modelcap.Model{
			Name:               m.Name,
			Kind:               string(modelcap.NormalizeKind(m.Kind)),
			ContextWindow:      m.ContextWindow,
			CostPerToken:       m.CostPerToken,
			SupportsVision:     m.SupportsVision,
			SupportsFunction:   m.SupportsFunctions,
			SupportsToolChoice: m.SupportsToolChoice,
			SupportsFim:        m.SupportsFIM,
			SupportsStream:     true, // overrides always advertise streaming
			MaxOutputTokens:    m.MaxOutputTokens,
		}.Normalized()
		b, err := jsoncodec.Marshal(snap)
		if err != nil {
			return fail(fmt.Errorf("marshal db model %s: %w", m.Name, err))
		}
		models[snap.Name] = snap
		payloads[snap.Name] = string(b)
	}

	now := time.Now().Unix()
	meta := modelcap.Meta{
		Version:     version,
		UpdatedAt:   fmt.Sprintf("%d", now),
		Source:      "models.dev",
		Checksum:    modelcap.ChecksumFromPayloads(payloads),
		UpstreamURL: "https://github.com/sst/models.dev",
		UpstreamRef: ref,
	}
	if err := s.applyToRedis(ctx, models, payloads, meta); err != nil {
		return fail(err)
	}
	if err := s.persistCache(models, meta); err != nil {
		return fail(err)
	}

	s.mu.Lock()
	s.lastError = ""
	s.lastRefresh = time.Now()
	s.lastApplied = meta
	s.mu.Unlock()
	return nil
}
// Rollback re-publishes the previous registry snapshot from CacheDir/prev.json.
//
// Every cached model is re-normalized and re-marshalled, and the metadata
// checksum is recomputed from those payloads. Source is set to "rollback" and
// UpdatedAt to the current Unix time. The snapshot is then written to Redis and
// persisted with persistCache. persistCache moves the current cache file to
// prev.json, so the current and previous snapshots effectively swap on disk.
//
// Failures are recorded via setError and the same error is returned. A missing
// prev.json surfaces as a wrapped os.ReadFile error, which is still matched by
// errors.Is(err, fs.ErrNotExist).
func (s *ModelRegistryService) Rollback(ctx context.Context) error {
	fail := func(err error) error {
		s.setError(err)
		return err
	}

	prev, err := readModelRegistryFile(filepath.Join(s.cfg.CacheDir, "prev.json"))
	if err != nil {
		return fail(fmt.Errorf("read prev cache: %w", err))
	}
	if prev == nil {
		return fail(fmt.Errorf("no prev cache"))
	}

	payloads := make(map[string]string, len(prev.Models))
	for name, m := range prev.Models {
		m = m.Normalized()
		b, err := jsoncodec.Marshal(m)
		if err != nil {
			return fail(fmt.Errorf("marshal cached model %s: %w", name, err))
		}
		// Overwriting an existing key while ranging over a map is well-defined.
		prev.Models[name] = m
		payloads[name] = string(b)
	}
	prev.Meta.Checksum = modelcap.ChecksumFromPayloads(payloads)
	prev.Meta.Source = "rollback"
	prev.Meta.UpdatedAt = fmt.Sprintf("%d", time.Now().Unix())

	if err := s.applyToRedis(ctx, prev.Models, payloads, prev.Meta); err != nil {
		return fail(err)
	}
	// Swap cache current/prev to reflect rollback.
	if err := s.persistCache(prev.Models, prev.Meta); err != nil {
		return fail(err)
	}

	s.mu.Lock()
	s.lastError = ""
	s.lastRefresh = time.Now()
	s.lastApplied = prev.Meta
	s.mu.Unlock()
	return nil
}
// setError records err as the latest failure and stamps lastRefresh with the
// current time, so Status reports when the failing attempt happened.
// err must be non-nil.
func (s *ModelRegistryService) setError(err error) {
	msg := err.Error()
	at := time.Now()
	s.mu.Lock()
	s.lastError, s.lastRefresh = msg, at
	s.mu.Unlock()
}
// fetchModelsDevRefSHA asks the GitHub commits API which commit ref points to
// in the sst/models.dev repository and returns that commit's SHA.
//
// The API base is cfg.ModelsDevAPIBaseURL, or https://api.github.com when it is
// blank. A non-2xx response becomes an error that includes up to 8 KiB of the
// body. A successful body is decoded from at most 1 MiB and must carry a
// non-empty "sha".
func (s *ModelRegistryService) fetchModelsDevRefSHA(ctx context.Context, ref string) (string, error) {
	apiBase := strings.TrimRight(strings.TrimSpace(s.cfg.ModelsDevAPIBaseURL), "/")
	if apiBase == "" {
		apiBase = "https://api.github.com"
	}
	endpoint := fmt.Sprintf("%s/repos/sst/models.dev/commits/%s", apiBase, url.PathEscape(strings.TrimSpace(ref)))

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return "", err
	}
	req.Header.Set("Accept", "application/vnd.github+json")
	req.Header.Set("User-Agent", "ez-api")

	resp, err := s.client.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	if success := resp.StatusCode >= 200 && resp.StatusCode <= 299; !success {
		snippet, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<10))
		return "", fmt.Errorf("models.dev commit check failed: %s body=%s", resp.Status, string(snippet))
	}

	var commit struct {
		SHA string `json:"sha"`
	}
	dec := json.NewDecoder(io.LimitReader(resp.Body, 1<<20))
	if err := dec.Decode(&commit); err != nil {
		return "", fmt.Errorf("decode github commit: %w", err)
	}
	if sha := strings.TrimSpace(commit.SHA); sha != "" {
		return sha, nil
	}
	return "", fmt.Errorf("github commit sha missing for ref %q", ref)
}
// versionsEqual reports whether two commit identifiers refer to the same commit.
// One may be an abbreviation of the other, so after trimming whitespace the
// shorter string must be a prefix of the longer one. An empty side never
// matches.
func versionsEqual(current, latest string) bool {
	a, b := strings.TrimSpace(current), strings.TrimSpace(latest)
	if a == "" || b == "" {
		return false
	}
	short, long := a, b
	if len(short) > len(long) {
		short, long = long, short
	}
	return strings.HasPrefix(long, short)
}
// boolVal is a boolean capability plus whether upstream data declared it.
type boolVal struct {
	Known bool // true when the value was present in the upstream TOML
	Val   bool // declared value; meaningful only when Known
}

// intVal is an integer capability (a token limit) plus whether upstream data
// declared it.
type intVal struct {
	Known bool // true when the value was present in the upstream TOML
	Val   int  // declared value; meaningful only when Known
}

// upstreamCap holds the capabilities parsed from one models.dev model TOML file
// (see parseModelsDevTOML).
type upstreamCap struct {
	ContextWindow   intVal  // limit.context
	MaxOutputTokens intVal  // limit.output
	SupportsVision  boolVal // whether "image" appears in modalities.input
	SupportsTools   boolVal // tool_call
}

// modelsDevRegistry indexes upstreamCap entries parsed from a models.dev tarball.
type modelsDevRegistry struct {
	ByProviderModel map[string]upstreamCap // key: providerID|modelID
	ByModel         map[string]upstreamCap // fallback: modelID
}

// shaSuffix captures a trailing "-<7 to 40 lowercase hex chars>" suffix; it is
// applied to the tarball's top-level directory name to recover a commit SHA.
var shaSuffix = regexp.MustCompile(`-([0-9a-f]{7,40})$`)
// fetchModelsDev downloads the models.dev source tarball at tarballURL and
// indexes the capabilities declared in its
// providers/<provider>/models/<model>.toml files.
//
// Indexing rules:
//   - Every parsed entry is stored under "<provider>|<model>" in ByProviderModel.
//   - The first entry seen for a model ID also becomes its provider-agnostic
//     ByModel fallback.
//   - Model IDs containing "/" are additionally indexed with "/" replaced by "-".
//   - Non-regular entries, non-TOML files, paths outside the providers tree and
//     TOML files without any recognised capability are skipped.
//   - Each TOML file is read up to 1 MiB.
//
// The returned version identifies the upstream snapshot. In order of preference
// it is:
//  1. the commit ID that git-archive stores in the "comment" record of the pax
//     global header;
//  2. a 7–40 character hex suffix on the archive's top-level directory;
//  3. the current Unix time.
//
// The pax record matters because, for branch refs, the top-level directory is
// typically named after the ref (e.g. "models.dev-dev") and carries no SHA.
// A timestamp version would never match the SHA that Check compares against.
func (s *ModelRegistryService) fetchModelsDev(ctx context.Context, tarballURL string) (*modelsDevRegistry, string, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, tarballURL, nil)
	if err != nil {
		return nil, "", err
	}
	resp, err := s.client.Do(req)
	if err != nil {
		return nil, "", err
	}
	defer resp.Body.Close()
	if resp.StatusCode/100 != 2 {
		b, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<10))
		return nil, "", fmt.Errorf("models.dev fetch failed: %s body=%s", resp.Status, string(b))
	}
	tarball, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, "", fmt.Errorf("read models.dev tarball: %w", err)
	}
	gr, err := gzip.NewReader(bytes.NewReader(tarball))
	if err != nil {
		return nil, "", fmt.Errorf("open models.dev tarball: %w", err)
	}
	defer gr.Close()
	tr := tar.NewReader(gr)

	reg := &modelsDevRegistry{
		ByProviderModel: make(map[string]upstreamCap),
		ByModel:         make(map[string]upstreamCap),
	}
	version := ""
	for {
		hdr, err := tr.Next()
		if err == io.EOF {
			break
		}
		if err != nil {
			return nil, "", fmt.Errorf("read models.dev tarball entry: %w", err)
		}
		name := hdr.Name

		if version == "" {
			// archive/tar surfaces the pax global header as its own entry with
			// PAXRecords populated; git archive puts the commit ID in "comment".
			sha := strings.TrimSpace(hdr.PAXRecords["comment"])
			if (len(sha) == 40 || len(sha) == 64) && strings.Trim(sha, "0123456789abcdef") == "" {
				version = sha
			} else if m := shaSuffix.FindStringSubmatch(strings.SplitN(name, "/", 2)[0]); len(m) == 2 {
				version = m[1]
			}
		}

		if hdr.Typeflag != tar.TypeReg || !strings.HasSuffix(name, ".toml") {
			continue
		}
		providerID, modelID, ok := parseModelsDevModelPath(name)
		if !ok {
			continue
		}
		data, err := io.ReadAll(io.LimitReader(tr, 1<<20))
		if err != nil {
			return nil, "", fmt.Errorf("read %s: %w", name, err)
		}
		c, ok := parseModelsDevTOML(data)
		if !ok {
			continue
		}

		reg.ByProviderModel[providerID+"|"+modelID] = c
		if _, exists := reg.ByModel[modelID]; !exists {
			reg.ByModel[modelID] = c
		}
		// Also store a flattened variant for providers that may use subdirs.
		if alt := strings.ReplaceAll(modelID, "/", "-"); alt != modelID {
			if _, exists := reg.ByModel[alt]; !exists {
				reg.ByModel[alt] = c
			}
			altKey := providerID + "|" + alt
			if _, exists := reg.ByProviderModel[altKey]; !exists {
				reg.ByProviderModel[altKey] = c
			}
		}
	}
	if version == "" {
		version = fmt.Sprintf("%d", time.Now().Unix())
	}
	return reg, version, nil
}
// parseModelsDevModelPath extracts the provider and model IDs from a tarball
// path. The path must contain ".../providers/<provider>/models/<model>.toml".
// Only the first "providers" segment is considered. Nested model directories
// are kept in the model ID (e.g. "sub/m"). The ".toml" suffix and surrounding
// whitespace are removed. ok is false when the path does not fit the pattern
// or either ID is blank.
func parseModelsDevModelPath(path string) (providerID string, modelID string, ok bool) {
	segs := strings.Split(path, "/")
	at := -1
	for i, seg := range segs {
		if seg == "providers" {
			at = i
			break
		}
	}
	if at < 0 || at+3 >= len(segs) || segs[at+2] != "models" {
		return "", "", false
	}
	provider := strings.TrimSpace(segs[at+1])
	if provider == "" {
		return "", "", false
	}
	modelName := strings.TrimSpace(strings.TrimSuffix(strings.Join(segs[at+3:], "/"), ".toml"))
	if modelName == "" {
		return "", "", false
	}
	return provider, modelName, true
}
// parseModelsDevTOML extracts the capabilities this service cares about from one
// models.dev model TOML document:
//   - tool_call: tool support
//   - limit.context / limit.output: context window and max output tokens
//   - modalities.input: vision, i.e. whether an "image" entry is present
//     (compared case-insensitively)
//
// ok is false when the document does not parse or declares none of these.
func parseModelsDevTOML(data []byte) (upstreamCap, bool) {
	var doc map[string]any
	if toml.Unmarshal(data, &doc) != nil {
		return upstreamCap{}, false
	}

	var out upstreamCap
	if v, found := getBool(doc, "tool_call"); found {
		out.SupportsTools = boolVal{Known: true, Val: v}
	}
	if limits, found := getMap(doc, "limit"); found {
		if v, has := getInt(limits, "context"); has {
			out.ContextWindow = intVal{Known: true, Val: v}
		}
		if v, has := getInt(limits, "output"); has {
			out.MaxOutputTokens = intVal{Known: true, Val: v}
		}
	}
	if mods, found := getMap(doc, "modalities"); found {
		if inputs, has := getStringSlice(mods["input"]); has {
			vision := false
			for _, modality := range inputs {
				if strings.EqualFold(strings.TrimSpace(modality), "image") {
					vision = true
					break
				}
			}
			out.SupportsVision = boolVal{Known: true, Val: vision}
		}
	}

	// Nothing recognised means out is still the zero value.
	known := out.SupportsTools.Known || out.SupportsVision.Known ||
		out.ContextWindow.Known || out.MaxOutputTokens.Known
	return out, known
}
// getMap returns doc[key] when it is a non-nil map[string]any (a TOML table).
// A nil doc, a missing key, a nil value or any other type yields (nil, false).
func getMap(doc map[string]any, key string) (map[string]any, bool) {
	raw, found := doc[key] // indexing a nil map is safe and reports !found
	if !found || raw == nil {
		return nil, false
	}
	table, isTable := raw.(map[string]any)
	return table, isTable
}
// getBool returns doc[key] when it holds a bool. A missing key, a nil value or
// any other type yields (false, false).
func getBool(doc map[string]any, key string) (bool, bool) {
	// A type assertion on a missing/nil entry simply fails, covering every
	// non-bool case at once.
	val, isBool := doc[key].(bool)
	return val, isBool
}
// getInt returns doc[key] as an int when it holds an int64, an int or a float64.
// Floats are truncated toward zero. Anything else, including a missing key or
// a nil value, yields (0, false).
func getInt(doc map[string]any, key string) (int, bool) {
	switch num := doc[key].(type) {
	case int64:
		return int(num), true
	case int:
		return num, true
	case float64:
		return int(num), true
	}
	return 0, false
}
// getStringSlice converts a TOML array value into a []string.
//   - []any: the string elements, in order (non-strings are dropped).
//   - []string: a copy of the input.
//   - nil or any other type: (nil, false).
func getStringSlice(v any) ([]string, bool) {
	if items, isAnySlice := v.([]any); isAnySlice {
		strs := make([]string, 0, len(items))
		for _, item := range items {
			if str, isStr := item.(string); isStr {
				strs = append(strs, str)
			}
		}
		return strs, true
	}
	if strs, isStrSlice := v.([]string); isStrSlice {
		return append([]string(nil), strs...), true
	}
	return nil, false
}
// capAgg folds the upstream capabilities of every provider serving one binding
// into a single snapshot; see merge and finalize.
type capAgg struct {
	kind           string // model kind inferred from the public model name
	visionAnyTrue  bool   // some merged provider declared vision support
	visionAllKnown bool   // every merge so far had a declared vision value (callers start it at true)
	visionKnownAny bool   // at least one merge had a declared vision value
	toolsAnyTrue   bool   // some merged provider declared tool support
	toolsAllKnown  bool   // every merge so far had a declared tools value (callers start it at true)
	toolsKnownAny  bool   // at least one merge had a declared tools value
	maxOutputKnown bool   // at least one positive output limit was seen
	maxOutputMax   int    // largest positive output limit seen
	contextKnown   bool   // at least one positive context window was seen
	contextMax     int    // largest positive context window seen
}
// inferKindFromPublicModel guesses a model kind from its public name. A name
// containing "embedding" is an embedding model; one containing "rerank" is a
// reranker (case-insensitive, embedding checked first). Everything else is
// treated as chat.
func inferKindFromPublicModel(publicModel string) string {
	name := strings.ToLower(strings.TrimSpace(publicModel))
	switch {
	case strings.Contains(name, "embedding"):
		return string(modelcap.KindEmbedding)
	case strings.Contains(name, "rerank"):
		return string(modelcap.KindRerank)
	default:
		return string(modelcap.KindChat)
	}
}
// merge folds one provider's upstream capabilities into the aggregate. ok
// reports whether an upstream entry was found for that provider.
//
// A missing entry, or a capability the entry leaves undeclared, clears the
// matching "all known" flag. Undeclared values are never recorded as
// unsupported. Limits keep the largest positive value seen.
func (a *capAgg) merge(c upstreamCap, ok bool) {
	if !ok {
		// Unknown upstream capability: do not force blocking, treat as "not all known".
		a.visionAllKnown = false
		a.toolsAllKnown = false
		return
	}

	if v := c.SupportsVision; v.Known {
		a.visionKnownAny = true
		if v.Val {
			a.visionAnyTrue = true
		}
	} else {
		a.visionAllKnown = false
	}

	if t := c.SupportsTools; t.Known {
		a.toolsKnownAny = true
		if t.Val {
			a.toolsAnyTrue = true
		}
	} else {
		a.toolsAllKnown = false
	}

	if out := c.MaxOutputTokens; out.Known && out.Val > 0 {
		a.maxOutputKnown = true
		if out.Val > a.maxOutputMax {
			a.maxOutputMax = out.Val
		}
	}
	if win := c.ContextWindow; win.Known && win.Val > 0 {
		a.contextKnown = true
		if win.Val > a.contextMax {
			a.contextMax = win.Val
		}
	}
}
// finalize turns the aggregate into a normalized modelcap.Model named name.
//
// Vision and tools are reported unsupported only when every merged provider
// declared the capability and none declared it supported. Any uncertainty
// defaults to "supported", to avoid falsely blocking requests. Tool choice,
// FIM and streaming are always reported as supported. Limits are the largest
// known values, or 0 when unknown.
func (a *capAgg) finalize(name string) modelcap.Model {
	// De Morgan of "known somewhere && never true && known everywhere".
	vision := !a.visionKnownAny || a.visionAnyTrue || !a.visionAllKnown
	tools := !a.toolsKnownAny || a.toolsAnyTrue || !a.toolsAllKnown

	m := modelcap.Model{
		Name:               name,
		Kind:               string(modelcap.NormalizeKind(a.kind)),
		SupportsVision:     vision,
		SupportsFunction:   tools,
		SupportsToolChoice: true,
		SupportsFim:        true,
		SupportsStream:     true,
	}
	if a.contextKnown {
		m.ContextWindow = a.contextMax
	}
	if a.maxOutputKnown {
		m.MaxOutputTokens = a.maxOutputMax
	}
	return m.Normalized()
}
// buildBindingModels derives one capability snapshot per active binding, keyed
// "<namespace>.<public_model>", plus the JSON payload of each snapshot under
// the same key.
//
// Eligible providers are active (blank status or "active"), not currently
// banned, in a non-empty group, and list at least one model. For each
// active binding with a namespace, a public model and a route group that has
// eligible providers:
//  1. each provider's upstream model is resolved via the binding's selector
//     (providers where resolution fails are skipped);
//  2. that model's capabilities are looked up in reg;
//  3. the per-provider results are folded with capAgg.
//
// Both DB queries are bound to ctx so a cancelled refresh stops waiting on
// the database.
func (s *ModelRegistryService) buildBindingModels(ctx context.Context, reg *modelsDevRegistry) (map[string]modelcap.Model, map[string]string, error) {
	db := s.db.WithContext(ctx)
	var providers []model.Provider
	if err := db.Find(&providers).Error; err != nil {
		return nil, nil, fmt.Errorf("load providers: %w", err)
	}
	var bindings []model.Binding
	if err := db.Find(&bindings).Error; err != nil {
		return nil, nil, fmt.Errorf("load bindings: %w", err)
	}

	// providerLite is the subset of a provider needed for capability lookup.
	type providerLite struct {
		id     uint
		group  string
		ptype  string
		models []string
	}
	providersByGroup := make(map[string][]providerLite)
	now := time.Now().Unix()
	for _, p := range providers {
		if st := strings.TrimSpace(p.Status); st != "" && st != "active" {
			continue
		}
		if p.BanUntil != nil && p.BanUntil.UTC().Unix() > now {
			continue
		}
		group := groupx.Normalize(p.Group)
		var outModels []string
		for _, m := range strings.Split(p.Models, ",") {
			if m = strings.TrimSpace(m); m != "" {
				outModels = append(outModels, m)
			}
		}
		if group == "" || len(outModels) == 0 {
			continue
		}
		providersByGroup[group] = append(providersByGroup[group], providerLite{
			id:     p.ID,
			group:  group,
			ptype:  strings.TrimSpace(p.Type),
			models: outModels,
		})
	}

	modelsOut := make(map[string]modelcap.Model)
	payloads := make(map[string]string)
	for _, b := range bindings {
		if st := strings.TrimSpace(b.Status); st != "" && st != "active" {
			continue
		}
		ns := strings.TrimSpace(b.Namespace)
		pm := strings.TrimSpace(b.PublicModel)
		if ns == "" || pm == "" {
			continue
		}
		key := ns + "." + pm
		rg := groupx.Normalize(b.RouteGroup)
		if rg == "" {
			continue
		}
		pgroup := providersByGroup[rg]
		if len(pgroup) == 0 {
			continue
		}
		agg := &capAgg{
			kind:           inferKindFromPublicModel(pm),
			visionAllKnown: true,
			toolsAllKnown:  true,
		}
		selectorType := routing.SelectorType(strings.TrimSpace(b.SelectorType))
		selectorValue := strings.TrimSpace(b.SelectorValue)
		for _, p := range pgroup {
			up, err := routing.ResolveUpstreamModel(selectorType, selectorValue, pm, p.models)
			if err != nil {
				continue
			}
			c, ok := lookupModelsDevCap(reg, modelsDevProviderKey(p.ptype), up)
			agg.merge(c, ok)
		}
		out := agg.finalize(key)
		bs, err := jsoncodec.Marshal(out)
		if err != nil {
			return nil, nil, fmt.Errorf("marshal model %s: %w", key, err)
		}
		modelsOut[key] = out
		payloads[key] = string(bs)
	}
	return modelsOut, payloads, nil
}
// modelsDevProviderKey maps an ez-api provider type to the models.dev provider
// directory that describes its models. Matching is case-insensitive and ignores
// surrounding whitespace. Unknown types map to "", which makes lookups use the
// provider-agnostic index only.
func modelsDevProviderKey(providerType string) string {
	aliases := map[string]string{
		"openai":         "openai",
		"compatible":     "openai",
		"anthropic":      "anthropic",
		"claude":         "anthropic",
		"gemini":         "google",
		"vertex":         "google",
		"vertex-express": "google",
	}
	return aliases[strings.ToLower(strings.TrimSpace(providerType))]
}
// lookupModelsDevCap finds upstream capabilities for modelID (whitespace
// trimmed). It tries the provider-specific "<providerID>|<modelID>" entry first
// when providerID is non-blank, then the provider-agnostic entry. A nil
// registry or blank modelID yields (upstreamCap{}, false).
func lookupModelsDevCap(reg *modelsDevRegistry, providerID, modelID string) (upstreamCap, bool) {
	id := strings.TrimSpace(modelID)
	if reg == nil || id == "" {
		return upstreamCap{}, false
	}
	if strings.TrimSpace(providerID) != "" {
		if c, hit := reg.ByProviderModel[providerID+"|"+id]; hit {
			return c, true
		}
	}
	// A miss returns the zero upstreamCap together with false.
	c, hit := reg.ByModel[id]
	return c, hit
}
// applyToRedis replaces the published registry in Redis in one MULTI/EXEC
// transaction (TxPipeline). The transaction:
//   - deletes "meta:models" and "meta:models_meta";
//   - writes one "meta:models" hash field per model (name -> JSON payload);
//   - writes the "meta:models_meta" hash.
//
// Model handling: blank model names are skipped. When payloads has no entry
// for a model, the normalized model is marshalled instead.
//
// Metadata handling: blank metadata fields are omitted, except these
// defaults:
//   - version and updated_at: the current Unix time;
//   - source: "models.dev";
//   - checksum: the checksum of payloads.
//
// All payloads are resolved before the transaction is built, so a marshal
// failure returns without queueing anything.
func (s *ModelRegistryService) applyToRedis(ctx context.Context, models map[string]modelcap.Model, payloads map[string]string, meta modelcap.Meta) error {
	entries := make(map[string]string, len(models))
	for name, m := range models {
		if strings.TrimSpace(name) == "" {
			continue
		}
		payload := payloads[name]
		if payload == "" {
			b, err := jsoncodec.Marshal(m.Normalized())
			if err != nil {
				return fmt.Errorf("marshal model %s: %w", name, err)
			}
			payload = string(b)
		}
		entries[name] = payload
	}

	fields := make(map[string]string, 6)
	for k, v := range map[string]string{
		"version":      meta.Version,
		"updated_at":   meta.UpdatedAt,
		"source":       meta.Source,
		"checksum":     meta.Checksum,
		"upstream_url": meta.UpstreamURL,
		"upstream_ref": meta.UpstreamRef,
	} {
		if v = strings.TrimSpace(v); v != "" {
			fields[k] = v
		}
	}
	nowUnix := fmt.Sprintf("%d", time.Now().Unix())
	if fields["version"] == "" {
		fields["version"] = nowUnix
	}
	if fields["updated_at"] == "" {
		fields["updated_at"] = nowUnix
	}
	if fields["source"] == "" {
		fields["source"] = "models.dev"
	}
	if fields["checksum"] == "" {
		fields["checksum"] = modelcap.ChecksumFromPayloads(payloads)
	}

	pipe := s.rdb.TxPipeline()
	pipe.Del(ctx, "meta:models", "meta:models_meta")
	// One HSet per model keeps the transaction valid even when entries is empty
	// (a bare HSET with no field/value pairs would be rejected).
	for name, payload := range entries {
		pipe.HSet(ctx, "meta:models", name, payload)
	}
	pipe.HSet(ctx, "meta:models_meta", fields)
	// Queued pipeline commands report their errors through Exec, so this is the
	// single place to check.
	if _, err := pipe.Exec(ctx); err != nil {
		return fmt.Errorf("apply registry to redis: %w", err)
	}
	return nil
}
// modelRegistryFile is the on-disk snapshot format of current.json / prev.json.
type modelRegistryFile struct {
	Meta   modelcap.Meta             `json:"meta"`   // metadata as published to meta:models_meta
	Models map[string]modelcap.Model `json:"models"` // model snapshots keyed by registry name
}
// readModelRegistryFile loads a registry snapshot from path. Read and decode
// errors are returned unchanged. On success Models is never nil.
func readModelRegistryFile(path string) (*modelRegistryFile, error) {
	raw, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}
	file := &modelRegistryFile{}
	if err = jsoncodec.Unmarshal(raw, file); err != nil {
		return nil, err
	}
	if file.Models == nil {
		file.Models = map[string]modelcap.Model{}
	}
	return file, nil
}
// persistCache writes models and meta to CacheDir/current.json and keeps the
// previously current file as prev.json.
//
// The new snapshot is fully written to current.json.tmp before anything is
// rotated. A marshal or write failure therefore leaves the existing
// current.json and prev.json untouched instead of leaving the directory without
// a current snapshot. Rotating current.json to prev.json stays best effort; a
// failed rotation only costs the rollback target. The tmp file is removed if it
// cannot be installed.
func (s *ModelRegistryService) persistCache(models map[string]modelcap.Model, meta modelcap.Meta) error {
	if err := os.MkdirAll(s.cfg.CacheDir, 0o755); err != nil {
		return fmt.Errorf("create cache dir: %w", err)
	}
	currentPath := filepath.Join(s.cfg.CacheDir, "current.json")
	prevPath := filepath.Join(s.cfg.CacheDir, "prev.json")
	tmpPath := filepath.Join(s.cfg.CacheDir, "current.json.tmp")

	b, err := json.MarshalIndent(modelRegistryFile{Meta: meta, Models: models}, "", " ")
	if err != nil {
		return fmt.Errorf("marshal registry cache: %w", err)
	}
	if err := os.WriteFile(tmpPath, b, 0o644); err != nil {
		_ = os.Remove(tmpPath) // don't leave a partial snapshot behind
		return fmt.Errorf("write registry cache: %w", err)
	}

	// Only rotate once the new snapshot is safely on disk.
	if _, err := os.Stat(currentPath); err == nil {
		_ = os.Remove(prevPath)
		_ = os.Rename(currentPath, prevPath)
	}
	if err := os.Rename(tmpPath, currentPath); err != nil {
		_ = os.Remove(tmpPath)
		return fmt.Errorf("install registry cache: %w", err)
	}
	return nil
}