diff --git a/cmd/server/main.go b/cmd/server/main.go index a704aef..4136bfc 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -194,6 +194,18 @@ func main() { cleanerCtx, cancelCleaner := context.WithCancel(context.Background()) defer cancelCleaner() go logCleaner.Start(cleanerCtx) + tokenRefresher := cron.NewTokenRefresher( + db, + rdb, + syncService, + time.Duration(cfg.TokenRefresh.IntervalSeconds)*time.Second, + time.Duration(cfg.TokenRefresh.RefreshSkewSeconds)*time.Second, + cfg.TokenRefresh.BatchSize, + cfg.TokenRefresh.MaxRetries, + ) + tokenCtx, cancelToken := context.WithCancel(context.Background()) + defer cancelToken() + go tokenRefresher.Start(tokenCtx) adminService, err := service.NewAdminService() if err != nil { diff --git a/internal/cron/token_refresh.go b/internal/cron/token_refresh.go new file mode 100644 index 0000000..1e75487 --- /dev/null +++ b/internal/cron/token_refresh.go @@ -0,0 +1,395 @@ +package cron + +import ( + "context" + "encoding/base64" + "fmt" + "io" + "log/slog" + "net/http" + "net/url" + "strings" + "time" + + "github.com/ez-api/ez-api/internal/model" + "github.com/ez-api/ez-api/internal/service" + "github.com/ez-api/foundation/jsoncodec" + "github.com/ez-api/foundation/provider" + "github.com/redis/go-redis/v9" + "gorm.io/gorm" +) + +const refreshEventKey = "events:refresh_provider" + +type RefreshEvent struct { + ProviderID uint `json:"provider_id"` + ProviderType string `json:"provider_type,omitempty"` + StatusCode int `json:"status_code,omitempty"` + Timestamp int64 `json:"timestamp,omitempty"` +} + +type TokenRefresher struct { + db *gorm.DB + rdb *redis.Client + sync *service.SyncService + interval time.Duration + refreshSkew time.Duration + batchSize int + maxRetries int + httpClient *http.Client +} + +func NewTokenRefresher(db *gorm.DB, rdb *redis.Client, sync *service.SyncService, interval, refreshSkew time.Duration, batchSize, maxRetries int) *TokenRefresher { + if interval <= 0 { + interval = 30 * time.Minute + } + if refreshSkew <= 0 { + refreshSkew = 50 * time.Minute + } + if batchSize <= 0 { + batchSize = 200 + } + if maxRetries <= 0 { + maxRetries = 3 + } + return &TokenRefresher{ + db: db, + rdb: rdb, + sync: sync, + interval: interval, + refreshSkew: refreshSkew, + batchSize: batchSize, + maxRetries: maxRetries, + httpClient: &http.Client{Timeout: 15 * time.Second}, + } +} + +func (t *TokenRefresher) Start(ctx context.Context) { + if t == nil || t.db == nil { + return + } + if ctx == nil { + ctx = context.Background() + } + ticker := time.NewTicker(t.interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := t.refreshOnce(ctx); err != nil { + slog.Default().Warn("token refresh failed", "err", err) + } + } + } +} + +func (t *TokenRefresher) refreshOnce(ctx context.Context) error { + if t == nil || t.db == nil { + return nil + } + if err := t.processRefreshEvents(ctx); err != nil { + slog.Default().Warn("token refresh event handling failed", "err", err) + } + return t.refreshExpiring(ctx) +} + +func (t *TokenRefresher) processRefreshEvents(ctx context.Context) error { + if t == nil || t.rdb == nil { + return nil + } + for i := 0; i < t.batchSize; i++ { + raw, err := t.rdb.LPop(ctx, refreshEventKey).Result() + if err == redis.Nil { + return nil + } + if err != nil { + return err + } + var evt RefreshEvent + if err := jsoncodec.Unmarshal([]byte(raw), &evt); err != nil || evt.ProviderID == 0 { + continue + } + if err := t.refreshByID(ctx, evt.ProviderID); err != nil { + slog.Default().Warn("token refresh event failed", "provider_id", evt.ProviderID, "err", err) + } + } + return nil +} + +func (t *TokenRefresher) refreshExpiring(ctx context.Context) error { + if t == nil || t.db == nil { + return nil + } + cutoff := time.Now().UTC().Add(t.refreshSkew) + var keys []model.APIKey + if err := t.db.WithContext(ctx). + Where("status = ?", "active"). + Where("refresh_token <> ''"). + Where("expires_at IS NOT NULL AND expires_at <= ?", cutoff). + Limit(t.batchSize). + Find(&keys).Error; err != nil { + return err + } + for i := range keys { + if err := t.refreshKey(ctx, &keys[i]); err != nil { + slog.Default().Warn("token refresh failed", "key_id", keys[i].ID, "err", err) + } + } + return nil +} + +func (t *TokenRefresher) refreshByID(ctx context.Context, id uint) error { + if t == nil || t.db == nil || id == 0 { + return nil + } + var key model.APIKey + if err := t.db.WithContext(ctx).First(&key, id).Error; err != nil { + return err + } + return t.refreshKey(ctx, &key) +} + +func (t *TokenRefresher) refreshKey(ctx context.Context, key *model.APIKey) error { + if t == nil || t.db == nil || key == nil { + return nil + } + var group model.ProviderGroup + if err := t.db.WithContext(ctx).First(&group, key.GroupID).Error; err != nil { + return err + } + ptype := provider.NormalizeType(group.Type) + if !isCPAProvider(ptype) { + return nil + } + if strings.TrimSpace(key.RefreshToken) == "" { + return nil + } + + var lastErr error + for attempt := 0; attempt < t.maxRetries; attempt++ { + if attempt > 0 { + time.Sleep(time.Duration(attempt) * time.Second) + } + out, err := t.refreshAccessToken(ctx, ptype, key.RefreshToken) + if err == nil { + updates := map[string]any{ + "access_token": strings.TrimSpace(out.AccessToken), + "expires_at": out.ExpiresAt, + "status": "active", + } + if strings.TrimSpace(out.RefreshToken) != "" { + updates["refresh_token"] = strings.TrimSpace(out.RefreshToken) + } + if strings.TrimSpace(out.AccountID) != "" { + updates["account_id"] = strings.TrimSpace(out.AccountID) + } + if err := t.db.WithContext(ctx).Model(&model.APIKey{}).Where("id = ?", key.ID).Updates(updates).Error; err != nil { + return err + } + if t.sync != nil { + _ = t.sync.SyncProvidersForAPIKey(t.db, key.ID) + } + return nil + } + lastErr = err + if rerr, ok := err.(*refreshError); ok && !rerr.Retryable { + if err := t.db.WithContext(ctx).Model(&model.APIKey{}).Where("id = ?", key.ID).Updates(map[string]any{ + "status": "inactive", + "access_token": "", + "expires_at": nil, + }).Error; err != nil { + return err + } + return rerr + } + } + return lastErr +} + +func isCPAProvider(ptype string) bool { + switch ptype { + case provider.TypeCodex, provider.TypeGeminiCLI, provider.TypeAntigravity, provider.TypeClaudeCode: + return true + default: + return false + } +} + +type refreshOutput struct { + AccessToken string + RefreshToken string + ExpiresAt time.Time + AccountID string +} + +type refreshError struct { + Retryable bool + Code string + Err error +} + +func (e *refreshError) Error() string { + if e == nil { + return "" + } + if e.Code != "" { + return fmt.Sprintf("refresh %s: %v", e.Code, e.Err) + } + return fmt.Sprintf("refresh error: %v", e.Err) +} + +func (t *TokenRefresher) refreshAccessToken(ctx context.Context, ptype, refreshToken string) (*refreshOutput, error) { + switch ptype { + case provider.TypeCodex: + return t.refreshCodex(ctx, refreshToken) + case provider.TypeGeminiCLI: + return t.refreshGoogle(ctx, refreshToken, geminiCLIClientID, geminiCLIClientSecret) + case provider.TypeAntigravity: + return t.refreshGoogle(ctx, refreshToken, antigravityClientID, antigravityClientSecret) + case provider.TypeClaudeCode: + return t.refreshClaude(ctx, refreshToken) + default: + return nil, &refreshError{Retryable: false, Code: "unsupported_provider", Err: fmt.Errorf("provider type %s unsupported", ptype)} + } +} + +type tokenResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token,omitempty"` + ExpiresIn int64 `json:"expires_in,omitempty"` + TokenType string `json:"token_type,omitempty"` + IDToken string `json:"id_token,omitempty"` + Error string `json:"error,omitempty"` + ErrorDescription string `json:"error_description,omitempty"` +} + +const ( + codexClientID = "app_EMoamEEZ73f0CkXaXp7hrann" + geminiCLIClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com" + geminiCLIClientSecret = "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" + antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" + antigravityClientSecret = "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" + claudeClientID = "9d1c250a-e61b-44d3-bcd4-8fbe4b736065" +) + +func (t *TokenRefresher) refreshCodex(ctx context.Context, refreshToken string) (*refreshOutput, error) { + form := url.Values{ + "grant_type": {"refresh_token"}, + "refresh_token": {refreshToken}, + "client_id": {codexClientID}, + "scope": {"openid profile email"}, + } + resp, err := t.postForm(ctx, "https://auth.openai.com/oauth/token", form) + if err != nil { + return nil, err + } + out := &refreshOutput{ + AccessToken: resp.AccessToken, + RefreshToken: resp.RefreshToken, + ExpiresAt: time.Now().UTC().Add(time.Duration(resp.ExpiresIn) * time.Second), + AccountID: parseAccountID(resp.IDToken), + } + return out, nil +} + +func (t *TokenRefresher) refreshClaude(ctx context.Context, refreshToken string) (*refreshOutput, error) { + form := url.Values{ + "grant_type": {"refresh_token"}, + "refresh_token": {refreshToken}, + "client_id": {claudeClientID}, + } + resp, err := t.postForm(ctx, "https://claude.ai/oauth2/token", form) + if err != nil { + return nil, err + } + return &refreshOutput{ + AccessToken: resp.AccessToken, + RefreshToken: resp.RefreshToken, + ExpiresAt: time.Now().UTC().Add(time.Duration(resp.ExpiresIn) * time.Second), + }, nil +} + +func (t *TokenRefresher) refreshGoogle(ctx context.Context, refreshToken, clientID, clientSecret string) (*refreshOutput, error) { + form := url.Values{ + "grant_type": {"refresh_token"}, + "refresh_token": {refreshToken}, + "client_id": {clientID}, + "client_secret": {clientSecret}, + } + resp, err := t.postForm(ctx, "https://oauth2.googleapis.com/token", form) + if err != nil { + return nil, err + } + return &refreshOutput{ + AccessToken: resp.AccessToken, + RefreshToken: resp.RefreshToken, + ExpiresAt: time.Now().UTC().Add(time.Duration(resp.ExpiresIn) * time.Second), + }, nil +} + +func (t *TokenRefresher) postForm(ctx context.Context, endpoint string, form url.Values) (*tokenResponse, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode())) + if err != nil { + return nil, &refreshError{Retryable: true, Code: "build_request", Err: err} + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := t.httpClient.Do(req) + if err != nil { + return nil, &refreshError{Retryable: true, Code: "transport", Err: err} + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, &refreshError{Retryable: true, Code: "read_body", Err: err} + } + var payload tokenResponse + if err := jsoncodec.Unmarshal(body, &payload); err != nil { + return nil, &refreshError{Retryable: true, Code: "parse_body", Err: err} + } + if resp.StatusCode >= 400 { + code := strings.TrimSpace(payload.Error) + if code == "" { + code = "http_" + fmt.Sprint(resp.StatusCode) + } + retryable := resp.StatusCode >= 500 + if code == "invalid_grant" || code == "invalid_client" { + retryable = false + } + return nil, &refreshError{Retryable: retryable, Code: code, Err: fmt.Errorf(strings.TrimSpace(payload.ErrorDescription))} + } + if strings.TrimSpace(payload.AccessToken) == "" { + return nil, &refreshError{Retryable: true, Code: "empty_token", Err: fmt.Errorf("missing access_token")} + } + return &payload, nil +} + +func parseAccountID(idToken string) string { + parts := strings.Split(idToken, ".") + if len(parts) < 2 { + return "" + } + raw, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return "" + } + var payload struct { + AccountID string `json:"account_id"` + OrgID string `json:"org_id"` + OrganizationID string `json:"organization_id"` + } + if err := jsoncodec.Unmarshal(raw, &payload); err != nil { + return "" + } + if payload.AccountID != "" { + return payload.AccountID + } + if payload.OrgID != "" { + return payload.OrgID + } + return payload.OrganizationID +}