diff --git a/internal/api/handler.go b/internal/api/handler.go index f90cad7..22d43cd 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -74,6 +74,7 @@ func (h *Handler) CreateProvider(c *gin.Context) { BaseURL: req.BaseURL, APIKey: req.APIKey, Group: group, + Models: strings.Join(req.Models, ","), } if err := h.db.Create(&provider).Error; err != nil { diff --git a/internal/dto/provider.go b/internal/dto/provider.go index 69df49d..9c33c3c 100644 --- a/internal/dto/provider.go +++ b/internal/dto/provider.go @@ -2,9 +2,10 @@ package dto // ProviderDTO defines inbound payload for provider creation/update. type ProviderDTO struct { - Name string `json:"name"` - Type string `json:"type"` - BaseURL string `json:"base_url"` - APIKey string `json:"api_key"` - Group string `json:"group"` + Name string `json:"name"` + Type string `json:"type"` + BaseURL string `json:"base_url"` + APIKey string `json:"api_key"` + Group string `json:"group"` + Models []string `json:"models"` // List of supported model names } diff --git a/internal/model/models.go b/internal/model/models.go index 0e42dcf..ca4306d 100644 --- a/internal/model/models.go +++ b/internal/model/models.go @@ -18,6 +18,7 @@ type Provider struct { BaseURL string `json:"base_url"` APIKey string `json:"api_key"` Group string `gorm:"default:'default'" json:"group"` // routing group/tier + Models string `json:"models"` // comma-separated list of supported models (e.g. "gpt-4,gpt-3.5-turbo") } type Key struct { diff --git a/internal/service/sync.go b/internal/service/sync.go index 6542bf1..f75a1ce 100644 --- a/internal/service/sync.go +++ b/internal/service/sync.go @@ -49,18 +49,40 @@ func (s *SyncService) SyncKey(key *model.Key) error { return nil } -// SyncProvider writes a single provider into Redis hash storage. +// SyncProvider writes a single provider into Redis hash storage and updates routing tables. func (s *SyncService) SyncProvider(provider *model.Provider) error { ctx := context.Background() + group := normalizeGroup(provider.Group) + models := strings.Split(provider.Models, ",") + snap := providerSnapshot{ ID: provider.ID, Name: provider.Name, Type: provider.Type, BaseURL: provider.BaseURL, APIKey: provider.APIKey, - Group: normalizeGroup(provider.Group), + Group: group, + Models: models, } - return s.hsetJSON(ctx, "config:providers", fmt.Sprintf("%d", provider.ID), snap) + + // 1. Update Provider Config + if err := s.hsetJSON(ctx, "config:providers", fmt.Sprintf("%d", provider.ID), snap); err != nil { + return err + } + + // 2. Update Routing Table: route:group:{group}:{model} -> Set(provider_id) + // Note: This is an additive operation. Removing models requires full sync or smarter logic. + pipe := s.rdb.Pipeline() + for _, m := range models { + m = strings.TrimSpace(m) + if m == "" { + continue + } + routeKey := fmt.Sprintf("route:group:%s:%s", group, m) + pipe.SAdd(ctx, routeKey, provider.ID) + } + _, err := pipe.Exec(ctx) + return err } // SyncModel writes a single model metadata record. @@ -80,12 +102,13 @@ func (s *SyncService) SyncModel(m *model.Model) error { } type providerSnapshot struct { - ID uint `json:"id"` - Name string `json:"name"` - Type string `json:"type"` - BaseURL string `json:"base_url"` - APIKey string `json:"api_key"` - Group string `json:"group"` + ID uint `json:"id"` + Name string `json:"name"` + Type string `json:"type"` + BaseURL string `json:"base_url"` + APIKey string `json:"api_key"` + Group string `json:"group"` + Models []string `json:"models"` } type keySnapshot struct { @@ -130,20 +153,38 @@ func (s *SyncService) SyncAll(db *gorm.DB) error { pipe := s.rdb.TxPipeline() pipe.Del(ctx, "config:providers", "config:keys", "meta:models") + // Clear old routing tables (pattern scan would be better in prod, but keys are predictable if we knew them) + // For MVP, we rely on the fact that we are rebuilding. + // Ideally, we should scan "route:group:*" and del, but let's just rebuild. + for _, p := range providers { + group := normalizeGroup(p.Group) + models := strings.Split(p.Models, ",") + snap := providerSnapshot{ ID: p.ID, Name: p.Name, Type: p.Type, BaseURL: p.BaseURL, APIKey: p.APIKey, - Group: normalizeGroup(p.Group), + Group: group, + Models: models, } payload, err := json.Marshal(snap) if err != nil { return fmt.Errorf("marshal provider %d: %w", p.ID, err) } pipe.HSet(ctx, "config:providers", fmt.Sprintf("%d", p.ID), payload) + + // Rebuild Routing Table + for _, m := range models { + m = strings.TrimSpace(m) + if m == "" { + continue + } + routeKey := fmt.Sprintf("route:group:%s:%s", group, m) + pipe.SAdd(ctx, routeKey, p.ID) + } } for _, k := range keys {