mirror of
https://github.com/EZ-Api/ez-api.git
synced 2026-01-13 17:47:51 +00:00
Implement an automatic token refresh mechanism for CPA providers (Codex, GeminiCLI, Antigravity, ClaudeCode) with the following features:
- Periodic refresh of expiring tokens based on a configurable interval
- Redis event queue processing for on-demand token refresh
- Retry logic with exponential backoff for transient failures
- Automatic key deactivation on non-retryable errors
- Provider-specific OAuth token refresh implementations
- Sync service integration to update providers after refresh
396 lines
11 KiB
Go
396 lines
11 KiB
Go
package cron
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/ez-api/ez-api/internal/model"
|
|
"github.com/ez-api/ez-api/internal/service"
|
|
"github.com/ez-api/foundation/jsoncodec"
|
|
"github.com/ez-api/foundation/provider"
|
|
"github.com/redis/go-redis/v9"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// refreshEventKey is the Redis list key from which on-demand refresh events
// are popped (see processRefreshEvents).
const refreshEventKey string = "events:refresh_provider"
|
|
|
|
// RefreshEvent is the JSON payload of an on-demand refresh request popped from
// the refreshEventKey Redis list.
type RefreshEvent struct {
	// ProviderID selects what to refresh; events with a zero ProviderID are
	// discarded. The refresher passes this value to refreshByID, which looks
	// it up as the primary key of a model.APIKey.
	ProviderID uint `json:"provider_id"`

	// ProviderType, StatusCode and Timestamp are decoded but not read by the
	// refresher in this file.
	// NOTE(review): Timestamp's unit is not established here — confirm with
	// the event producer.
	ProviderType string `json:"provider_type,omitempty"`
	StatusCode   int    `json:"status_code,omitempty"`
	Timestamp    int64  `json:"timestamp,omitempty"`
}
|
|
|
|
type TokenRefresher struct {
|
|
db *gorm.DB
|
|
rdb *redis.Client
|
|
sync *service.SyncService
|
|
interval time.Duration
|
|
refreshSkew time.Duration
|
|
batchSize int
|
|
maxRetries int
|
|
httpClient *http.Client
|
|
}
|
|
|
|
func NewTokenRefresher(db *gorm.DB, rdb *redis.Client, sync *service.SyncService, interval, refreshSkew time.Duration, batchSize, maxRetries int) *TokenRefresher {
|
|
if interval <= 0 {
|
|
interval = 30 * time.Minute
|
|
}
|
|
if refreshSkew <= 0 {
|
|
refreshSkew = 50 * time.Minute
|
|
}
|
|
if batchSize <= 0 {
|
|
batchSize = 200
|
|
}
|
|
if maxRetries <= 0 {
|
|
maxRetries = 3
|
|
}
|
|
return &TokenRefresher{
|
|
db: db,
|
|
rdb: rdb,
|
|
sync: sync,
|
|
interval: interval,
|
|
refreshSkew: refreshSkew,
|
|
batchSize: batchSize,
|
|
maxRetries: maxRetries,
|
|
httpClient: &http.Client{Timeout: 15 * time.Second},
|
|
}
|
|
}
|
|
|
|
func (t *TokenRefresher) Start(ctx context.Context) {
|
|
if t == nil || t.db == nil {
|
|
return
|
|
}
|
|
if ctx == nil {
|
|
ctx = context.Background()
|
|
}
|
|
ticker := time.NewTicker(t.interval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
if err := t.refreshOnce(ctx); err != nil {
|
|
slog.Default().Warn("token refresh failed", "err", err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (t *TokenRefresher) refreshOnce(ctx context.Context) error {
|
|
if t == nil || t.db == nil {
|
|
return nil
|
|
}
|
|
if err := t.processRefreshEvents(ctx); err != nil {
|
|
slog.Default().Warn("token refresh event handling failed", "err", err)
|
|
}
|
|
return t.refreshExpiring(ctx)
|
|
}
|
|
|
|
func (t *TokenRefresher) processRefreshEvents(ctx context.Context) error {
|
|
if t == nil || t.rdb == nil {
|
|
return nil
|
|
}
|
|
for i := 0; i < t.batchSize; i++ {
|
|
raw, err := t.rdb.LPop(ctx, refreshEventKey).Result()
|
|
if err == redis.Nil {
|
|
return nil
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
var evt RefreshEvent
|
|
if err := jsoncodec.Unmarshal([]byte(raw), &evt); err != nil || evt.ProviderID == 0 {
|
|
continue
|
|
}
|
|
if err := t.refreshByID(ctx, evt.ProviderID); err != nil {
|
|
slog.Default().Warn("token refresh event failed", "provider_id", evt.ProviderID, "err", err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (t *TokenRefresher) refreshExpiring(ctx context.Context) error {
|
|
if t == nil || t.db == nil {
|
|
return nil
|
|
}
|
|
cutoff := time.Now().UTC().Add(t.refreshSkew)
|
|
var keys []model.APIKey
|
|
if err := t.db.WithContext(ctx).
|
|
Where("status = ?", "active").
|
|
Where("refresh_token <> ''").
|
|
Where("expires_at IS NOT NULL AND expires_at <= ?", cutoff).
|
|
Limit(t.batchSize).
|
|
Find(&keys).Error; err != nil {
|
|
return err
|
|
}
|
|
for i := range keys {
|
|
if err := t.refreshKey(ctx, &keys[i]); err != nil {
|
|
slog.Default().Warn("token refresh failed", "key_id", keys[i].ID, "err", err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (t *TokenRefresher) refreshByID(ctx context.Context, id uint) error {
|
|
if t == nil || t.db == nil || id == 0 {
|
|
return nil
|
|
}
|
|
var key model.APIKey
|
|
if err := t.db.WithContext(ctx).First(&key, id).Error; err != nil {
|
|
return err
|
|
}
|
|
return t.refreshKey(ctx, &key)
|
|
}
|
|
|
|
func (t *TokenRefresher) refreshKey(ctx context.Context, key *model.APIKey) error {
|
|
if t == nil || t.db == nil || key == nil {
|
|
return nil
|
|
}
|
|
var group model.ProviderGroup
|
|
if err := t.db.WithContext(ctx).First(&group, key.GroupID).Error; err != nil {
|
|
return err
|
|
}
|
|
ptype := provider.NormalizeType(group.Type)
|
|
if !isCPAProvider(ptype) {
|
|
return nil
|
|
}
|
|
if strings.TrimSpace(key.RefreshToken) == "" {
|
|
return nil
|
|
}
|
|
|
|
var lastErr error
|
|
for attempt := 0; attempt < t.maxRetries; attempt++ {
|
|
if attempt > 0 {
|
|
time.Sleep(time.Duration(attempt) * time.Second)
|
|
}
|
|
out, err := t.refreshAccessToken(ctx, ptype, key.RefreshToken)
|
|
if err == nil {
|
|
updates := map[string]any{
|
|
"access_token": strings.TrimSpace(out.AccessToken),
|
|
"expires_at": out.ExpiresAt,
|
|
"status": "active",
|
|
}
|
|
if strings.TrimSpace(out.RefreshToken) != "" {
|
|
updates["refresh_token"] = strings.TrimSpace(out.RefreshToken)
|
|
}
|
|
if strings.TrimSpace(out.AccountID) != "" {
|
|
updates["account_id"] = strings.TrimSpace(out.AccountID)
|
|
}
|
|
if err := t.db.WithContext(ctx).Model(&model.APIKey{}).Where("id = ?", key.ID).Updates(updates).Error; err != nil {
|
|
return err
|
|
}
|
|
if t.sync != nil {
|
|
_ = t.sync.SyncProvidersForAPIKey(t.db, key.ID)
|
|
}
|
|
return nil
|
|
}
|
|
lastErr = err
|
|
if rerr, ok := err.(*refreshError); ok && !rerr.Retryable {
|
|
if err := t.db.WithContext(ctx).Model(&model.APIKey{}).Where("id = ?", key.ID).Updates(map[string]any{
|
|
"status": "inactive",
|
|
"access_token": "",
|
|
"expires_at": nil,
|
|
}).Error; err != nil {
|
|
return err
|
|
}
|
|
return rerr
|
|
}
|
|
}
|
|
return lastErr
|
|
}
|
|
|
|
func isCPAProvider(ptype string) bool {
|
|
switch ptype {
|
|
case provider.TypeCodex, provider.TypeGeminiCLI, provider.TypeAntigravity, provider.TypeClaudeCode:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
// refreshOutput is the provider-independent result of a successful token
// refresh.
type refreshOutput struct {
	// AccessToken is the newly issued access token.
	AccessToken string
	// RefreshToken is the replacement refresh token, or empty when the
	// provider did not return one.
	RefreshToken string
	// ExpiresAt is the UTC instant computed from the response's expires_in.
	ExpiresAt time.Time
	// AccountID is only populated by the Codex refresh, from the id_token.
	AccountID string
}
|
|
|
|
// refreshError describes a failed token refresh and whether retrying it can
// help.
type refreshError struct {
	// Retryable is false when the failure is permanent; refreshKey
	// deactivates the key in that case.
	Retryable bool
	// Code is a short machine-readable reason, e.g. an OAuth error code or
	// "http_<status>".
	Code string
	// Err is the underlying cause.
	Err error
}

// Error formats the error as "refresh <code>: <cause>", or
// "refresh error: <cause>" when no code is set. A nil receiver yields "".
func (e *refreshError) Error() string {
	if e == nil {
		return ""
	}
	if e.Code != "" {
		return fmt.Sprintf("refresh %s: %v", e.Code, e.Err)
	}
	return fmt.Sprintf("refresh error: %v", e.Err)
}

// Unwrap exposes the underlying cause so errors.Is and errors.As can see
// through a refreshError (for example to detect context.Canceled on a
// transport failure). A nil receiver unwraps to nil.
func (e *refreshError) Unwrap() error {
	if e == nil {
		return nil
	}
	return e.Err
}
|
|
|
|
func (t *TokenRefresher) refreshAccessToken(ctx context.Context, ptype, refreshToken string) (*refreshOutput, error) {
|
|
switch ptype {
|
|
case provider.TypeCodex:
|
|
return t.refreshCodex(ctx, refreshToken)
|
|
case provider.TypeGeminiCLI:
|
|
return t.refreshGoogle(ctx, refreshToken, geminiCLIClientID, geminiCLIClientSecret)
|
|
case provider.TypeAntigravity:
|
|
return t.refreshGoogle(ctx, refreshToken, antigravityClientID, antigravityClientSecret)
|
|
case provider.TypeClaudeCode:
|
|
return t.refreshClaude(ctx, refreshToken)
|
|
default:
|
|
return nil, &refreshError{Retryable: false, Code: "unsupported_provider", Err: fmt.Errorf("provider type %s unsupported", ptype)}
|
|
}
|
|
}
|
|
|
|
// tokenResponse is the JSON body returned by the providers' OAuth token
// endpoints. It carries both the success fields and the error fields, since
// postForm decodes the body before looking at the HTTP status.
type tokenResponse struct {
	// Success fields.
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token,omitempty"`
	ExpiresIn    int64  `json:"expires_in,omitempty"` // lifetime in seconds
	TokenType    string `json:"token_type,omitempty"`
	IDToken      string `json:"id_token,omitempty"`

	// Error fields.
	Error            string `json:"error,omitempty"`
	ErrorDescription string `json:"error_description,omitempty"`
}
|
|
|
|
// OAuth client credentials sent to each provider's token endpoint.
//
// NOTE(review): these are compiled into the binary; confirm they are the
// providers' public installed-application credentials and not secrets that
// should come from configuration.
const (
	// Codex (auth.openai.com).
	codexClientID = "app_EMoamEEZ73f0CkXaXp7hrann"

	// GeminiCLI (Google OAuth).
	geminiCLIClientID     = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
	geminiCLIClientSecret = "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"

	// Antigravity (Google OAuth).
	antigravityClientID     = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
	antigravityClientSecret = "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"

	// ClaudeCode.
	claudeClientID = "9d1c250a-e61b-44d3-bcd4-8fbe4b736065"
)
|
|
|
|
func (t *TokenRefresher) refreshCodex(ctx context.Context, refreshToken string) (*refreshOutput, error) {
|
|
form := url.Values{
|
|
"grant_type": {"refresh_token"},
|
|
"refresh_token": {refreshToken},
|
|
"client_id": {codexClientID},
|
|
"scope": {"openid profile email"},
|
|
}
|
|
resp, err := t.postForm(ctx, "https://auth.openai.com/oauth/token", form)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
out := &refreshOutput{
|
|
AccessToken: resp.AccessToken,
|
|
RefreshToken: resp.RefreshToken,
|
|
ExpiresAt: time.Now().UTC().Add(time.Duration(resp.ExpiresIn) * time.Second),
|
|
AccountID: parseAccountID(resp.IDToken),
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func (t *TokenRefresher) refreshClaude(ctx context.Context, refreshToken string) (*refreshOutput, error) {
|
|
form := url.Values{
|
|
"grant_type": {"refresh_token"},
|
|
"refresh_token": {refreshToken},
|
|
"client_id": {claudeClientID},
|
|
}
|
|
resp, err := t.postForm(ctx, "https://claude.ai/oauth2/token", form)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &refreshOutput{
|
|
AccessToken: resp.AccessToken,
|
|
RefreshToken: resp.RefreshToken,
|
|
ExpiresAt: time.Now().UTC().Add(time.Duration(resp.ExpiresIn) * time.Second),
|
|
}, nil
|
|
}
|
|
|
|
func (t *TokenRefresher) refreshGoogle(ctx context.Context, refreshToken, clientID, clientSecret string) (*refreshOutput, error) {
|
|
form := url.Values{
|
|
"grant_type": {"refresh_token"},
|
|
"refresh_token": {refreshToken},
|
|
"client_id": {clientID},
|
|
"client_secret": {clientSecret},
|
|
}
|
|
resp, err := t.postForm(ctx, "https://oauth2.googleapis.com/token", form)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &refreshOutput{
|
|
AccessToken: resp.AccessToken,
|
|
RefreshToken: resp.RefreshToken,
|
|
ExpiresAt: time.Now().UTC().Add(time.Duration(resp.ExpiresIn) * time.Second),
|
|
}, nil
|
|
}
|
|
|
|
func (t *TokenRefresher) postForm(ctx context.Context, endpoint string, form url.Values) (*tokenResponse, error) {
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode()))
|
|
if err != nil {
|
|
return nil, &refreshError{Retryable: true, Code: "build_request", Err: err}
|
|
}
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
|
|
resp, err := t.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, &refreshError{Retryable: true, Code: "transport", Err: err}
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, &refreshError{Retryable: true, Code: "read_body", Err: err}
|
|
}
|
|
var payload tokenResponse
|
|
if err := jsoncodec.Unmarshal(body, &payload); err != nil {
|
|
return nil, &refreshError{Retryable: true, Code: "parse_body", Err: err}
|
|
}
|
|
if resp.StatusCode >= 400 {
|
|
code := strings.TrimSpace(payload.Error)
|
|
if code == "" {
|
|
code = "http_" + fmt.Sprint(resp.StatusCode)
|
|
}
|
|
retryable := resp.StatusCode >= 500
|
|
if code == "invalid_grant" || code == "invalid_client" {
|
|
retryable = false
|
|
}
|
|
return nil, &refreshError{Retryable: retryable, Code: code, Err: fmt.Errorf(strings.TrimSpace(payload.ErrorDescription))}
|
|
}
|
|
if strings.TrimSpace(payload.AccessToken) == "" {
|
|
return nil, &refreshError{Retryable: true, Code: "empty_token", Err: fmt.Errorf("missing access_token")}
|
|
}
|
|
return &payload, nil
|
|
}
|
|
|
|
func parseAccountID(idToken string) string {
|
|
parts := strings.Split(idToken, ".")
|
|
if len(parts) < 2 {
|
|
return ""
|
|
}
|
|
raw, err := base64.RawURLEncoding.DecodeString(parts[1])
|
|
if err != nil {
|
|
return ""
|
|
}
|
|
var payload struct {
|
|
AccountID string `json:"account_id"`
|
|
OrgID string `json:"org_id"`
|
|
OrganizationID string `json:"organization_id"`
|
|
}
|
|
if err := jsoncodec.Unmarshal(raw, &payload); err != nil {
|
|
return ""
|
|
}
|
|
if payload.AccountID != "" {
|
|
return payload.AccountID
|
|
}
|
|
if payload.OrgID != "" {
|
|
return payload.OrgID
|
|
}
|
|
return payload.OrganizationID
|
|
}
|