package cron

import (
	"context"
	"encoding/base64"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"time"

	"github.com/ez-api/ez-api/internal/model"
	"github.com/ez-api/ez-api/internal/service"
	"github.com/ez-api/foundation/jsoncodec"
	"github.com/ez-api/foundation/provider"
	"github.com/redis/go-redis/v9"
	"gorm.io/gorm"
)

// refreshEventKey is the Redis list that processRefreshEvents drains (LPOP)
// to handle on-demand refresh requests.
const refreshEventKey = "events:refresh_provider"

// maxTokenRefreshBodyBytes bounds how much of a token endpoint response body
// postForm reads into memory.
const maxTokenRefreshBodyBytes = 1 << 20

// RefreshEvent is the JSON payload of one entry in the refreshEventKey list.
// Only ProviderID is consumed here: it is handed to refreshByID, which loads
// it as a model.APIKey primary key. Events with a zero ProviderID are dropped.
type RefreshEvent struct {
	ProviderID   uint   `json:"provider_id"`
	ProviderType string `json:"provider_type,omitempty"`
	StatusCode   int    `json:"status_code,omitempty"`
	Timestamp    int64  `json:"timestamp,omitempty"`
}

// TokenRefresher refreshes OAuth access tokens of API keys that belong to
// CPA-style provider groups (see isCPAProvider). It periodically refreshes
// keys close to expiry and also handles refresh requests queued in Redis.
type TokenRefresher struct {
	db          *gorm.DB
	rdb         *redis.Client
	sync        *service.SyncService
	interval    time.Duration // time between refresh passes
	refreshSkew time.Duration // keys expiring within this window are refreshed
	batchSize   int           // max queued events and max expiring keys per pass
	maxRetries  int           // refresh attempts per key
	httpClient  *http.Client
}

// NewTokenRefresher builds a TokenRefresher. Non-positive arguments fall back
// to defaults: interval 30m, refreshSkew 50m, batchSize 200, maxRetries 3.
// rdb and sync may be nil, which disables event processing and provider
// re-syncing respectively. Token requests use an HTTP client with a 15s timeout.
func NewTokenRefresher(db *gorm.DB, rdb *redis.Client, sync *service.SyncService, interval, refreshSkew time.Duration, batchSize, maxRetries int) *TokenRefresher {
	if interval <= 0 {
		interval = 30 * time.Minute
	}
	if refreshSkew <= 0 {
		refreshSkew = 50 * time.Minute
	}
	if batchSize <= 0 {
		batchSize = 200
	}
	if maxRetries <= 0 {
		maxRetries = 3
	}
	return &TokenRefresher{
		db:          db,
		rdb:         rdb,
		sync:        sync,
		interval:    interval,
		refreshSkew: refreshSkew,
		batchSize:   batchSize,
		maxRetries:  maxRetries,
		httpClient:  &http.Client{Timeout: 15 * time.Second},
	}
}

// Start runs refreshOnce every t.interval until ctx is done; the first pass
// happens one interval after Start is called. It returns immediately when t
// or its database handle is nil, and treats a nil ctx as context.Background().
// Failures are logged, except those caused by ctx being cancelled.
func (t *TokenRefresher) Start(ctx context.Context) {
	if t == nil || t.db == nil {
		return
	}
	if ctx == nil {
		ctx = context.Background()
	}
	ticker := time.NewTicker(t.interval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			// Errors produced because we are shutting down are not worth a warning.
			if err := t.refreshOnce(ctx); err != nil && ctx.Err() == nil {
				slog.Default().Warn("token refresh failed", "err", err)
			}
		}
	}
}

// refreshOnce performs one pass: it drains queued refresh events (logging any
// failure) and then refreshes keys that are close to expiry.
func (t *TokenRefresher) refreshOnce(ctx context.Context) error {
	if t == nil || t.db == nil {
		return nil
	}
	if err := t.processRefreshEvents(ctx); err != nil {
		slog.Default().Warn("token refresh event handling failed", "err", err)
	}
	return t.refreshExpiring(ctx)
}

// processRefreshEvents pops up to t.batchSize entries from refreshEventKey
// and refreshes the referenced keys. Malformed events and events with a zero
// ProviderID are dropped; per-key refresh failures are logged, not returned.
// It returns nil when the list is empty or Redis is not configured.
func (t *TokenRefresher) processRefreshEvents(ctx context.Context) error {
	if t == nil || t.rdb == nil {
		return nil
	}
	// Several events for the same key may be queued between passes; refresh
	// each key at most once per batch.
	seen := make(map[uint]struct{})
	for i := 0; i < t.batchSize; i++ {
		raw, err := t.rdb.LPop(ctx, refreshEventKey).Result()
		if errors.Is(err, redis.Nil) {
			return nil
		}
		if err != nil {
			return err
		}
		var evt RefreshEvent
		if err := jsoncodec.Unmarshal([]byte(raw), &evt); err != nil || evt.ProviderID == 0 {
			slog.Default().Debug("token refresh event dropped", "err", err)
			continue
		}
		if _, dup := seen[evt.ProviderID]; dup {
			continue
		}
		seen[evt.ProviderID] = struct{}{}
		if err := t.refreshByID(ctx, evt.ProviderID); err != nil {
			slog.Default().Warn("token refresh event failed", "provider_id", evt.ProviderID, "err", err)
		}
	}
	return nil
}

// refreshExpiring refreshes up to t.batchSize active keys that have a
// refresh token and an expires_at at or before now+refreshSkew (UTC).
// Per-key failures are logged; the query error or ctx.Err() is returned.
func (t *TokenRefresher) refreshExpiring(ctx context.Context) error {
	if t == nil || t.db == nil {
		return nil
	}
	cutoff := time.Now().UTC().Add(t.refreshSkew)
	var keys []model.APIKey
	if err := t.db.WithContext(ctx).
		Where("status = ?", "active").
		Where("refresh_token <> ''").
		Where("expires_at IS NOT NULL AND expires_at <= ?", cutoff).
		Limit(t.batchSize).
		Find(&keys).Error; err != nil {
		return err
	}
	for i := range keys {
		// Stop promptly on shutdown instead of issuing doomed requests.
		if err := ctx.Err(); err != nil {
			return err
		}
		if err := t.refreshKey(ctx, &keys[i]); err != nil {
			slog.Default().Warn("token refresh failed", "key_id", keys[i].ID, "err", err)
		}
	}
	return nil
}

// refreshByID loads the model.APIKey with primary key id and refreshes it.
// A zero id is a no-op.
func (t *TokenRefresher) refreshByID(ctx context.Context, id uint) error {
	if t == nil || t.db == nil || id == 0 {
		return nil
	}
	var key model.APIKey
	if err := t.db.WithContext(ctx).First(&key, id).Error; err != nil {
		return err
	}
	return t.refreshKey(ctx, &key)
}

// refreshKey exchanges key's refresh token for a new access token, trying up
// to t.maxRetries times with a linear backoff (1s, 2s, ...) that is abandoned
// when ctx is done. Keys whose provider group is not a CPA provider, or that
// have no refresh token, are skipped.
//
// On success the new token is stored (see storeRefreshedToken). On a
// non-retryable *refreshError the key is deactivated (see deactivateKey) and
// the refresh error is returned. Otherwise the last attempt's error is returned.
func (t *TokenRefresher) refreshKey(ctx context.Context, key *model.APIKey) error {
	if t == nil || t.db == nil || key == nil {
		return nil
	}
	var group model.ProviderGroup
	if err := t.db.WithContext(ctx).First(&group, key.GroupID).Error; err != nil {
		return err
	}
	ptype := provider.NormalizeType(group.Type)
	if !isCPAProvider(ptype) {
		return nil
	}
	if strings.TrimSpace(key.RefreshToken) == "" {
		return nil
	}

	var lastErr error
	for attempt := 0; attempt < t.maxRetries; attempt++ {
		if attempt > 0 {
			if err := t.backoff(ctx, time.Duration(attempt)*time.Second); err != nil {
				return errors.Join(err, lastErr)
			}
		}
		out, err := t.refreshAccessToken(ctx, ptype, key.RefreshToken)
		if err == nil {
			return t.storeRefreshedToken(ctx, key.ID, out)
		}
		lastErr = err

		var rerr *refreshError
		if errors.As(err, &rerr) && !rerr.Retryable {
			return t.deactivateKey(ctx, key.ID, rerr)
		}
	}
	return lastErr
}

// backoff waits for d, or returns ctx.Err() as soon as ctx is done.
func (t *TokenRefresher) backoff(ctx context.Context, d time.Duration) error {
	timer := time.NewTimer(d)
	defer timer.Stop()
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-timer.C:
		return nil
	}
}

// storeRefreshedToken writes a successful refresh for key id: the trimmed
// access token, its expiry, status "active", and — only when the provider
// returned them — a rotated refresh token and an account ID. It then re-syncs
// providers for the key.
func (t *TokenRefresher) storeRefreshedToken(ctx context.Context, id uint, out *refreshOutput) error {
	updates := map[string]any{
		"access_token": strings.TrimSpace(out.AccessToken),
		"expires_at":   out.ExpiresAt,
		"status":       "active",
	}
	if rt := strings.TrimSpace(out.RefreshToken); rt != "" {
		updates["refresh_token"] = rt
	}
	if acct := strings.TrimSpace(out.AccountID); acct != "" {
		updates["account_id"] = acct
	}
	if err := t.db.WithContext(ctx).Model(&model.APIKey{}).Where("id = ?", id).Updates(updates).Error; err != nil {
		return err
	}
	t.syncKey(id)
	return nil
}

// deactivateKey marks key id "inactive" and clears its access token and
// expiry after a non-retryable refresh failure, then re-syncs providers for
// the key so the cleared state is propagated. It returns cause, joined with
// the database error if the update fails.
func (t *TokenRefresher) deactivateKey(ctx context.Context, id uint, cause error) error {
	if err := t.db.WithContext(ctx).Model(&model.APIKey{}).Where("id = ?", id).Updates(map[string]any{
		"status":       "inactive",
		"access_token": "",
		"expires_at":   nil,
	}).Error; err != nil {
		return errors.Join(cause, err)
	}
	t.syncKey(id)
	return cause
}

// syncKey calls SyncProvidersForAPIKey for key id when a SyncService is
// configured. A failure is logged rather than returned because the database
// already holds the new state at this point.
func (t *TokenRefresher) syncKey(id uint) {
	if t.sync == nil {
		return
	}
	if err := t.sync.SyncProvidersForAPIKey(t.db, id); err != nil {
		slog.Default().Warn("token refresh provider sync failed", "key_id", id, "err", err)
	}
}

// isCPAProvider reports whether ptype (already normalized) is one of the
// provider types whose keys this refresher manages.
func isCPAProvider(ptype string) bool {
	switch ptype {
	case provider.TypeCodex, provider.TypeGeminiCLI, provider.TypeAntigravity, provider.TypeClaudeCode:
		return true
	default:
		return false
	}
}

// refreshOutput is the provider-independent result of a token refresh.
// RefreshToken and AccountID may be empty when the provider did not return them.
type refreshOutput struct {
	AccessToken  string
	RefreshToken string
	ExpiresAt    time.Time
	AccountID    string
}

// refreshError describes a failed refresh. When Retryable is false,
// refreshKey stops retrying and deactivates the key.
type refreshError struct {
	// Retryable reports whether another attempt may succeed.
	Retryable bool
	// Code is the OAuth "error" value, an "http_<status>" fallback, or an
	// internal stage name such as "transport" or "parse_body".
	Code string
	// Err is the underlying cause.
	Err error
}

func (e *refreshError) Error() string {
	if e == nil {
		return ""
	}
	if e.Code != "" {
		return fmt.Sprintf("refresh %s: %v", e.Code, e.Err)
	}
	return fmt.Sprintf("refresh error: %v", e.Err)
}

// Unwrap exposes the underlying cause to errors.Is / errors.As (for example
// context.Canceled from a transport failure).
func (e *refreshError) Unwrap() error {
	if e == nil {
		return nil
	}
	return e.Err
}

// refreshAccessToken dispatches to the refresh flow for ptype. Unsupported
// types yield a non-retryable "unsupported_provider" error.
func (t *TokenRefresher) refreshAccessToken(ctx context.Context, ptype, refreshToken string) (*refreshOutput, error) {
	switch ptype {
	case provider.TypeCodex:
		return t.refreshCodex(ctx, refreshToken)
	case provider.TypeGeminiCLI:
		return t.refreshGoogle(ctx, refreshToken, geminiCLIClientID, geminiCLIClientSecret)
	case provider.TypeAntigravity:
		return t.refreshGoogle(ctx, refreshToken, antigravityClientID, antigravityClientSecret)
	case provider.TypeClaudeCode:
		return t.refreshClaude(ctx, refreshToken)
	default:
		return nil, &refreshError{Retryable: false, Code: "unsupported_provider", Err: fmt.Errorf("provider type %s unsupported", ptype)}
	}
}

// tokenResponse is the JSON body of an OAuth token endpoint response,
// covering both success fields and the error/error_description pair.
type tokenResponse struct {
	AccessToken      string `json:"access_token"`
	RefreshToken     string `json:"refresh_token,omitempty"`
	ExpiresIn        int64  `json:"expires_in,omitempty"`
	TokenType        string `json:"token_type,omitempty"`
	IDToken          string `json:"id_token,omitempty"`
	Error            string `json:"error,omitempty"`
	ErrorDescription string `json:"error_description,omitempty"`
}

// OAuth client identifiers (and, for Google, client secrets) sent with
// refresh requests. NOTE(review): presumably these are the values embedded in
// the respective official CLIs — confirm before changing them.
const (
	codexClientID           = "app_EMoamEEZ73f0CkXaXp7hrann"
	geminiCLIClientID       = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
	geminiCLIClientSecret   = "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"
	antigravityClientID     = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
	antigravityClientSecret = "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"
	claudeClientID          = "9d1c250a-e61b-44d3-bcd4-8fbe4b736065"
)

// refreshCodex refreshes a Codex token at auth.openai.com and extracts the
// account ID from the returned id_token. A missing expires_in yields an
// ExpiresAt of "now", so the key is picked up again on the next pass.
func (t *TokenRefresher) refreshCodex(ctx context.Context, refreshToken string) (*refreshOutput, error) {
	form := url.Values{
		"grant_type":    {"refresh_token"},
		"refresh_token": {refreshToken},
		"client_id":     {codexClientID},
		"scope":         {"openid profile email"},
	}
	resp, err := t.postForm(ctx, "https://auth.openai.com/oauth/token", form)
	if err != nil {
		return nil, err
	}
	return &refreshOutput{
		AccessToken:  resp.AccessToken,
		RefreshToken: resp.RefreshToken,
		ExpiresAt:    time.Now().UTC().Add(time.Duration(resp.ExpiresIn) * time.Second),
		AccountID:    parseAccountID(resp.IDToken),
	}, nil
}

// refreshClaude refreshes a Claude Code token at claude.ai.
func (t *TokenRefresher) refreshClaude(ctx context.Context, refreshToken string) (*refreshOutput, error) {
	form := url.Values{
		"grant_type":    {"refresh_token"},
		"refresh_token": {refreshToken},
		"client_id":     {claudeClientID},
	}
	resp, err := t.postForm(ctx, "https://claude.ai/oauth2/token", form)
	if err != nil {
		return nil, err
	}
	return &refreshOutput{
		AccessToken:  resp.AccessToken,
		RefreshToken: resp.RefreshToken,
		ExpiresAt:    time.Now().UTC().Add(time.Duration(resp.ExpiresIn) * time.Second),
	}, nil
}

// refreshGoogle refreshes a Google OAuth token using the given client
// credentials (shared by the Gemini CLI and Antigravity flows).
func (t *TokenRefresher) refreshGoogle(ctx context.Context, refreshToken, clientID, clientSecret string) (*refreshOutput, error) {
	form := url.Values{
		"grant_type":    {"refresh_token"},
		"refresh_token": {refreshToken},
		"client_id":     {clientID},
		"client_secret": {clientSecret},
	}
	resp, err := t.postForm(ctx, "https://oauth2.googleapis.com/token", form)
	if err != nil {
		return nil, err
	}
	return &refreshOutput{
		AccessToken:  resp.AccessToken,
		RefreshToken: resp.RefreshToken,
		ExpiresAt:    time.Now().UTC().Add(time.Duration(resp.ExpiresIn) * time.Second),
	}, nil
}

// postForm POSTs form to endpoint and decodes the token response.
//
// Every failure is a *refreshError. Transport, read, and body-parse failures
// are retryable. HTTP errors are retryable for 5xx, 429, and 408 unless the
// OAuth error code is invalid_grant or invalid_client; all other 4xx are
// permanent. A 2xx response without an access_token is retryable.
func (t *TokenRefresher) postForm(ctx context.Context, endpoint string, form url.Values) (*tokenResponse, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode()))
	if err != nil {
		return nil, &refreshError{Retryable: true, Code: "build_request", Err: err}
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	resp, err := t.httpClient.Do(req)
	if err != nil {
		return nil, &refreshError{Retryable: true, Code: "transport", Err: err}
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(io.LimitReader(resp.Body, maxTokenRefreshBodyBytes))
	if err != nil {
		return nil, &refreshError{Retryable: true, Code: "read_body", Err: err}
	}

	var payload tokenResponse
	parseErr := jsoncodec.Unmarshal(body, &payload)
	if parseErr != nil {
		// Never trust a partially decoded body.
		payload = tokenResponse{}
	}

	if resp.StatusCode >= 400 {
		// Error bodies are not always JSON (e.g. an HTML page from a proxy), so
		// classify by status first and use the OAuth error fields if present.
		code := strings.TrimSpace(payload.Error)
		if code == "" {
			code = "http_" + strconv.Itoa(resp.StatusCode)
		}
		// A non-retryable error deactivates the key, so transient statuses
		// (rate limiting, timeouts, server errors) must stay retryable.
		retryable := resp.StatusCode >= 500 ||
			resp.StatusCode == http.StatusTooManyRequests ||
			resp.StatusCode == http.StatusRequestTimeout
		if code == "invalid_grant" || code == "invalid_client" {
			retryable = false
		}
		desc := strings.TrimSpace(payload.ErrorDescription)
		if desc == "" {
			desc = "status " + strconv.Itoa(resp.StatusCode)
		}
		return nil, &refreshError{Retryable: retryable, Code: code, Err: errors.New(desc)}
	}

	if parseErr != nil {
		return nil, &refreshError{Retryable: true, Code: "parse_body", Err: parseErr}
	}
	if strings.TrimSpace(payload.AccessToken) == "" {
		return nil, &refreshError{Retryable: true, Code: "empty_token", Err: fmt.Errorf("missing access_token")}
	}
	return &payload, nil
}

// parseAccountID extracts an account identifier from the (unverified) payload
// segment of a JWT id_token. It prefers the top-level claims account_id, then
// org_id, then organization_id. As a last resort it uses
// "https://api.openai.com/auth".chatgpt_account_id. It returns "" if the token
// is malformed or carries none of these claims. The signature is NOT checked;
// the token comes straight from the token endpoint response.
func parseAccountID(idToken string) string {
	parts := strings.Split(idToken, ".")
	if len(parts) < 2 {
		return ""
	}
	// JWT segments are unpadded base64url; tolerate stray padding anyway.
	raw, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(parts[1], "="))
	if err != nil {
		return ""
	}
	var claims struct {
		AccountID      string `json:"account_id"`
		OrgID          string `json:"org_id"`
		OrganizationID string `json:"organization_id"`
	}
	if err := jsoncodec.Unmarshal(raw, &claims); err != nil {
		return ""
	}
	switch {
	case claims.AccountID != "":
		return claims.AccountID
	case claims.OrgID != "":
		return claims.OrgID
	case claims.OrganizationID != "":
		return claims.OrganizationID
	}
	// Decoded separately so that an unexpected shape of this namespaced claim
	// can never break the top-level lookups above.
	var nested struct {
		Auth struct {
			ChatGPTAccountID string `json:"chatgpt_account_id"`
		} `json:"https://api.openai.com/auth"`
	}
	if err := jsoncodec.Unmarshal(raw, &nested); err != nil {
		return ""
	}
	return nested.Auth.ChatGPTAccountID
}