diff --git a/internal/api/handler.go b/internal/api/handler.go index 596415e..f90cad7 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -3,6 +3,7 @@ package api import ( "net/http" "strconv" + "strings" "github.com/ez-api/ez-api/internal/dto" "github.com/ez-api/ez-api/internal/model" @@ -28,12 +29,17 @@ func (h *Handler) CreateKey(c *gin.Context) { return } + group := strings.TrimSpace(req.Group) + if group == "" { + group = "default" + } + key := model.Key{ - ProviderID: &req.ProviderID, - KeySecret: req.KeySecret, - Balance: req.Balance, - Status: req.Status, - Weight: req.Weight, + KeySecret: req.KeySecret, + Group: group, + Balance: req.Balance, + Status: req.Status, + Weight: req.Weight, } if err := h.db.Create(&key).Error; err != nil { @@ -42,7 +48,7 @@ func (h *Handler) CreateKey(c *gin.Context) { } // Write auth hash and refresh snapshots - if err := h.sync.SyncAll(h.db); err != nil { + if err := h.sync.SyncKey(&key); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync key to Redis", "details": err.Error()}) return } @@ -57,11 +63,17 @@ func (h *Handler) CreateProvider(c *gin.Context) { return } + group := strings.TrimSpace(req.Group) + if group == "" { + group = "default" + } + provider := model.Provider{ Name: req.Name, Type: req.Type, BaseURL: req.BaseURL, APIKey: req.APIKey, + Group: group, } if err := h.db.Create(&provider).Error; err != nil { @@ -69,7 +81,7 @@ func (h *Handler) CreateProvider(c *gin.Context) { return } - if err := h.sync.SyncAll(h.db); err != nil { + if err := h.sync.SyncProvider(&provider); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync provider", "details": err.Error()}) return } @@ -100,7 +112,7 @@ func (h *Handler) CreateModel(c *gin.Context) { return } - if err := h.sync.SyncAll(h.db); err != nil { + if err := h.sync.SyncModel(&modelReq); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync model", "details": err.Error()}) return } @@ -151,7 +163,7 @@ func (h *Handler) UpdateModel(c *gin.Context) { return } - if err := h.sync.SyncAll(h.db); err != nil { + if err := h.sync.SyncModel(&existing); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync model", "details": err.Error()}) return } diff --git a/internal/dto/key.go b/internal/dto/key.go index 976096a..685b98f 100644 --- a/internal/dto/key.go +++ b/internal/dto/key.go @@ -2,9 +2,9 @@ package dto // KeyDTO defines payload for key creation/update. type KeyDTO struct { - ProviderID uint `json:"provider_id"` - KeySecret string `json:"key_secret"` - Balance float64 `json:"balance"` - Status string `json:"status"` - Weight int `json:"weight"` + Group string `json:"group"` + KeySecret string `json:"key_secret"` + Balance float64 `json:"balance"` + Status string `json:"status"` + Weight int `json:"weight"` } diff --git a/internal/dto/provider.go b/internal/dto/provider.go index 373693b..69df49d 100644 --- a/internal/dto/provider.go +++ b/internal/dto/provider.go @@ -2,8 +2,9 @@ package dto // ProviderDTO defines inbound payload for provider creation/update. type ProviderDTO struct { - Name string `json:"name"` - Type string `json:"type"` - BaseURL string `json:"base_url"` - APIKey string `json:"api_key"` + Name string `json:"name"` + Type string `json:"type"` + BaseURL string `json:"base_url"` + APIKey string `json:"api_key"` + Group string `json:"group"` } diff --git a/internal/model/models.go b/internal/model/models.go index 1273332..0e42dcf 100644 --- a/internal/model/models.go +++ b/internal/model/models.go @@ -17,16 +17,16 @@ type Provider struct { Type string `gorm:"not null" json:"type"` // openai, anthropic, etc. BaseURL string `json:"base_url"` APIKey string `json:"api_key"` + Group string `gorm:"default:'default'" json:"group"` // routing group/tier } type Key struct { gorm.Model - ProviderID *uint `json:"provider_id"` - Provider *Provider `json:"-"` - KeySecret string `gorm:"not null" json:"key_secret"` - Balance float64 `json:"balance"` - Status string `gorm:"default:'active'" json:"status"` // active, suspended - Weight int `gorm:"default:10" json:"weight"` + KeySecret string `gorm:"not null" json:"key_secret"` + Group string `gorm:"default:'default'" json:"group"` // routing group/tier + Balance float64 `json:"balance"` + Status string `gorm:"default:'active'" json:"status"` // active, suspended + Weight int `gorm:"default:10" json:"weight"` } type Model struct { diff --git a/internal/service/sync.go b/internal/service/sync.go index c2df3ab..6542bf1 100644 --- a/internal/service/sync.go +++ b/internal/service/sync.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "encoding/json" "fmt" + "strings" "github.com/ez-api/ez-api/internal/model" "github.com/redis/go-redis/v9" @@ -20,21 +21,62 @@ func NewSyncService(rdb *redis.Client) *SyncService { return &SyncService{rdb: rdb} } +// SyncKey writes a single key into Redis without rebuilding the entire snapshot. func (s *SyncService) SyncKey(key *model.Key) error { - tokenHash := hashToken(key.KeySecret) - redisKey := fmt.Sprintf("auth:token:%s", tokenHash) + ctx := context.Background() + snap := keySnapshot{ + ID: key.ID, + TokenHash: hashToken(key.KeySecret), + Group: normalizeGroup(key.Group), + Status: key.Status, + Weight: key.Weight, + Balance: key.Balance, + } + + if err := s.hsetJSON(ctx, "config:keys", snap.TokenHash, snap); err != nil { + return err + } fields := map[string]interface{}{ - "status": key.Status, - "balance": key.Balance, + "status": snap.Status, + "group": snap.Group, + "weight": snap.Weight, + "balance": snap.Balance, } - if key.ProviderID != nil { - fields["provider_id"] = *key.ProviderID - } else { - fields["provider_id"] = 0 + if err := s.rdb.HSet(ctx, fmt.Sprintf("auth:token:%s", snap.TokenHash), fields).Err(); err != nil { + return fmt.Errorf("write auth token: %w", err) } + return nil +} - return s.rdb.HSet(context.Background(), redisKey, fields).Err() +// SyncProvider writes a single provider into Redis hash storage. +func (s *SyncService) SyncProvider(provider *model.Provider) error { + ctx := context.Background() + snap := providerSnapshot{ + ID: provider.ID, + Name: provider.Name, + Type: provider.Type, + BaseURL: provider.BaseURL, + APIKey: provider.APIKey, + Group: normalizeGroup(provider.Group), + } + return s.hsetJSON(ctx, "config:providers", fmt.Sprintf("%d", provider.ID), snap) +} + +// SyncModel writes a single model metadata record. +func (s *SyncService) SyncModel(m *model.Model) error { + ctx := context.Background() + snap := modelSnapshot{ + Name: m.Name, + ContextWindow: m.ContextWindow, + CostPerToken: m.CostPerToken, + SupportsVision: m.SupportsVision, + SupportsFunction: m.SupportsFunctions, + SupportsToolChoice: m.SupportsToolChoice, + SupportsFIM: m.SupportsFIM, + MaxOutputTokens: m.MaxOutputTokens, + } + return s.hsetJSON(ctx, "meta:models", snap.Name, snap) } type providerSnapshot struct { @@ -43,15 +85,16 @@ type providerSnapshot struct { Type string `json:"type"` BaseURL string `json:"base_url"` APIKey string `json:"api_key"` + Group string `json:"group"` } type keySnapshot struct { - ID uint `json:"id"` - ProviderID uint `json:"provider_id"` - TokenHash string `json:"token_hash"` - Status string `json:"status"` - Weight int `json:"weight"` - Balance float64 `json:"balance"` + ID uint `json:"id"` + TokenHash string `json:"token_hash"` + Group string `json:"group"` + Status string `json:"status"` + Weight int `json:"weight"` + Balance float64 `json:"balance"` } type modelSnapshot struct { @@ -65,69 +108,68 @@ type modelSnapshot struct { MaxOutputTokens int `json:"max_output_tokens"` } -// SyncAll writes full snapshots (providers/keys/models) into Redis for DP consumption. +// SyncAll rebuilds Redis hashes from the database; use for cold starts or forced refreshes. func (s *SyncService) SyncAll(db *gorm.DB) error { ctx := context.Background() - // Providers snapshot var providers []model.Provider if err := db.Find(&providers).Error; err != nil { return fmt.Errorf("load providers: %w", err) } - providerSnap := make([]providerSnapshot, 0, len(providers)) + + var keys []model.Key + if err := db.Find(&keys).Error; err != nil { + return fmt.Errorf("load keys: %w", err) + } + + var models []model.Model + if err := db.Find(&models).Error; err != nil { + return fmt.Errorf("load models: %w", err) + } + + pipe := s.rdb.TxPipeline() + pipe.Del(ctx, "config:providers", "config:keys", "meta:models") + for _, p := range providers { - providerSnap = append(providerSnap, providerSnapshot{ + snap := providerSnapshot{ ID: p.ID, Name: p.Name, Type: p.Type, BaseURL: p.BaseURL, APIKey: p.APIKey, - }) - } - if err := s.storeJSON(ctx, "config:providers", providerSnap); err != nil { - return err + Group: normalizeGroup(p.Group), + } + payload, err := json.Marshal(snap) + if err != nil { + return fmt.Errorf("marshal provider %d: %w", p.ID, err) + } + pipe.HSet(ctx, "config:providers", fmt.Sprintf("%d", p.ID), payload) } - // Keys snapshot + auth hashes - var keys []model.Key - if err := db.Find(&keys).Error; err != nil { - return fmt.Errorf("load keys: %w", err) - } - keySnap := make([]keySnapshot, 0, len(keys)) for _, k := range keys { - tokenHash := hashToken(k.KeySecret) - keySnap = append(keySnap, keySnapshot{ - ID: k.ID, - ProviderID: firstID(k.ProviderID), - TokenHash: tokenHash, - Status: k.Status, - Weight: k.Weight, - Balance: k.Balance, + snap := keySnapshot{ + ID: k.ID, + TokenHash: hashToken(k.KeySecret), + Group: normalizeGroup(k.Group), + Status: k.Status, + Weight: k.Weight, + Balance: k.Balance, + } + payload, err := json.Marshal(snap) + if err != nil { + return fmt.Errorf("marshal key %d: %w", k.ID, err) + } + pipe.HSet(ctx, "config:keys", snap.TokenHash, payload) + pipe.HSet(ctx, fmt.Sprintf("auth:token:%s", snap.TokenHash), map[string]interface{}{ + "status": snap.Status, + "group": snap.Group, + "weight": snap.Weight, + "balance": snap.Balance, }) - - // Maintain per-token auth hash for quick checks - fields := map[string]interface{}{ - "status": k.Status, - "provider_id": firstID(k.ProviderID), - "weight": k.Weight, - "balance": k.Balance, - } - if err := s.rdb.HSet(ctx, fmt.Sprintf("auth:token:%s", tokenHash), fields).Err(); err != nil { - return fmt.Errorf("write auth token: %w", err) - } - } - if err := s.storeJSON(ctx, "config:keys", keySnap); err != nil { - return err } - // Models snapshot - var models []model.Model - if err := db.Find(&models).Error; err != nil { - return fmt.Errorf("load models: %w", err) - } - modelSnap := make([]modelSnapshot, 0, len(models)) for _, m := range models { - modelSnap = append(modelSnap, modelSnapshot{ + snap := modelSnapshot{ Name: m.Name, ContextWindow: m.ContextWindow, CostPerToken: m.CostPerToken, @@ -136,22 +178,28 @@ func (s *SyncService) SyncAll(db *gorm.DB) error { SupportsToolChoice: m.SupportsToolChoice, SupportsFIM: m.SupportsFIM, MaxOutputTokens: m.MaxOutputTokens, - }) + } + payload, err := json.Marshal(snap) + if err != nil { + return fmt.Errorf("marshal model %s: %w", m.Name, err) + } + pipe.HSet(ctx, "meta:models", snap.Name, payload) } - if err := s.storeJSON(ctx, "meta:models", modelSnap); err != nil { - return err + + if _, err := pipe.Exec(ctx); err != nil { + return fmt.Errorf("write snapshots: %w", err) } return nil } -func (s *SyncService) storeJSON(ctx context.Context, key string, val interface{}) error { +func (s *SyncService) hsetJSON(ctx context.Context, key, field string, val interface{}) error { payload, err := json.Marshal(val) if err != nil { - return fmt.Errorf("marshal %s: %w", key, err) + return fmt.Errorf("marshal %s:%s: %w", key, field, err) } - if err := s.rdb.Set(ctx, key, payload, 0).Err(); err != nil { - return fmt.Errorf("write %s: %w", key, err) + if err := s.rdb.HSet(ctx, key, field, payload).Err(); err != nil { + return fmt.Errorf("write %s:%s: %w", key, field, err) } return nil } @@ -162,9 +210,9 @@ func hashToken(token string) string { return hex.EncodeToString(hasher.Sum(nil)) } -func firstID(id *uint) uint { - if id == nil { - return 0 +func normalizeGroup(group string) string { + if strings.TrimSpace(group) == "" { + return "default" } - return *id + return group }