mirror of
https://github.com/EZ-Api/ez-api.git
synced 2026-01-13 17:47:51 +00:00
feat(api): add OAuth token fields and new provider types support
Add support for OAuth-based authentication with access/refresh tokens and expiration tracking for API keys. Extend provider groups with static headers configuration and headers profile options. Changes include: - Add AccessToken, RefreshToken, ExpiresAt, AccountID, ProjectID to APIKey model - Add StaticHeaders and HeadersProfile to ProviderGroup model - Add TokenRefresh configuration for background token management - Support new provider types: ClaudeCode, Codex, GeminiCLI, Antigravity - Update sync service to include new fields in provider snapshots
This commit is contained in:
@@ -51,11 +51,19 @@ func (h *Handler) CreateAPIKey(c *gin.Context) {
|
||||
key := model.APIKey{
|
||||
GroupID: req.GroupID,
|
||||
APIKey: apiKey,
|
||||
AccessToken: strings.TrimSpace(req.AccessToken),
|
||||
RefreshToken: strings.TrimSpace(req.RefreshToken),
|
||||
AccountID: strings.TrimSpace(req.AccountID),
|
||||
ProjectID: strings.TrimSpace(req.ProjectID),
|
||||
Weight: normalizeWeight(req.Weight),
|
||||
Status: status,
|
||||
AutoBan: autoBan,
|
||||
BanReason: strings.TrimSpace(req.BanReason),
|
||||
}
|
||||
if !req.ExpiresAt.IsZero() {
|
||||
tu := req.ExpiresAt.UTC()
|
||||
key.ExpiresAt = &tu
|
||||
}
|
||||
if !req.BanUntil.IsZero() {
|
||||
tu := req.BanUntil.UTC()
|
||||
key.BanUntil = &tu
|
||||
@@ -180,6 +188,22 @@ func (h *Handler) UpdateAPIKey(c *gin.Context) {
|
||||
if strings.TrimSpace(req.APIKey) != "" {
|
||||
update["api_key"] = strings.TrimSpace(req.APIKey)
|
||||
}
|
||||
if strings.TrimSpace(req.AccessToken) != "" {
|
||||
update["access_token"] = strings.TrimSpace(req.AccessToken)
|
||||
}
|
||||
if strings.TrimSpace(req.RefreshToken) != "" {
|
||||
update["refresh_token"] = strings.TrimSpace(req.RefreshToken)
|
||||
}
|
||||
if !req.ExpiresAt.IsZero() {
|
||||
tu := req.ExpiresAt.UTC()
|
||||
update["expires_at"] = &tu
|
||||
}
|
||||
if strings.TrimSpace(req.AccountID) != "" {
|
||||
update["account_id"] = strings.TrimSpace(req.AccountID)
|
||||
}
|
||||
if strings.TrimSpace(req.ProjectID) != "" {
|
||||
update["project_id"] = strings.TrimSpace(req.ProjectID)
|
||||
}
|
||||
if req.Weight > 0 {
|
||||
update["weight"] = normalizeWeight(req.Weight)
|
||||
}
|
||||
@@ -199,11 +223,17 @@ func (h *Handler) UpdateAPIKey(c *gin.Context) {
|
||||
if req.BanUntil.IsZero() && strings.TrimSpace(req.Status) == "active" {
|
||||
update["ban_until"] = nil
|
||||
}
|
||||
if req.GroupID != 0 || strings.TrimSpace(req.APIKey) != "" {
|
||||
if req.GroupID != 0 || strings.TrimSpace(req.APIKey) != "" || strings.TrimSpace(req.AccessToken) != "" || strings.TrimSpace(req.RefreshToken) != "" {
|
||||
nextKey := key
|
||||
if v, ok := update["api_key"].(string); ok {
|
||||
nextKey.APIKey = v
|
||||
}
|
||||
if v, ok := update["access_token"].(string); ok {
|
||||
nextKey.AccessToken = v
|
||||
}
|
||||
if v, ok := update["refresh_token"].(string); ok {
|
||||
nextKey.RefreshToken = v
|
||||
}
|
||||
if req.GroupID != 0 {
|
||||
nextKey.GroupID = req.GroupID
|
||||
}
|
||||
|
||||
@@ -42,6 +42,8 @@ func (h *Handler) CreateProviderGroup(c *gin.Context) {
|
||||
BaseURL: strings.TrimSpace(req.BaseURL),
|
||||
GoogleProject: strings.TrimSpace(req.GoogleProject),
|
||||
GoogleLocation: strings.TrimSpace(req.GoogleLocation),
|
||||
StaticHeaders: strings.TrimSpace(req.StaticHeaders),
|
||||
HeadersProfile: strings.TrimSpace(req.HeadersProfile),
|
||||
Models: strings.Join(req.Models, ","),
|
||||
Status: strings.TrimSpace(req.Status),
|
||||
}
|
||||
@@ -167,6 +169,12 @@ func (h *Handler) UpdateProviderGroup(c *gin.Context) {
|
||||
if strings.TrimSpace(req.GoogleLocation) != "" {
|
||||
next.GoogleLocation = strings.TrimSpace(req.GoogleLocation)
|
||||
}
|
||||
if strings.TrimSpace(req.StaticHeaders) != "" {
|
||||
next.StaticHeaders = strings.TrimSpace(req.StaticHeaders)
|
||||
}
|
||||
if strings.TrimSpace(req.HeadersProfile) != "" {
|
||||
next.HeadersProfile = strings.TrimSpace(req.HeadersProfile)
|
||||
}
|
||||
if req.Models != nil {
|
||||
next.Models = strings.Join(req.Models, ",")
|
||||
}
|
||||
@@ -184,6 +192,8 @@ func (h *Handler) UpdateProviderGroup(c *gin.Context) {
|
||||
group.BaseURL = normalized.BaseURL
|
||||
group.GoogleProject = normalized.GoogleProject
|
||||
group.GoogleLocation = normalized.GoogleLocation
|
||||
group.StaticHeaders = normalized.StaticHeaders
|
||||
group.HeadersProfile = normalized.HeadersProfile
|
||||
group.Models = normalized.Models
|
||||
group.Status = normalized.Status
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ type Config struct {
|
||||
Quota QuotaConfig
|
||||
Internal InternalConfig
|
||||
SyncOutbox SyncOutboxConfig
|
||||
TokenRefresh TokenRefreshConfig
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
@@ -80,6 +81,13 @@ type SyncOutboxConfig struct {
|
||||
MaxRetries int
|
||||
}
|
||||
|
||||
type TokenRefreshConfig struct {
|
||||
IntervalSeconds int
|
||||
RefreshSkewSeconds int
|
||||
BatchSize int
|
||||
MaxRetries int
|
||||
}
|
||||
|
||||
func Load() (*Config, error) {
|
||||
v := viper.New()
|
||||
|
||||
@@ -111,6 +119,10 @@ func Load() (*Config, error) {
|
||||
v.SetDefault("sync_outbox.interval_seconds", 5)
|
||||
v.SetDefault("sync_outbox.batch_size", 200)
|
||||
v.SetDefault("sync_outbox.max_retries", 10)
|
||||
v.SetDefault("token_refresh.interval_seconds", 1800)
|
||||
v.SetDefault("token_refresh.refresh_skew_seconds", 3000)
|
||||
v.SetDefault("token_refresh.batch_size", 200)
|
||||
v.SetDefault("token_refresh.max_retries", 3)
|
||||
|
||||
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||||
v.AutomaticEnv()
|
||||
@@ -143,6 +155,10 @@ func Load() (*Config, error) {
|
||||
_ = v.BindEnv("sync_outbox.interval_seconds", "EZ_SYNC_OUTBOX_INTERVAL_SECONDS")
|
||||
_ = v.BindEnv("sync_outbox.batch_size", "EZ_SYNC_OUTBOX_BATCH_SIZE")
|
||||
_ = v.BindEnv("sync_outbox.max_retries", "EZ_SYNC_OUTBOX_MAX_RETRIES")
|
||||
_ = v.BindEnv("token_refresh.interval_seconds", "EZ_TOKEN_REFRESH_INTERVAL_SECONDS")
|
||||
_ = v.BindEnv("token_refresh.refresh_skew_seconds", "EZ_TOKEN_REFRESH_SKEW_SECONDS")
|
||||
_ = v.BindEnv("token_refresh.batch_size", "EZ_TOKEN_REFRESH_BATCH_SIZE")
|
||||
_ = v.BindEnv("token_refresh.max_retries", "EZ_TOKEN_REFRESH_MAX_RETRIES")
|
||||
|
||||
if configFile := os.Getenv("EZ_CONFIG_FILE"); configFile != "" {
|
||||
v.SetConfigFile(configFile)
|
||||
@@ -208,6 +224,12 @@ func Load() (*Config, error) {
|
||||
BatchSize: v.GetInt("sync_outbox.batch_size"),
|
||||
MaxRetries: v.GetInt("sync_outbox.max_retries"),
|
||||
},
|
||||
TokenRefresh: TokenRefreshConfig{
|
||||
IntervalSeconds: v.GetInt("token_refresh.interval_seconds"),
|
||||
RefreshSkewSeconds: v.GetInt("token_refresh.refresh_skew_seconds"),
|
||||
BatchSize: v.GetInt("token_refresh.batch_size"),
|
||||
MaxRetries: v.GetInt("token_refresh.max_retries"),
|
||||
},
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
|
||||
@@ -6,6 +6,11 @@ import "time"
|
||||
type APIKeyDTO struct {
|
||||
GroupID uint `json:"group_id"`
|
||||
APIKey string `json:"api_key"`
|
||||
AccessToken string `json:"access_token,omitempty"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
ExpiresAt time.Time `json:"expires_at,omitempty"`
|
||||
AccountID string `json:"account_id,omitempty"`
|
||||
ProjectID string `json:"project_id,omitempty"`
|
||||
Weight int `json:"weight,omitempty"`
|
||||
Status string `json:"status"`
|
||||
AutoBan *bool `json:"auto_ban,omitempty"`
|
||||
|
||||
@@ -7,6 +7,8 @@ type ProviderGroupDTO struct {
|
||||
BaseURL string `json:"base_url"`
|
||||
GoogleProject string `json:"google_project,omitempty"`
|
||||
GoogleLocation string `json:"google_location,omitempty"`
|
||||
StaticHeaders string `json:"static_headers,omitempty"`
|
||||
HeadersProfile string `json:"headers_profile,omitempty"`
|
||||
Models []string `json:"models"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
@@ -14,6 +14,8 @@ type ProviderGroup struct {
|
||||
BaseURL string `gorm:"size:512;not null" json:"base_url"`
|
||||
GoogleProject string `gorm:"size:128" json:"google_project,omitempty"`
|
||||
GoogleLocation string `gorm:"size:64" json:"google_location,omitempty"`
|
||||
StaticHeaders string `gorm:"type:text" json:"static_headers,omitempty"`
|
||||
HeadersProfile string `gorm:"size:64" json:"headers_profile,omitempty"`
|
||||
Models string `json:"models"` // comma-separated list of supported models
|
||||
Status string `gorm:"size:50;default:'active'" json:"status"`
|
||||
}
|
||||
@@ -23,6 +25,11 @@ type APIKey struct {
|
||||
gorm.Model
|
||||
GroupID uint `gorm:"not null;index" json:"group_id"`
|
||||
APIKey string `gorm:"not null" json:"api_key"`
|
||||
AccessToken string `gorm:"type:text" json:"access_token,omitempty"`
|
||||
RefreshToken string `gorm:"type:text" json:"-"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
AccountID string `gorm:"size:255" json:"account_id,omitempty"`
|
||||
ProjectID string `gorm:"size:255" json:"project_id,omitempty"`
|
||||
Weight int `gorm:"default:1" json:"weight"`
|
||||
Status string `gorm:"size:50;default:'active'" json:"status"`
|
||||
AutoBan bool `gorm:"default:true" json:"auto_ban"`
|
||||
|
||||
@@ -33,16 +33,30 @@ func (m *ProviderGroupManager) NormalizeGroup(group model.ProviderGroup) (model.
|
||||
group.BaseURL = strings.TrimSpace(group.BaseURL)
|
||||
group.GoogleProject = strings.TrimSpace(group.GoogleProject)
|
||||
group.GoogleLocation = strings.TrimSpace(group.GoogleLocation)
|
||||
group.StaticHeaders = strings.TrimSpace(group.StaticHeaders)
|
||||
group.HeadersProfile = strings.TrimSpace(group.HeadersProfile)
|
||||
|
||||
switch ptype {
|
||||
case provider.TypeOpenAI:
|
||||
if group.BaseURL == "" {
|
||||
group.BaseURL = "https://api.openai.com/v1"
|
||||
}
|
||||
case provider.TypeAnthropic, provider.TypeClaude:
|
||||
case provider.TypeAnthropic, provider.TypeClaude, provider.TypeClaudeCode:
|
||||
if group.BaseURL == "" {
|
||||
group.BaseURL = "https://api.anthropic.com"
|
||||
}
|
||||
case provider.TypeCodex:
|
||||
if group.BaseURL == "" {
|
||||
group.BaseURL = "https://chatgpt.com"
|
||||
}
|
||||
case provider.TypeGeminiCLI:
|
||||
if group.BaseURL == "" {
|
||||
group.BaseURL = "https://cloudcode-pa.googleapis.com"
|
||||
}
|
||||
case provider.TypeAntigravity:
|
||||
if group.BaseURL == "" {
|
||||
group.BaseURL = "https://daily-cloudcode-pa.googleapis.com"
|
||||
}
|
||||
case provider.TypeCompatible:
|
||||
if group.BaseURL == "" {
|
||||
return model.ProviderGroup{}, fmt.Errorf("base_url required for compatible providers")
|
||||
@@ -72,8 +86,14 @@ func (m *ProviderGroupManager) ValidateAPIKey(group model.ProviderGroup, key mod
|
||||
return fmt.Errorf("provider group type required")
|
||||
}
|
||||
apiKey := strings.TrimSpace(key.APIKey)
|
||||
accessToken := strings.TrimSpace(key.AccessToken)
|
||||
|
||||
switch {
|
||||
case ptype == provider.TypeCodex || ptype == provider.TypeGeminiCLI || ptype == provider.TypeAntigravity || ptype == provider.TypeClaudeCode:
|
||||
if accessToken == "" {
|
||||
return fmt.Errorf("access_token required")
|
||||
}
|
||||
return nil
|
||||
case provider.IsVertexFamily(ptype):
|
||||
// Vertex uses ADC; api_key can be empty.
|
||||
return nil
|
||||
|
||||
@@ -278,8 +278,14 @@ type providerSnapshot struct {
|
||||
Type string `json:"type"`
|
||||
BaseURL string `json:"base_url"`
|
||||
APIKey string `json:"api_key"`
|
||||
AccessToken string `json:"access_token,omitempty"`
|
||||
ExpiresAt int64 `json:"expires_at,omitempty"`
|
||||
AccountID string `json:"account_id,omitempty"`
|
||||
ProjectID string `json:"project_id,omitempty"`
|
||||
GoogleProject string `json:"google_project,omitempty"`
|
||||
GoogleLocation string `json:"google_location,omitempty"`
|
||||
StaticHeaders string `json:"static_headers,omitempty"`
|
||||
HeadersProfile string `json:"headers_profile,omitempty"`
|
||||
GroupID uint `json:"group_id,omitempty"`
|
||||
Group string `json:"group"`
|
||||
Models []string `json:"models"`
|
||||
@@ -333,8 +339,13 @@ func (s *SyncService) writeProvidersSnapshot(ctx context.Context, pipe redis.Pip
|
||||
Type: strings.TrimSpace(g.Type),
|
||||
BaseURL: strings.TrimSpace(g.BaseURL),
|
||||
APIKey: strings.TrimSpace(k.APIKey),
|
||||
AccessToken: strings.TrimSpace(k.AccessToken),
|
||||
GoogleProject: strings.TrimSpace(g.GoogleProject),
|
||||
GoogleLocation: strings.TrimSpace(g.GoogleLocation),
|
||||
StaticHeaders: strings.TrimSpace(g.StaticHeaders),
|
||||
HeadersProfile: strings.TrimSpace(g.HeadersProfile),
|
||||
AccountID: strings.TrimSpace(k.AccountID),
|
||||
ProjectID: strings.TrimSpace(k.ProjectID),
|
||||
GroupID: g.ID,
|
||||
Group: groupName,
|
||||
Models: models,
|
||||
@@ -343,6 +354,9 @@ func (s *SyncService) writeProvidersSnapshot(ctx context.Context, pipe redis.Pip
|
||||
AutoBan: k.AutoBan,
|
||||
BanReason: strings.TrimSpace(k.BanReason),
|
||||
}
|
||||
if k.ExpiresAt != nil {
|
||||
snap.ExpiresAt = k.ExpiresAt.UTC().Unix()
|
||||
}
|
||||
if k.BanUntil != nil {
|
||||
snap.BanUntil = k.BanUntil.UTC().Unix()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user