From f0fe9f0dadc87e570950f715acbcc76134b67a0c Mon Sep 17 00:00:00 2001 From: zenfun Date: Sun, 28 Dec 2025 02:49:54 +0800 Subject: [PATCH] feat(api): add OAuth token fields and new provider types support Add support for OAuth-based authentication with access/refresh tokens and expiration tracking for API keys. Extend provider groups with static headers configuration and headers profile options. Changes include: - Add AccessToken, RefreshToken, ExpiresAt, AccountID, ProjectID to APIKey model - Add StaticHeaders and HeadersProfile to ProviderGroup model - Add TokenRefresh configuration for background token management - Support new provider types: ClaudeCode, Codex, GeminiCLI, Antigravity - Update sync service to include new fields in provider snapshots --- internal/api/api_key_handler.go | 44 ++++++++++++++++++---- internal/api/provider_group_handler.go | 10 +++++ internal/config/config.go | 22 +++++++++++ internal/dto/api_key.go | 19 ++++++---- internal/dto/provider_group.go | 2 + internal/model/provider_group.go | 21 +++++++---- internal/service/provider_group_manager.go | 22 ++++++++++- internal/service/sync.go | 14 +++++++ 8 files changed, 132 insertions(+), 22 deletions(-) diff --git a/internal/api/api_key_handler.go b/internal/api/api_key_handler.go index 102958f..7b5ea20 100644 --- a/internal/api/api_key_handler.go +++ b/internal/api/api_key_handler.go @@ -49,12 +49,20 @@ func (h *Handler) CreateAPIKey(c *gin.Context) { } key := model.APIKey{ - GroupID: req.GroupID, - APIKey: apiKey, - Weight: normalizeWeight(req.Weight), - Status: status, - AutoBan: autoBan, - BanReason: strings.TrimSpace(req.BanReason), + GroupID: req.GroupID, + APIKey: apiKey, + AccessToken: strings.TrimSpace(req.AccessToken), + RefreshToken: strings.TrimSpace(req.RefreshToken), + AccountID: strings.TrimSpace(req.AccountID), + ProjectID: strings.TrimSpace(req.ProjectID), + Weight: normalizeWeight(req.Weight), + Status: status, + AutoBan: autoBan, + BanReason: strings.TrimSpace(req.BanReason), + } + if !req.ExpiresAt.IsZero() { + tu := req.ExpiresAt.UTC() + key.ExpiresAt = &tu } if !req.BanUntil.IsZero() { tu := req.BanUntil.UTC() @@ -180,6 +188,22 @@ func (h *Handler) UpdateAPIKey(c *gin.Context) { if strings.TrimSpace(req.APIKey) != "" { update["api_key"] = strings.TrimSpace(req.APIKey) } + if strings.TrimSpace(req.AccessToken) != "" { + update["access_token"] = strings.TrimSpace(req.AccessToken) + } + if strings.TrimSpace(req.RefreshToken) != "" { + update["refresh_token"] = strings.TrimSpace(req.RefreshToken) + } + if !req.ExpiresAt.IsZero() { + tu := req.ExpiresAt.UTC() + update["expires_at"] = &tu + } + if strings.TrimSpace(req.AccountID) != "" { + update["account_id"] = strings.TrimSpace(req.AccountID) + } + if strings.TrimSpace(req.ProjectID) != "" { + update["project_id"] = strings.TrimSpace(req.ProjectID) + } if req.Weight > 0 { update["weight"] = normalizeWeight(req.Weight) } @@ -199,11 +223,17 @@ func (h *Handler) UpdateAPIKey(c *gin.Context) { if req.BanUntil.IsZero() && strings.TrimSpace(req.Status) == "active" { update["ban_until"] = nil } - if req.GroupID != 0 || strings.TrimSpace(req.APIKey) != "" { + if req.GroupID != 0 || strings.TrimSpace(req.APIKey) != "" || strings.TrimSpace(req.AccessToken) != "" || strings.TrimSpace(req.RefreshToken) != "" { nextKey := key if v, ok := update["api_key"].(string); ok { nextKey.APIKey = v } + if v, ok := update["access_token"].(string); ok { + nextKey.AccessToken = v + } + if v, ok := update["refresh_token"].(string); ok { + nextKey.RefreshToken = v + } if req.GroupID != 0 { nextKey.GroupID = req.GroupID } diff --git a/internal/api/provider_group_handler.go b/internal/api/provider_group_handler.go index 302048d..a432a11 100644 --- a/internal/api/provider_group_handler.go +++ b/internal/api/provider_group_handler.go @@ -42,6 +42,8 @@ func (h *Handler) CreateProviderGroup(c *gin.Context) { BaseURL: strings.TrimSpace(req.BaseURL), GoogleProject: strings.TrimSpace(req.GoogleProject), GoogleLocation: strings.TrimSpace(req.GoogleLocation), + StaticHeaders: strings.TrimSpace(req.StaticHeaders), + HeadersProfile: strings.TrimSpace(req.HeadersProfile), Models: strings.Join(req.Models, ","), Status: strings.TrimSpace(req.Status), } @@ -167,6 +169,12 @@ func (h *Handler) UpdateProviderGroup(c *gin.Context) { if strings.TrimSpace(req.GoogleLocation) != "" { next.GoogleLocation = strings.TrimSpace(req.GoogleLocation) } + if strings.TrimSpace(req.StaticHeaders) != "" { + next.StaticHeaders = strings.TrimSpace(req.StaticHeaders) + } + if strings.TrimSpace(req.HeadersProfile) != "" { + next.HeadersProfile = strings.TrimSpace(req.HeadersProfile) + } if req.Models != nil { next.Models = strings.Join(req.Models, ",") } @@ -184,6 +192,8 @@ func (h *Handler) UpdateProviderGroup(c *gin.Context) { group.BaseURL = normalized.BaseURL group.GoogleProject = normalized.GoogleProject group.GoogleLocation = normalized.GoogleLocation + group.StaticHeaders = normalized.StaticHeaders + group.HeadersProfile = normalized.HeadersProfile group.Models = normalized.Models group.Status = normalized.Status diff --git a/internal/config/config.go b/internal/config/config.go index a07031c..b8e9548 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -20,6 +20,7 @@ type Config struct { Quota QuotaConfig Internal InternalConfig SyncOutbox SyncOutboxConfig + TokenRefresh TokenRefreshConfig } type ServerConfig struct { @@ -80,6 +81,13 @@ type SyncOutboxConfig struct { MaxRetries int } +type TokenRefreshConfig struct { + IntervalSeconds int + RefreshSkewSeconds int + BatchSize int + MaxRetries int +} + func Load() (*Config, error) { v := viper.New() @@ -111,6 +119,10 @@ func Load() (*Config, error) { v.SetDefault("sync_outbox.interval_seconds", 5) v.SetDefault("sync_outbox.batch_size", 200) v.SetDefault("sync_outbox.max_retries", 10) + v.SetDefault("token_refresh.interval_seconds", 1800) + v.SetDefault("token_refresh.refresh_skew_seconds", 3000) + v.SetDefault("token_refresh.batch_size", 200) + v.SetDefault("token_refresh.max_retries", 3) v.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) v.AutomaticEnv() @@ -143,6 +155,10 @@ func Load() (*Config, error) { _ = v.BindEnv("sync_outbox.interval_seconds", "EZ_SYNC_OUTBOX_INTERVAL_SECONDS") _ = v.BindEnv("sync_outbox.batch_size", "EZ_SYNC_OUTBOX_BATCH_SIZE") _ = v.BindEnv("sync_outbox.max_retries", "EZ_SYNC_OUTBOX_MAX_RETRIES") + _ = v.BindEnv("token_refresh.interval_seconds", "EZ_TOKEN_REFRESH_INTERVAL_SECONDS") + _ = v.BindEnv("token_refresh.refresh_skew_seconds", "EZ_TOKEN_REFRESH_SKEW_SECONDS") + _ = v.BindEnv("token_refresh.batch_size", "EZ_TOKEN_REFRESH_BATCH_SIZE") + _ = v.BindEnv("token_refresh.max_retries", "EZ_TOKEN_REFRESH_MAX_RETRIES") if configFile := os.Getenv("EZ_CONFIG_FILE"); configFile != "" { v.SetConfigFile(configFile) @@ -208,6 +224,12 @@ func Load() (*Config, error) { BatchSize: v.GetInt("sync_outbox.batch_size"), MaxRetries: v.GetInt("sync_outbox.max_retries"), }, + TokenRefresh: TokenRefreshConfig{ + IntervalSeconds: v.GetInt("token_refresh.interval_seconds"), + RefreshSkewSeconds: v.GetInt("token_refresh.refresh_skew_seconds"), + BatchSize: v.GetInt("token_refresh.batch_size"), + MaxRetries: v.GetInt("token_refresh.max_retries"), + }, } return cfg, nil diff --git a/internal/dto/api_key.go b/internal/dto/api_key.go index 53ea860..d6e2d5d 100644 --- a/internal/dto/api_key.go +++ b/internal/dto/api_key.go @@ -4,11 +4,16 @@ import "time" // APIKeyDTO defines inbound payload for API key creation/update. type APIKeyDTO struct { - GroupID uint `json:"group_id"` - APIKey string `json:"api_key"` - Weight int `json:"weight,omitempty"` - Status string `json:"status"` - AutoBan *bool `json:"auto_ban,omitempty"` - BanReason string `json:"ban_reason,omitempty"` - BanUntil time.Time `json:"ban_until,omitempty"` + GroupID uint `json:"group_id"` + APIKey string `json:"api_key"` + AccessToken string `json:"access_token,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` + ExpiresAt time.Time `json:"expires_at,omitempty"` + AccountID string `json:"account_id,omitempty"` + ProjectID string `json:"project_id,omitempty"` + Weight int `json:"weight,omitempty"` + Status string `json:"status"` + AutoBan *bool `json:"auto_ban,omitempty"` + BanReason string `json:"ban_reason,omitempty"` + BanUntil time.Time `json:"ban_until,omitempty"` } diff --git a/internal/dto/provider_group.go b/internal/dto/provider_group.go index 2358cb5..9dfa124 100644 --- a/internal/dto/provider_group.go +++ b/internal/dto/provider_group.go @@ -7,6 +7,8 @@ type ProviderGroupDTO struct { BaseURL string `json:"base_url"` GoogleProject string `json:"google_project,omitempty"` GoogleLocation string `json:"google_location,omitempty"` + StaticHeaders string `json:"static_headers,omitempty"` + HeadersProfile string `json:"headers_profile,omitempty"` Models []string `json:"models"` Status string `json:"status"` } diff --git a/internal/model/provider_group.go b/internal/model/provider_group.go index 37e1e0c..dd3f406 100644 --- a/internal/model/provider_group.go +++ b/internal/model/provider_group.go @@ -14,6 +14,8 @@ type ProviderGroup struct { BaseURL string `gorm:"size:512;not null" json:"base_url"` GoogleProject string `gorm:"size:128" json:"google_project,omitempty"` GoogleLocation string `gorm:"size:64" json:"google_location,omitempty"` + StaticHeaders string `gorm:"type:text" json:"static_headers,omitempty"` + HeadersProfile string `gorm:"size:64" json:"headers_profile,omitempty"` Models string `json:"models"` // comma-separated list of supported models Status string `gorm:"size:50;default:'active'" json:"status"` } @@ -21,11 +23,16 @@ type ProviderGroup struct { // APIKey represents a credential within a provider group. type APIKey struct { gorm.Model - GroupID uint `gorm:"not null;index" json:"group_id"` - APIKey string `gorm:"not null" json:"api_key"` - Weight int `gorm:"default:1" json:"weight"` - Status string `gorm:"size:50;default:'active'" json:"status"` - AutoBan bool `gorm:"default:true" json:"auto_ban"` - BanReason string `gorm:"size:255" json:"ban_reason"` - BanUntil *time.Time `json:"ban_until"` + GroupID uint `gorm:"not null;index" json:"group_id"` + APIKey string `gorm:"not null" json:"api_key"` + AccessToken string `gorm:"type:text" json:"access_token,omitempty"` + RefreshToken string `gorm:"type:text" json:"-"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` + AccountID string `gorm:"size:255" json:"account_id,omitempty"` + ProjectID string `gorm:"size:255" json:"project_id,omitempty"` + Weight int `gorm:"default:1" json:"weight"` + Status string `gorm:"size:50;default:'active'" json:"status"` + AutoBan bool `gorm:"default:true" json:"auto_ban"` + BanReason string `gorm:"size:255" json:"ban_reason"` + BanUntil *time.Time `json:"ban_until"` } diff --git a/internal/service/provider_group_manager.go b/internal/service/provider_group_manager.go index 82b01fc..d8b0aef 100644 --- a/internal/service/provider_group_manager.go +++ b/internal/service/provider_group_manager.go @@ -33,16 +33,30 @@ func (m *ProviderGroupManager) NormalizeGroup(group model.ProviderGroup) (model. group.BaseURL = strings.TrimSpace(group.BaseURL) group.GoogleProject = strings.TrimSpace(group.GoogleProject) group.GoogleLocation = strings.TrimSpace(group.GoogleLocation) + group.StaticHeaders = strings.TrimSpace(group.StaticHeaders) + group.HeadersProfile = strings.TrimSpace(group.HeadersProfile) switch ptype { case provider.TypeOpenAI: if group.BaseURL == "" { group.BaseURL = "https://api.openai.com/v1" } - case provider.TypeAnthropic, provider.TypeClaude: + case provider.TypeAnthropic, provider.TypeClaude, provider.TypeClaudeCode: if group.BaseURL == "" { group.BaseURL = "https://api.anthropic.com" } + case provider.TypeCodex: + if group.BaseURL == "" { + group.BaseURL = "https://chatgpt.com" + } + case provider.TypeGeminiCLI: + if group.BaseURL == "" { + group.BaseURL = "https://cloudcode-pa.googleapis.com" + } + case provider.TypeAntigravity: + if group.BaseURL == "" { + group.BaseURL = "https://daily-cloudcode-pa.googleapis.com" + } case provider.TypeCompatible: if group.BaseURL == "" { return model.ProviderGroup{}, fmt.Errorf("base_url required for compatible providers") @@ -72,8 +86,14 @@ func (m *ProviderGroupManager) ValidateAPIKey(group model.ProviderGroup, key mod return fmt.Errorf("provider group type required") } apiKey := strings.TrimSpace(key.APIKey) + accessToken := strings.TrimSpace(key.AccessToken) switch { + case ptype == provider.TypeCodex || ptype == provider.TypeGeminiCLI || ptype == provider.TypeAntigravity || ptype == provider.TypeClaudeCode: + if accessToken == "" { + return fmt.Errorf("access_token required") + } + return nil case provider.IsVertexFamily(ptype): // Vertex uses ADC; api_key can be empty. return nil diff --git a/internal/service/sync.go b/internal/service/sync.go index 9900548..ba35817 100644 --- a/internal/service/sync.go +++ b/internal/service/sync.go @@ -278,8 +278,14 @@ type providerSnapshot struct { Type string `json:"type"` BaseURL string `json:"base_url"` APIKey string `json:"api_key"` + AccessToken string `json:"access_token,omitempty"` + ExpiresAt int64 `json:"expires_at,omitempty"` + AccountID string `json:"account_id,omitempty"` + ProjectID string `json:"project_id,omitempty"` GoogleProject string `json:"google_project,omitempty"` GoogleLocation string `json:"google_location,omitempty"` + StaticHeaders string `json:"static_headers,omitempty"` + HeadersProfile string `json:"headers_profile,omitempty"` GroupID uint `json:"group_id,omitempty"` Group string `json:"group"` Models []string `json:"models"` @@ -333,8 +339,13 @@ func (s *SyncService) writeProvidersSnapshot(ctx context.Context, pipe redis.Pip Type: strings.TrimSpace(g.Type), BaseURL: strings.TrimSpace(g.BaseURL), APIKey: strings.TrimSpace(k.APIKey), + AccessToken: strings.TrimSpace(k.AccessToken), GoogleProject: strings.TrimSpace(g.GoogleProject), GoogleLocation: strings.TrimSpace(g.GoogleLocation), + StaticHeaders: strings.TrimSpace(g.StaticHeaders), + HeadersProfile: strings.TrimSpace(g.HeadersProfile), + AccountID: strings.TrimSpace(k.AccountID), + ProjectID: strings.TrimSpace(k.ProjectID), GroupID: g.ID, Group: groupName, Models: models, @@ -343,6 +354,9 @@ func (s *SyncService) writeProvidersSnapshot(ctx context.Context, pipe redis.Pip AutoBan: k.AutoBan, BanReason: strings.TrimSpace(k.BanReason), } + if k.ExpiresAt != nil { + snap.ExpiresAt = k.ExpiresAt.UTC().Unix() + } if k.BanUntil != nil { snap.BanUntil = k.BanUntil.UTC().Unix() }