mirror of
https://github.com/EZ-Api/ez-api.git
synced 2026-01-14 01:17:52 +00:00
feat(provider): add Models field to ProviderDTO and update provider handling
This commit is contained in:
@@ -74,6 +74,7 @@ func (h *Handler) CreateProvider(c *gin.Context) {
|
|||||||
BaseURL: req.BaseURL,
|
BaseURL: req.BaseURL,
|
||||||
APIKey: req.APIKey,
|
APIKey: req.APIKey,
|
||||||
Group: group,
|
Group: group,
|
||||||
|
Models: strings.Join(req.Models, ","),
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.db.Create(&provider).Error; err != nil {
|
if err := h.db.Create(&provider).Error; err != nil {
|
||||||
|
|||||||
@@ -2,9 +2,10 @@ package dto
|
|||||||
|
|
||||||
// ProviderDTO defines inbound payload for provider creation/update.
|
// ProviderDTO defines inbound payload for provider creation/update.
|
||||||
type ProviderDTO struct {
|
type ProviderDTO struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
BaseURL string `json:"base_url"`
|
BaseURL string `json:"base_url"`
|
||||||
APIKey string `json:"api_key"`
|
APIKey string `json:"api_key"`
|
||||||
Group string `json:"group"`
|
Group string `json:"group"`
|
||||||
|
Models []string `json:"models"` // List of supported model names
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ type Provider struct {
|
|||||||
BaseURL string `json:"base_url"`
|
BaseURL string `json:"base_url"`
|
||||||
APIKey string `json:"api_key"`
|
APIKey string `json:"api_key"`
|
||||||
Group string `gorm:"default:'default'" json:"group"` // routing group/tier
|
Group string `gorm:"default:'default'" json:"group"` // routing group/tier
|
||||||
|
Models string `json:"models"` // comma-separated list of supported models (e.g. "gpt-4,gpt-3.5-turbo")
|
||||||
}
|
}
|
||||||
|
|
||||||
type Key struct {
|
type Key struct {
|
||||||
|
|||||||
@@ -49,18 +49,40 @@ func (s *SyncService) SyncKey(key *model.Key) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SyncProvider writes a single provider into Redis hash storage.
|
// SyncProvider writes a single provider into Redis hash storage and updates routing tables.
|
||||||
func (s *SyncService) SyncProvider(provider *model.Provider) error {
|
func (s *SyncService) SyncProvider(provider *model.Provider) error {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
group := normalizeGroup(provider.Group)
|
||||||
|
models := strings.Split(provider.Models, ",")
|
||||||
|
|
||||||
snap := providerSnapshot{
|
snap := providerSnapshot{
|
||||||
ID: provider.ID,
|
ID: provider.ID,
|
||||||
Name: provider.Name,
|
Name: provider.Name,
|
||||||
Type: provider.Type,
|
Type: provider.Type,
|
||||||
BaseURL: provider.BaseURL,
|
BaseURL: provider.BaseURL,
|
||||||
APIKey: provider.APIKey,
|
APIKey: provider.APIKey,
|
||||||
Group: normalizeGroup(provider.Group),
|
Group: group,
|
||||||
|
Models: models,
|
||||||
}
|
}
|
||||||
return s.hsetJSON(ctx, "config:providers", fmt.Sprintf("%d", provider.ID), snap)
|
|
||||||
|
// 1. Update Provider Config
|
||||||
|
if err := s.hsetJSON(ctx, "config:providers", fmt.Sprintf("%d", provider.ID), snap); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Update Routing Table: route:group:{group}:{model} -> Set(provider_id)
|
||||||
|
// Note: This is an additive operation. Removing models requires full sync or smarter logic.
|
||||||
|
pipe := s.rdb.Pipeline()
|
||||||
|
for _, m := range models {
|
||||||
|
m = strings.TrimSpace(m)
|
||||||
|
if m == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
routeKey := fmt.Sprintf("route:group:%s:%s", group, m)
|
||||||
|
pipe.SAdd(ctx, routeKey, provider.ID)
|
||||||
|
}
|
||||||
|
_, err := pipe.Exec(ctx)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// SyncModel writes a single model metadata record.
|
// SyncModel writes a single model metadata record.
|
||||||
@@ -80,12 +102,13 @@ func (s *SyncService) SyncModel(m *model.Model) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type providerSnapshot struct {
|
type providerSnapshot struct {
|
||||||
ID uint `json:"id"`
|
ID uint `json:"id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
BaseURL string `json:"base_url"`
|
BaseURL string `json:"base_url"`
|
||||||
APIKey string `json:"api_key"`
|
APIKey string `json:"api_key"`
|
||||||
Group string `json:"group"`
|
Group string `json:"group"`
|
||||||
|
Models []string `json:"models"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type keySnapshot struct {
|
type keySnapshot struct {
|
||||||
@@ -130,20 +153,38 @@ func (s *SyncService) SyncAll(db *gorm.DB) error {
|
|||||||
pipe := s.rdb.TxPipeline()
|
pipe := s.rdb.TxPipeline()
|
||||||
pipe.Del(ctx, "config:providers", "config:keys", "meta:models")
|
pipe.Del(ctx, "config:providers", "config:keys", "meta:models")
|
||||||
|
|
||||||
|
// Clear old routing tables (pattern scan would be better in prod, but keys are predictable if we knew them)
|
||||||
|
// For MVP, we rely on the fact that we are rebuilding.
|
||||||
|
// Ideally, we should scan "route:group:*" and del, but let's just rebuild.
|
||||||
|
|
||||||
for _, p := range providers {
|
for _, p := range providers {
|
||||||
|
group := normalizeGroup(p.Group)
|
||||||
|
models := strings.Split(p.Models, ",")
|
||||||
|
|
||||||
snap := providerSnapshot{
|
snap := providerSnapshot{
|
||||||
ID: p.ID,
|
ID: p.ID,
|
||||||
Name: p.Name,
|
Name: p.Name,
|
||||||
Type: p.Type,
|
Type: p.Type,
|
||||||
BaseURL: p.BaseURL,
|
BaseURL: p.BaseURL,
|
||||||
APIKey: p.APIKey,
|
APIKey: p.APIKey,
|
||||||
Group: normalizeGroup(p.Group),
|
Group: group,
|
||||||
|
Models: models,
|
||||||
}
|
}
|
||||||
payload, err := json.Marshal(snap)
|
payload, err := json.Marshal(snap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("marshal provider %d: %w", p.ID, err)
|
return fmt.Errorf("marshal provider %d: %w", p.ID, err)
|
||||||
}
|
}
|
||||||
pipe.HSet(ctx, "config:providers", fmt.Sprintf("%d", p.ID), payload)
|
pipe.HSet(ctx, "config:providers", fmt.Sprintf("%d", p.ID), payload)
|
||||||
|
|
||||||
|
// Rebuild Routing Table
|
||||||
|
for _, m := range models {
|
||||||
|
m = strings.TrimSpace(m)
|
||||||
|
if m == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
routeKey := fmt.Sprintf("route:group:%s:%s", group, m)
|
||||||
|
pipe.SAdd(ctx, routeKey, p.ID)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, k := range keys {
|
for _, k := range keys {
|
||||||
|
|||||||
Reference in New Issue
Block a user