feat(api): add OAuth token fields and new provider types support

Add support for OAuth-based authentication with access/refresh tokens
and expiration tracking for API keys. Extend provider groups with
static headers configuration and headers profile options.

Changes include:
- Add AccessToken, RefreshToken, ExpiresAt, AccountID, ProjectID to APIKey model
- Add StaticHeaders and HeadersProfile to ProviderGroup model
- Add TokenRefresh configuration for background token management
- Support new provider types: ClaudeCode, Codex, GeminiCLI, Antigravity
- Update sync service to include new fields in provider snapshots
This commit is contained in:
zenfun
2025-12-28 02:49:54 +08:00
parent cca0802620
commit f0fe9f0dad
8 changed files with 132 additions and 22 deletions

View File

@@ -51,11 +51,19 @@ func (h *Handler) CreateAPIKey(c *gin.Context) {
key := model.APIKey{ key := model.APIKey{
GroupID: req.GroupID, GroupID: req.GroupID,
APIKey: apiKey, APIKey: apiKey,
AccessToken: strings.TrimSpace(req.AccessToken),
RefreshToken: strings.TrimSpace(req.RefreshToken),
AccountID: strings.TrimSpace(req.AccountID),
ProjectID: strings.TrimSpace(req.ProjectID),
Weight: normalizeWeight(req.Weight), Weight: normalizeWeight(req.Weight),
Status: status, Status: status,
AutoBan: autoBan, AutoBan: autoBan,
BanReason: strings.TrimSpace(req.BanReason), BanReason: strings.TrimSpace(req.BanReason),
} }
if !req.ExpiresAt.IsZero() {
tu := req.ExpiresAt.UTC()
key.ExpiresAt = &tu
}
if !req.BanUntil.IsZero() { if !req.BanUntil.IsZero() {
tu := req.BanUntil.UTC() tu := req.BanUntil.UTC()
key.BanUntil = &tu key.BanUntil = &tu
@@ -180,6 +188,22 @@ func (h *Handler) UpdateAPIKey(c *gin.Context) {
if strings.TrimSpace(req.APIKey) != "" { if strings.TrimSpace(req.APIKey) != "" {
update["api_key"] = strings.TrimSpace(req.APIKey) update["api_key"] = strings.TrimSpace(req.APIKey)
} }
if strings.TrimSpace(req.AccessToken) != "" {
update["access_token"] = strings.TrimSpace(req.AccessToken)
}
if strings.TrimSpace(req.RefreshToken) != "" {
update["refresh_token"] = strings.TrimSpace(req.RefreshToken)
}
if !req.ExpiresAt.IsZero() {
tu := req.ExpiresAt.UTC()
update["expires_at"] = &tu
}
if strings.TrimSpace(req.AccountID) != "" {
update["account_id"] = strings.TrimSpace(req.AccountID)
}
if strings.TrimSpace(req.ProjectID) != "" {
update["project_id"] = strings.TrimSpace(req.ProjectID)
}
if req.Weight > 0 { if req.Weight > 0 {
update["weight"] = normalizeWeight(req.Weight) update["weight"] = normalizeWeight(req.Weight)
} }
@@ -199,11 +223,17 @@ func (h *Handler) UpdateAPIKey(c *gin.Context) {
if req.BanUntil.IsZero() && strings.TrimSpace(req.Status) == "active" { if req.BanUntil.IsZero() && strings.TrimSpace(req.Status) == "active" {
update["ban_until"] = nil update["ban_until"] = nil
} }
if req.GroupID != 0 || strings.TrimSpace(req.APIKey) != "" { if req.GroupID != 0 || strings.TrimSpace(req.APIKey) != "" || strings.TrimSpace(req.AccessToken) != "" || strings.TrimSpace(req.RefreshToken) != "" {
nextKey := key nextKey := key
if v, ok := update["api_key"].(string); ok { if v, ok := update["api_key"].(string); ok {
nextKey.APIKey = v nextKey.APIKey = v
} }
if v, ok := update["access_token"].(string); ok {
nextKey.AccessToken = v
}
if v, ok := update["refresh_token"].(string); ok {
nextKey.RefreshToken = v
}
if req.GroupID != 0 { if req.GroupID != 0 {
nextKey.GroupID = req.GroupID nextKey.GroupID = req.GroupID
} }

View File

@@ -42,6 +42,8 @@ func (h *Handler) CreateProviderGroup(c *gin.Context) {
BaseURL: strings.TrimSpace(req.BaseURL), BaseURL: strings.TrimSpace(req.BaseURL),
GoogleProject: strings.TrimSpace(req.GoogleProject), GoogleProject: strings.TrimSpace(req.GoogleProject),
GoogleLocation: strings.TrimSpace(req.GoogleLocation), GoogleLocation: strings.TrimSpace(req.GoogleLocation),
StaticHeaders: strings.TrimSpace(req.StaticHeaders),
HeadersProfile: strings.TrimSpace(req.HeadersProfile),
Models: strings.Join(req.Models, ","), Models: strings.Join(req.Models, ","),
Status: strings.TrimSpace(req.Status), Status: strings.TrimSpace(req.Status),
} }
@@ -167,6 +169,12 @@ func (h *Handler) UpdateProviderGroup(c *gin.Context) {
if strings.TrimSpace(req.GoogleLocation) != "" { if strings.TrimSpace(req.GoogleLocation) != "" {
next.GoogleLocation = strings.TrimSpace(req.GoogleLocation) next.GoogleLocation = strings.TrimSpace(req.GoogleLocation)
} }
if strings.TrimSpace(req.StaticHeaders) != "" {
next.StaticHeaders = strings.TrimSpace(req.StaticHeaders)
}
if strings.TrimSpace(req.HeadersProfile) != "" {
next.HeadersProfile = strings.TrimSpace(req.HeadersProfile)
}
if req.Models != nil { if req.Models != nil {
next.Models = strings.Join(req.Models, ",") next.Models = strings.Join(req.Models, ",")
} }
@@ -184,6 +192,8 @@ func (h *Handler) UpdateProviderGroup(c *gin.Context) {
group.BaseURL = normalized.BaseURL group.BaseURL = normalized.BaseURL
group.GoogleProject = normalized.GoogleProject group.GoogleProject = normalized.GoogleProject
group.GoogleLocation = normalized.GoogleLocation group.GoogleLocation = normalized.GoogleLocation
group.StaticHeaders = normalized.StaticHeaders
group.HeadersProfile = normalized.HeadersProfile
group.Models = normalized.Models group.Models = normalized.Models
group.Status = normalized.Status group.Status = normalized.Status

View File

@@ -20,6 +20,7 @@ type Config struct {
Quota QuotaConfig Quota QuotaConfig
Internal InternalConfig Internal InternalConfig
SyncOutbox SyncOutboxConfig SyncOutbox SyncOutboxConfig
TokenRefresh TokenRefreshConfig
} }
type ServerConfig struct { type ServerConfig struct {
@@ -80,6 +81,13 @@ type SyncOutboxConfig struct {
MaxRetries int MaxRetries int
} }
type TokenRefreshConfig struct {
IntervalSeconds int
RefreshSkewSeconds int
BatchSize int
MaxRetries int
}
func Load() (*Config, error) { func Load() (*Config, error) {
v := viper.New() v := viper.New()
@@ -111,6 +119,10 @@ func Load() (*Config, error) {
v.SetDefault("sync_outbox.interval_seconds", 5) v.SetDefault("sync_outbox.interval_seconds", 5)
v.SetDefault("sync_outbox.batch_size", 200) v.SetDefault("sync_outbox.batch_size", 200)
v.SetDefault("sync_outbox.max_retries", 10) v.SetDefault("sync_outbox.max_retries", 10)
v.SetDefault("token_refresh.interval_seconds", 1800)
v.SetDefault("token_refresh.refresh_skew_seconds", 3000)
v.SetDefault("token_refresh.batch_size", 200)
v.SetDefault("token_refresh.max_retries", 3)
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
v.AutomaticEnv() v.AutomaticEnv()
@@ -143,6 +155,10 @@ func Load() (*Config, error) {
_ = v.BindEnv("sync_outbox.interval_seconds", "EZ_SYNC_OUTBOX_INTERVAL_SECONDS") _ = v.BindEnv("sync_outbox.interval_seconds", "EZ_SYNC_OUTBOX_INTERVAL_SECONDS")
_ = v.BindEnv("sync_outbox.batch_size", "EZ_SYNC_OUTBOX_BATCH_SIZE") _ = v.BindEnv("sync_outbox.batch_size", "EZ_SYNC_OUTBOX_BATCH_SIZE")
_ = v.BindEnv("sync_outbox.max_retries", "EZ_SYNC_OUTBOX_MAX_RETRIES") _ = v.BindEnv("sync_outbox.max_retries", "EZ_SYNC_OUTBOX_MAX_RETRIES")
_ = v.BindEnv("token_refresh.interval_seconds", "EZ_TOKEN_REFRESH_INTERVAL_SECONDS")
_ = v.BindEnv("token_refresh.refresh_skew_seconds", "EZ_TOKEN_REFRESH_SKEW_SECONDS")
_ = v.BindEnv("token_refresh.batch_size", "EZ_TOKEN_REFRESH_BATCH_SIZE")
_ = v.BindEnv("token_refresh.max_retries", "EZ_TOKEN_REFRESH_MAX_RETRIES")
if configFile := os.Getenv("EZ_CONFIG_FILE"); configFile != "" { if configFile := os.Getenv("EZ_CONFIG_FILE"); configFile != "" {
v.SetConfigFile(configFile) v.SetConfigFile(configFile)
@@ -208,6 +224,12 @@ func Load() (*Config, error) {
BatchSize: v.GetInt("sync_outbox.batch_size"), BatchSize: v.GetInt("sync_outbox.batch_size"),
MaxRetries: v.GetInt("sync_outbox.max_retries"), MaxRetries: v.GetInt("sync_outbox.max_retries"),
}, },
TokenRefresh: TokenRefreshConfig{
IntervalSeconds: v.GetInt("token_refresh.interval_seconds"),
RefreshSkewSeconds: v.GetInt("token_refresh.refresh_skew_seconds"),
BatchSize: v.GetInt("token_refresh.batch_size"),
MaxRetries: v.GetInt("token_refresh.max_retries"),
},
} }
return cfg, nil return cfg, nil

View File

@@ -6,6 +6,11 @@ import "time"
type APIKeyDTO struct { type APIKeyDTO struct {
GroupID uint `json:"group_id"` GroupID uint `json:"group_id"`
APIKey string `json:"api_key"` APIKey string `json:"api_key"`
AccessToken string `json:"access_token,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
ExpiresAt time.Time `json:"expires_at,omitempty"`
AccountID string `json:"account_id,omitempty"`
ProjectID string `json:"project_id,omitempty"`
Weight int `json:"weight,omitempty"` Weight int `json:"weight,omitempty"`
Status string `json:"status"` Status string `json:"status"`
AutoBan *bool `json:"auto_ban,omitempty"` AutoBan *bool `json:"auto_ban,omitempty"`

View File

@@ -7,6 +7,8 @@ type ProviderGroupDTO struct {
BaseURL string `json:"base_url"` BaseURL string `json:"base_url"`
GoogleProject string `json:"google_project,omitempty"` GoogleProject string `json:"google_project,omitempty"`
GoogleLocation string `json:"google_location,omitempty"` GoogleLocation string `json:"google_location,omitempty"`
StaticHeaders string `json:"static_headers,omitempty"`
HeadersProfile string `json:"headers_profile,omitempty"`
Models []string `json:"models"` Models []string `json:"models"`
Status string `json:"status"` Status string `json:"status"`
} }

View File

@@ -14,6 +14,8 @@ type ProviderGroup struct {
BaseURL string `gorm:"size:512;not null" json:"base_url"` BaseURL string `gorm:"size:512;not null" json:"base_url"`
GoogleProject string `gorm:"size:128" json:"google_project,omitempty"` GoogleProject string `gorm:"size:128" json:"google_project,omitempty"`
GoogleLocation string `gorm:"size:64" json:"google_location,omitempty"` GoogleLocation string `gorm:"size:64" json:"google_location,omitempty"`
StaticHeaders string `gorm:"type:text" json:"static_headers,omitempty"`
HeadersProfile string `gorm:"size:64" json:"headers_profile,omitempty"`
Models string `json:"models"` // comma-separated list of supported models Models string `json:"models"` // comma-separated list of supported models
Status string `gorm:"size:50;default:'active'" json:"status"` Status string `gorm:"size:50;default:'active'" json:"status"`
} }
@@ -23,6 +25,11 @@ type APIKey struct {
gorm.Model gorm.Model
GroupID uint `gorm:"not null;index" json:"group_id"` GroupID uint `gorm:"not null;index" json:"group_id"`
APIKey string `gorm:"not null" json:"api_key"` APIKey string `gorm:"not null" json:"api_key"`
AccessToken string `gorm:"type:text" json:"access_token,omitempty"`
RefreshToken string `gorm:"type:text" json:"-"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
AccountID string `gorm:"size:255" json:"account_id,omitempty"`
ProjectID string `gorm:"size:255" json:"project_id,omitempty"`
Weight int `gorm:"default:1" json:"weight"` Weight int `gorm:"default:1" json:"weight"`
Status string `gorm:"size:50;default:'active'" json:"status"` Status string `gorm:"size:50;default:'active'" json:"status"`
AutoBan bool `gorm:"default:true" json:"auto_ban"` AutoBan bool `gorm:"default:true" json:"auto_ban"`

View File

@@ -33,16 +33,30 @@ func (m *ProviderGroupManager) NormalizeGroup(group model.ProviderGroup) (model.
group.BaseURL = strings.TrimSpace(group.BaseURL) group.BaseURL = strings.TrimSpace(group.BaseURL)
group.GoogleProject = strings.TrimSpace(group.GoogleProject) group.GoogleProject = strings.TrimSpace(group.GoogleProject)
group.GoogleLocation = strings.TrimSpace(group.GoogleLocation) group.GoogleLocation = strings.TrimSpace(group.GoogleLocation)
group.StaticHeaders = strings.TrimSpace(group.StaticHeaders)
group.HeadersProfile = strings.TrimSpace(group.HeadersProfile)
switch ptype { switch ptype {
case provider.TypeOpenAI: case provider.TypeOpenAI:
if group.BaseURL == "" { if group.BaseURL == "" {
group.BaseURL = "https://api.openai.com/v1" group.BaseURL = "https://api.openai.com/v1"
} }
case provider.TypeAnthropic, provider.TypeClaude: case provider.TypeAnthropic, provider.TypeClaude, provider.TypeClaudeCode:
if group.BaseURL == "" { if group.BaseURL == "" {
group.BaseURL = "https://api.anthropic.com" group.BaseURL = "https://api.anthropic.com"
} }
case provider.TypeCodex:
if group.BaseURL == "" {
group.BaseURL = "https://chatgpt.com"
}
case provider.TypeGeminiCLI:
if group.BaseURL == "" {
group.BaseURL = "https://cloudcode-pa.googleapis.com"
}
case provider.TypeAntigravity:
if group.BaseURL == "" {
group.BaseURL = "https://daily-cloudcode-pa.googleapis.com"
}
case provider.TypeCompatible: case provider.TypeCompatible:
if group.BaseURL == "" { if group.BaseURL == "" {
return model.ProviderGroup{}, fmt.Errorf("base_url required for compatible providers") return model.ProviderGroup{}, fmt.Errorf("base_url required for compatible providers")
@@ -72,8 +86,14 @@ func (m *ProviderGroupManager) ValidateAPIKey(group model.ProviderGroup, key mod
return fmt.Errorf("provider group type required") return fmt.Errorf("provider group type required")
} }
apiKey := strings.TrimSpace(key.APIKey) apiKey := strings.TrimSpace(key.APIKey)
accessToken := strings.TrimSpace(key.AccessToken)
switch { switch {
case ptype == provider.TypeCodex || ptype == provider.TypeGeminiCLI || ptype == provider.TypeAntigravity || ptype == provider.TypeClaudeCode:
if accessToken == "" {
return fmt.Errorf("access_token required")
}
return nil
case provider.IsVertexFamily(ptype): case provider.IsVertexFamily(ptype):
// Vertex uses ADC; api_key can be empty. // Vertex uses ADC; api_key can be empty.
return nil return nil

View File

@@ -278,8 +278,14 @@ type providerSnapshot struct {
Type string `json:"type"` Type string `json:"type"`
BaseURL string `json:"base_url"` BaseURL string `json:"base_url"`
APIKey string `json:"api_key"` APIKey string `json:"api_key"`
AccessToken string `json:"access_token,omitempty"`
ExpiresAt int64 `json:"expires_at,omitempty"`
AccountID string `json:"account_id,omitempty"`
ProjectID string `json:"project_id,omitempty"`
GoogleProject string `json:"google_project,omitempty"` GoogleProject string `json:"google_project,omitempty"`
GoogleLocation string `json:"google_location,omitempty"` GoogleLocation string `json:"google_location,omitempty"`
StaticHeaders string `json:"static_headers,omitempty"`
HeadersProfile string `json:"headers_profile,omitempty"`
GroupID uint `json:"group_id,omitempty"` GroupID uint `json:"group_id,omitempty"`
Group string `json:"group"` Group string `json:"group"`
Models []string `json:"models"` Models []string `json:"models"`
@@ -333,8 +339,13 @@ func (s *SyncService) writeProvidersSnapshot(ctx context.Context, pipe redis.Pip
Type: strings.TrimSpace(g.Type), Type: strings.TrimSpace(g.Type),
BaseURL: strings.TrimSpace(g.BaseURL), BaseURL: strings.TrimSpace(g.BaseURL),
APIKey: strings.TrimSpace(k.APIKey), APIKey: strings.TrimSpace(k.APIKey),
AccessToken: strings.TrimSpace(k.AccessToken),
GoogleProject: strings.TrimSpace(g.GoogleProject), GoogleProject: strings.TrimSpace(g.GoogleProject),
GoogleLocation: strings.TrimSpace(g.GoogleLocation), GoogleLocation: strings.TrimSpace(g.GoogleLocation),
StaticHeaders: strings.TrimSpace(g.StaticHeaders),
HeadersProfile: strings.TrimSpace(g.HeadersProfile),
AccountID: strings.TrimSpace(k.AccountID),
ProjectID: strings.TrimSpace(k.ProjectID),
GroupID: g.ID, GroupID: g.ID,
Group: groupName, Group: groupName,
Models: models, Models: models,
@@ -343,6 +354,9 @@ func (s *SyncService) writeProvidersSnapshot(ctx context.Context, pipe redis.Pip
AutoBan: k.AutoBan, AutoBan: k.AutoBan,
BanReason: strings.TrimSpace(k.BanReason), BanReason: strings.TrimSpace(k.BanReason),
} }
if k.ExpiresAt != nil {
snap.ExpiresAt = k.ExpiresAt.UTC().Unix()
}
if k.BanUntil != nil { if k.BanUntil != nil {
snap.BanUntil = k.BanUntil.UTC().Unix() snap.BanUntil = k.BanUntil.UTC().Unix()
} }