Files
ez-api/internal/service/sync.go
zenfun 305f2ebf18 feat(provider): add update endpoint and enforce status checks
Add `PUT /admin/providers/{id}` endpoint to allow updating provider
configurations, including status and ban details. Update synchronization
logic to exclude inactive or banned providers from routing tables to
ensure traffic is not routed to them.
2025-12-12 23:44:52 +08:00

318 lines
8.7 KiB
Go

package service
import (
	"context"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/bytedance/sonic"
	"github.com/ez-api/ez-api/internal/model"
	"github.com/ez-api/ez-api/internal/util"
	"github.com/redis/go-redis/v9"
	"gorm.io/gorm"
)
// SyncService mirrors database records (keys, masters, providers, models)
// into Redis through a single client.
type SyncService struct {
	rdb *redis.Client
}

// NewSyncService returns a SyncService that writes through the given Redis client.
func NewSyncService(rdb *redis.Client) *SyncService {
	svc := new(SyncService)
	svc.rdb = rdb
	return svc
}
// SyncKey writes a single key into Redis as the hash auth:token:{hash}
// without rebuilding the entire snapshot.
//
// The hash is key.TokenHash with surrounding whitespace removed, matching
// what SyncAll writes so incremental and full syncs address the same Redis
// key. Legacy rows without a stored hash fall back to hashing key.KeySecret.
// An error is returned when no usable hash can be derived or the write fails.
func (s *SyncService) SyncKey(key *model.Key) error {
	ctx := context.Background()
	tokenHash := strings.TrimSpace(key.TokenHash)
	if tokenHash == "" {
		tokenHash = util.HashToken(key.KeySecret) // backward compatibility
	}
	if strings.TrimSpace(tokenHash) == "" {
		return fmt.Errorf("token hash missing for key %d", key.ID)
	}
	fields := map[string]interface{}{
		"master_id":       key.MasterID,
		"issued_at_epoch": key.IssuedAtEpoch,
		"status":          key.Status,
		"group":           key.Group,
		"scopes":          key.Scopes,
	}
	if err := s.rdb.HSet(ctx, fmt.Sprintf("auth:token:%s", tokenHash), fields).Err(); err != nil {
		return fmt.Errorf("write auth token: %w", err)
	}
	return nil
}
// SyncMaster publishes a master's epoch, status and global QPS to the hash
// auth:master:{id}, which the balancer uses for validation.
func (s *SyncService) SyncMaster(master *model.Master) error {
	redisKey := fmt.Sprintf("auth:master:%d", master.ID)
	meta := map[string]interface{}{
		"epoch":      master.Epoch,
		"status":     master.Status,
		"global_qps": master.GlobalQPS,
	}
	err := s.rdb.HSet(context.Background(), redisKey, meta).Err()
	if err != nil {
		return fmt.Errorf("write master metadata: %w", err)
	}
	return nil
}
// SyncProvider writes a single provider into the "config:providers" hash
// (field = provider ID) and reconciles its routing-table memberships.
//
// Routing sets are named route:group:{group}:{model} and hold provider IDs.
// The previously published snapshot, if any, is read back first. The provider
// is then removed from every route set that snapshot placed it in, and re-added
// only to the sets for its current group and models. Re-adding happens only
// when the provider is routable: status "active" and no ban whose BanUntil is
// still in the future. As a result, disabling, banning, regrouping or dropping
// a model takes effect immediately instead of waiting for SyncAll.
//
// The snapshot write and all route changes are queued in one MULTI/EXEC
// transaction.
//
// NOTE(review): the read of the previous snapshot is not WATCHed, so two
// concurrent syncs of the same provider may race on stale-route cleanup.
func (s *SyncService) SyncProvider(provider *model.Provider) error {
	ctx := context.Background()
	group := normalizeGroup(provider.Group)
	models := strings.Split(provider.Models, ",")
	snap := providerSnapshot{
		ID:        provider.ID,
		Name:      provider.Name,
		Type:      provider.Type,
		BaseURL:   provider.BaseURL,
		APIKey:    provider.APIKey,
		Group:     group,
		Models:    models,
		Status:    normalizeStatus(provider.Status),
		AutoBan:   provider.AutoBan,
		BanReason: provider.BanReason,
	}
	if provider.BanUntil != nil {
		snap.BanUntil = provider.BanUntil.UTC().Unix()
	}

	field := fmt.Sprintf("%d", provider.ID)
	payload, err := sonic.Marshal(snap)
	if err != nil {
		return fmt.Errorf("marshal config:providers:%s: %w", field, err)
	}

	// Load the previously published snapshot so stale routes can be dropped.
	var prev providerSnapshot
	hasPrev := false
	raw, err := s.rdb.HGet(ctx, "config:providers", field).Bytes()
	switch {
	case err == nil:
		// An unreadable old payload must not block the update: it is
		// overwritten below and only stale-route cleanup is skipped.
		hasPrev = sonic.Unmarshal(raw, &prev) == nil
	case errors.Is(err, redis.Nil):
		// First sync for this provider: nothing to clean up.
	default:
		return fmt.Errorf("read config:providers:%s: %w", field, err)
	}

	pipe := s.rdb.TxPipeline()
	pipe.HSet(ctx, "config:providers", field, payload)

	// 1. Remove the provider from every route it was previously registered in.
	if hasPrev {
		prevGroup := normalizeGroup(prev.Group)
		for _, m := range prev.Models {
			if m = strings.TrimSpace(m); m != "" {
				pipe.SRem(ctx, fmt.Sprintf("route:group:%s:%s", prevGroup, m), provider.ID)
			}
		}
	}

	// 2. Re-add it to its current routes, but only if it may receive traffic.
	banned := snap.BanUntil > 0 && time.Now().Unix() < snap.BanUntil
	if snap.Status == "active" && !banned {
		for _, m := range models {
			if m = strings.TrimSpace(m); m != "" {
				pipe.SAdd(ctx, fmt.Sprintf("route:group:%s:%s", group, m), provider.ID)
			}
		}
	}

	if _, err := pipe.Exec(ctx); err != nil {
		return fmt.Errorf("write provider %d: %w", provider.ID, err)
	}
	return nil
}
// SyncModel publishes one model's metadata as JSON in the "meta:models" hash,
// using the model name as the field.
func (s *SyncService) SyncModel(m *model.Model) error {
	return s.hsetJSON(context.Background(), "meta:models", m.Name, modelSnapshot{
		Name:               m.Name,
		ContextWindow:      m.ContextWindow,
		CostPerToken:       m.CostPerToken,
		SupportsVision:     m.SupportsVision,
		SupportsFunction:   m.SupportsFunctions,
		SupportsToolChoice: m.SupportsToolChoice,
		SupportsFIM:        m.SupportsFIM,
		MaxOutputTokens:    m.MaxOutputTokens,
	})
}
// providerSnapshot is the JSON document stored per provider in the
// "config:providers" Redis hash, under the decimal provider ID as field.
type providerSnapshot struct {
	ID        uint     `json:"id"`
	Name      string   `json:"name"`
	Type      string   `json:"type"`
	BaseURL   string   `json:"base_url"`
	APIKey    string   `json:"api_key"`
	Group     string   `json:"group"`  // normalized: a blank group becomes "default"
	Models    []string `json:"models"` // raw comma-split of Provider.Models; entries are not trimmed
	Status    string   `json:"status"` // lower-cased/trimmed via normalizeStatus; blank becomes "active"
	AutoBan   bool     `json:"auto_ban"`
	BanReason string   `json:"ban_reason,omitempty"`
	BanUntil  int64    `json:"ban_until,omitempty"` // unix seconds (UTC); 0 (omitted) when there is no ban
}
// Keys have no snapshot type: their fields are written directly into the
// auth:token:* hashes.

// modelSnapshot is the JSON document stored per model in the "meta:models"
// Redis hash, under the model name as field.
type modelSnapshot struct {
	Name               string  `json:"name"`
	ContextWindow      int     `json:"context_window"`
	CostPerToken       float64 `json:"cost_per_token"`
	SupportsVision     bool    `json:"supports_vision"`
	SupportsFunction   bool    `json:"supports_functions"` // filled from model.Model.SupportsFunctions
	SupportsToolChoice bool    `json:"supports_tool_choice"`
	SupportsFIM        bool    `json:"supports_fim"`
	MaxOutputTokens    int     `json:"max_output_tokens"`
}
// SyncAll rebuilds the Redis snapshot from the database; use it for cold
// starts or forced refreshes.
//
// It loads every provider, key, master and model. It then queues all changes
// in a single MULTI/EXEC transaction:
//   - delete the config:providers, config:keys and meta:models hashes, every
//     auth:master:* hash and every route:group:* set;
//   - write fresh provider/model/master entries and route sets from the
//     loaded rows, and HSET each key's auth:token:{hash} entry.
//
// Deleting the route:group:* sets before rebuilding them is what makes the
// rebuild authoritative. SADD alone never removes a member, so a provider
// that became inactive, got banned, changed group or dropped a model would
// otherwise keep receiving traffic.
//
// NOTE(review): auth:token:* hashes are overwritten but never deleted here,
// so tokens of keys no longer returned by the database stay in Redis.
// Confirm this is intended.
// NOTE(review): existing keys are discovered with SCAN before the
// transaction runs. Keys created concurrently by incremental syncs in that
// window are not covered by this rebuild.
func (s *SyncService) SyncAll(db *gorm.DB) error {
	ctx := context.Background()

	var providers []model.Provider
	if err := db.Find(&providers).Error; err != nil {
		return fmt.Errorf("load providers: %w", err)
	}
	var keys []model.Key
	if err := db.Find(&keys).Error; err != nil {
		return fmt.Errorf("load keys: %w", err)
	}
	var masters []model.Master
	if err := db.Find(&masters).Error; err != nil {
		return fmt.Errorf("load masters: %w", err)
	}
	var models []model.Model
	if err := db.Find(&models).Error; err != nil {
		return fmt.Errorf("load models: %w", err)
	}

	// scan collects every key matching pattern. SCAN may return a key more
	// than once, which is harmless for DEL.
	scan := func(pattern string) ([]string, error) {
		var found []string
		iter := s.rdb.Scan(ctx, 0, pattern, 0).Iterator()
		for iter.Next(ctx) {
			found = append(found, iter.Val())
		}
		return found, iter.Err()
	}
	masterKeys, err := scan("auth:master:*")
	if err != nil {
		return fmt.Errorf("scan master keys: %w", err)
	}
	routeKeys, err := scan("route:group:*")
	if err != nil {
		return fmt.Errorf("scan route keys: %w", err)
	}

	pipe := s.rdb.TxPipeline()
	// Nothing in this file writes "config:keys" any more; deleting it clears
	// leftovers from older versions.
	pipe.Del(ctx, "config:providers", "config:keys", "meta:models")
	if len(masterKeys) > 0 {
		pipe.Del(ctx, masterKeys...)
	}
	if len(routeKeys) > 0 {
		pipe.Del(ctx, routeKeys...)
	}

	now := time.Now().Unix()
	for _, p := range providers {
		group := normalizeGroup(p.Group)
		providerModels := strings.Split(p.Models, ",")
		snap := providerSnapshot{
			ID:        p.ID,
			Name:      p.Name,
			Type:      p.Type,
			BaseURL:   p.BaseURL,
			APIKey:    p.APIKey,
			Group:     group,
			Models:    providerModels,
			Status:    normalizeStatus(p.Status),
			AutoBan:   p.AutoBan,
			BanReason: p.BanReason,
		}
		if p.BanUntil != nil {
			snap.BanUntil = p.BanUntil.UTC().Unix()
		}
		payload, err := sonic.Marshal(snap)
		if err != nil {
			return fmt.Errorf("marshal provider %d: %w", p.ID, err)
		}
		pipe.HSet(ctx, "config:providers", fmt.Sprintf("%d", p.ID), payload)

		// Only active providers that are not currently banned get routes.
		if snap.Status != "active" || (snap.BanUntil > 0 && now < snap.BanUntil) {
			continue
		}
		for _, m := range providerModels {
			m = strings.TrimSpace(m)
			if m == "" {
				continue
			}
			pipe.SAdd(ctx, fmt.Sprintf("route:group:%s:%s", group, m), p.ID)
		}
	}

	for _, k := range keys {
		tokenHash := strings.TrimSpace(k.TokenHash)
		if tokenHash == "" {
			tokenHash = util.HashToken(k.KeySecret) // fallback for legacy rows
		}
		if tokenHash == "" {
			return fmt.Errorf("token hash missing for key %d", k.ID)
		}
		pipe.HSet(ctx, fmt.Sprintf("auth:token:%s", tokenHash), map[string]interface{}{
			"master_id":       k.MasterID,
			"issued_at_epoch": k.IssuedAtEpoch,
			"status":          k.Status,
			"group":           k.Group,
			"scopes":          k.Scopes,
		})
	}

	for _, m := range masters {
		pipe.HSet(ctx, fmt.Sprintf("auth:master:%d", m.ID), map[string]interface{}{
			"epoch":      m.Epoch,
			"status":     m.Status,
			"global_qps": m.GlobalQPS,
		})
	}

	for _, m := range models {
		snap := modelSnapshot{
			Name:               m.Name,
			ContextWindow:      m.ContextWindow,
			CostPerToken:       m.CostPerToken,
			SupportsVision:     m.SupportsVision,
			SupportsFunction:   m.SupportsFunctions,
			SupportsToolChoice: m.SupportsToolChoice,
			SupportsFIM:        m.SupportsFIM,
			MaxOutputTokens:    m.MaxOutputTokens,
		}
		payload, err := sonic.Marshal(snap)
		if err != nil {
			return fmt.Errorf("marshal model %s: %w", m.Name, err)
		}
		pipe.HSet(ctx, "meta:models", snap.Name, payload)
	}

	if _, err := pipe.Exec(ctx); err != nil {
		return fmt.Errorf("write snapshots: %w", err)
	}
	return nil
}
// hsetJSON marshals val to JSON with sonic and stores the bytes in the Redis
// hash key under field. Errors are labelled "key:field" to identify the entry.
func (s *SyncService) hsetJSON(ctx context.Context, key, field string, val interface{}) error {
	entry := key + ":" + field // used only to label errors
	payload, err := sonic.Marshal(val)
	if err != nil {
		return fmt.Errorf("marshal %s: %w", entry, err)
	}
	err = s.rdb.HSet(ctx, key, field, payload).Err()
	if err != nil {
		return fmt.Errorf("write %s: %w", entry, err)
	}
	return nil
}
// normalizeGroup maps a blank (empty or whitespace-only) group name to
// "default". Any other value is returned unchanged, surrounding whitespace
// included.
func normalizeGroup(group string) string {
	if trimmed := strings.TrimSpace(group); len(trimmed) > 0 {
		return group
	}
	return "default"
}
// normalizeStatus lower-cases and trims a provider status, treating a blank
// value as "active".
//
// The recognised values are "active", "auto_disabled" and "manual_disabled".
// Any other value is passed through unchanged (normalized) rather than
// rejected. Because only "active" providers are added to routing tables,
// unknown statuses are effectively treated as not routable.
func normalizeStatus(status string) string {
	st := strings.ToLower(strings.TrimSpace(status))
	if st == "" {
		return "active"
	}
	return st
}