mirror of
https://github.com/EZ-Api/ez-api.git
synced 2026-01-13 17:47:51 +00:00
feat(cron): add OAuth token refresh background job
Implement an automatic token refresh mechanism for CPA providers (Codex, GeminiCLI, Antigravity, ClaudeCode) with the following features:
- Periodic refresh of expiring tokens based on a configurable interval
- Redis event queue processing for on-demand token refresh
- Retry logic with exponential backoff for transient failures
- Automatic key deactivation on non-retryable errors
- Provider-specific OAuth token refresh implementations
- Sync service integration to update providers after refresh
This commit is contained in:
395
internal/cron/token_refresh.go
Normal file
395
internal/cron/token_refresh.go
Normal file
@@ -0,0 +1,395 @@
|
||||
package cron
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ez-api/ez-api/internal/model"
|
||||
"github.com/ez-api/ez-api/internal/service"
|
||||
"github.com/ez-api/foundation/jsoncodec"
|
||||
"github.com/ez-api/foundation/provider"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// refreshEventKey is the Redis list key from which processRefreshEvents pops
// JSON-encoded RefreshEvent payloads.
const refreshEventKey = "events:refresh_provider"
|
||||
|
||||
// RefreshEvent is the JSON payload of an on-demand refresh request taken from
// the refresh event queue.
//
// Within this file only ProviderID is consulted: it is passed to refreshByID,
// which looks it up as the primary key of a model.APIKey row. The remaining
// fields are decoded but not read here.
type RefreshEvent struct {
	// ProviderID identifies the record to refresh; events with 0 are skipped.
	ProviderID uint `json:"provider_id"`
	// ProviderType is optional metadata supplied by the event producer.
	ProviderType string `json:"provider_type,omitempty"`
	// StatusCode is optional; presumably the upstream HTTP status that
	// triggered the event — confirm against the producer.
	StatusCode int `json:"status_code,omitempty"`
	// Timestamp is optional; its unit is not established by this file.
	Timestamp int64 `json:"timestamp,omitempty"`
}
|
||||
|
||||
type TokenRefresher struct {
|
||||
db *gorm.DB
|
||||
rdb *redis.Client
|
||||
sync *service.SyncService
|
||||
interval time.Duration
|
||||
refreshSkew time.Duration
|
||||
batchSize int
|
||||
maxRetries int
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewTokenRefresher(db *gorm.DB, rdb *redis.Client, sync *service.SyncService, interval, refreshSkew time.Duration, batchSize, maxRetries int) *TokenRefresher {
|
||||
if interval <= 0 {
|
||||
interval = 30 * time.Minute
|
||||
}
|
||||
if refreshSkew <= 0 {
|
||||
refreshSkew = 50 * time.Minute
|
||||
}
|
||||
if batchSize <= 0 {
|
||||
batchSize = 200
|
||||
}
|
||||
if maxRetries <= 0 {
|
||||
maxRetries = 3
|
||||
}
|
||||
return &TokenRefresher{
|
||||
db: db,
|
||||
rdb: rdb,
|
||||
sync: sync,
|
||||
interval: interval,
|
||||
refreshSkew: refreshSkew,
|
||||
batchSize: batchSize,
|
||||
maxRetries: maxRetries,
|
||||
httpClient: &http.Client{Timeout: 15 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TokenRefresher) Start(ctx context.Context) {
|
||||
if t == nil || t.db == nil {
|
||||
return
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
ticker := time.NewTicker(t.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := t.refreshOnce(ctx); err != nil {
|
||||
slog.Default().Warn("token refresh failed", "err", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TokenRefresher) refreshOnce(ctx context.Context) error {
|
||||
if t == nil || t.db == nil {
|
||||
return nil
|
||||
}
|
||||
if err := t.processRefreshEvents(ctx); err != nil {
|
||||
slog.Default().Warn("token refresh event handling failed", "err", err)
|
||||
}
|
||||
return t.refreshExpiring(ctx)
|
||||
}
|
||||
|
||||
func (t *TokenRefresher) processRefreshEvents(ctx context.Context) error {
|
||||
if t == nil || t.rdb == nil {
|
||||
return nil
|
||||
}
|
||||
for i := 0; i < t.batchSize; i++ {
|
||||
raw, err := t.rdb.LPop(ctx, refreshEventKey).Result()
|
||||
if err == redis.Nil {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var evt RefreshEvent
|
||||
if err := jsoncodec.Unmarshal([]byte(raw), &evt); err != nil || evt.ProviderID == 0 {
|
||||
continue
|
||||
}
|
||||
if err := t.refreshByID(ctx, evt.ProviderID); err != nil {
|
||||
slog.Default().Warn("token refresh event failed", "provider_id", evt.ProviderID, "err", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TokenRefresher) refreshExpiring(ctx context.Context) error {
|
||||
if t == nil || t.db == nil {
|
||||
return nil
|
||||
}
|
||||
cutoff := time.Now().UTC().Add(t.refreshSkew)
|
||||
var keys []model.APIKey
|
||||
if err := t.db.WithContext(ctx).
|
||||
Where("status = ?", "active").
|
||||
Where("refresh_token <> ''").
|
||||
Where("expires_at IS NOT NULL AND expires_at <= ?", cutoff).
|
||||
Limit(t.batchSize).
|
||||
Find(&keys).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
for i := range keys {
|
||||
if err := t.refreshKey(ctx, &keys[i]); err != nil {
|
||||
slog.Default().Warn("token refresh failed", "key_id", keys[i].ID, "err", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TokenRefresher) refreshByID(ctx context.Context, id uint) error {
|
||||
if t == nil || t.db == nil || id == 0 {
|
||||
return nil
|
||||
}
|
||||
var key model.APIKey
|
||||
if err := t.db.WithContext(ctx).First(&key, id).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return t.refreshKey(ctx, &key)
|
||||
}
|
||||
|
||||
func (t *TokenRefresher) refreshKey(ctx context.Context, key *model.APIKey) error {
|
||||
if t == nil || t.db == nil || key == nil {
|
||||
return nil
|
||||
}
|
||||
var group model.ProviderGroup
|
||||
if err := t.db.WithContext(ctx).First(&group, key.GroupID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
ptype := provider.NormalizeType(group.Type)
|
||||
if !isCPAProvider(ptype) {
|
||||
return nil
|
||||
}
|
||||
if strings.TrimSpace(key.RefreshToken) == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < t.maxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
time.Sleep(time.Duration(attempt) * time.Second)
|
||||
}
|
||||
out, err := t.refreshAccessToken(ctx, ptype, key.RefreshToken)
|
||||
if err == nil {
|
||||
updates := map[string]any{
|
||||
"access_token": strings.TrimSpace(out.AccessToken),
|
||||
"expires_at": out.ExpiresAt,
|
||||
"status": "active",
|
||||
}
|
||||
if strings.TrimSpace(out.RefreshToken) != "" {
|
||||
updates["refresh_token"] = strings.TrimSpace(out.RefreshToken)
|
||||
}
|
||||
if strings.TrimSpace(out.AccountID) != "" {
|
||||
updates["account_id"] = strings.TrimSpace(out.AccountID)
|
||||
}
|
||||
if err := t.db.WithContext(ctx).Model(&model.APIKey{}).Where("id = ?", key.ID).Updates(updates).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if t.sync != nil {
|
||||
_ = t.sync.SyncProvidersForAPIKey(t.db, key.ID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
lastErr = err
|
||||
if rerr, ok := err.(*refreshError); ok && !rerr.Retryable {
|
||||
if err := t.db.WithContext(ctx).Model(&model.APIKey{}).Where("id = ?", key.ID).Updates(map[string]any{
|
||||
"status": "inactive",
|
||||
"access_token": "",
|
||||
"expires_at": nil,
|
||||
}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return rerr
|
||||
}
|
||||
}
|
||||
return lastErr
|
||||
}
|
||||
|
||||
func isCPAProvider(ptype string) bool {
|
||||
switch ptype {
|
||||
case provider.TypeCodex, provider.TypeGeminiCLI, provider.TypeAntigravity, provider.TypeClaudeCode:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// refreshOutput is the provider-independent result of a successful token
// refresh, as consumed by refreshKey.
type refreshOutput struct {
	// AccessToken is the newly issued access token.
	AccessToken string
	// RefreshToken is the rotated refresh token; empty when the provider did
	// not return one.
	RefreshToken string
	// ExpiresAt is the expiry computed from the provider's expires_in.
	ExpiresAt time.Time
	// AccountID is the account identifier extracted from the ID token; it is
	// only populated by the Codex refresh and may be empty.
	AccountID string
}
|
||||
|
||||
// refreshError describes a failed token refresh attempt.
type refreshError struct {
	// Retryable reports whether another attempt may succeed. refreshKey
	// deactivates the key when this is false.
	Retryable bool
	// Code is a short machine-readable reason, e.g. an OAuth error code or a
	// local stage name such as "transport"; it may be empty.
	Code string
	// Err is the underlying cause.
	Err error
}

// Error implements the error interface. The message is "refresh <code>: <err>"
// when Code is set and "refresh error: <err>" otherwise; a nil receiver
// yields the empty string.
func (e *refreshError) Error() string {
	if e == nil {
		return ""
	}
	if e.Code != "" {
		return fmt.Sprintf("refresh %s: %v", e.Code, e.Err)
	}
	return fmt.Sprintf("refresh error: %v", e.Err)
}

// Unwrap returns the underlying cause so that errors.Is and errors.As can see
// through the wrapper (for example to detect context cancellation). It
// returns nil for a nil receiver.
func (e *refreshError) Unwrap() error {
	if e == nil {
		return nil
	}
	return e.Err
}
|
||||
|
||||
func (t *TokenRefresher) refreshAccessToken(ctx context.Context, ptype, refreshToken string) (*refreshOutput, error) {
|
||||
switch ptype {
|
||||
case provider.TypeCodex:
|
||||
return t.refreshCodex(ctx, refreshToken)
|
||||
case provider.TypeGeminiCLI:
|
||||
return t.refreshGoogle(ctx, refreshToken, geminiCLIClientID, geminiCLIClientSecret)
|
||||
case provider.TypeAntigravity:
|
||||
return t.refreshGoogle(ctx, refreshToken, antigravityClientID, antigravityClientSecret)
|
||||
case provider.TypeClaudeCode:
|
||||
return t.refreshClaude(ctx, refreshToken)
|
||||
default:
|
||||
return nil, &refreshError{Retryable: false, Code: "unsupported_provider", Err: fmt.Errorf("provider type %s unsupported", ptype)}
|
||||
}
|
||||
}
|
||||
|
||||
// tokenResponse is the JSON body returned by an OAuth token endpoint. It
// covers both the success fields and the error fields, because postForm
// decodes every response into it before looking at the HTTP status.
type tokenResponse struct {
	// Success fields.
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token,omitempty"`
	ExpiresIn    int64  `json:"expires_in,omitempty"` // token lifetime in seconds
	TokenType    string `json:"token_type,omitempty"`
	IDToken      string `json:"id_token,omitempty"`

	// Error fields.
	Error            string `json:"error,omitempty"`
	ErrorDescription string `json:"error_description,omitempty"`
}
|
||||
|
||||
// OAuth client credentials sent to the providers' token endpoints.
//
// NOTE(review): the Google client secrets are embedded in source. Presumably
// these are the public installed-application credentials of the respective
// CLI tools — confirm, or move them to configuration.
const (
	// Codex: used by refreshCodex.
	codexClientID = "app_EMoamEEZ73f0CkXaXp7hrann"

	// Gemini CLI: passed to refreshGoogle.
	geminiCLIClientID     = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
	geminiCLIClientSecret = "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"

	// Antigravity: passed to refreshGoogle.
	antigravityClientID     = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
	antigravityClientSecret = "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"

	// Claude Code: used by refreshClaude.
	claudeClientID = "9d1c250a-e61b-44d3-bcd4-8fbe4b736065"
)
|
||||
|
||||
func (t *TokenRefresher) refreshCodex(ctx context.Context, refreshToken string) (*refreshOutput, error) {
|
||||
form := url.Values{
|
||||
"grant_type": {"refresh_token"},
|
||||
"refresh_token": {refreshToken},
|
||||
"client_id": {codexClientID},
|
||||
"scope": {"openid profile email"},
|
||||
}
|
||||
resp, err := t.postForm(ctx, "https://auth.openai.com/oauth/token", form)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := &refreshOutput{
|
||||
AccessToken: resp.AccessToken,
|
||||
RefreshToken: resp.RefreshToken,
|
||||
ExpiresAt: time.Now().UTC().Add(time.Duration(resp.ExpiresIn) * time.Second),
|
||||
AccountID: parseAccountID(resp.IDToken),
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (t *TokenRefresher) refreshClaude(ctx context.Context, refreshToken string) (*refreshOutput, error) {
|
||||
form := url.Values{
|
||||
"grant_type": {"refresh_token"},
|
||||
"refresh_token": {refreshToken},
|
||||
"client_id": {claudeClientID},
|
||||
}
|
||||
resp, err := t.postForm(ctx, "https://claude.ai/oauth2/token", form)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &refreshOutput{
|
||||
AccessToken: resp.AccessToken,
|
||||
RefreshToken: resp.RefreshToken,
|
||||
ExpiresAt: time.Now().UTC().Add(time.Duration(resp.ExpiresIn) * time.Second),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *TokenRefresher) refreshGoogle(ctx context.Context, refreshToken, clientID, clientSecret string) (*refreshOutput, error) {
|
||||
form := url.Values{
|
||||
"grant_type": {"refresh_token"},
|
||||
"refresh_token": {refreshToken},
|
||||
"client_id": {clientID},
|
||||
"client_secret": {clientSecret},
|
||||
}
|
||||
resp, err := t.postForm(ctx, "https://oauth2.googleapis.com/token", form)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &refreshOutput{
|
||||
AccessToken: resp.AccessToken,
|
||||
RefreshToken: resp.RefreshToken,
|
||||
ExpiresAt: time.Now().UTC().Add(time.Duration(resp.ExpiresIn) * time.Second),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *TokenRefresher) postForm(ctx context.Context, endpoint string, form url.Values) (*tokenResponse, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return nil, &refreshError{Retryable: true, Code: "build_request", Err: err}
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := t.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, &refreshError{Retryable: true, Code: "transport", Err: err}
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, &refreshError{Retryable: true, Code: "read_body", Err: err}
|
||||
}
|
||||
var payload tokenResponse
|
||||
if err := jsoncodec.Unmarshal(body, &payload); err != nil {
|
||||
return nil, &refreshError{Retryable: true, Code: "parse_body", Err: err}
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
code := strings.TrimSpace(payload.Error)
|
||||
if code == "" {
|
||||
code = "http_" + fmt.Sprint(resp.StatusCode)
|
||||
}
|
||||
retryable := resp.StatusCode >= 500
|
||||
if code == "invalid_grant" || code == "invalid_client" {
|
||||
retryable = false
|
||||
}
|
||||
return nil, &refreshError{Retryable: retryable, Code: code, Err: fmt.Errorf(strings.TrimSpace(payload.ErrorDescription))}
|
||||
}
|
||||
if strings.TrimSpace(payload.AccessToken) == "" {
|
||||
return nil, &refreshError{Retryable: true, Code: "empty_token", Err: fmt.Errorf("missing access_token")}
|
||||
}
|
||||
return &payload, nil
|
||||
}
|
||||
|
||||
func parseAccountID(idToken string) string {
|
||||
parts := strings.Split(idToken, ".")
|
||||
if len(parts) < 2 {
|
||||
return ""
|
||||
}
|
||||
raw, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
var payload struct {
|
||||
AccountID string `json:"account_id"`
|
||||
OrgID string `json:"org_id"`
|
||||
OrganizationID string `json:"organization_id"`
|
||||
}
|
||||
if err := jsoncodec.Unmarshal(raw, &payload); err != nil {
|
||||
return ""
|
||||
}
|
||||
if payload.AccountID != "" {
|
||||
return payload.AccountID
|
||||
}
|
||||
if payload.OrgID != "" {
|
||||
return payload.OrgID
|
||||
}
|
||||
return payload.OrganizationID
|
||||
}
|
||||
Reference in New Issue
Block a user