mirror of
https://github.com/EZ-Api/ez-api.git
synced 2026-01-13 17:47:51 +00:00
feat(model-registry): models.dev updater + admin endpoints
This commit is contained in:
844
internal/service/model_registry.go
Normal file
844
internal/service/model_registry.go
Normal file
@@ -0,0 +1,844 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ez-api/ez-api/internal/model"
|
||||
groupx "github.com/ez-api/foundation/group"
|
||||
"github.com/ez-api/foundation/jsoncodec"
|
||||
"github.com/ez-api/foundation/modelcap"
|
||||
"github.com/ez-api/foundation/routing"
|
||||
"github.com/pelletier/go-toml/v2"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// ModelRegistryConfig holds the settings of the models.dev-backed model
// registry. Zero or blank values are replaced with defaults by
// NewModelRegistryService.
type ModelRegistryConfig struct {
	// Enabled gates the background refresh loop launched by Start.
	Enabled bool
	// RefreshEvery is the interval between background refreshes.
	RefreshEvery time.Duration

	// ModelsDevBaseURL is the tarball endpoint; the ref is appended as a
	// path segment.
	ModelsDevBaseURL string
	// ModelsDevRef is the upstream ref used when Refresh gets a blank ref.
	ModelsDevRef string

	// CacheDir is the directory that holds current.json and prev.json.
	CacheDir string
	// Timeout is the timeout of the HTTP client that downloads the tarball.
	Timeout time.Duration
}
|
||||
|
||||
type ModelRegistryService struct {
|
||||
db *gorm.DB
|
||||
rdb *redis.Client
|
||||
|
||||
cfg ModelRegistryConfig
|
||||
client *http.Client
|
||||
|
||||
mu sync.Mutex
|
||||
lastError string
|
||||
lastRefresh time.Time
|
||||
lastApplied modelcap.Meta
|
||||
lastUpstream string
|
||||
}
|
||||
|
||||
func NewModelRegistryService(db *gorm.DB, rdb *redis.Client, cfg ModelRegistryConfig) *ModelRegistryService {
|
||||
if cfg.RefreshEvery <= 0 {
|
||||
cfg.RefreshEvery = 30 * time.Minute
|
||||
}
|
||||
if strings.TrimSpace(cfg.ModelsDevBaseURL) == "" {
|
||||
cfg.ModelsDevBaseURL = "https://codeload.github.com/sst/models.dev/tar.gz"
|
||||
}
|
||||
if strings.TrimSpace(cfg.ModelsDevRef) == "" {
|
||||
cfg.ModelsDevRef = "dev"
|
||||
}
|
||||
if strings.TrimSpace(cfg.CacheDir) == "" {
|
||||
cfg.CacheDir = "./data/model-registry"
|
||||
}
|
||||
if cfg.Timeout <= 0 {
|
||||
cfg.Timeout = 30 * time.Second
|
||||
}
|
||||
return &ModelRegistryService{
|
||||
db: db,
|
||||
rdb: rdb,
|
||||
cfg: cfg,
|
||||
client: &http.Client{
|
||||
Timeout: cfg.Timeout,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ModelRegistryService) Start(ctx context.Context) {
|
||||
if !s.cfg.Enabled {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
ticker := time.NewTicker(s.cfg.RefreshEvery)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Best-effort initial refresh.
|
||||
_ = s.Refresh(ctx, "")
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
_ = s.Refresh(ctx, "")
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
type ModelRegistryStatus struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
ModelsDevRef string `json:"models_dev_ref"`
|
||||
ModelsDevURL string `json:"models_dev_url"`
|
||||
LastRefreshAt int64 `json:"last_refresh_at,omitempty"`
|
||||
LastError string `json:"last_error,omitempty"`
|
||||
RedisMeta map[string]string `json:"redis_meta,omitempty"`
|
||||
CacheCurrent *modelRegistryFile `json:"cache_current,omitempty"`
|
||||
CachePrev *modelRegistryFile `json:"cache_prev,omitempty"`
|
||||
}
|
||||
|
||||
func (s *ModelRegistryService) Status(ctx context.Context) (*ModelRegistryStatus, error) {
|
||||
s.mu.Lock()
|
||||
lastErr := s.lastError
|
||||
lastRefresh := s.lastRefresh
|
||||
s.mu.Unlock()
|
||||
|
||||
redisMeta, _ := s.rdb.HGetAll(ctx, "meta:models_meta").Result()
|
||||
|
||||
current, _ := readModelRegistryFile(filepath.Join(s.cfg.CacheDir, "current.json"))
|
||||
prev, _ := readModelRegistryFile(filepath.Join(s.cfg.CacheDir, "prev.json"))
|
||||
|
||||
out := &ModelRegistryStatus{
|
||||
Enabled: s.cfg.Enabled,
|
||||
ModelsDevRef: s.cfg.ModelsDevRef,
|
||||
ModelsDevURL: strings.TrimRight(s.cfg.ModelsDevBaseURL, "/") + "/" + s.cfg.ModelsDevRef,
|
||||
LastError: lastErr,
|
||||
RedisMeta: redisMeta,
|
||||
CacheCurrent: current,
|
||||
CachePrev: prev,
|
||||
}
|
||||
if !lastRefresh.IsZero() {
|
||||
out.LastRefreshAt = lastRefresh.Unix()
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *ModelRegistryService) Refresh(ctx context.Context, ref string) error {
|
||||
if strings.TrimSpace(ref) == "" {
|
||||
ref = s.cfg.ModelsDevRef
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.lastUpstream = ref
|
||||
s.mu.Unlock()
|
||||
|
||||
tarballURL := strings.TrimRight(s.cfg.ModelsDevBaseURL, "/") + "/" + ref
|
||||
reg, version, err := s.fetchModelsDev(ctx, tarballURL)
|
||||
if err != nil {
|
||||
s.setError(err)
|
||||
return err
|
||||
}
|
||||
|
||||
models, payloads, err := s.buildBindingModels(ctx, reg)
|
||||
if err != nil {
|
||||
s.setError(err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Overlay explicit DB entries (admin overrides).
|
||||
var dbModels []model.Model
|
||||
if err := s.db.Find(&dbModels).Error; err != nil {
|
||||
s.setError(fmt.Errorf("load db models: %w", err))
|
||||
return err
|
||||
}
|
||||
for _, m := range dbModels {
|
||||
snap := modelcap.Model{
|
||||
Name: m.Name,
|
||||
Kind: string(modelcap.NormalizeKind(m.Kind)),
|
||||
ContextWindow: m.ContextWindow,
|
||||
CostPerToken: m.CostPerToken,
|
||||
SupportsVision: m.SupportsVision,
|
||||
SupportsFunction: m.SupportsFunctions,
|
||||
SupportsToolChoice: m.SupportsToolChoice,
|
||||
SupportsFim: m.SupportsFIM,
|
||||
MaxOutputTokens: m.MaxOutputTokens,
|
||||
}.Normalized()
|
||||
b, err := jsoncodec.Marshal(snap)
|
||||
if err != nil {
|
||||
s.setError(fmt.Errorf("marshal db model %s: %w", m.Name, err))
|
||||
return err
|
||||
}
|
||||
models[snap.Name] = snap
|
||||
payloads[snap.Name] = string(b)
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
meta := modelcap.Meta{
|
||||
Version: version,
|
||||
UpdatedAt: fmt.Sprintf("%d", now),
|
||||
Source: "models.dev",
|
||||
Checksum: modelcap.ChecksumFromPayloads(payloads),
|
||||
UpstreamURL: "https://github.com/sst/models.dev",
|
||||
UpstreamRef: ref,
|
||||
}
|
||||
|
||||
if err := s.applyToRedis(ctx, models, payloads, meta); err != nil {
|
||||
s.setError(err)
|
||||
return err
|
||||
}
|
||||
if err := s.persistCache(models, meta); err != nil {
|
||||
s.setError(err)
|
||||
return err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.lastError = ""
|
||||
s.lastRefresh = time.Now()
|
||||
s.lastApplied = meta
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ModelRegistryService) Rollback(ctx context.Context) error {
|
||||
prevPath := filepath.Join(s.cfg.CacheDir, "prev.json")
|
||||
prev, err := readModelRegistryFile(prevPath)
|
||||
if err != nil {
|
||||
s.setError(fmt.Errorf("read prev cache: %w", err))
|
||||
return err
|
||||
}
|
||||
if prev == nil {
|
||||
s.setError(fmt.Errorf("no prev cache"))
|
||||
return fmt.Errorf("no prev cache")
|
||||
}
|
||||
|
||||
payloads := make(map[string]string, len(prev.Models))
|
||||
for name, m := range prev.Models {
|
||||
m = m.Normalized()
|
||||
b, err := jsoncodec.Marshal(m)
|
||||
if err != nil {
|
||||
s.setError(fmt.Errorf("marshal cached model %s: %w", name, err))
|
||||
return err
|
||||
}
|
||||
prev.Models[name] = m
|
||||
payloads[name] = string(b)
|
||||
}
|
||||
prev.Meta.Checksum = modelcap.ChecksumFromPayloads(payloads)
|
||||
prev.Meta.Source = "rollback"
|
||||
prev.Meta.UpdatedAt = fmt.Sprintf("%d", time.Now().Unix())
|
||||
|
||||
if err := s.applyToRedis(ctx, prev.Models, payloads, prev.Meta); err != nil {
|
||||
s.setError(err)
|
||||
return err
|
||||
}
|
||||
// Swap cache current/prev to reflect rollback.
|
||||
if err := s.persistCache(prev.Models, prev.Meta); err != nil {
|
||||
s.setError(err)
|
||||
return err
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.lastError = ""
|
||||
s.lastRefresh = time.Now()
|
||||
s.lastApplied = prev.Meta
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ModelRegistryService) setError(err error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.lastError = err.Error()
|
||||
s.lastRefresh = time.Now()
|
||||
}
|
||||
|
||||
// boolVal is a boolean that may be absent upstream.
type boolVal struct {
	// Known reports whether the upstream data carried the value at all.
	Known bool
	// Val is the value itself; only meaningful when Known is true.
	Val bool
}
|
||||
|
||||
// intVal is an integer that may be absent upstream.
type intVal struct {
	// Known reports whether the upstream data carried the value at all.
	Known bool
	// Val is the value itself; only meaningful when Known is true.
	Val int
}
|
||||
|
||||
type upstreamCap struct {
|
||||
ContextWindow intVal
|
||||
MaxOutputTokens intVal
|
||||
SupportsVision boolVal
|
||||
SupportsTools boolVal
|
||||
}
|
||||
|
||||
type modelsDevRegistry struct {
|
||||
ByProviderModel map[string]upstreamCap // key: providerID|modelID
|
||||
ByModel map[string]upstreamCap // fallback: modelID
|
||||
}
|
||||
|
||||
// shaSuffix captures a trailing "-<hex>" group of 7 to 40 lowercase hex
// digits; fetchModelsDev applies it to the tarball's top-level directory name
// to derive the snapshot version.
var shaSuffix = regexp.MustCompile(`-([0-9a-f]{7,40})$`)
|
||||
|
||||
func (s *ModelRegistryService) fetchModelsDev(ctx context.Context, tarballURL string) (*modelsDevRegistry, string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, tarballURL, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
resp, err := s.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode/100 != 2 {
|
||||
b, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<10))
|
||||
return nil, "", fmt.Errorf("models.dev fetch failed: %s body=%s", resp.Status, string(b))
|
||||
}
|
||||
tarball, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
gr, err := gzip.NewReader(bytes.NewReader(tarball))
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
defer gr.Close()
|
||||
|
||||
tr := tar.NewReader(gr)
|
||||
reg := &modelsDevRegistry{
|
||||
ByProviderModel: make(map[string]upstreamCap),
|
||||
ByModel: make(map[string]upstreamCap),
|
||||
}
|
||||
|
||||
version := ""
|
||||
for {
|
||||
hdr, err := tr.Next()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
name := hdr.Name
|
||||
if version == "" {
|
||||
if parts := strings.SplitN(name, "/", 2); len(parts) > 0 {
|
||||
if m := shaSuffix.FindStringSubmatch(parts[0]); len(m) == 2 {
|
||||
version = m[1]
|
||||
}
|
||||
}
|
||||
}
|
||||
if hdr.Typeflag != tar.TypeReg {
|
||||
continue
|
||||
}
|
||||
if !strings.HasSuffix(name, ".toml") {
|
||||
continue
|
||||
}
|
||||
providerID, modelID, ok := parseModelsDevModelPath(name)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(io.LimitReader(tr, 1<<20))
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
cap, ok := parseModelsDevTOML(data)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
key := providerID + "|" + modelID
|
||||
reg.ByProviderModel[key] = cap
|
||||
if _, exists := reg.ByModel[modelID]; !exists {
|
||||
reg.ByModel[modelID] = cap
|
||||
}
|
||||
// Also store a flattened variant for providers that may use subdirs.
|
||||
if alt := strings.ReplaceAll(modelID, "/", "-"); alt != modelID {
|
||||
if _, exists := reg.ByModel[alt]; !exists {
|
||||
reg.ByModel[alt] = cap
|
||||
}
|
||||
altKey := providerID + "|" + alt
|
||||
if _, exists := reg.ByProviderModel[altKey]; !exists {
|
||||
reg.ByProviderModel[altKey] = cap
|
||||
}
|
||||
}
|
||||
}
|
||||
if version == "" {
|
||||
version = fmt.Sprintf("%d", time.Now().Unix())
|
||||
}
|
||||
return reg, version, nil
|
||||
}
|
||||
|
||||
// parseModelsDevModelPath extracts the provider and model IDs from an archive
// path shaped like ".../providers/<provider>/models/<model...>.toml".
//
// Only the first "providers" segment is considered. The model ID is every
// segment after "models" joined with "/", with the ".toml" suffix removed.
// ok is false when the path does not follow that layout or when either ID is
// blank.
func parseModelsDevModelPath(path string) (providerID string, modelID string, ok bool) {
	segs := strings.Split(path, "/")

	anchor := -1
	for idx, seg := range segs {
		if seg == "providers" {
			anchor = idx
			break
		}
	}
	// Need at least <provider>, "models" and one model segment after it.
	if anchor < 0 || anchor+3 >= len(segs) {
		return "", "", false
	}

	provider := strings.TrimSpace(segs[anchor+1])
	if provider == "" || segs[anchor+2] != "models" {
		return "", "", false
	}

	joined := strings.Join(segs[anchor+3:], "/")
	model := strings.TrimSpace(strings.TrimSuffix(joined, ".toml"))
	if model == "" {
		return "", "", false
	}
	return provider, model, true
}
|
||||
|
||||
func parseModelsDevTOML(data []byte) (upstreamCap, bool) {
|
||||
var doc map[string]any
|
||||
if err := toml.Unmarshal(data, &doc); err != nil {
|
||||
return upstreamCap{}, false
|
||||
}
|
||||
|
||||
var cap upstreamCap
|
||||
|
||||
if v, ok := getBool(doc, "tool_call"); ok {
|
||||
cap.SupportsTools = boolVal{Known: true, Val: v}
|
||||
}
|
||||
if limit, ok := getMap(doc, "limit"); ok {
|
||||
if v, ok := getInt(limit, "context"); ok {
|
||||
cap.ContextWindow = intVal{Known: true, Val: v}
|
||||
}
|
||||
if v, ok := getInt(limit, "output"); ok {
|
||||
cap.MaxOutputTokens = intVal{Known: true, Val: v}
|
||||
}
|
||||
}
|
||||
if mods, ok := getMap(doc, "modalities"); ok {
|
||||
if input, ok := getStringSlice(mods["input"]); ok {
|
||||
hasImage := false
|
||||
for _, it := range input {
|
||||
if strings.EqualFold(strings.TrimSpace(it), "image") {
|
||||
hasImage = true
|
||||
break
|
||||
}
|
||||
}
|
||||
cap.SupportsVision = boolVal{Known: true, Val: hasImage}
|
||||
}
|
||||
}
|
||||
|
||||
if !cap.SupportsTools.Known && !cap.SupportsVision.Known && !cap.ContextWindow.Known && !cap.MaxOutputTokens.Known {
|
||||
return upstreamCap{}, false
|
||||
}
|
||||
return cap, true
|
||||
}
|
||||
|
||||
// getMap returns doc[key] when it holds a map[string]any. ok is false when
// doc is nil, the key is absent, or the value is nil or of another type.
func getMap(doc map[string]any, key string) (map[string]any, bool) {
	// Indexing a nil map and asserting on a nil interface both fail safely,
	// so a single assertion covers every "not a map" case.
	nested, isMap := doc[key].(map[string]any)
	return nested, isMap
}
|
||||
|
||||
// getBool returns doc[key] when it holds a bool. ok is false when the key is
// absent or the value is nil or of another type.
func getBool(doc map[string]any, key string) (bool, bool) {
	flag, isBool := doc[key].(bool)
	return flag, isBool
}
|
||||
|
||||
// getInt returns doc[key] as an int when it holds an int64, int or float64
// (floats are converted with int(), dropping the fractional part). ok is
// false when the key is absent or the value is nil or of another type.
func getInt(doc map[string]any, key string) (int, bool) {
	raw := doc[key]
	if wide, isInt64 := raw.(int64); isInt64 {
		return int(wide), true
	}
	if plain, isInt := raw.(int); isInt {
		return plain, true
	}
	if frac, isFloat := raw.(float64); isFloat {
		return int(frac), true
	}
	return 0, false
}
|
||||
|
||||
// getStringSlice converts v to a string slice. A []any yields its string
// elements (other elements are dropped); a []string yields a copy. ok is
// false for nil and for every other type.
func getStringSlice(v any) ([]string, bool) {
	if mixed, isAny := v.([]any); isAny {
		kept := make([]string, 0, len(mixed))
		for _, elem := range mixed {
			text, isString := elem.(string)
			if !isString {
				continue
			}
			kept = append(kept, text)
		}
		return kept, true
	}
	if typed, isStrings := v.([]string); isStrings {
		// Copy so the caller cannot alias the input's backing array.
		return append([]string(nil), typed...), true
	}
	return nil, false
}
|
||||
|
||||
// capAgg accumulates the capabilities of every upstream model that can serve
// one binding. merge feeds it; finalize turns it into a modelcap.Model.
type capAgg struct {
	// kind is the model kind inferred from the public model name.
	kind string

	// Vision: whether any upstream reported true, whether every merged
	// upstream reported a value, and whether at least one reported a value.
	visionAnyTrue  bool
	visionAllKnown bool
	visionKnownAny bool

	// Tools: same three-way bookkeeping as for vision.
	toolsAnyTrue  bool
	toolsAllKnown bool
	toolsKnownAny bool

	// Largest positive output-token limit seen, and whether one was seen.
	maxOutputKnown bool
	maxOutputMax   int

	// Largest positive context window seen, and whether one was seen.
	contextKnown bool
	contextMax   int
}
|
||||
|
||||
func inferKindFromPublicModel(publicModel string) string {
|
||||
pm := strings.ToLower(strings.TrimSpace(publicModel))
|
||||
if strings.Contains(pm, "embedding") {
|
||||
return string(modelcap.KindEmbedding)
|
||||
}
|
||||
if strings.Contains(pm, "rerank") {
|
||||
return string(modelcap.KindRerank)
|
||||
}
|
||||
return string(modelcap.KindChat)
|
||||
}
|
||||
|
||||
func (a *capAgg) merge(cap upstreamCap, ok bool) {
|
||||
if ok {
|
||||
// vision
|
||||
if cap.SupportsVision.Known {
|
||||
a.visionKnownAny = true
|
||||
if cap.SupportsVision.Val {
|
||||
a.visionAnyTrue = true
|
||||
}
|
||||
} else {
|
||||
a.visionAllKnown = false
|
||||
}
|
||||
// tools
|
||||
if cap.SupportsTools.Known {
|
||||
a.toolsKnownAny = true
|
||||
if cap.SupportsTools.Val {
|
||||
a.toolsAnyTrue = true
|
||||
}
|
||||
} else {
|
||||
a.toolsAllKnown = false
|
||||
}
|
||||
// limits
|
||||
if cap.MaxOutputTokens.Known && cap.MaxOutputTokens.Val > 0 {
|
||||
a.maxOutputKnown = true
|
||||
if cap.MaxOutputTokens.Val > a.maxOutputMax {
|
||||
a.maxOutputMax = cap.MaxOutputTokens.Val
|
||||
}
|
||||
}
|
||||
if cap.ContextWindow.Known && cap.ContextWindow.Val > 0 {
|
||||
a.contextKnown = true
|
||||
if cap.ContextWindow.Val > a.contextMax {
|
||||
a.contextMax = cap.ContextWindow.Val
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
// Unknown upstream capability: do not force blocking, treat as "not all known".
|
||||
a.visionAllKnown = false
|
||||
a.toolsAllKnown = false
|
||||
}
|
||||
|
||||
func (a *capAgg) finalize(name string) modelcap.Model {
|
||||
// Safe defaults: unknown -> allow (true) so we avoid false blocking.
|
||||
supportsVision := true
|
||||
if a.visionKnownAny {
|
||||
if a.visionAnyTrue {
|
||||
supportsVision = true
|
||||
} else if a.visionAllKnown {
|
||||
supportsVision = false
|
||||
}
|
||||
}
|
||||
|
||||
supportsTools := true
|
||||
if a.toolsKnownAny {
|
||||
if a.toolsAnyTrue {
|
||||
supportsTools = true
|
||||
} else if a.toolsAllKnown {
|
||||
supportsTools = false
|
||||
}
|
||||
}
|
||||
|
||||
out := modelcap.Model{
|
||||
Name: name,
|
||||
Kind: string(modelcap.NormalizeKind(a.kind)),
|
||||
ContextWindow: 0,
|
||||
CostPerToken: 0,
|
||||
SupportsVision: supportsVision,
|
||||
SupportsFunction: supportsTools,
|
||||
SupportsToolChoice: true,
|
||||
SupportsFim: true,
|
||||
MaxOutputTokens: 0,
|
||||
}
|
||||
if a.contextKnown {
|
||||
out.ContextWindow = a.contextMax
|
||||
}
|
||||
if a.maxOutputKnown {
|
||||
out.MaxOutputTokens = a.maxOutputMax
|
||||
}
|
||||
return out.Normalized()
|
||||
}
|
||||
|
||||
func (s *ModelRegistryService) buildBindingModels(ctx context.Context, reg *modelsDevRegistry) (map[string]modelcap.Model, map[string]string, error) {
|
||||
var providers []model.Provider
|
||||
if err := s.db.Find(&providers).Error; err != nil {
|
||||
return nil, nil, fmt.Errorf("load providers: %w", err)
|
||||
}
|
||||
var bindings []model.Binding
|
||||
if err := s.db.Find(&bindings).Error; err != nil {
|
||||
return nil, nil, fmt.Errorf("load bindings: %w", err)
|
||||
}
|
||||
|
||||
type providerLite struct {
|
||||
id uint
|
||||
group string
|
||||
ptype string
|
||||
models []string
|
||||
}
|
||||
providersByGroup := make(map[string][]providerLite)
|
||||
now := time.Now().Unix()
|
||||
for _, p := range providers {
|
||||
if strings.TrimSpace(p.Status) != "" && strings.TrimSpace(p.Status) != "active" {
|
||||
continue
|
||||
}
|
||||
if p.BanUntil != nil && p.BanUntil.UTC().Unix() > now {
|
||||
continue
|
||||
}
|
||||
group := groupx.Normalize(p.Group)
|
||||
rawModels := strings.Split(p.Models, ",")
|
||||
var outModels []string
|
||||
for _, m := range rawModels {
|
||||
m = strings.TrimSpace(m)
|
||||
if m != "" {
|
||||
outModels = append(outModels, m)
|
||||
}
|
||||
}
|
||||
if group == "" || len(outModels) == 0 {
|
||||
continue
|
||||
}
|
||||
providersByGroup[group] = append(providersByGroup[group], providerLite{
|
||||
id: p.ID,
|
||||
group: group,
|
||||
ptype: strings.TrimSpace(p.Type),
|
||||
models: outModels,
|
||||
})
|
||||
}
|
||||
|
||||
modelsOut := make(map[string]modelcap.Model)
|
||||
payloads := make(map[string]string)
|
||||
|
||||
for _, b := range bindings {
|
||||
if strings.TrimSpace(b.Status) != "" && strings.TrimSpace(b.Status) != "active" {
|
||||
continue
|
||||
}
|
||||
ns := strings.TrimSpace(b.Namespace)
|
||||
pm := strings.TrimSpace(b.PublicModel)
|
||||
if ns == "" || pm == "" {
|
||||
continue
|
||||
}
|
||||
key := ns + "." + pm
|
||||
rg := groupx.Normalize(b.RouteGroup)
|
||||
if rg == "" {
|
||||
continue
|
||||
}
|
||||
pgroup := providersByGroup[rg]
|
||||
if len(pgroup) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
agg := &capAgg{
|
||||
kind: inferKindFromPublicModel(pm),
|
||||
visionAllKnown: true,
|
||||
toolsAllKnown: true,
|
||||
}
|
||||
|
||||
selectorType := routing.SelectorType(strings.TrimSpace(b.SelectorType))
|
||||
selectorValue := strings.TrimSpace(b.SelectorValue)
|
||||
|
||||
for _, p := range pgroup {
|
||||
up, err := routing.ResolveUpstreamModel(selectorType, selectorValue, pm, p.models)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
cap, ok := lookupModelsDevCap(reg, modelsDevProviderKey(p.ptype), up)
|
||||
agg.merge(cap, ok)
|
||||
}
|
||||
|
||||
out := agg.finalize(key)
|
||||
bs, err := jsoncodec.Marshal(out)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("marshal model %s: %w", key, err)
|
||||
}
|
||||
modelsOut[key] = out
|
||||
payloads[key] = string(bs)
|
||||
}
|
||||
|
||||
return modelsOut, payloads, nil
|
||||
}
|
||||
|
||||
// modelsDevProviderKey maps a local provider type to the provider directory
// name used by models.dev. The match ignores case and surrounding whitespace;
// unmapped types yield "".
func modelsDevProviderKey(providerType string) string {
	var key string
	switch strings.ToLower(strings.TrimSpace(providerType)) {
	case "openai":
		key = "openai"
	case "compatible":
		key = "openai"
	case "anthropic":
		key = "anthropic"
	case "claude":
		key = "anthropic"
	case "gemini":
		key = "google"
	case "vertex":
		key = "google"
	case "vertex-express":
		key = "google"
	}
	return key
}
|
||||
|
||||
func lookupModelsDevCap(reg *modelsDevRegistry, providerID, modelID string) (upstreamCap, bool) {
|
||||
modelID = strings.TrimSpace(modelID)
|
||||
if modelID == "" || reg == nil {
|
||||
return upstreamCap{}, false
|
||||
}
|
||||
if strings.TrimSpace(providerID) != "" {
|
||||
if cap, ok := reg.ByProviderModel[providerID+"|"+modelID]; ok {
|
||||
return cap, true
|
||||
}
|
||||
}
|
||||
if cap, ok := reg.ByModel[modelID]; ok {
|
||||
return cap, true
|
||||
}
|
||||
return upstreamCap{}, false
|
||||
}
|
||||
|
||||
func (s *ModelRegistryService) applyToRedis(ctx context.Context, models map[string]modelcap.Model, payloads map[string]string, meta modelcap.Meta) error {
|
||||
pipe := s.rdb.TxPipeline()
|
||||
pipe.Del(ctx, "meta:models", "meta:models_meta")
|
||||
|
||||
for name := range models {
|
||||
if strings.TrimSpace(name) == "" {
|
||||
continue
|
||||
}
|
||||
payload := payloads[name]
|
||||
if payload == "" {
|
||||
b, err := jsoncodec.Marshal(models[name].Normalized())
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal model %s: %w", name, err)
|
||||
}
|
||||
payload = string(b)
|
||||
}
|
||||
pipe.HSet(ctx, "meta:models", name, payload)
|
||||
}
|
||||
|
||||
fields := map[string]string{
|
||||
"version": strings.TrimSpace(meta.Version),
|
||||
"updated_at": strings.TrimSpace(meta.UpdatedAt),
|
||||
"source": strings.TrimSpace(meta.Source),
|
||||
"checksum": strings.TrimSpace(meta.Checksum),
|
||||
"upstream_url": strings.TrimSpace(meta.UpstreamURL),
|
||||
"upstream_ref": strings.TrimSpace(meta.UpstreamRef),
|
||||
}
|
||||
for k, v := range fields {
|
||||
if v == "" {
|
||||
delete(fields, k)
|
||||
}
|
||||
}
|
||||
if fields["version"] == "" {
|
||||
fields["version"] = fmt.Sprintf("%d", time.Now().Unix())
|
||||
}
|
||||
if fields["updated_at"] == "" {
|
||||
fields["updated_at"] = fmt.Sprintf("%d", time.Now().Unix())
|
||||
}
|
||||
if fields["source"] == "" {
|
||||
fields["source"] = "models.dev"
|
||||
}
|
||||
if fields["checksum"] == "" {
|
||||
fields["checksum"] = modelcap.ChecksumFromPayloads(payloads)
|
||||
}
|
||||
if err := pipe.HSet(ctx, "meta:models_meta", fields).Err(); err != nil {
|
||||
return fmt.Errorf("write meta:models_meta: %w", err)
|
||||
}
|
||||
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
return fmt.Errorf("apply registry to redis: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type modelRegistryFile struct {
|
||||
Meta modelcap.Meta `json:"meta"`
|
||||
Models map[string]modelcap.Model `json:"models"`
|
||||
}
|
||||
|
||||
func readModelRegistryFile(path string) (*modelRegistryFile, error) {
|
||||
b, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var out modelRegistryFile
|
||||
if err := jsoncodec.Unmarshal(b, &out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if out.Models == nil {
|
||||
out.Models = make(map[string]modelcap.Model)
|
||||
}
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func (s *ModelRegistryService) persistCache(models map[string]modelcap.Model, meta modelcap.Meta) error {
|
||||
if err := os.MkdirAll(s.cfg.CacheDir, 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
currentPath := filepath.Join(s.cfg.CacheDir, "current.json")
|
||||
prevPath := filepath.Join(s.cfg.CacheDir, "prev.json")
|
||||
tmpPath := filepath.Join(s.cfg.CacheDir, "current.json.tmp")
|
||||
|
||||
if _, err := os.Stat(currentPath); err == nil {
|
||||
_ = os.Remove(prevPath)
|
||||
_ = os.Rename(currentPath, prevPath)
|
||||
}
|
||||
|
||||
out := modelRegistryFile{
|
||||
Meta: meta,
|
||||
Models: models,
|
||||
}
|
||||
b, err := json.MarshalIndent(out, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(tmpPath, b, 0o644); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.Rename(tmpPath, currentPath); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user