mirror of
https://github.com/EZ-Api/ez-api.git
synced 2026-01-13 17:47:51 +00:00
Add `PUT /admin/providers/{id}` endpoint to allow updating provider
configurations, including status and ban details. Update synchronization
logic to exclude inactive or banned providers from routing tables to
ensure traffic is not routed to them.
318 lines
8.7 KiB
Go
318 lines
8.7 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/bytedance/sonic"
|
|
"github.com/ez-api/ez-api/internal/model"
|
|
"github.com/ez-api/ez-api/internal/util"
|
|
"github.com/redis/go-redis/v9"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
type SyncService struct {
|
|
rdb *redis.Client
|
|
}
|
|
|
|
func NewSyncService(rdb *redis.Client) *SyncService {
|
|
return &SyncService{rdb: rdb}
|
|
}
|
|
|
|
// SyncKey writes a single key into Redis without rebuilding the entire snapshot.
|
|
func (s *SyncService) SyncKey(key *model.Key) error {
|
|
ctx := context.Background()
|
|
tokenHash := key.TokenHash
|
|
if strings.TrimSpace(tokenHash) == "" {
|
|
tokenHash = util.HashToken(key.KeySecret) // backward compatibility
|
|
}
|
|
if strings.TrimSpace(tokenHash) == "" {
|
|
return fmt.Errorf("token hash missing for key %d", key.ID)
|
|
}
|
|
|
|
fields := map[string]interface{}{
|
|
"master_id": key.MasterID,
|
|
"issued_at_epoch": key.IssuedAtEpoch,
|
|
"status": key.Status,
|
|
"group": key.Group,
|
|
"scopes": key.Scopes,
|
|
}
|
|
if err := s.rdb.HSet(ctx, fmt.Sprintf("auth:token:%s", tokenHash), fields).Err(); err != nil {
|
|
return fmt.Errorf("write auth token: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// SyncMaster writes master metadata into Redis used by the balancer for validation.
|
|
func (s *SyncService) SyncMaster(master *model.Master) error {
|
|
ctx := context.Background()
|
|
key := fmt.Sprintf("auth:master:%d", master.ID)
|
|
if err := s.rdb.HSet(ctx, key, map[string]interface{}{
|
|
"epoch": master.Epoch,
|
|
"status": master.Status,
|
|
"global_qps": master.GlobalQPS,
|
|
}).Err(); err != nil {
|
|
return fmt.Errorf("write master metadata: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// SyncProvider writes a single provider into Redis hash storage and updates routing tables.
|
|
func (s *SyncService) SyncProvider(provider *model.Provider) error {
|
|
ctx := context.Background()
|
|
group := normalizeGroup(provider.Group)
|
|
models := strings.Split(provider.Models, ",")
|
|
|
|
snap := providerSnapshot{
|
|
ID: provider.ID,
|
|
Name: provider.Name,
|
|
Type: provider.Type,
|
|
BaseURL: provider.BaseURL,
|
|
APIKey: provider.APIKey,
|
|
Group: group,
|
|
Models: models,
|
|
Status: normalizeStatus(provider.Status),
|
|
AutoBan: provider.AutoBan,
|
|
BanReason: provider.BanReason,
|
|
}
|
|
if provider.BanUntil != nil {
|
|
snap.BanUntil = provider.BanUntil.UTC().Unix()
|
|
}
|
|
|
|
// 1. Update Provider Config
|
|
if err := s.hsetJSON(ctx, "config:providers", fmt.Sprintf("%d", provider.ID), snap); err != nil {
|
|
return err
|
|
}
|
|
|
|
// 2. Update Routing Table: route:group:{group}:{model} -> Set(provider_id)
|
|
// Note: This is an additive operation. Removing models requires full sync or smarter logic.
|
|
pipe := s.rdb.Pipeline()
|
|
for _, m := range models {
|
|
m = strings.TrimSpace(m)
|
|
if m == "" {
|
|
continue
|
|
}
|
|
if snap.Status != "active" {
|
|
continue
|
|
}
|
|
if snap.BanUntil > 0 && time.Now().Unix() < snap.BanUntil {
|
|
continue
|
|
}
|
|
routeKey := fmt.Sprintf("route:group:%s:%s", group, m)
|
|
pipe.SAdd(ctx, routeKey, provider.ID)
|
|
}
|
|
_, err := pipe.Exec(ctx)
|
|
return err
|
|
}
|
|
|
|
// SyncModel writes a single model metadata record.
|
|
func (s *SyncService) SyncModel(m *model.Model) error {
|
|
ctx := context.Background()
|
|
snap := modelSnapshot{
|
|
Name: m.Name,
|
|
ContextWindow: m.ContextWindow,
|
|
CostPerToken: m.CostPerToken,
|
|
SupportsVision: m.SupportsVision,
|
|
SupportsFunction: m.SupportsFunctions,
|
|
SupportsToolChoice: m.SupportsToolChoice,
|
|
SupportsFIM: m.SupportsFIM,
|
|
MaxOutputTokens: m.MaxOutputTokens,
|
|
}
|
|
return s.hsetJSON(ctx, "meta:models", snap.Name, snap)
|
|
}
|
|
|
|
// providerSnapshot is the JSON document stored per provider in the
// config:providers hash. Group and Status hold the normalized values, Models
// is the provider's comma-separated model list split into elements, and the
// ban fields are omitted from the JSON when empty.
type providerSnapshot struct {
	ID        uint     `json:"id"`
	Name      string   `json:"name"`
	Type      string   `json:"type"`
	BaseURL   string   `json:"base_url"`
	APIKey    string   `json:"api_key"`
	Group     string   `json:"group"`
	Models    []string `json:"models"`
	Status    string   `json:"status"`
	AutoBan   bool     `json:"auto_ban"`
	BanReason string   `json:"ban_reason,omitempty"`
	BanUntil  int64    `json:"ban_until,omitempty"` // unix seconds; zero when no ban expiry is set
}
|
|
|
|
// keySnapshot is no longer needed as we write directly to auth:token:*
|
|
|
|
// modelSnapshot is the JSON document stored per model in the meta:models
// hash, keyed by Name.
type modelSnapshot struct {
	Name               string  `json:"name"`
	ContextWindow      int     `json:"context_window"`
	CostPerToken       float64 `json:"cost_per_token"`
	SupportsVision     bool    `json:"supports_vision"`
	SupportsFunction   bool    `json:"supports_functions"`
	SupportsToolChoice bool    `json:"supports_tool_choice"`
	SupportsFIM        bool    `json:"supports_fim"`
	MaxOutputTokens    int     `json:"max_output_tokens"`
}
|
|
|
|
// SyncAll rebuilds Redis hashes from the database; use for cold starts or forced refreshes.
|
|
func (s *SyncService) SyncAll(db *gorm.DB) error {
|
|
ctx := context.Background()
|
|
|
|
var providers []model.Provider
|
|
if err := db.Find(&providers).Error; err != nil {
|
|
return fmt.Errorf("load providers: %w", err)
|
|
}
|
|
|
|
var keys []model.Key
|
|
if err := db.Find(&keys).Error; err != nil {
|
|
return fmt.Errorf("load keys: %w", err)
|
|
}
|
|
|
|
var masters []model.Master
|
|
if err := db.Find(&masters).Error; err != nil {
|
|
return fmt.Errorf("load masters: %w", err)
|
|
}
|
|
|
|
var models []model.Model
|
|
if err := db.Find(&models).Error; err != nil {
|
|
return fmt.Errorf("load models: %w", err)
|
|
}
|
|
|
|
pipe := s.rdb.TxPipeline()
|
|
pipe.Del(ctx, "config:providers", "config:keys", "meta:models")
|
|
// Also clear master keys
|
|
var masterKeys []string
|
|
iter := s.rdb.Scan(ctx, 0, "auth:master:*", 0).Iterator()
|
|
for iter.Next(ctx) {
|
|
masterKeys = append(masterKeys, iter.Val())
|
|
}
|
|
if err := iter.Err(); err != nil {
|
|
return fmt.Errorf("scan master keys: %w", err)
|
|
}
|
|
if len(masterKeys) > 0 {
|
|
pipe.Del(ctx, masterKeys...)
|
|
}
|
|
|
|
// Clear old routing tables (pattern scan would be better in prod, but keys are predictable if we knew them)
|
|
// For MVP, we rely on the fact that we are rebuilding.
|
|
// Ideally, we should scan "route:group:*" and del, but let's just rebuild.
|
|
|
|
for _, p := range providers {
|
|
group := normalizeGroup(p.Group)
|
|
models := strings.Split(p.Models, ",")
|
|
|
|
snap := providerSnapshot{
|
|
ID: p.ID,
|
|
Name: p.Name,
|
|
Type: p.Type,
|
|
BaseURL: p.BaseURL,
|
|
APIKey: p.APIKey,
|
|
Group: group,
|
|
Models: models,
|
|
Status: normalizeStatus(p.Status),
|
|
AutoBan: p.AutoBan,
|
|
BanReason: p.BanReason,
|
|
}
|
|
if p.BanUntil != nil {
|
|
snap.BanUntil = p.BanUntil.UTC().Unix()
|
|
}
|
|
payload, err := sonic.Marshal(snap)
|
|
if err != nil {
|
|
return fmt.Errorf("marshal provider %d: %w", p.ID, err)
|
|
}
|
|
pipe.HSet(ctx, "config:providers", fmt.Sprintf("%d", p.ID), payload)
|
|
|
|
// Rebuild Routing Table
|
|
for _, m := range models {
|
|
m = strings.TrimSpace(m)
|
|
if m == "" {
|
|
continue
|
|
}
|
|
if snap.Status != "active" {
|
|
continue
|
|
}
|
|
if snap.BanUntil > 0 && time.Now().Unix() < snap.BanUntil {
|
|
continue
|
|
}
|
|
routeKey := fmt.Sprintf("route:group:%s:%s", group, m)
|
|
pipe.SAdd(ctx, routeKey, p.ID)
|
|
}
|
|
}
|
|
|
|
for _, k := range keys {
|
|
tokenHash := strings.TrimSpace(k.TokenHash)
|
|
if tokenHash == "" {
|
|
tokenHash = util.HashToken(k.KeySecret) // fallback for legacy rows
|
|
}
|
|
if tokenHash == "" {
|
|
return fmt.Errorf("token hash missing for key %d", k.ID)
|
|
}
|
|
pipe.HSet(ctx, fmt.Sprintf("auth:token:%s", tokenHash), map[string]interface{}{
|
|
"master_id": k.MasterID,
|
|
"issued_at_epoch": k.IssuedAtEpoch,
|
|
"status": k.Status,
|
|
"group": k.Group,
|
|
"scopes": k.Scopes,
|
|
})
|
|
}
|
|
|
|
for _, m := range masters {
|
|
pipe.HSet(ctx, fmt.Sprintf("auth:master:%d", m.ID), map[string]interface{}{
|
|
"epoch": m.Epoch,
|
|
"status": m.Status,
|
|
"global_qps": m.GlobalQPS,
|
|
})
|
|
}
|
|
|
|
for _, m := range models {
|
|
snap := modelSnapshot{
|
|
Name: m.Name,
|
|
ContextWindow: m.ContextWindow,
|
|
CostPerToken: m.CostPerToken,
|
|
SupportsVision: m.SupportsVision,
|
|
SupportsFunction: m.SupportsFunctions,
|
|
SupportsToolChoice: m.SupportsToolChoice,
|
|
SupportsFIM: m.SupportsFIM,
|
|
MaxOutputTokens: m.MaxOutputTokens,
|
|
}
|
|
payload, err := sonic.Marshal(snap)
|
|
if err != nil {
|
|
return fmt.Errorf("marshal model %s: %w", m.Name, err)
|
|
}
|
|
pipe.HSet(ctx, "meta:models", snap.Name, payload)
|
|
}
|
|
|
|
if _, err := pipe.Exec(ctx); err != nil {
|
|
return fmt.Errorf("write snapshots: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *SyncService) hsetJSON(ctx context.Context, key, field string, val interface{}) error {
|
|
payload, err := sonic.Marshal(val)
|
|
if err != nil {
|
|
return fmt.Errorf("marshal %s:%s: %w", key, field, err)
|
|
}
|
|
if err := s.rdb.HSet(ctx, key, field, payload).Err(); err != nil {
|
|
return fmt.Errorf("write %s:%s: %w", key, field, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// normalizeGroup maps a blank (empty or whitespace-only) group name to
// "default". Any other value is returned exactly as given, including any
// surrounding whitespace.
func normalizeGroup(group string) string {
	if trimmed := strings.TrimSpace(group); trimmed != "" {
		return group
	}
	return "default"
}
|
|
|
|
// normalizeStatus trims and lower-cases a provider status. A blank status
// becomes "active"; every other value ("active", "auto_disabled",
// "manual_disabled", or anything unrecognized) is returned in its trimmed,
// lower-cased form.
func normalizeStatus(status string) string {
	if st := strings.ToLower(strings.TrimSpace(status)); st != "" {
		return st
	}
	return "active"
}
|