feat(cron): add OAuth token refresh background job

Implement automatic token refresh mechanism for CPA providers (Codex,
GeminiCLI, Antigravity, ClaudeCode) with the following features:

- Periodic refresh of expiring tokens based on configurable interval
- Redis event queue processing for on-demand token refresh
- Retry logic with exponential backoff for transient failures
- Automatic key deactivation on non-retryable errors
- Provider-specific OAuth token refresh implementations
- Sync service integration to update providers after refresh
This commit is contained in:
zenfun
2025-12-28 03:03:19 +08:00
parent f0fe9f0dad
commit 6170931454
2 changed files with 407 additions and 0 deletions

View File

@@ -194,6 +194,18 @@ func main() {
cleanerCtx, cancelCleaner := context.WithCancel(context.Background())
defer cancelCleaner()
go logCleaner.Start(cleanerCtx)
tokenRefresher := cron.NewTokenRefresher(
db,
rdb,
syncService,
time.Duration(cfg.TokenRefresh.IntervalSeconds)*time.Second,
time.Duration(cfg.TokenRefresh.RefreshSkewSeconds)*time.Second,
cfg.TokenRefresh.BatchSize,
cfg.TokenRefresh.MaxRetries,
)
tokenCtx, cancelToken := context.WithCancel(context.Background())
defer cancelToken()
go tokenRefresher.Start(tokenCtx)
adminService, err := service.NewAdminService()
if err != nil {

View File

@@ -0,0 +1,395 @@
package cron
import (
"context"
"encoding/base64"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"strings"
"time"
"github.com/ez-api/ez-api/internal/model"
"github.com/ez-api/ez-api/internal/service"
"github.com/ez-api/foundation/jsoncodec"
"github.com/ez-api/foundation/provider"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
)
// refreshEventKey is the Redis list key that processRefreshEvents pops
// JSON-encoded RefreshEvent payloads from (via LPOP).
const refreshEventKey = "events:refresh_provider"
// RefreshEvent is the JSON payload of an entry popped from the refresh event
// list. Within this file only ProviderID is read; the remaining fields are
// decoded but not otherwise used here.
type RefreshEvent struct {
// ProviderID is handed to refreshByID, which loads a model.APIKey by this
// value. Events whose ProviderID is zero are skipped.
ProviderID uint `json:"provider_id"`
// ProviderType is not read in this file.
// NOTE(review): presumably set by the event producer; confirm against the publisher.
ProviderType string `json:"provider_type,omitempty"`
// StatusCode is not read in this file.
// NOTE(review): presumably the upstream HTTP status that triggered the event; confirm.
StatusCode int `json:"status_code,omitempty"`
// Timestamp is not read in this file.
// NOTE(review): unit (seconds vs. milliseconds) is not established here; confirm.
Timestamp int64 `json:"timestamp,omitempty"`
}
// TokenRefresher is the background job that refreshes OAuth access tokens
// stored on API keys. Each cycle first drains queued refresh events from Redis
// and then refreshes keys whose tokens are close to expiry.
type TokenRefresher struct {
// db is used to load API keys and provider groups and to persist refresh
// results. A nil db turns every method into a no-op.
db *gorm.DB
// rdb is the Redis client the event list is popped from. When nil, event
// processing is skipped.
rdb *redis.Client
// sync, when non-nil, is asked to re-sync providers for a key after its
// token has been refreshed successfully.
sync *service.SyncService
// interval is the ticker period between refresh cycles.
interval time.Duration
// refreshSkew is how far ahead of now a token's expires_at may lie for the
// key to be selected for refresh.
refreshSkew time.Duration
// batchSize bounds both the number of events popped per cycle and the
// number of expiring keys loaded per cycle.
batchSize int
// maxRetries is the maximum number of refresh attempts per key.
maxRetries int
// httpClient performs the token endpoint requests.
httpClient *http.Client
}
// NewTokenRefresher builds a TokenRefresher around the given database, Redis
// client and sync service.
//
// Any non-positive tuning argument is replaced by its default: interval 30
// minutes, refreshSkew 50 minutes, batchSize 200, maxRetries 3. The HTTP
// client used for token requests always has a 15 second timeout.
func NewTokenRefresher(db *gorm.DB, rdb *redis.Client, sync *service.SyncService, interval, refreshSkew time.Duration, batchSize, maxRetries int) *TokenRefresher {
	// Start from the defaults and overwrite only with valid caller values.
	refresher := &TokenRefresher{
		db:          db,
		rdb:         rdb,
		sync:        sync,
		interval:    30 * time.Minute,
		refreshSkew: 50 * time.Minute,
		batchSize:   200,
		maxRetries:  3,
		httpClient:  &http.Client{Timeout: 15 * time.Second},
	}
	if interval > 0 {
		refresher.interval = interval
	}
	if refreshSkew > 0 {
		refresher.refreshSkew = refreshSkew
	}
	if batchSize > 0 {
		refresher.batchSize = batchSize
	}
	if maxRetries > 0 {
		refresher.maxRetries = maxRetries
	}
	return refresher
}
// Start runs the refresh loop until ctx is cancelled. It blocks, so callers
// are expected to invoke it in its own goroutine.
//
// A refresh cycle is executed on every tick of the configured interval; the
// first cycle therefore happens one interval after Start is called. Cycle
// errors are logged and never stop the loop. Start returns immediately when
// the receiver or its database handle is nil.
func (t *TokenRefresher) Start(ctx context.Context) {
	if t == nil || t.db == nil {
		return
	}
	if ctx == nil {
		ctx = context.Background()
	}

	tick := time.NewTicker(t.interval)
	defer tick.Stop()

	done := ctx.Done()
	for {
		// Block until either shutdown is requested or the next tick fires.
		select {
		case <-done:
			return
		case <-tick.C:
		}

		err := t.refreshOnce(ctx)
		if err == nil {
			continue
		}
		slog.Default().Warn("token refresh failed", "err", err)
	}
}
// refreshOnce executes a single refresh cycle: queued refresh events are
// handled first, then keys with soon-to-expire tokens are refreshed.
//
// A failure while handling events is only logged; the returned error is the
// one from the expiring-key pass. A nil receiver or nil database is a no-op.
func (t *TokenRefresher) refreshOnce(ctx context.Context) error {
	if t == nil || t.db == nil {
		return nil
	}

	eventErr := t.processRefreshEvents(ctx)
	if eventErr != nil {
		// Event handling is best-effort and must not block the periodic pass.
		slog.Default().Warn("token refresh event handling failed", "err", eventErr)
	}

	return t.refreshExpiring(ctx)
}
// processRefreshEvents pops up to batchSize entries from the Redis refresh
// event list and refreshes the key referenced by each one.
//
// It returns nil when the list is empty or the batch limit is reached, and
// the Redis error when a pop fails. Events that cannot be decoded or that
// carry no provider ID have already been removed from the list and cannot be
// retried, so they are logged before being skipped instead of disappearing
// silently. Per-event refresh failures are logged and do not stop the batch.
// A nil receiver or nil Redis client is a no-op.
func (t *TokenRefresher) processRefreshEvents(ctx context.Context) error {
	if t == nil || t.rdb == nil {
		return nil
	}
	for i := 0; i < t.batchSize; i++ {
		raw, err := t.rdb.LPop(ctx, refreshEventKey).Result()
		if err == redis.Nil {
			// List drained.
			return nil
		}
		if err != nil {
			return err
		}

		var evt RefreshEvent
		if err := jsoncodec.Unmarshal([]byte(raw), &evt); err != nil {
			slog.Default().Warn("token refresh event dropped: malformed payload", "err", err)
			continue
		}
		if evt.ProviderID == 0 {
			slog.Default().Warn("token refresh event dropped: missing provider_id")
			continue
		}

		if err := t.refreshByID(ctx, evt.ProviderID); err != nil {
			slog.Default().Warn("token refresh event failed", "provider_id", evt.ProviderID, "err", err)
		}
	}
	return nil
}
// refreshExpiring loads at most batchSize active keys that have a refresh
// token and whose expires_at is set and falls at or before now+refreshSkew,
// then attempts to refresh each of them.
//
// Only the query error is returned; a failure on an individual key is logged
// and the remaining keys are still processed. A nil receiver or nil database
// is a no-op.
func (t *TokenRefresher) refreshExpiring(ctx context.Context) error {
	if t == nil || t.db == nil {
		return nil
	}

	deadline := time.Now().UTC().Add(t.refreshSkew)
	query := t.db.WithContext(ctx).
		Where("status = ?", "active").
		Where("refresh_token <> ''").
		Where("expires_at IS NOT NULL AND expires_at <= ?", deadline).
		Limit(t.batchSize)

	var due []model.APIKey
	if res := query.Find(&due); res.Error != nil {
		return res.Error
	}

	for idx := range due {
		current := &due[idx]
		err := t.refreshKey(ctx, current)
		if err == nil {
			continue
		}
		slog.Default().Warn("token refresh failed", "key_id", current.ID, "err", err)
	}
	return nil
}
// refreshByID loads the API key with the given primary key and refreshes its
// token. The lookup error is returned when the key cannot be loaded. A nil
// receiver, nil database or zero id is a no-op.
func (t *TokenRefresher) refreshByID(ctx context.Context, id uint) error {
	if t == nil || t.db == nil {
		return nil
	}
	if id == 0 {
		return nil
	}

	record := new(model.APIKey)
	if res := t.db.WithContext(ctx).First(record, id); res.Error != nil {
		return res.Error
	}
	return t.refreshKey(ctx, record)
}
// refreshKey refreshes the OAuth access token stored on key.
//
// The key's provider group is loaded to determine the provider type; keys
// that do not belong to a CPA provider or that have no refresh token are
// skipped and nil is returned.
//
// Up to maxRetries attempts are made. Between attempts the function waits
// with exponential backoff (1s, 2s, 4s, ... capped at 30s); the wait is
// abandoned as soon as ctx is cancelled so that shutdown is not delayed, and
// no further attempts are made once ctx is done.
//
// On success the new access token and expiry are persisted, the status is set
// to "active", the refresh token and account ID are updated when the provider
// returned them, and the sync service (if configured) is asked to re-sync the
// key's providers. A sync failure is logged and does not fail the refresh,
// because the new token has already been stored.
//
// On a non-retryable refresh error the key is marked "inactive" and its
// access token and expiry are cleared before the error is returned.
//
// Returns nil on success or skip, the database error when a lookup or update
// fails, ctx.Err() when cancelled during backoff, and otherwise the last
// refresh error.
func (t *TokenRefresher) refreshKey(ctx context.Context, key *model.APIKey) error {
	if t == nil || t.db == nil || key == nil {
		return nil
	}
	var group model.ProviderGroup
	if err := t.db.WithContext(ctx).First(&group, key.GroupID).Error; err != nil {
		return err
	}
	ptype := provider.NormalizeType(group.Type)
	if !isCPAProvider(ptype) {
		return nil
	}
	if strings.TrimSpace(key.RefreshToken) == "" {
		return nil
	}

	const maxBackoff = 30 * time.Second

	var lastErr error
	for attempt := 0; attempt < t.maxRetries; attempt++ {
		if attempt > 0 {
			// Exponential backoff; the shift is bounded so it cannot overflow.
			delay := maxBackoff
			if shift := attempt - 1; shift < 5 {
				delay = time.Second << shift
			}
			timer := time.NewTimer(delay)
			select {
			case <-ctx.Done():
				timer.Stop()
				return ctx.Err()
			case <-timer.C:
			}
		}

		out, err := t.refreshAccessToken(ctx, ptype, key.RefreshToken)
		if err == nil {
			updates := map[string]any{
				"access_token": strings.TrimSpace(out.AccessToken),
				"expires_at":   out.ExpiresAt,
				"status":       "active",
			}
			// Keep the stored values when the provider did not send new ones.
			if rt := strings.TrimSpace(out.RefreshToken); rt != "" {
				updates["refresh_token"] = rt
			}
			if acct := strings.TrimSpace(out.AccountID); acct != "" {
				updates["account_id"] = acct
			}
			if err := t.db.WithContext(ctx).Model(&model.APIKey{}).Where("id = ?", key.ID).Updates(updates).Error; err != nil {
				return err
			}
			if t.sync != nil {
				if syncErr := t.sync.SyncProvidersForAPIKey(t.db, key.ID); syncErr != nil {
					slog.Default().Warn("provider sync after token refresh failed", "key_id", key.ID, "err", syncErr)
				}
			}
			return nil
		}

		lastErr = err
		if rerr, ok := err.(*refreshError); ok && !rerr.Retryable {
			// The provider rejected the refresh token itself; retrying cannot
			// help, so take the key out of rotation.
			if err := t.db.WithContext(ctx).Model(&model.APIKey{}).Where("id = ?", key.ID).Updates(map[string]any{
				"status":       "inactive",
				"access_token": "",
				"expires_at":   nil,
			}).Error; err != nil {
				return err
			}
			return rerr
		}
		if ctx.Err() != nil {
			// Cancelled or timed out: further attempts would fail the same way.
			return lastErr
		}
	}
	return lastErr
}
// isCPAProvider reports whether ptype is one of the provider types whose
// tokens this job refreshes: Codex, GeminiCLI, Antigravity or ClaudeCode.
func isCPAProvider(ptype string) bool {
	return ptype == provider.TypeCodex ||
		ptype == provider.TypeGeminiCLI ||
		ptype == provider.TypeAntigravity ||
		ptype == provider.TypeClaudeCode
}
// refreshOutput is the provider-independent result of a successful token
// refresh, as consumed by refreshKey when persisting the new credentials.
type refreshOutput struct {
// AccessToken is the newly issued access token.
AccessToken string
// RefreshToken is the refresh token returned by the provider; when empty,
// refreshKey leaves the stored refresh token untouched.
RefreshToken string
// ExpiresAt is computed by the provider-specific refreshers as the current
// UTC time plus the response's expires_in seconds.
ExpiresAt time.Time
// AccountID is only populated by refreshCodex (from the ID token); when
// empty, refreshKey leaves the stored account ID untouched.
AccountID string
}
// refreshError describes a failed token refresh attempt and whether it is
// worth retrying.
type refreshError struct {
	// Retryable reports whether another attempt may succeed. refreshKey
	// deactivates the key when this is false.
	Retryable bool
	// Code is a short machine-readable reason, e.g. an OAuth error code or a
	// local stage name such as "transport".
	Code string
	// Err is the underlying cause.
	Err error
}

// Error formats the error as "refresh <code>: <cause>", or
// "refresh error: <cause>" when no code is set. A nil receiver yields "".
func (e *refreshError) Error() string {
	if e == nil {
		return ""
	}
	if e.Code != "" {
		return fmt.Sprintf("refresh %s: %v", e.Code, e.Err)
	}
	return fmt.Sprintf("refresh error: %v", e.Err)
}

// Unwrap exposes the underlying cause so that errors.Is and errors.As can see
// through the wrapper (for example to detect context cancellation). A nil
// receiver yields nil.
func (e *refreshError) Unwrap() error {
	if e == nil {
		return nil
	}
	return e.Err
}
// refreshAccessToken routes a refresh request to the implementation for the
// given provider type. GeminiCLI and Antigravity share the Google flow and
// differ only in client credentials. An unrecognised provider type yields a
// non-retryable refreshError with code "unsupported_provider".
func (t *TokenRefresher) refreshAccessToken(ctx context.Context, ptype, refreshToken string) (*refreshOutput, error) {
	if ptype == provider.TypeCodex {
		return t.refreshCodex(ctx, refreshToken)
	}
	if ptype == provider.TypeGeminiCLI {
		return t.refreshGoogle(ctx, refreshToken, geminiCLIClientID, geminiCLIClientSecret)
	}
	if ptype == provider.TypeAntigravity {
		return t.refreshGoogle(ctx, refreshToken, antigravityClientID, antigravityClientSecret)
	}
	if ptype == provider.TypeClaudeCode {
		return t.refreshClaude(ctx, refreshToken)
	}

	unsupported := &refreshError{
		Retryable: false,
		Code:      "unsupported_provider",
		Err:       fmt.Errorf("provider type %s unsupported", ptype),
	}
	return nil, unsupported
}
// tokenResponse is the JSON body that postForm decodes from a token endpoint
// response. It carries both the success fields and the error fields, since
// the body is decoded before the HTTP status is inspected.
type tokenResponse struct {
// AccessToken must be non-blank for postForm to treat the response as a success.
AccessToken string `json:"access_token"`
// RefreshToken is the refresh token returned by the provider, if any.
RefreshToken string `json:"refresh_token,omitempty"`
// ExpiresIn is interpreted by the callers as a lifetime in seconds.
ExpiresIn int64 `json:"expires_in,omitempty"`
// TokenType is decoded but not read in this file.
TokenType string `json:"token_type,omitempty"`
// IDToken is passed to parseAccountID by refreshCodex.
IDToken string `json:"id_token,omitempty"`
// Error is used as the refreshError code when the HTTP status is >= 400.
Error string `json:"error,omitempty"`
// ErrorDescription becomes the refreshError cause text when the HTTP status is >= 400.
ErrorDescription string `json:"error_description,omitempty"`
}
// OAuth client credentials sent with the refresh-token requests below.
//
// NOTE(review): the Google client secrets are hard-coded in source. Presumably
// these are the publicly distributed installed-application credentials of the
// respective CLI tools rather than confidential secrets — confirm, and
// consider moving them to configuration if they are not.
const (
// codexClientID is the client_id sent by refreshCodex.
codexClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
// geminiCLIClientID and geminiCLIClientSecret are passed to refreshGoogle for GeminiCLI keys.
geminiCLIClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
geminiCLIClientSecret = "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"
// antigravityClientID and antigravityClientSecret are passed to refreshGoogle for Antigravity keys.
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
antigravityClientSecret = "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"
// claudeClientID is the client_id sent by refreshClaude.
claudeClientID = "9d1c250a-e61b-44d3-bcd4-8fbe4b736065"
)
// refreshCodex exchanges a Codex refresh token at the OpenAI token endpoint.
// The result's expiry is the current UTC time plus the returned expires_in
// seconds, and its account ID is extracted from the returned ID token. Errors
// from postForm are passed through unchanged.
func (t *TokenRefresher) refreshCodex(ctx context.Context, refreshToken string) (*refreshOutput, error) {
	params := url.Values{}
	params.Set("grant_type", "refresh_token")
	params.Set("refresh_token", refreshToken)
	params.Set("client_id", codexClientID)
	params.Set("scope", "openid profile email")

	tok, err := t.postForm(ctx, "https://auth.openai.com/oauth/token", params)
	if err != nil {
		return nil, err
	}

	lifetime := time.Duration(tok.ExpiresIn) * time.Second
	return &refreshOutput{
		AccessToken:  tok.AccessToken,
		RefreshToken: tok.RefreshToken,
		ExpiresAt:    time.Now().UTC().Add(lifetime),
		AccountID:    parseAccountID(tok.IDToken),
	}, nil
}
// refreshClaude exchanges a ClaudeCode refresh token at the Claude token
// endpoint. The result's expiry is the current UTC time plus the returned
// expires_in seconds; no account ID is set. Errors from postForm are passed
// through unchanged.
func (t *TokenRefresher) refreshClaude(ctx context.Context, refreshToken string) (*refreshOutput, error) {
	params := url.Values{}
	params.Set("grant_type", "refresh_token")
	params.Set("refresh_token", refreshToken)
	params.Set("client_id", claudeClientID)

	tok, err := t.postForm(ctx, "https://claude.ai/oauth2/token", params)
	if err != nil {
		return nil, err
	}

	lifetime := time.Duration(tok.ExpiresIn) * time.Second
	result := &refreshOutput{
		AccessToken:  tok.AccessToken,
		RefreshToken: tok.RefreshToken,
		ExpiresAt:    time.Now().UTC().Add(lifetime),
	}
	return result, nil
}
// refreshGoogle exchanges a refresh token at the Google OAuth token endpoint
// using the supplied client ID and secret. The result's expiry is the current
// UTC time plus the returned expires_in seconds; no account ID is set. Errors
// from postForm are passed through unchanged.
func (t *TokenRefresher) refreshGoogle(ctx context.Context, refreshToken, clientID, clientSecret string) (*refreshOutput, error) {
	params := url.Values{}
	params.Set("grant_type", "refresh_token")
	params.Set("refresh_token", refreshToken)
	params.Set("client_id", clientID)
	params.Set("client_secret", clientSecret)

	tok, err := t.postForm(ctx, "https://oauth2.googleapis.com/token", params)
	if err != nil {
		return nil, err
	}

	lifetime := time.Duration(tok.ExpiresIn) * time.Second
	result := &refreshOutput{
		AccessToken:  tok.AccessToken,
		RefreshToken: tok.RefreshToken,
		ExpiresAt:    time.Now().UTC().Add(lifetime),
	}
	return result, nil
}
// postForm sends form as an application/x-www-form-urlencoded POST to
// endpoint and decodes the JSON token response.
//
// Every failure is returned as a *refreshError whose Retryable flag tells the
// caller whether another attempt is worthwhile:
//   - request construction, transport, body read and JSON decode failures are
//     retryable;
//   - HTTP status >= 400 is retryable for 5xx, 429 (rate limited) and 408
//     (request timeout), and non-retryable otherwise. The OAuth error codes
//     "invalid_grant" and "invalid_client" are never retryable, whatever the
//     status;
//   - a 2xx/3xx response without an access token is retryable.
//
// The body is decoded before the status is examined, so an error response
// with a non-JSON body (for example an HTML page from an intermediary) is
// reported as a retryable "parse_body" error rather than deactivating the key.
func (t *TokenRefresher) postForm(ctx context.Context, endpoint string, form url.Values) (*tokenResponse, error) {
	// Token responses are small JSON documents; cap the read so a misbehaving
	// endpoint cannot make the job buffer an arbitrarily large body.
	const maxResponseBytes = 1 << 20

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode()))
	if err != nil {
		return nil, &refreshError{Retryable: true, Code: "build_request", Err: err}
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	resp, err := t.httpClient.Do(req)
	if err != nil {
		return nil, &refreshError{Retryable: true, Code: "transport", Err: err}
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
	if err != nil {
		return nil, &refreshError{Retryable: true, Code: "read_body", Err: err}
	}

	var payload tokenResponse
	if err := jsoncodec.Unmarshal(body, &payload); err != nil {
		return nil, &refreshError{Retryable: true, Code: "parse_body", Err: err}
	}

	if resp.StatusCode >= 400 {
		code := strings.TrimSpace(payload.Error)
		if code == "" {
			code = "http_" + fmt.Sprint(resp.StatusCode)
		}

		// Rate limiting and request timeouts are transient even though they
		// are 4xx; treating them as permanent would deactivate a healthy key.
		retryable := resp.StatusCode >= 500 ||
			resp.StatusCode == http.StatusTooManyRequests ||
			resp.StatusCode == http.StatusRequestTimeout
		if code == "invalid_grant" || code == "invalid_client" {
			retryable = false
		}

		desc := strings.TrimSpace(payload.ErrorDescription)
		if desc == "" {
			desc = http.StatusText(resp.StatusCode)
		}
		if desc == "" {
			desc = fmt.Sprintf("unexpected HTTP status %d", resp.StatusCode)
		}
		// The description is server-supplied text: pass it as an argument,
		// never as the format string.
		return nil, &refreshError{Retryable: retryable, Code: code, Err: fmt.Errorf("%s", desc)}
	}

	if strings.TrimSpace(payload.AccessToken) == "" {
		return nil, &refreshError{Retryable: true, Code: "empty_token", Err: fmt.Errorf("missing access_token")}
	}
	return &payload, nil
}
// parseAccountID extracts an account identifier from the payload (second
// dot-separated segment) of a JWT-style ID token. The segment is decoded as
// unpadded base64url JSON; the signature is not verified.
//
// The first non-empty claim among "account_id", "org_id" and
// "organization_id" is returned. An empty string is returned when the token
// has fewer than two segments, the segment is not valid base64url, the
// payload is not valid JSON, or none of the claims is present.
func parseAccountID(idToken string) string {
	segments := strings.Split(idToken, ".")
	if len(segments) < 2 {
		return ""
	}

	decoded, decodeErr := base64.RawURLEncoding.DecodeString(segments[1])
	if decodeErr != nil {
		return ""
	}

	var claims struct {
		AccountID      string `json:"account_id"`
		OrgID          string `json:"org_id"`
		OrganizationID string `json:"organization_id"`
	}
	if jsoncodec.Unmarshal(decoded, &claims) != nil {
		return ""
	}

	switch {
	case claims.AccountID != "":
		return claims.AccountID
	case claims.OrgID != "":
		return claims.OrgID
	default:
		return claims.OrganizationID
	}
}