mirror of
https://github.com/EZ-Api/ez-api.git
synced 2026-01-13 17:47:51 +00:00
refactor(api): split Provider into ProviderGroup and APIKey models
Restructure the provider management system by separating the monolithic Provider model into two distinct entities: - ProviderGroup: defines shared upstream configuration (type, base_url, google settings, models, status) - APIKey: represents individual credentials within a group (api_key, weight, status, auto_ban, ban settings) This change also updates: - Binding model to reference GroupID instead of RouteGroup string - All CRUD handlers for the new provider-group and api-key endpoints - Sync service to rebuild provider snapshots from joined tables - Model registry to aggregate capabilities across group/key pairs - Access handler to validate namespace existence and subset constraints - Migration importer to handle the new schema structure - All related tests to use the new model relationships BREAKING CHANGE: Provider API endpoints replaced with /provider-groups and /api-keys endpoints; Binding.RouteGroup replaced with Binding.GroupID
This commit is contained in:
@@ -18,7 +18,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ez-api/ez-api/internal/model"
|
||||
groupx "github.com/ez-api/foundation/group"
|
||||
"github.com/ez-api/foundation/jsoncodec"
|
||||
"github.com/ez-api/foundation/modelcap"
|
||||
"github.com/ez-api/foundation/routing"
|
||||
@@ -373,6 +372,33 @@ type upstreamCap struct {
|
||||
SupportsTools boolVal
|
||||
}
|
||||
|
||||
func boolValEqual(a, b boolVal) bool {
|
||||
if a.Known != b.Known {
|
||||
return false
|
||||
}
|
||||
if !a.Known {
|
||||
return true
|
||||
}
|
||||
return a.Val == b.Val
|
||||
}
|
||||
|
||||
func intValEqual(a, b intVal) bool {
|
||||
if a.Known != b.Known {
|
||||
return false
|
||||
}
|
||||
if !a.Known {
|
||||
return true
|
||||
}
|
||||
return a.Val == b.Val
|
||||
}
|
||||
|
||||
func capsEqual(a, b upstreamCap) bool {
|
||||
return boolValEqual(a.SupportsVision, b.SupportsVision) &&
|
||||
boolValEqual(a.SupportsTools, b.SupportsTools) &&
|
||||
intValEqual(a.ContextWindow, b.ContextWindow) &&
|
||||
intValEqual(a.MaxOutputTokens, b.MaxOutputTokens)
|
||||
}
|
||||
|
||||
type modelsDevRegistry struct {
|
||||
ByProviderModel map[string]upstreamCap // key: providerID|modelID
|
||||
ByModel map[string]upstreamCap // fallback: modelID
|
||||
@@ -707,9 +733,13 @@ func (a *capAgg) finalize(name string) modelcap.Model {
|
||||
}
|
||||
|
||||
func (s *ModelRegistryService) buildBindingModels(ctx context.Context, reg *modelsDevRegistry) (map[string]modelcap.Model, map[string]string, error) {
|
||||
var providers []model.Provider
|
||||
if err := s.db.Find(&providers).Error; err != nil {
|
||||
return nil, nil, fmt.Errorf("load providers: %w", err)
|
||||
var groups []model.ProviderGroup
|
||||
if err := s.db.Find(&groups).Error; err != nil {
|
||||
return nil, nil, fmt.Errorf("load provider groups: %w", err)
|
||||
}
|
||||
var apiKeys []model.APIKey
|
||||
if err := s.db.Find(&apiKeys).Error; err != nil {
|
||||
return nil, nil, fmt.Errorf("load api keys: %w", err)
|
||||
}
|
||||
var bindings []model.Binding
|
||||
if err := s.db.Find(&bindings).Error; err != nil {
|
||||
@@ -718,21 +748,29 @@ func (s *ModelRegistryService) buildBindingModels(ctx context.Context, reg *mode
|
||||
|
||||
type providerLite struct {
|
||||
id uint
|
||||
group string
|
||||
ptype string
|
||||
models []string
|
||||
}
|
||||
providersByGroup := make(map[string][]providerLite)
|
||||
providersByGroupID := make(map[uint]providerLite)
|
||||
now := time.Now().Unix()
|
||||
for _, p := range providers {
|
||||
if strings.TrimSpace(p.Status) != "" && strings.TrimSpace(p.Status) != "active" {
|
||||
activeKeys := make(map[uint]bool)
|
||||
for _, k := range apiKeys {
|
||||
if strings.TrimSpace(k.Status) != "" && strings.TrimSpace(k.Status) != "active" {
|
||||
continue
|
||||
}
|
||||
if p.BanUntil != nil && p.BanUntil.UTC().Unix() > now {
|
||||
if k.BanUntil != nil && k.BanUntil.UTC().Unix() > now {
|
||||
continue
|
||||
}
|
||||
group := groupx.Normalize(p.Group)
|
||||
rawModels := strings.Split(p.Models, ",")
|
||||
activeKeys[k.GroupID] = true
|
||||
}
|
||||
for _, g := range groups {
|
||||
if strings.TrimSpace(g.Status) != "" && strings.TrimSpace(g.Status) != "active" {
|
||||
continue
|
||||
}
|
||||
if !activeKeys[g.ID] {
|
||||
continue
|
||||
}
|
||||
rawModels := strings.Split(g.Models, ",")
|
||||
var outModels []string
|
||||
for _, m := range rawModels {
|
||||
m = strings.TrimSpace(m)
|
||||
@@ -740,19 +778,20 @@ func (s *ModelRegistryService) buildBindingModels(ctx context.Context, reg *mode
|
||||
outModels = append(outModels, m)
|
||||
}
|
||||
}
|
||||
if group == "" || len(outModels) == 0 {
|
||||
if len(outModels) == 0 {
|
||||
continue
|
||||
}
|
||||
providersByGroup[group] = append(providersByGroup[group], providerLite{
|
||||
id: p.ID,
|
||||
group: group,
|
||||
ptype: strings.TrimSpace(p.Type),
|
||||
providersByGroupID[g.ID] = providerLite{
|
||||
id: g.ID,
|
||||
ptype: strings.TrimSpace(g.Type),
|
||||
models: outModels,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
modelsOut := make(map[string]modelcap.Model)
|
||||
payloads := make(map[string]string)
|
||||
capBaseline := make(map[string]upstreamCap)
|
||||
capBaselineOK := make(map[string]bool)
|
||||
|
||||
for _, b := range bindings {
|
||||
if strings.TrimSpace(b.Status) != "" && strings.TrimSpace(b.Status) != "active" {
|
||||
@@ -764,12 +803,8 @@ func (s *ModelRegistryService) buildBindingModels(ctx context.Context, reg *mode
|
||||
continue
|
||||
}
|
||||
key := ns + "." + pm
|
||||
rg := groupx.Normalize(b.RouteGroup)
|
||||
if rg == "" {
|
||||
continue
|
||||
}
|
||||
pgroup := providersByGroup[rg]
|
||||
if len(pgroup) == 0 {
|
||||
group := providersByGroupID[b.GroupID]
|
||||
if group.id == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -782,13 +817,25 @@ func (s *ModelRegistryService) buildBindingModels(ctx context.Context, reg *mode
|
||||
selectorType := routing.SelectorType(strings.TrimSpace(b.SelectorType))
|
||||
selectorValue := strings.TrimSpace(b.SelectorValue)
|
||||
|
||||
for _, p := range pgroup {
|
||||
up, err := routing.ResolveUpstreamModel(selectorType, selectorValue, pm, p.models)
|
||||
if err != nil {
|
||||
continue
|
||||
up, err := routing.ResolveUpstreamModel(selectorType, selectorValue, pm, group.models)
|
||||
if err == nil {
|
||||
cap, ok := lookupModelsDevCap(reg, modelsDevProviderKey(group.ptype), up)
|
||||
if baseOK, seen := capBaselineOK[key]; seen {
|
||||
if !ok || !baseOK || !capsEqual(capBaseline[key], cap) {
|
||||
return nil, nil, fmt.Errorf("bindingKey %s has inconsistent capabilities", key)
|
||||
}
|
||||
} else {
|
||||
capBaselineOK[key] = ok
|
||||
if ok {
|
||||
capBaseline[key] = cap
|
||||
}
|
||||
}
|
||||
cap, ok := lookupModelsDevCap(reg, modelsDevProviderKey(p.ptype), up)
|
||||
agg.merge(cap, ok)
|
||||
} else {
|
||||
if _, seen := capBaselineOK[key]; seen {
|
||||
return nil, nil, fmt.Errorf("bindingKey %s has inconsistent capabilities", key)
|
||||
}
|
||||
capBaselineOK[key] = false
|
||||
}
|
||||
|
||||
out := agg.finalize(key)
|
||||
|
||||
@@ -23,7 +23,7 @@ func TestModelRegistry_Check(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("open sqlite: %v", err)
|
||||
}
|
||||
if err := db.AutoMigrate(&model.Provider{}, &model.Binding{}, &model.Model{}); err != nil {
|
||||
if err := db.AutoMigrate(&model.ProviderGroup{}, &model.APIKey{}, &model.Binding{}, &model.Model{}); err != nil {
|
||||
t.Fatalf("migrate: %v", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -58,22 +58,31 @@ func TestModelRegistry_RefreshAndRollback(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("open sqlite: %v", err)
|
||||
}
|
||||
if err := db.AutoMigrate(&model.Provider{}, &model.Binding{}, &model.Model{}); err != nil {
|
||||
if err := db.AutoMigrate(&model.ProviderGroup{}, &model.APIKey{}, &model.Binding{}, &model.Model{}); err != nil {
|
||||
t.Fatalf("migrate: %v", err)
|
||||
}
|
||||
if err := db.Create(&model.Provider{
|
||||
Name: "p1",
|
||||
Type: "openai",
|
||||
Group: "rg",
|
||||
Models: "gpt-4o-mini",
|
||||
Status: "active",
|
||||
group := model.ProviderGroup{
|
||||
Name: "rg",
|
||||
Type: "openai",
|
||||
BaseURL: "https://api.openai.com/v1",
|
||||
Models: "gpt-4o-mini",
|
||||
Status: "active",
|
||||
}
|
||||
if err := db.Create(&group).Error; err != nil {
|
||||
t.Fatalf("create provider group: %v", err)
|
||||
}
|
||||
if err := db.Create(&model.APIKey{
|
||||
GroupID: group.ID,
|
||||
APIKey: "k",
|
||||
Status: "active",
|
||||
}).Error; err != nil {
|
||||
t.Fatalf("create provider: %v", err)
|
||||
t.Fatalf("create api key: %v", err)
|
||||
}
|
||||
if err := db.Create(&model.Binding{
|
||||
Namespace: "ns",
|
||||
PublicModel: "m",
|
||||
RouteGroup: "rg",
|
||||
GroupID: group.ID,
|
||||
Weight: 1,
|
||||
SelectorType: "exact",
|
||||
SelectorValue: "gpt-4o-mini",
|
||||
Status: "active",
|
||||
|
||||
@@ -77,78 +77,29 @@ func (s *SyncService) SyncMaster(master *model.Master) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SyncProvider writes a single provider into Redis hash storage and updates routing tables.
|
||||
func (s *SyncService) SyncProvider(provider *model.Provider) error {
|
||||
ctx := context.Background()
|
||||
group := groupx.Normalize(provider.Group)
|
||||
models := strings.Split(provider.Models, ",")
|
||||
|
||||
snap := providerSnapshot{
|
||||
ID: provider.ID,
|
||||
Name: provider.Name,
|
||||
Type: provider.Type,
|
||||
BaseURL: provider.BaseURL,
|
||||
APIKey: provider.APIKey,
|
||||
GoogleProject: provider.GoogleProject,
|
||||
GoogleLocation: provider.GoogleLocation,
|
||||
Group: group,
|
||||
Models: models,
|
||||
Weight: provider.Weight,
|
||||
Status: normalizeStatus(provider.Status),
|
||||
AutoBan: provider.AutoBan,
|
||||
BanReason: provider.BanReason,
|
||||
}
|
||||
if provider.BanUntil != nil {
|
||||
snap.BanUntil = provider.BanUntil.UTC().Unix()
|
||||
}
|
||||
|
||||
// 1. Update Provider Config
|
||||
if err := s.hsetJSON(ctx, "config:providers", fmt.Sprintf("%d", provider.ID), snap); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 2. Update Routing Table: route:group:{group}:{model} -> Set(provider_id)
|
||||
// Note: This is an additive operation. Removing models requires full sync or smarter logic.
|
||||
pipe := s.rdb.Pipeline()
|
||||
for _, m := range models {
|
||||
m = strings.TrimSpace(m)
|
||||
if m == "" {
|
||||
continue
|
||||
}
|
||||
if snap.Status != "active" {
|
||||
continue
|
||||
}
|
||||
if snap.BanUntil > 0 && time.Now().Unix() < snap.BanUntil {
|
||||
continue
|
||||
}
|
||||
routeKey := fmt.Sprintf("route:group:%s:%s", group, m)
|
||||
pipe.SAdd(ctx, routeKey, provider.ID)
|
||||
}
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// SyncProviderDelete removes provider snapshot and routing entries from Redis.
|
||||
func (s *SyncService) SyncProviderDelete(provider *model.Provider) error {
|
||||
if provider == nil {
|
||||
return fmt.Errorf("provider required")
|
||||
// SyncProviders rebuilds provider snapshots from ProviderGroup + APIKey tables.
|
||||
func (s *SyncService) SyncProviders(db *gorm.DB) error {
|
||||
if db == nil {
|
||||
return fmt.Errorf("db required")
|
||||
}
|
||||
ctx := context.Background()
|
||||
group := groupx.Normalize(provider.Group)
|
||||
models := strings.Split(provider.Models, ",")
|
||||
|
||||
var groups []model.ProviderGroup
|
||||
if err := db.Find(&groups).Error; err != nil {
|
||||
return fmt.Errorf("load provider groups: %w", err)
|
||||
}
|
||||
var apiKeys []model.APIKey
|
||||
if err := db.Find(&apiKeys).Error; err != nil {
|
||||
return fmt.Errorf("load api keys: %w", err)
|
||||
}
|
||||
|
||||
pipe := s.rdb.TxPipeline()
|
||||
pipe.HDel(ctx, "config:providers", fmt.Sprintf("%d", provider.ID))
|
||||
for _, m := range models {
|
||||
m = strings.TrimSpace(m)
|
||||
if m == "" {
|
||||
continue
|
||||
}
|
||||
routeKey := fmt.Sprintf("route:group:%s:%s", group, m)
|
||||
pipe.SRem(ctx, routeKey, provider.ID)
|
||||
pipe.Del(ctx, "config:providers")
|
||||
if err := s.writeProvidersSnapshot(ctx, pipe, groups, apiKeys); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
return fmt.Errorf("delete provider snapshot: %w", err)
|
||||
return fmt.Errorf("write provider snapshot: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -203,6 +154,7 @@ type providerSnapshot struct {
|
||||
APIKey string `json:"api_key"`
|
||||
GoogleProject string `json:"google_project,omitempty"`
|
||||
GoogleLocation string `json:"google_location,omitempty"`
|
||||
GroupID uint `json:"group_id,omitempty"`
|
||||
Group string `json:"group"`
|
||||
Models []string `json:"models"`
|
||||
Weight int `json:"weight,omitempty"`
|
||||
@@ -212,15 +164,100 @@ type providerSnapshot struct {
|
||||
BanUntil int64 `json:"ban_until,omitempty"` // unix seconds
|
||||
}
|
||||
|
||||
func (s *SyncService) writeProvidersSnapshot(ctx context.Context, pipe redis.Pipeliner, groups []model.ProviderGroup, apiKeys []model.APIKey) error {
|
||||
groupMap := make(map[uint]model.ProviderGroup, len(groups))
|
||||
for _, g := range groups {
|
||||
groupMap[g.ID] = g
|
||||
}
|
||||
|
||||
for _, k := range apiKeys {
|
||||
g, ok := groupMap[k.GroupID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
groupName := groupx.Normalize(g.Name)
|
||||
if strings.TrimSpace(groupName) == "" {
|
||||
groupName = "default"
|
||||
}
|
||||
groupStatus := normalizeStatus(g.Status)
|
||||
keyStatus := normalizeStatus(k.Status)
|
||||
status := keyStatus
|
||||
if groupStatus != "" && groupStatus != "active" {
|
||||
status = groupStatus
|
||||
}
|
||||
|
||||
rawModels := strings.Split(g.Models, ",")
|
||||
var models []string
|
||||
for _, m := range rawModels {
|
||||
m = strings.TrimSpace(m)
|
||||
if m != "" {
|
||||
models = append(models, m)
|
||||
}
|
||||
}
|
||||
|
||||
name := strings.TrimSpace(g.Name)
|
||||
if name == "" {
|
||||
name = groupName
|
||||
}
|
||||
name = fmt.Sprintf("%s#%d", name, k.ID)
|
||||
|
||||
snap := providerSnapshot{
|
||||
ID: k.ID,
|
||||
Name: name,
|
||||
Type: strings.TrimSpace(g.Type),
|
||||
BaseURL: strings.TrimSpace(g.BaseURL),
|
||||
APIKey: strings.TrimSpace(k.APIKey),
|
||||
GoogleProject: strings.TrimSpace(g.GoogleProject),
|
||||
GoogleLocation: strings.TrimSpace(g.GoogleLocation),
|
||||
GroupID: g.ID,
|
||||
Group: groupName,
|
||||
Models: models,
|
||||
Weight: k.Weight,
|
||||
Status: status,
|
||||
AutoBan: k.AutoBan,
|
||||
BanReason: strings.TrimSpace(k.BanReason),
|
||||
}
|
||||
if k.BanUntil != nil {
|
||||
snap.BanUntil = k.BanUntil.UTC().Unix()
|
||||
}
|
||||
|
||||
payload, err := jsoncodec.Marshal(snap)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal provider %d: %w", k.ID, err)
|
||||
}
|
||||
pipe.HSet(ctx, "config:providers", fmt.Sprintf("%d", k.ID), payload)
|
||||
|
||||
// Legacy route table maintenance for compatibility.
|
||||
for _, m := range models {
|
||||
if m == "" {
|
||||
continue
|
||||
}
|
||||
if snap.Status != "active" {
|
||||
continue
|
||||
}
|
||||
if snap.BanUntil > 0 && time.Now().Unix() < snap.BanUntil {
|
||||
continue
|
||||
}
|
||||
routeKey := fmt.Sprintf("route:group:%s:%s", groupName, m)
|
||||
pipe.SAdd(ctx, routeKey, k.ID)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// keySnapshot is no longer needed as we write directly to auth:token:*
|
||||
|
||||
// SyncAll rebuilds Redis hashes from the database; use for cold starts or forced refreshes.
|
||||
func (s *SyncService) SyncAll(db *gorm.DB) error {
|
||||
ctx := context.Background()
|
||||
|
||||
var providers []model.Provider
|
||||
if err := db.Find(&providers).Error; err != nil {
|
||||
return fmt.Errorf("load providers: %w", err)
|
||||
var groups []model.ProviderGroup
|
||||
if err := db.Find(&groups).Error; err != nil {
|
||||
return fmt.Errorf("load provider groups: %w", err)
|
||||
}
|
||||
var apiKeys []model.APIKey
|
||||
if err := db.Find(&apiKeys).Error; err != nil {
|
||||
return fmt.Errorf("load api keys: %w", err)
|
||||
}
|
||||
|
||||
var keys []model.Key
|
||||
@@ -259,53 +296,8 @@ func (s *SyncService) SyncAll(db *gorm.DB) error {
|
||||
pipe.Del(ctx, masterKeys...)
|
||||
}
|
||||
|
||||
// Clear old routing tables (pattern scan would be better in prod, but keys are predictable if we knew them)
|
||||
// For MVP, we rely on the fact that we are rebuilding.
|
||||
// Ideally, we should scan "route:group:*" and del, but let's just rebuild.
|
||||
|
||||
for _, p := range providers {
|
||||
group := groupx.Normalize(p.Group)
|
||||
models := strings.Split(p.Models, ",")
|
||||
|
||||
snap := providerSnapshot{
|
||||
ID: p.ID,
|
||||
Name: p.Name,
|
||||
Type: p.Type,
|
||||
BaseURL: p.BaseURL,
|
||||
APIKey: p.APIKey,
|
||||
GoogleProject: p.GoogleProject,
|
||||
GoogleLocation: p.GoogleLocation,
|
||||
Group: group,
|
||||
Models: models,
|
||||
Weight: p.Weight,
|
||||
Status: normalizeStatus(p.Status),
|
||||
AutoBan: p.AutoBan,
|
||||
BanReason: p.BanReason,
|
||||
}
|
||||
if p.BanUntil != nil {
|
||||
snap.BanUntil = p.BanUntil.UTC().Unix()
|
||||
}
|
||||
payload, err := jsoncodec.Marshal(snap)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal provider %d: %w", p.ID, err)
|
||||
}
|
||||
pipe.HSet(ctx, "config:providers", fmt.Sprintf("%d", p.ID), payload)
|
||||
|
||||
// Rebuild Routing Table
|
||||
for _, m := range models {
|
||||
m = strings.TrimSpace(m)
|
||||
if m == "" {
|
||||
continue
|
||||
}
|
||||
if snap.Status != "active" {
|
||||
continue
|
||||
}
|
||||
if snap.BanUntil > 0 && time.Now().Unix() < snap.BanUntil {
|
||||
continue
|
||||
}
|
||||
routeKey := fmt.Sprintf("route:group:%s:%s", group, m)
|
||||
pipe.SAdd(ctx, routeKey, p.ID)
|
||||
}
|
||||
if err := s.writeProvidersSnapshot(ctx, pipe, groups, apiKeys); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, k := range keys {
|
||||
@@ -382,7 +374,7 @@ func (s *SyncService) SyncAll(db *gorm.DB) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.writeBindingsSnapshot(ctx, pipe, bindings, providers); err != nil {
|
||||
if err := s.writeBindingsSnapshot(ctx, pipe, bindings, groups, apiKeys); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -398,9 +390,13 @@ func (s *SyncService) SyncAll(db *gorm.DB) error {
|
||||
func (s *SyncService) SyncBindings(db *gorm.DB) error {
|
||||
ctx := context.Background()
|
||||
|
||||
var providers []model.Provider
|
||||
if err := db.Find(&providers).Error; err != nil {
|
||||
return fmt.Errorf("load providers: %w", err)
|
||||
var groups []model.ProviderGroup
|
||||
if err := db.Find(&groups).Error; err != nil {
|
||||
return fmt.Errorf("load provider groups: %w", err)
|
||||
}
|
||||
var apiKeys []model.APIKey
|
||||
if err := db.Find(&apiKeys).Error; err != nil {
|
||||
return fmt.Errorf("load api keys: %w", err)
|
||||
}
|
||||
var bindings []model.Binding
|
||||
if err := db.Find(&bindings).Error; err != nil {
|
||||
@@ -409,7 +405,7 @@ func (s *SyncService) SyncBindings(db *gorm.DB) error {
|
||||
|
||||
pipe := s.rdb.TxPipeline()
|
||||
pipe.Del(ctx, "config:bindings", "meta:bindings_meta")
|
||||
if err := s.writeBindingsSnapshot(ctx, pipe, bindings, providers); err != nil {
|
||||
if err := s.writeBindingsSnapshot(ctx, pipe, bindings, groups, apiKeys); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
@@ -418,32 +414,65 @@ func (s *SyncService) SyncBindings(db *gorm.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SyncService) writeBindingsSnapshot(ctx context.Context, pipe redis.Pipeliner, bindings []model.Binding, providers []model.Provider) error {
|
||||
// Group providers by route group for selector resolution.
|
||||
type providerLite struct {
|
||||
id uint
|
||||
group string
|
||||
models []string
|
||||
func (s *SyncService) writeBindingsSnapshot(ctx context.Context, pipe redis.Pipeliner, bindings []model.Binding, groups []model.ProviderGroup, apiKeys []model.APIKey) error {
|
||||
type groupLite struct {
|
||||
id uint
|
||||
name string
|
||||
ptype string
|
||||
baseURL string
|
||||
googleProject string
|
||||
googleLoc string
|
||||
models []string
|
||||
status string
|
||||
}
|
||||
providersByGroup := make(map[string][]providerLite)
|
||||
for _, p := range providers {
|
||||
group := groupx.Normalize(p.Group)
|
||||
models := strings.Split(p.Models, ",")
|
||||
groupsByID := make(map[uint]groupLite, len(groups))
|
||||
for _, g := range groups {
|
||||
rawModels := strings.Split(g.Models, ",")
|
||||
var outModels []string
|
||||
for _, m := range models {
|
||||
for _, m := range rawModels {
|
||||
m = strings.TrimSpace(m)
|
||||
if m != "" {
|
||||
outModels = append(outModels, m)
|
||||
}
|
||||
}
|
||||
providersByGroup[group] = append(providersByGroup[group], providerLite{
|
||||
id: p.ID,
|
||||
group: group,
|
||||
models: outModels,
|
||||
groupsByID[g.ID] = groupLite{
|
||||
id: g.ID,
|
||||
name: groupx.Normalize(g.Name),
|
||||
ptype: strings.TrimSpace(g.Type),
|
||||
baseURL: strings.TrimSpace(g.BaseURL),
|
||||
googleProject: strings.TrimSpace(g.GoogleProject),
|
||||
googleLoc: strings.TrimSpace(g.GoogleLocation),
|
||||
models: outModels,
|
||||
status: normalizeStatus(g.Status),
|
||||
}
|
||||
}
|
||||
|
||||
type apiKeyLite struct {
|
||||
id uint
|
||||
groupID uint
|
||||
status string
|
||||
weight int
|
||||
autoBan bool
|
||||
banUntil *time.Time
|
||||
}
|
||||
keysByGroup := make(map[uint][]apiKeyLite)
|
||||
for _, k := range apiKeys {
|
||||
keysByGroup[k.GroupID] = append(keysByGroup[k.GroupID], apiKeyLite{
|
||||
id: k.ID,
|
||||
groupID: k.GroupID,
|
||||
status: normalizeStatus(k.Status),
|
||||
weight: k.Weight,
|
||||
autoBan: k.AutoBan,
|
||||
banUntil: k.BanUntil,
|
||||
})
|
||||
}
|
||||
|
||||
type bindingAgg struct {
|
||||
snap routing.BindingSnapshot
|
||||
}
|
||||
snaps := make(map[string]*routing.BindingSnapshot)
|
||||
now := time.Now().Unix()
|
||||
|
||||
for _, b := range bindings {
|
||||
if strings.TrimSpace(b.Status) != "" && strings.TrimSpace(b.Status) != "active" {
|
||||
continue
|
||||
@@ -453,43 +482,65 @@ func (s *SyncService) writeBindingsSnapshot(ctx context.Context, pipe redis.Pipe
|
||||
if ns == "" || pm == "" {
|
||||
continue
|
||||
}
|
||||
rg := groupx.Normalize(b.RouteGroup)
|
||||
if rg == "" {
|
||||
group, ok := groupsByID[b.GroupID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if group.status != "" && group.status != "active" {
|
||||
continue
|
||||
}
|
||||
|
||||
snap := struct {
|
||||
Namespace string `json:"namespace"`
|
||||
PublicModel string `json:"public_model"`
|
||||
RouteGroup string `json:"route_group"`
|
||||
SelectorType string `json:"selector_type,omitempty"`
|
||||
SelectorValue string `json:"selector_value,omitempty"`
|
||||
Status string `json:"status,omitempty"`
|
||||
UpdatedAt int64 `json:"updated_at,omitempty"`
|
||||
Upstreams map[string]string `json:"upstreams"`
|
||||
}{
|
||||
Namespace: ns,
|
||||
PublicModel: pm,
|
||||
RouteGroup: rg,
|
||||
key := ns + "." + pm
|
||||
snap := snaps[key]
|
||||
if snap == nil {
|
||||
snap = &routing.BindingSnapshot{
|
||||
Namespace: ns,
|
||||
PublicModel: pm,
|
||||
Status: "active",
|
||||
UpdatedAt: now,
|
||||
}
|
||||
snaps[key] = snap
|
||||
}
|
||||
|
||||
candidate := routing.BindingCandidate{
|
||||
GroupID: group.id,
|
||||
RouteGroup: group.name,
|
||||
Weight: normalizeWeight(b.Weight),
|
||||
SelectorType: strings.TrimSpace(b.SelectorType),
|
||||
SelectorValue: strings.TrimSpace(b.SelectorValue),
|
||||
Status: "active",
|
||||
UpdatedAt: now,
|
||||
Upstreams: make(map[string]string),
|
||||
}
|
||||
|
||||
selectorType := strings.TrimSpace(b.SelectorType)
|
||||
selectorValue := strings.TrimSpace(b.SelectorValue)
|
||||
keys := keysByGroup[b.GroupID]
|
||||
if len(keys) == 0 {
|
||||
candidate.Error = "no_provider"
|
||||
}
|
||||
|
||||
for _, p := range providersByGroup[rg] {
|
||||
up, err := routing.ResolveUpstreamModel(routing.SelectorType(selectorType), selectorValue, pm, p.models)
|
||||
nowUnix := time.Now().Unix()
|
||||
for _, k := range keys {
|
||||
if k.status != "" && k.status != "active" {
|
||||
continue
|
||||
}
|
||||
if k.banUntil != nil && k.banUntil.UTC().Unix() > nowUnix {
|
||||
continue
|
||||
}
|
||||
up, err := routing.ResolveUpstreamModel(routing.SelectorType(selectorType), selectorValue, pm, group.models)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
snap.Upstreams[fmt.Sprintf("%d", p.id)] = up
|
||||
candidate.Upstreams[fmt.Sprintf("%d", k.id)] = up
|
||||
}
|
||||
if len(candidate.Upstreams) == 0 && candidate.Error == "" {
|
||||
candidate.Error = "config_error"
|
||||
}
|
||||
|
||||
key := ns + "." + pm
|
||||
snap.Candidates = append(snap.Candidates, candidate)
|
||||
}
|
||||
|
||||
for key, snap := range snaps {
|
||||
payload, err := jsoncodec.Marshal(snap)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal config:bindings:%s: %w", key, err)
|
||||
@@ -519,6 +570,13 @@ func (s *SyncService) hsetJSON(ctx context.Context, key, field string, val inter
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeWeight(weight int) int {
|
||||
if weight <= 0 {
|
||||
return 1
|
||||
}
|
||||
return weight
|
||||
}
|
||||
|
||||
func normalizeStatus(status string) string {
|
||||
st := strings.ToLower(strings.TrimSpace(status))
|
||||
if st == "" {
|
||||
|
||||
@@ -13,10 +13,13 @@ import (
|
||||
)
|
||||
|
||||
type bindingSnapshot struct {
|
||||
Namespace string `json:"namespace"`
|
||||
PublicModel string `json:"public_model"`
|
||||
RouteGroup string `json:"route_group"`
|
||||
Upstreams map[string]string `json:"upstreams"`
|
||||
Namespace string `json:"namespace"`
|
||||
PublicModel string `json:"public_model"`
|
||||
Candidates []struct {
|
||||
RouteGroup string `json:"route_group"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Upstreams map[string]string `json:"upstreams"`
|
||||
} `json:"candidates"`
|
||||
}
|
||||
|
||||
func TestSyncBindings_SelectorExact(t *testing.T) {
|
||||
@@ -26,15 +29,19 @@ func TestSyncBindings_SelectorExact(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("open sqlite: %v", err)
|
||||
}
|
||||
if err := db.AutoMigrate(&model.Provider{}, &model.Binding{}); err != nil {
|
||||
if err := db.AutoMigrate(&model.ProviderGroup{}, &model.APIKey{}, &model.Binding{}); err != nil {
|
||||
t.Fatalf("migrate: %v", err)
|
||||
}
|
||||
|
||||
p := model.Provider{Name: "p1", Type: "openai", Group: "rg", Models: "m"}
|
||||
if err := db.Create(&p).Error; err != nil {
|
||||
t.Fatalf("create provider: %v", err)
|
||||
group := model.ProviderGroup{Name: "rg", Type: "openai", BaseURL: "https://api.openai.com/v1", Models: "m", Status: "active"}
|
||||
if err := db.Create(&group).Error; err != nil {
|
||||
t.Fatalf("create group: %v", err)
|
||||
}
|
||||
b := model.Binding{Namespace: "ns", PublicModel: "m", RouteGroup: "rg", SelectorType: "exact", Status: "active"}
|
||||
key := model.APIKey{GroupID: group.ID, APIKey: "k1", Status: "active"}
|
||||
if err := db.Create(&key).Error; err != nil {
|
||||
t.Fatalf("create api key: %v", err)
|
||||
}
|
||||
b := model.Binding{Namespace: "ns", PublicModel: "m", GroupID: group.ID, Weight: 1, SelectorType: "exact", Status: "active"}
|
||||
if err := db.Create(&b).Error; err != nil {
|
||||
t.Fatalf("create binding: %v", err)
|
||||
}
|
||||
@@ -54,8 +61,11 @@ func TestSyncBindings_SelectorExact(t *testing.T) {
|
||||
if err := json.Unmarshal([]byte(raw), &snap); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if snap.Upstreams == nil || snap.Upstreams[jsonID(p.ID)] != "m" {
|
||||
t.Fatalf("unexpected upstreams: %+v", snap.Upstreams)
|
||||
if len(snap.Candidates) != 1 {
|
||||
t.Fatalf("expected 1 candidate, got %+v", snap.Candidates)
|
||||
}
|
||||
if snap.Candidates[0].Upstreams == nil || snap.Candidates[0].Upstreams[jsonID(key.ID)] != "m" {
|
||||
t.Fatalf("unexpected upstreams: %+v", snap.Candidates[0].Upstreams)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,27 +76,31 @@ func TestSyncBindings_SelectorRegexAndNormalize(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("open sqlite: %v", err)
|
||||
}
|
||||
if err := db.AutoMigrate(&model.Provider{}, &model.Binding{}); err != nil {
|
||||
if err := db.AutoMigrate(&model.ProviderGroup{}, &model.APIKey{}, &model.Binding{}); err != nil {
|
||||
t.Fatalf("migrate: %v", err)
|
||||
}
|
||||
|
||||
p1 := model.Provider{Name: "p1", Type: "openai", Group: "rg", Models: "moonshot/kimi2,kimi2"}
|
||||
p2 := model.Provider{Name: "p2", Type: "openai", Group: "rg", Models: "moonshot/kimi2"}
|
||||
if err := db.Create(&p1).Error; err != nil {
|
||||
t.Fatalf("create provider1: %v", err)
|
||||
group := model.ProviderGroup{Name: "rg", Type: "openai", BaseURL: "https://api.openai.com/v1", Models: "moonshot/kimi2,kimi2", Status: "active"}
|
||||
if err := db.Create(&group).Error; err != nil {
|
||||
t.Fatalf("create group: %v", err)
|
||||
}
|
||||
if err := db.Create(&p2).Error; err != nil {
|
||||
t.Fatalf("create provider2: %v", err)
|
||||
k1 := model.APIKey{GroupID: group.ID, APIKey: "k1", Status: "active"}
|
||||
k2 := model.APIKey{GroupID: group.ID, APIKey: "k2", Status: "active"}
|
||||
if err := db.Create(&k1).Error; err != nil {
|
||||
t.Fatalf("create api key1: %v", err)
|
||||
}
|
||||
if err := db.Create(&k2).Error; err != nil {
|
||||
t.Fatalf("create api key2: %v", err)
|
||||
}
|
||||
|
||||
// Regex should match uniquely (moonshot/kimi2 only).
|
||||
bRegex := model.Binding{Namespace: "ns", PublicModel: "kimi2", RouteGroup: "rg", SelectorType: "regex", SelectorValue: "^moonshot/kimi2$", Status: "active"}
|
||||
bRegex := model.Binding{Namespace: "ns", PublicModel: "kimi2", GroupID: group.ID, Weight: 1, SelectorType: "regex", SelectorValue: "^moonshot/kimi2$", Status: "active"}
|
||||
if err := db.Create(&bRegex).Error; err != nil {
|
||||
t.Fatalf("create binding regex: %v", err)
|
||||
}
|
||||
|
||||
// Normalize_exact should match p2 (moonshot/kimi2) for "kimi2".
|
||||
bNorm := model.Binding{Namespace: "ns", PublicModel: "kimi2-n", RouteGroup: "rg", SelectorType: "normalize_exact", SelectorValue: "kimi2", Status: "active"}
|
||||
bNorm := model.Binding{Namespace: "ns", PublicModel: "kimi2-n", GroupID: group.ID, Weight: 1, SelectorType: "normalize_exact", SelectorValue: "kimi2", Status: "active"}
|
||||
if err := db.Create(&bNorm).Error; err != nil {
|
||||
t.Fatalf("create binding normalize: %v", err)
|
||||
}
|
||||
@@ -104,8 +118,12 @@ func TestSyncBindings_SelectorRegexAndNormalize(t *testing.T) {
|
||||
if err := json.Unmarshal([]byte(raw), &snapRegex); err != nil {
|
||||
t.Fatalf("unmarshal regex: %v", err)
|
||||
}
|
||||
if snapRegex.Upstreams[jsonID(p1.ID)] != "moonshot/kimi2" || snapRegex.Upstreams[jsonID(p2.ID)] != "moonshot/kimi2" {
|
||||
t.Fatalf("unexpected regex upstreams: %+v", snapRegex.Upstreams)
|
||||
if len(snapRegex.Candidates) != 1 {
|
||||
t.Fatalf("expected 1 candidate, got %+v", snapRegex.Candidates)
|
||||
}
|
||||
upstreams := snapRegex.Candidates[0].Upstreams
|
||||
if upstreams[jsonID(k1.ID)] != "moonshot/kimi2" || upstreams[jsonID(k2.ID)] != "moonshot/kimi2" {
|
||||
t.Fatalf("unexpected regex upstreams: %+v", upstreams)
|
||||
}
|
||||
|
||||
// Normalize_exact binding should include p2 but exclude p1 due to multi-match (moonshot/kimi2 + kimi2).
|
||||
@@ -114,11 +132,11 @@ func TestSyncBindings_SelectorRegexAndNormalize(t *testing.T) {
|
||||
if err := json.Unmarshal([]byte(raw), &snapNorm); err != nil {
|
||||
t.Fatalf("unmarshal normalize: %v", err)
|
||||
}
|
||||
if snapNorm.Upstreams[jsonID(p2.ID)] != "moonshot/kimi2" {
|
||||
t.Fatalf("expected p2 upstream, got %+v", snapNorm.Upstreams)
|
||||
if len(snapNorm.Candidates) != 1 {
|
||||
t.Fatalf("expected 1 candidate, got %+v", snapNorm.Candidates)
|
||||
}
|
||||
if _, ok := snapNorm.Upstreams[jsonID(p1.ID)]; ok {
|
||||
t.Fatalf("did not expect p1 upstream due to normalize multi-match, got %+v", snapNorm.Upstreams)
|
||||
if len(snapNorm.Candidates[0].Upstreams) != 0 || snapNorm.Candidates[0].Error != "config_error" {
|
||||
t.Fatalf("expected config_error with no upstreams, got %+v", snapNorm.Candidates[0])
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,68 +2,76 @@ package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/ez-api/ez-api/internal/model"
|
||||
"github.com/ez-api/foundation/contract"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func TestSyncProvider_WritesSnapshotAndRouting(t *testing.T) {
|
||||
goldenRaw := contract.ProviderSnapshotJSON()
|
||||
var golden map[string]any
|
||||
if err := json.Unmarshal(goldenRaw, &golden); err != nil {
|
||||
t.Fatalf("parse golden json: %v", err)
|
||||
}
|
||||
|
||||
func TestSyncProviders_WritesSnapshotAndRouting(t *testing.T) {
|
||||
mr := miniredis.RunT(t)
|
||||
rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()})
|
||||
|
||||
svc := NewSyncService(rdb)
|
||||
|
||||
p := &model.Provider{
|
||||
Name: "p1",
|
||||
db, err := gorm.Open(sqlite.Open("file:"+t.Name()+"?mode=memory&cache=shared"), &gorm.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("open sqlite: %v", err)
|
||||
}
|
||||
if err := db.AutoMigrate(&model.ProviderGroup{}, &model.APIKey{}); err != nil {
|
||||
t.Fatalf("migrate: %v", err)
|
||||
}
|
||||
|
||||
group := model.ProviderGroup{
|
||||
Name: "default",
|
||||
Type: "vertex-express",
|
||||
Group: "default",
|
||||
BaseURL: "https://vertex.example",
|
||||
GoogleLocation: "global",
|
||||
Models: "gemini-3-pro-preview",
|
||||
Status: "active",
|
||||
AutoBan: true,
|
||||
GoogleProject: "",
|
||||
GoogleLocation: "global",
|
||||
}
|
||||
p.ID = 42
|
||||
|
||||
if err := svc.SyncProvider(p); err != nil {
|
||||
t.Fatalf("SyncProvider: %v", err)
|
||||
if err := db.Create(&group).Error; err != nil {
|
||||
t.Fatalf("create group: %v", err)
|
||||
}
|
||||
key := model.APIKey{
|
||||
GroupID: group.ID,
|
||||
APIKey: "k",
|
||||
Status: "active",
|
||||
AutoBan: true,
|
||||
}
|
||||
if err := db.Create(&key).Error; err != nil {
|
||||
t.Fatalf("create key: %v", err)
|
||||
}
|
||||
|
||||
raw := mr.HGet("config:providers", "42")
|
||||
if err := svc.SyncProviders(db); err != nil {
|
||||
t.Fatalf("SyncProviders: %v", err)
|
||||
}
|
||||
|
||||
raw := mr.HGet("config:providers", jsonID(key.ID))
|
||||
if raw == "" {
|
||||
t.Fatalf("expected config:providers hash entry")
|
||||
}
|
||||
|
||||
var snap map[string]any
|
||||
if err := json.Unmarshal([]byte(raw), &snap); err != nil {
|
||||
t.Fatalf("invalid snapshot json: %v", err)
|
||||
}
|
||||
for k, v := range golden {
|
||||
if !reflect.DeepEqual(snap[k], v) {
|
||||
t.Fatalf("snapshot mismatch for %q: got=%#v want=%#v", k, snap[k], v)
|
||||
}
|
||||
if snap["group"] != "default" {
|
||||
t.Fatalf("expected group default, got %#v", snap["group"])
|
||||
}
|
||||
|
||||
routeKey := "route:group:default:gemini-3-pro-preview"
|
||||
if !mr.Exists(routeKey) {
|
||||
t.Fatalf("expected routing key %q to exist", routeKey)
|
||||
}
|
||||
ok, err := mr.SIsMember(routeKey, "42")
|
||||
ok, err := mr.SIsMember(routeKey, jsonID(key.ID))
|
||||
if err != nil {
|
||||
t.Fatalf("SIsMember: %v", err)
|
||||
}
|
||||
if !ok {
|
||||
t.Fatalf("expected provider id 42 in routing set %q", routeKey)
|
||||
t.Fatalf("expected provider id in routing set %q", routeKey)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -113,34 +121,6 @@ func TestSyncModelDelete_RemovesMeta(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncProviderDelete_RemovesSnapshotAndRouting(t *testing.T) {
|
||||
mr := miniredis.RunT(t)
|
||||
rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()})
|
||||
svc := NewSyncService(rdb)
|
||||
|
||||
p := &model.Provider{
|
||||
Name: "p1",
|
||||
Type: "openai",
|
||||
Group: "default",
|
||||
Models: "gpt-4o-mini,gpt-4o",
|
||||
Status: "active",
|
||||
}
|
||||
p.ID = 7
|
||||
|
||||
if err := svc.SyncProvider(p); err != nil {
|
||||
t.Fatalf("SyncProvider: %v", err)
|
||||
}
|
||||
if err := svc.SyncProviderDelete(p); err != nil {
|
||||
t.Fatalf("SyncProviderDelete: %v", err)
|
||||
}
|
||||
|
||||
if got := mr.HGet("config:providers", "7"); got != "" {
|
||||
t.Fatalf("expected provider snapshot removed, got %q", got)
|
||||
}
|
||||
if ok, _ := mr.SIsMember("route:group:default:gpt-4o-mini", "7"); ok {
|
||||
t.Fatalf("expected provider removed from route set")
|
||||
}
|
||||
if ok, _ := mr.SIsMember("route:group:default:gpt-4o", "7"); ok {
|
||||
t.Fatalf("expected provider removed from route set")
|
||||
}
|
||||
func jsonID(id uint) string {
|
||||
return strconv.FormatUint(uint64(id), 10)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user