Files
ez-api/internal/cron/token_refresh.go
zenfun 05caed37c2 refactor(cron): migrate cron jobs to foundation scheduler
Replace custom goroutine-based scheduling in cron jobs with centralized
foundation scheduler. Each cron job now exposes a RunOnce method called
by the scheduler instead of managing its own ticker loop.

Changes:
- Remove interval/enabled config from cron job structs
- Convert Start() methods to RunOnce() for all cron jobs
- Add scheduler setup in main.go with configurable intervals
- Update foundation dependency to v0.6.0 for scheduler support
- Update tests to validate RunOnce nil-safety
2025-12-31 20:42:25 +08:00

379 lines
11 KiB
Go

package cron
import (
"context"
"encoding/base64"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"strings"
"time"
"github.com/ez-api/ez-api/internal/model"
"github.com/ez-api/ez-api/internal/service"
"github.com/ez-api/foundation/jsoncodec"
"github.com/ez-api/foundation/provider"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
)
// refreshEventKey is the Redis list that processRefreshEvents pops (LPOP)
// JSON-encoded RefreshEvent values from to trigger on-demand token refreshes.
const refreshEventKey = "events:refresh_provider"
// RefreshEvent is the JSON payload queued under refreshEventKey to request an
// out-of-band token refresh. Within this file only ProviderID is read;
// events that fail to decode or carry a zero ProviderID are discarded.
//
// NOTE(review): ProviderID is handed to refreshByID, which loads a
// model.APIKey by that ID — confirm producers actually put an API key ID here.
type RefreshEvent struct {
	ProviderID   uint   `json:"provider_id"`
	ProviderType string `json:"provider_type,omitempty"`
	StatusCode   int    `json:"status_code,omitempty"`
	Timestamp    int64  `json:"timestamp,omitempty"`
}
// TokenRefresher refreshes OAuth access tokens of API keys that belong to CPA
// provider groups (see isCPAProvider). Each RunOnce call first drains queued
// refresh events from Redis and then refreshes keys whose tokens expire within
// refreshSkew. Build it with NewTokenRefresher so zero/negative settings are
// replaced by defaults.
type TokenRefresher struct {
	// db holds API keys and provider groups; a nil db turns every run into a no-op.
	db *gorm.DB
	// rdb is the event queue; a nil rdb skips event processing.
	rdb *redis.Client
	// sync, when non-nil, re-syncs providers after a successful refresh.
	sync *service.SyncService
	// refreshSkew: tokens expiring within this window are refreshed (default 50m).
	refreshSkew time.Duration
	// batchSize caps events popped and keys loaded per run (default 200).
	batchSize int
	// maxRetries caps attempts per key for retryable failures (default 3).
	maxRetries int
	// httpClient performs token-endpoint requests (15s timeout by default).
	httpClient *http.Client
}
// NewTokenRefresher builds a TokenRefresher. Non-positive settings fall back to
// defaults: refreshSkew 50 minutes, batchSize 200, maxRetries 3. Token endpoint
// calls use an HTTP client with a 15 second timeout.
func NewTokenRefresher(db *gorm.DB, rdb *redis.Client, sync *service.SyncService, refreshSkew time.Duration, batchSize, maxRetries int) *TokenRefresher {
	tr := &TokenRefresher{
		db:          db,
		rdb:         rdb,
		sync:        sync,
		refreshSkew: refreshSkew,
		batchSize:   batchSize,
		maxRetries:  maxRetries,
		httpClient:  &http.Client{Timeout: 15 * time.Second},
	}
	if tr.refreshSkew <= 0 {
		tr.refreshSkew = 50 * time.Minute
	}
	if tr.batchSize <= 0 {
		tr.batchSize = 200
	}
	if tr.maxRetries <= 0 {
		tr.maxRetries = 3
	}
	return tr
}
// RunOnce performs one token refresh cycle; it is the entry point invoked by
// the scheduler. Errors are logged, not returned. A nil receiver or a
// refresher without a database handle does nothing.
func (t *TokenRefresher) RunOnce(ctx context.Context) {
	if t == nil || t.db == nil {
		return
	}
	err := t.refreshOnce(ctx)
	if err == nil {
		return
	}
	slog.Default().Warn("token refresh failed", "err", err)
}
// refreshOnce handles queued refresh events and then sweeps keys whose tokens
// are about to expire. Only the sweep's error is returned.
func (t *TokenRefresher) refreshOnce(ctx context.Context) error {
	ready := t != nil && t.db != nil
	if !ready {
		return nil
	}
	// Event handling is best-effort: its failure is logged and must not stop
	// the scheduled expiry sweep below.
	eventsErr := t.processRefreshEvents(ctx)
	if eventsErr != nil {
		slog.Default().Warn("token refresh event handling failed", "err", eventsErr)
	}
	return t.refreshExpiring(ctx)
}
// processRefreshEvents pops up to t.batchSize events from the head of the
// refreshEventKey Redis list and refreshes the key each one names.
//
// It returns nil once the list is empty (redis.Nil) or the batch limit is
// reached, and the Redis error if a pop fails. Undecodable events and events
// with a zero ProviderID are dropped (they have already been popped).
// Per-key refresh failures are logged and do not stop the loop.
func (t *TokenRefresher) processRefreshEvents(ctx context.Context) error {
	if t == nil || t.rdb == nil {
		return nil
	}
	for popped := 0; popped < t.batchSize; popped++ {
		raw, popErr := t.rdb.LPop(ctx, refreshEventKey).Result()
		switch {
		case popErr == redis.Nil:
			// Queue drained.
			return nil
		case popErr != nil:
			return popErr
		}

		var evt RefreshEvent
		decodeErr := jsoncodec.Unmarshal([]byte(raw), &evt)
		if decodeErr != nil || evt.ProviderID == 0 {
			continue
		}
		if refreshErr := t.refreshByID(ctx, evt.ProviderID); refreshErr != nil {
			slog.Default().Warn("token refresh event failed", "provider_id", evt.ProviderID, "err", refreshErr)
		}
	}
	return nil
}
// refreshExpiring loads up to t.batchSize active keys that have a refresh
// token and an expires_at within t.refreshSkew of now (UTC), and refreshes
// each one.
//
// Keys are taken soonest-expiring first, so the batch is chosen by urgency
// rather than by whatever row order the database returns. The sweep stops and
// returns ctx.Err() once ctx is cancelled; the remaining keys are picked up on
// a later run. Per-key refresh failures are logged, not returned.
func (t *TokenRefresher) refreshExpiring(ctx context.Context) error {
	if t == nil || t.db == nil {
		return nil
	}
	cutoff := time.Now().UTC().Add(t.refreshSkew)
	var keys []model.APIKey
	if err := t.db.WithContext(ctx).
		Where("status = ?", "active").
		Where("refresh_token <> ''").
		Where("expires_at IS NOT NULL AND expires_at <= ?", cutoff).
		Order("expires_at ASC").
		Limit(t.batchSize).
		Find(&keys).Error; err != nil {
		return err
	}
	for i := range keys {
		// Don't keep issuing token requests for a run the scheduler has
		// already cancelled.
		if err := ctx.Err(); err != nil {
			return err
		}
		if err := t.refreshKey(ctx, &keys[i]); err != nil {
			slog.Default().Warn("token refresh failed", "key_id", keys[i].ID, "err", err)
		}
	}
	return nil
}
// refreshByID loads the API key with the given primary key and refreshes it.
// A zero id, nil receiver or missing database handle is a no-op. The lookup
// error, including a not-found error, is returned unchanged.
func (t *TokenRefresher) refreshByID(ctx context.Context, id uint) error {
	if t == nil || t.db == nil || id == 0 {
		return nil
	}
	key := model.APIKey{}
	lookupErr := t.db.WithContext(ctx).First(&key, id).Error
	if lookupErr != nil {
		return lookupErr
	}
	return t.refreshKey(ctx, &key)
}
// refreshKey refreshes the OAuth access token of a single API key.
//
// The key's provider group is loaded to check that it is a CPA provider (see
// isCPAProvider). Keys of other providers, and keys without a refresh token,
// are skipped and nil is returned.
//
// Retryable failures are retried up to t.maxRetries times (always at least
// once), waiting attempt×1s between attempts. If ctx is cancelled during that
// wait, the last refresh error is returned wrapped together with ctx.Err().
// A non-retryable *refreshError (e.g. invalid_grant) marks the key inactive,
// clears its access token and expiry, and is returned.
//
// On success the new access token and expiry are stored and the key is set
// active. The refresh token and account ID are overwritten only when the
// provider returned non-empty values. The providers using the key are then
// re-synced through t.sync; a sync failure is logged, not returned.
func (t *TokenRefresher) refreshKey(ctx context.Context, key *model.APIKey) error {
	if t == nil || t.db == nil || key == nil {
		return nil
	}
	var group model.ProviderGroup
	if err := t.db.WithContext(ctx).First(&group, key.GroupID).Error; err != nil {
		return err
	}
	ptype := provider.NormalizeType(group.Type)
	if !isCPAProvider(ptype) {
		return nil
	}
	if strings.TrimSpace(key.RefreshToken) == "" {
		return nil
	}

	attempts := t.maxRetries
	if attempts <= 0 {
		// A TokenRefresher not built via NewTokenRefresher has maxRetries == 0;
		// without this the loop would be skipped and nil reported as success.
		attempts = 1
	}

	var lastErr error
	for attempt := 0; attempt < attempts; attempt++ {
		if attempt > 0 {
			// Linear backoff. Unlike time.Sleep, this returns promptly when
			// the scheduler cancels the run.
			backoff := time.NewTimer(time.Duration(attempt) * time.Second)
			select {
			case <-ctx.Done():
				backoff.Stop()
				return fmt.Errorf("%w; retry aborted: %w", lastErr, ctx.Err())
			case <-backoff.C:
			}
		}

		out, err := t.refreshAccessToken(ctx, ptype, key.RefreshToken)
		if err == nil {
			updates := map[string]any{
				"access_token": strings.TrimSpace(out.AccessToken),
				"expires_at":   out.ExpiresAt,
				"status":       "active",
			}
			// Some providers do not rotate the refresh token or return an
			// account ID; keep the stored values in that case.
			if rt := strings.TrimSpace(out.RefreshToken); rt != "" {
				updates["refresh_token"] = rt
			}
			if acct := strings.TrimSpace(out.AccountID); acct != "" {
				updates["account_id"] = acct
			}
			if err := t.db.WithContext(ctx).Model(&model.APIKey{}).Where("id = ?", key.ID).Updates(updates).Error; err != nil {
				return err
			}
			if t.sync != nil {
				if err := t.sync.SyncProvidersForAPIKey(t.db, key.ID); err != nil {
					// The new token is already persisted, so a failed sync is
					// not a refresh failure; surface it rather than dropping it.
					slog.Default().Warn("token refresh provider sync failed", "key_id", key.ID, "err", err)
				}
			}
			return nil
		}
		lastErr = err

		rerr, ok := err.(*refreshError)
		if !ok || rerr.Retryable {
			continue
		}
		// The provider rejected the grant permanently. Deactivate the key so
		// the expiry sweep (which filters on status = 'active') stops picking it.
		if err := t.db.WithContext(ctx).Model(&model.APIKey{}).Where("id = ?", key.ID).Updates(map[string]any{
			"status":       "inactive",
			"access_token": "",
			"expires_at":   nil,
		}).Error; err != nil {
			return err
		}
		return rerr
	}
	return lastErr
}
// isCPAProvider reports whether the normalized provider type is one whose
// keys use OAuth refresh tokens handled by this file: Codex, Gemini CLI,
// Antigravity or Claude Code.
func isCPAProvider(ptype string) bool {
	cpa := ptype == provider.TypeCodex ||
		ptype == provider.TypeGeminiCLI ||
		ptype == provider.TypeAntigravity ||
		ptype == provider.TypeClaudeCode
	return cpa
}
// refreshOutput is the provider-independent result of a successful token
// refresh.
type refreshOutput struct {
	// AccessToken is the newly issued access token.
	AccessToken string
	// RefreshToken is a rotated refresh token, or "" if none was returned.
	RefreshToken string
	// ExpiresAt is the absolute UTC expiry computed as now + expires_in.
	ExpiresAt time.Time
	// AccountID is taken from the ID token (only set by refreshCodex).
	AccountID string
}
// refreshError is the error type produced by the token refresh path.
// Retryable tells refreshKey whether to try again or to deactivate the key.
// Code is a short classifier: an OAuth error code such as "invalid_grant", or
// a local label such as "transport" or "http_500". Err carries the detail.
type refreshError struct {
	Retryable bool
	Code      string
	Err       error
}

// Error formats the error as "refresh <code>: <detail>", or as
// "refresh error: <detail>" when Code is empty. A nil Err omits the detail
// instead of printing "<nil>". A nil receiver yields "".
func (e *refreshError) Error() string {
	if e == nil {
		return ""
	}
	switch {
	case e.Code != "" && e.Err != nil:
		return fmt.Sprintf("refresh %s: %v", e.Code, e.Err)
	case e.Code != "":
		return "refresh " + e.Code
	case e.Err != nil:
		return fmt.Sprintf("refresh error: %v", e.Err)
	default:
		return "refresh error"
	}
}

// Unwrap returns the underlying cause. This lets errors.Is and errors.As see
// through a refreshError, for example to detect context.Canceled behind a
// transport failure.
func (e *refreshError) Unwrap() error {
	if e == nil {
		return nil
	}
	return e.Err
}
// refreshAccessToken dispatches to the refresh routine for the given
// normalized provider type. Gemini CLI and Antigravity share the Google token
// endpoint and differ only in OAuth client credentials. Any other type yields
// a non-retryable "unsupported_provider" refreshError.
func (t *TokenRefresher) refreshAccessToken(ctx context.Context, ptype, refreshToken string) (*refreshOutput, error) {
	switch ptype {
	case provider.TypeCodex:
		return t.refreshCodex(ctx, refreshToken)
	case provider.TypeClaudeCode:
		return t.refreshClaude(ctx, refreshToken)
	case provider.TypeGeminiCLI, provider.TypeAntigravity:
		clientID, clientSecret := geminiCLIClientID, geminiCLIClientSecret
		if ptype == provider.TypeAntigravity {
			clientID, clientSecret = antigravityClientID, antigravityClientSecret
		}
		return t.refreshGoogle(ctx, refreshToken, clientID, clientSecret)
	}
	unsupported := fmt.Errorf("provider type %s unsupported", ptype)
	return nil, &refreshError{Retryable: false, Code: "unsupported_provider", Err: unsupported}
}
// tokenResponse is the JSON body returned by an OAuth 2.0 token endpoint.
// On success AccessToken (and usually ExpiresIn) is set; on failure Error and
// ErrorDescription carry the OAuth error response fields. IDToken is only used
// by refreshCodex, which extracts an account ID from it. TokenType is decoded
// but not read in this file.
type tokenResponse struct {
	AccessToken      string `json:"access_token"`
	RefreshToken     string `json:"refresh_token,omitempty"`
	ExpiresIn        int64  `json:"expires_in,omitempty"` // seconds from now
	TokenType        string `json:"token_type,omitempty"`
	IDToken          string `json:"id_token,omitempty"`
	Error            string `json:"error,omitempty"`
	ErrorDescription string `json:"error_description,omitempty"`
}
// OAuth client credentials sent to each provider's token endpoint when
// refreshing. Codex and Claude Code send only a client ID; the Google-based
// providers (Gemini CLI, Antigravity) also send a client secret.
//
// NOTE(review): client secrets are embedded in source. Presumably these are the
// public installed-app credentials of the upstream CLI tools; confirm that, and
// consider loading them from configuration so they can be rotated without a
// rebuild.
const (
	codexClientID           = "app_EMoamEEZ73f0CkXaXp7hrann"
	geminiCLIClientID       = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
	geminiCLIClientSecret   = "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"
	antigravityClientID     = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
	antigravityClientSecret = "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"
	claudeClientID          = "9d1c250a-e61b-44d3-bcd4-8fbe4b736065"
)
// refreshCodex exchanges a Codex refresh token at the OpenAI OAuth token
// endpoint, requesting the "openid profile email" scope. The account ID is
// taken from the returned ID token when present.
func (t *TokenRefresher) refreshCodex(ctx context.Context, refreshToken string) (*refreshOutput, error) {
	form := make(url.Values, 4)
	form.Set("grant_type", "refresh_token")
	form.Set("refresh_token", refreshToken)
	form.Set("client_id", codexClientID)
	form.Set("scope", "openid profile email")

	tok, err := t.postForm(ctx, "https://auth.openai.com/oauth/token", form)
	if err != nil {
		return nil, err
	}
	lifetime := time.Duration(tok.ExpiresIn) * time.Second
	result := new(refreshOutput)
	result.AccessToken = tok.AccessToken
	result.RefreshToken = tok.RefreshToken
	result.ExpiresAt = time.Now().UTC().Add(lifetime)
	result.AccountID = parseAccountID(tok.IDToken)
	return result, nil
}
// refreshClaude exchanges a Claude Code refresh token at claude.ai's OAuth
// token endpoint. No account ID is extracted.
func (t *TokenRefresher) refreshClaude(ctx context.Context, refreshToken string) (*refreshOutput, error) {
	form := make(url.Values, 3)
	form.Set("grant_type", "refresh_token")
	form.Set("refresh_token", refreshToken)
	form.Set("client_id", claudeClientID)

	tok, err := t.postForm(ctx, "https://claude.ai/oauth2/token", form)
	if err != nil {
		return nil, err
	}
	lifetime := time.Duration(tok.ExpiresIn) * time.Second
	result := new(refreshOutput)
	result.AccessToken = tok.AccessToken
	result.RefreshToken = tok.RefreshToken
	result.ExpiresAt = time.Now().UTC().Add(lifetime)
	return result, nil
}
// refreshGoogle exchanges a refresh token at Google's OAuth token endpoint
// using the supplied client ID and secret. It is shared by the Gemini CLI and
// Antigravity providers. No account ID is extracted.
func (t *TokenRefresher) refreshGoogle(ctx context.Context, refreshToken, clientID, clientSecret string) (*refreshOutput, error) {
	form := make(url.Values, 4)
	form.Set("grant_type", "refresh_token")
	form.Set("refresh_token", refreshToken)
	form.Set("client_id", clientID)
	form.Set("client_secret", clientSecret)

	tok, err := t.postForm(ctx, "https://oauth2.googleapis.com/token", form)
	if err != nil {
		return nil, err
	}
	lifetime := time.Duration(tok.ExpiresIn) * time.Second
	result := new(refreshOutput)
	result.AccessToken = tok.AccessToken
	result.RefreshToken = tok.RefreshToken
	result.ExpiresAt = time.Now().UTC().Add(lifetime)
	return result, nil
}
// postForm POSTs form as application/x-www-form-urlencoded to an OAuth token
// endpoint and decodes the JSON token response.
//
// Every failure is a *refreshError, and its Retryable flag drives refreshKey
// (which deactivates the key on a non-retryable error):
//   - request-build, transport, body-read and body-parse failures: retryable;
//   - HTTP 408, 429 and 5xx: retryable (transient by definition);
//   - other HTTP >= 400 responses whose body is JSON: not retryable;
//   - HTTP >= 400 responses whose body is not JSON (e.g. a proxy/WAF HTML
//     page): retryable, since there is no OAuth error to prove the grant dead;
//   - an OAuth error code of invalid_grant or invalid_client: never retryable;
//   - a success response without an access_token: retryable.
func (t *TokenRefresher) postForm(ctx context.Context, endpoint string, form url.Values) (*tokenResponse, error) {
	// Token responses are small; the cap only guards against an endpoint that
	// streams an unbounded body into memory.
	const maxResponseBytes = 1 << 20

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode()))
	if err != nil {
		return nil, &refreshError{Retryable: true, Code: "build_request", Err: err}
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	resp, err := t.httpClient.Do(req)
	if err != nil {
		return nil, &refreshError{Retryable: true, Code: "transport", Err: err}
	}
	defer resp.Body.Close()
	body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
	if err != nil {
		return nil, &refreshError{Retryable: true, Code: "read_body", Err: err}
	}

	var payload tokenResponse
	parseErr := jsoncodec.Unmarshal(body, &payload)

	// Classify HTTP errors by status first. Error bodies are not guaranteed to
	// be JSON, and a parse failure must not hide the status code.
	if resp.StatusCode >= 400 {
		code := fmt.Sprintf("http_%d", resp.StatusCode)
		desc := http.StatusText(resp.StatusCode)
		retryable := resp.StatusCode >= 500 ||
			resp.StatusCode == http.StatusTooManyRequests ||
			resp.StatusCode == http.StatusRequestTimeout
		if parseErr != nil {
			retryable = true
			desc = fmt.Sprintf("%s (unparseable body: %v)", desc, parseErr)
		} else {
			if c := strings.TrimSpace(payload.Error); c != "" {
				code = c
			}
			if d := strings.TrimSpace(payload.ErrorDescription); d != "" {
				desc = d
			}
		}
		if code == "invalid_grant" || code == "invalid_client" {
			retryable = false
		}
		return nil, &refreshError{Retryable: retryable, Code: code, Err: fmt.Errorf("%s", desc)}
	}
	if parseErr != nil {
		return nil, &refreshError{Retryable: true, Code: "parse_body", Err: parseErr}
	}
	if strings.TrimSpace(payload.AccessToken) == "" {
		return nil, &refreshError{Retryable: true, Code: "empty_token", Err: fmt.Errorf("missing access_token")}
	}
	return &payload, nil
}
// parseAccountID extracts an account identifier from the claims of a JWT ID
// token without verifying its signature. It returns the first non-empty claim
// among account_id, org_id and organization_id, or "" if the token is
// malformed or none of those claims is set.
func parseAccountID(idToken string) string {
	segments := strings.Split(idToken, ".")
	if len(segments) < 2 {
		return ""
	}
	claimsJSON, err := base64.RawURLEncoding.DecodeString(segments[1])
	if err != nil {
		return ""
	}
	var claims struct {
		AccountID      string `json:"account_id"`
		OrgID          string `json:"org_id"`
		OrganizationID string `json:"organization_id"`
	}
	if jsoncodec.Unmarshal(claimsJSON, &claims) != nil {
		return ""
	}
	// Ordered by preference.
	for _, candidate := range []string{claims.AccountID, claims.OrgID, claims.OrganizationID} {
		if candidate != "" {
			return candidate
		}
	}
	return ""
}