From 96e1fe41a598a24fb5dd2776642ad1ead0ea5594 Mon Sep 17 00:00:00 2001 From: zenfun Date: Wed, 17 Dec 2025 23:15:12 +0800 Subject: [PATCH] feat(models): add kind and models_meta snapshot --- internal/api/handler.go | 41 ++++++++- internal/api/model_handler_test.go | 105 +++++++++++++++++++++++ internal/api/provider_handler_test.go | 2 +- internal/dto/model.go | 17 ++-- internal/model/models.go | 1 + internal/service/sync.go | 119 +++++++++++++++++++++++++- 6 files changed, 272 insertions(+), 13 deletions(-) create mode 100644 internal/api/model_handler_test.go diff --git a/internal/api/handler.go b/internal/api/handler.go index 1517072..8b1058c 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -294,8 +294,25 @@ func (h *Handler) CreateModel(c *gin.Context) { return } + name := strings.TrimSpace(req.Name) + if name == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "name required"}) + return + } + kind := strings.ToLower(strings.TrimSpace(req.Kind)) + if kind == "" { + kind = "chat" + } + switch kind { + case "chat", "embedding", "rerank", "other": + default: + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid kind"}) + return + } + modelReq := model.Model{ - Name: req.Name, + Name: name, + Kind: kind, ContextWindow: req.ContextWindow, CostPerToken: req.CostPerToken, SupportsVision: req.SupportsVision, @@ -370,7 +387,27 @@ func (h *Handler) UpdateModel(c *gin.Context) { return } - existing.Name = req.Name + name := strings.TrimSpace(req.Name) + if name == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "name required"}) + return + } + kind := strings.ToLower(strings.TrimSpace(req.Kind)) + if kind == "" { + kind = strings.ToLower(strings.TrimSpace(existing.Kind)) + } + if kind == "" { + kind = "chat" + } + switch kind { + case "chat", "embedding", "rerank", "other": + default: + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid kind"}) + return + } + + existing.Name = name + existing.Kind = kind existing.ContextWindow = req.ContextWindow existing.CostPerToken = req.CostPerToken existing.SupportsVision = req.SupportsVision diff --git a/internal/api/model_handler_test.go b/internal/api/model_handler_test.go new file mode 100644 index 0000000..7d7ce70 --- /dev/null +++ b/internal/api/model_handler_test.go @@ -0,0 +1,105 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/alicebob/miniredis/v2" + "github.com/ez-api/ez-api/internal/model" + "github.com/ez-api/ez-api/internal/service" + "github.com/gin-gonic/gin" + "github.com/redis/go-redis/v9" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +func newTestHandlerWithRedis(t *testing.T) (*Handler, *gorm.DB, *miniredis.Miniredis) { + t.Helper() + gin.SetMode(gin.TestMode) + + dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", t.Name()) + db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{}) + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + if err := db.AutoMigrate(&model.Provider{}, &model.Binding{}, &model.Model{}); err != nil { + t.Fatalf("migrate: %v", err) + } + + mr := miniredis.RunT(t) + rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()}) + sync := service.NewSyncService(rdb) + return NewHandler(db, sync, nil), db, mr +} + +func TestCreateModel_DefaultsKindChat_AndWritesModelsMeta(t *testing.T) { + h, _, mr := newTestHandlerWithRedis(t) + + r := gin.New() + r.POST("/admin/models", h.CreateModel) + + reqBody := map[string]any{ + "name": "ns.m", + } + b, _ := json.Marshal(reqBody) + + req := httptest.NewRequest(http.MethodPost, "/admin/models", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusCreated { + t.Fatalf("expected 201, got %d body=%s", rr.Code, rr.Body.String()) + } + + raw := mr.HGet("meta:models", "ns.m") + if raw == "" { + t.Fatalf("expected meta:models[ns.m] to be written") + } + var snap map[string]any + if err := json.Unmarshal([]byte(raw), &snap); err != nil { + t.Fatalf("unmarshal snapshot: %v raw=%s", err, raw) + } + if snap["kind"] != "chat" { + t.Fatalf("expected kind=chat, got %v raw=%s", snap["kind"], raw) + } + + if v := mr.HGet("meta:models_meta", "version"); v == "" { + t.Fatalf("expected meta:models_meta.version") + } + if v := mr.HGet("meta:models_meta", "updated_at"); v == "" { + t.Fatalf("expected meta:models_meta.updated_at") + } + if v := mr.HGet("meta:models_meta", "source"); v == "" { + t.Fatalf("expected meta:models_meta.source") + } + if v := mr.HGet("meta:models_meta", "checksum"); v == "" { + t.Fatalf("expected meta:models_meta.checksum") + } +} + +func TestCreateModel_InvalidKind_Returns400(t *testing.T) { + h, _, _ := newTestHandlerWithRedis(t) + + r := gin.New() + r.POST("/admin/models", h.CreateModel) + + reqBody := map[string]any{ + "name": "ns.m2", + "kind": "bad", + } + b, _ := json.Marshal(reqBody) + + req := httptest.NewRequest(http.MethodPost, "/admin/models", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d body=%s", rr.Code, rr.Body.String()) + } +} diff --git a/internal/api/provider_handler_test.go b/internal/api/provider_handler_test.go index 1d24f29..3cfea47 100644 --- a/internal/api/provider_handler_test.go +++ b/internal/api/provider_handler_test.go @@ -28,7 +28,7 @@ func newTestHandler(t *testing.T) (*Handler, *gorm.DB) { if err != nil { t.Fatalf("open sqlite: %v", err) } - if err := db.AutoMigrate(&model.Provider{}, &model.Binding{}); err != nil { + if err := db.AutoMigrate(&model.Provider{}, &model.Binding{}, &model.Model{}); err != nil { t.Fatalf("migrate: %v", err) } diff --git a/internal/dto/model.go b/internal/dto/model.go index 3845a4f..c4c5783 100644 --- a/internal/dto/model.go +++ b/internal/dto/model.go @@ -2,12 +2,13 @@ package dto // ModelDTO is used for create/update of model capabilities. type ModelDTO struct { - Name string `json:"name"` - ContextWindow int `json:"context_window"` - CostPerToken float64 `json:"cost_per_token"` - SupportsVision bool `json:"supports_vision"` - SupportsFunctions bool `json:"supports_functions"` - SupportsToolChoice bool `json:"supports_tool_choice"` - SupportsFIM bool `json:"supports_fim"` - MaxOutputTokens int `json:"max_output_tokens"` + Name string `json:"name"` + Kind string `json:"kind"` + ContextWindow int `json:"context_window"` + CostPerToken float64 `json:"cost_per_token"` + SupportsVision bool `json:"supports_vision"` + SupportsFunctions bool `json:"supports_functions"` + SupportsToolChoice bool `json:"supports_tool_choice"` + SupportsFIM bool `json:"supports_fim"` + MaxOutputTokens int `json:"max_output_tokens"` } diff --git a/internal/model/models.go b/internal/model/models.go index 838d905..31a0ef8 100644 --- a/internal/model/models.go +++ b/internal/model/models.go @@ -60,6 +60,7 @@ type Provider struct { type Model struct { gorm.Model Name string `gorm:"uniqueIndex;not null" json:"name"` + Kind string `gorm:"size:50;default:'chat'" json:"kind"` ContextWindow int `json:"context_window"` CostPerToken float64 `json:"cost_per_token"` SupportsVision bool `json:"supports_vision"` diff --git a/internal/service/sync.go b/internal/service/sync.go index 157785e..73ebc90 100644 --- a/internal/service/sync.go +++ b/internal/service/sync.go @@ -2,7 +2,10 @@ package service import ( "context" + "crypto/sha256" + "encoding/hex" "fmt" + "sort" "strings" "time" @@ -119,6 +122,7 @@ func (s *SyncService) SyncModel(m *model.Model) error { ctx := context.Background() snap := modelSnapshot{ Name: m.Name, + Kind: normalizeModelKind(m.Kind), ContextWindow: m.ContextWindow, CostPerToken: m.CostPerToken, SupportsVision: m.SupportsVision, @@ -127,7 +131,13 @@ func (s *SyncService) SyncModel(m *model.Model) error { SupportsFIM: m.SupportsFIM, MaxOutputTokens: m.MaxOutputTokens, } - return s.hsetJSON(ctx, "meta:models", snap.Name, snap) + if err := s.hsetJSON(ctx, "meta:models", snap.Name, snap); err != nil { + return err + } + if err := s.refreshModelsMetaFromRedis(ctx, "db"); err != nil { + return err + } + return nil } type providerSnapshot struct { @@ -151,6 +161,7 @@ type providerSnapshot struct { type modelSnapshot struct { Name string `json:"name"` + Kind string `json:"kind"` ContextWindow int `json:"context_window"` CostPerToken float64 `json:"cost_per_token"` SupportsVision bool `json:"supports_vision"` @@ -190,7 +201,7 @@ func (s *SyncService) SyncAll(db *gorm.DB) error { } pipe := s.rdb.TxPipeline() - pipe.Del(ctx, "config:providers", "config:keys", "meta:models", "config:bindings", "meta:bindings_meta") + pipe.Del(ctx, "config:providers", "config:keys", "meta:models", "meta:models_meta", "config:bindings", "meta:bindings_meta") // Also clear master keys var masterKeys []string iter := s.rdb.Scan(ctx, 0, "auth:master:*", 0).Iterator() @@ -283,6 +294,7 @@ func (s *SyncService) SyncAll(db *gorm.DB) error { for _, m := range models { snap := modelSnapshot{ Name: m.Name, + Kind: normalizeModelKind(m.Kind), ContextWindow: m.ContextWindow, CostPerToken: m.CostPerToken, SupportsVision: m.SupportsVision, @@ -298,6 +310,15 @@ func (s *SyncService) SyncAll(db *gorm.DB) error { pipe.HSet(ctx, "meta:models", snap.Name, payload) } + if err := writeModelsMeta(ctx, pipe, modelsMetaInput{ + Source: "db", + Version: fmt.Sprintf("%d", time.Now().Unix()), + UpdatedAtSec: time.Now().Unix(), + Models: models, + }); err != nil { + return err + } + if err := s.writeBindingsSnapshot(ctx, pipe, bindings, providers); err != nil { return err } @@ -447,3 +468,97 @@ func normalizeStatus(status string) string { return st } } + +func normalizeModelKind(kind string) string { + k := strings.ToLower(strings.TrimSpace(kind)) + if k == "" { + return "chat" + } + switch k { + case "chat", "embedding", "rerank", "other": + return k + default: + return "other" + } +} + +func checksumModelPayloads(payloads map[string]string) string { + keys := make([]string, 0, len(payloads)) + for k := range payloads { + keys = append(keys, k) + } + sort.Strings(keys) + + h := sha256.New() + for _, k := range keys { + _, _ = h.Write([]byte(k)) + _, _ = h.Write([]byte{'\n'}) + _, _ = h.Write([]byte(payloads[k])) + _, _ = h.Write([]byte{'\n'}) + } + return hex.EncodeToString(h.Sum(nil)) +} + +func (s *SyncService) refreshModelsMetaFromRedis(ctx context.Context, source string) error { + raw, err := s.rdb.HGetAll(ctx, "meta:models").Result() + if err != nil { + return fmt.Errorf("read meta:models: %w", err) + } + now := time.Now().Unix() + meta := map[string]interface{}{ + "version": fmt.Sprintf("%d", now), + "updated_at": fmt.Sprintf("%d", now), + "source": source, + "checksum": checksumModelPayloads(raw), + } + if err := s.rdb.HSet(ctx, "meta:models_meta", meta).Err(); err != nil { + return fmt.Errorf("write meta:models_meta: %w", err) + } + return nil +} + +type modelsMetaInput struct { + Source string + Version string + UpdatedAtSec int64 + Models []model.Model +} + +func writeModelsMeta(ctx context.Context, pipe redis.Pipeliner, in modelsMetaInput) error { + payloads := make(map[string]string, len(in.Models)) + for _, m := range in.Models { + snap := modelSnapshot{ + Name: m.Name, + Kind: normalizeModelKind(m.Kind), + ContextWindow: m.ContextWindow, + CostPerToken: m.CostPerToken, + SupportsVision: m.SupportsVision, + SupportsFunction: m.SupportsFunctions, + SupportsToolChoice: m.SupportsToolChoice, + SupportsFIM: m.SupportsFIM, + MaxOutputTokens: m.MaxOutputTokens, + } + b, err := jsoncodec.Marshal(snap) + if err != nil { + return fmt.Errorf("marshal model %s for meta: %w", m.Name, err) + } + payloads[snap.Name] = string(b) + } + + meta := map[string]string{ + "version": strings.TrimSpace(in.Version), + "updated_at": fmt.Sprintf("%d", in.UpdatedAtSec), + "source": strings.TrimSpace(in.Source), + "checksum": checksumModelPayloads(payloads), + } + if strings.TrimSpace(meta["version"]) == "" { + meta["version"] = fmt.Sprintf("%d", time.Now().Unix()) + } + if strings.TrimSpace(meta["source"]) == "" { + meta["source"] = "db" + } + if err := pipe.HSet(ctx, "meta:models_meta", meta).Err(); err != nil { + return fmt.Errorf("write meta:models_meta: %w", err) + } + return nil +}