refactor(deps): use foundation shared utilities

This commit is contained in:
zenfun
2025-12-14 23:52:46 +08:00
parent 71c183a480
commit d1d1b1c42a
9 changed files with 22 additions and 78 deletions

View File

@@ -44,6 +44,7 @@ go test ./...
### 阶段 3已落地一部分契约测试 ### 阶段 3已落地一部分契约测试
- 与 DP 的 provider snapshot schema 契约:`internal/service/testdata/provider_snapshot.json` + SyncProvider 输出回归 - 与 DP 的 provider snapshot schema 契约:`internal/service/testdata/provider_snapshot.json` + SyncProvider 输出回归
- golden 来自 `github.com/ez-api/foundation/contract`go:embed避免 DP/CP 两边复制文件
待扩展: 待扩展:

View File

@@ -8,6 +8,7 @@ import (
"github.com/ez-api/ez-api/internal/dto" "github.com/ez-api/ez-api/internal/dto"
"github.com/ez-api/ez-api/internal/model" "github.com/ez-api/ez-api/internal/model"
"github.com/ez-api/ez-api/internal/service" "github.com/ez-api/ez-api/internal/service"
groupx "github.com/ez-api/foundation/group"
"github.com/ez-api/foundation/provider" "github.com/ez-api/foundation/provider"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
@@ -157,7 +158,7 @@ func (h *Handler) UpdateProvider(c *gin.Context) {
update["models"] = strings.Join(req.Models, ",") update["models"] = strings.Join(req.Models, ",")
} }
if strings.TrimSpace(req.Group) != "" { if strings.TrimSpace(req.Group) != "" {
update["group"] = normalizeGroup(req.Group) update["group"] = groupx.Normalize(req.Group)
} }
if req.AutoBan != nil { if req.AutoBan != nil {
update["auto_ban"] = *req.AutoBan update["auto_ban"] = *req.AutoBan
@@ -354,10 +355,3 @@ func (h *Handler) IngestLog(c *gin.Context) {
h.logger.Write(rec) h.logger.Write(rec)
c.JSON(http.StatusAccepted, gin.H{"status": "queued"}) c.JSON(http.StatusAccepted, gin.H{"status": "queued"})
} }
func normalizeGroup(group string) string {
if strings.TrimSpace(group) == "" {
return "default"
}
return group
}

View File

@@ -1,35 +1,20 @@
package middleware package middleware
import ( import (
"crypto/rand" "github.com/ez-api/foundation/requestid"
"encoding/hex"
"strings"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// RequestID ensures every request has an X-Request-ID and echoes it back to the client. // RequestID ensures every request has an X-Request-ID and echoes it back to the client.
func RequestID() gin.HandlerFunc { func RequestID() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
id := strings.TrimSpace(c.GetHeader("X-Request-ID")) id := requestid.Extract(c.GetHeader)
if id == "" { if id == "" {
id = strings.TrimSpace(c.GetHeader("X-Request-Id")) id = requestid.New()
} }
if id == "" { c.Request.Header.Set(requestid.HeaderName, id)
id = newRequestID() c.Writer.Header().Set(requestid.HeaderName, id)
}
c.Request.Header.Set("X-Request-ID", id)
c.Writer.Header().Set("X-Request-ID", id)
c.Set("request_id", id) c.Set("request_id", id)
c.Next() c.Next()
} }
} }
func newRequestID() string {
var b [16]byte
if _, err := rand.Read(b[:]); err != nil {
return hex.EncodeToString([]byte(time.Now().Format(time.RFC3339Nano)))
}
return hex.EncodeToString(b[:])
}

View File

@@ -8,7 +8,7 @@ import (
"strings" "strings"
"github.com/ez-api/ez-api/internal/model" "github.com/ez-api/ez-api/internal/model"
"github.com/ez-api/ez-api/internal/util" "github.com/ez-api/foundation/tokenhash"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -32,7 +32,7 @@ func (s *MasterService) CreateMaster(name, group string, maxChildKeys, globalQPS
return nil, "", fmt.Errorf("failed to hash master key: %w", err) return nil, "", fmt.Errorf("failed to hash master key: %w", err)
} }
masterKeyDigest := util.HashToken(rawMasterKey) masterKeyDigest := tokenhash.HashToken(rawMasterKey)
master := &model.Master{ master := &model.Master{
Name: name, Name: name,
@@ -53,7 +53,7 @@ func (s *MasterService) CreateMaster(name, group string, maxChildKeys, globalQPS
} }
func (s *MasterService) ValidateMasterKey(masterKey string) (*model.Master, error) { func (s *MasterService) ValidateMasterKey(masterKey string) (*model.Master, error) {
digest := util.HashToken(masterKey) digest := tokenhash.HashToken(masterKey)
var master model.Master var master model.Master
if err := s.db.Where("master_key_digest = ?", digest).First(&master).Error; err != nil { if err := s.db.Where("master_key_digest = ?", digest).First(&master).Error; err != nil {
@@ -108,7 +108,7 @@ func (s *MasterService) IssueChildKey(masterID uint, group string, scopes string
return nil, "", fmt.Errorf("failed to generate child key: %w", err) return nil, "", fmt.Errorf("failed to generate child key: %w", err)
} }
tokenHash := util.HashToken(rawChildKey) tokenHash := tokenhash.HashToken(rawChildKey)
hashedChildKey, err := bcrypt.GenerateFromPassword([]byte(rawChildKey), bcrypt.DefaultCost) hashedChildKey, err := bcrypt.GenerateFromPassword([]byte(rawChildKey), bcrypt.DefaultCost)
if err != nil { if err != nil {

View File

@@ -7,8 +7,9 @@ import (
"time" "time"
"github.com/ez-api/ez-api/internal/model" "github.com/ez-api/ez-api/internal/model"
"github.com/ez-api/ez-api/internal/util" groupx "github.com/ez-api/foundation/group"
"github.com/ez-api/foundation/jsoncodec" "github.com/ez-api/foundation/jsoncodec"
"github.com/ez-api/foundation/tokenhash"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -26,7 +27,7 @@ func (s *SyncService) SyncKey(key *model.Key) error {
ctx := context.Background() ctx := context.Background()
tokenHash := key.TokenHash tokenHash := key.TokenHash
if strings.TrimSpace(tokenHash) == "" { if strings.TrimSpace(tokenHash) == "" {
tokenHash = util.HashToken(key.KeySecret) // backward compatibility tokenHash = tokenhash.HashToken(key.KeySecret) // backward compatibility
} }
if strings.TrimSpace(tokenHash) == "" { if strings.TrimSpace(tokenHash) == "" {
return fmt.Errorf("token hash missing for key %d", key.ID) return fmt.Errorf("token hash missing for key %d", key.ID)
@@ -62,7 +63,7 @@ func (s *SyncService) SyncMaster(master *model.Master) error {
// SyncProvider writes a single provider into Redis hash storage and updates routing tables. // SyncProvider writes a single provider into Redis hash storage and updates routing tables.
func (s *SyncService) SyncProvider(provider *model.Provider) error { func (s *SyncService) SyncProvider(provider *model.Provider) error {
ctx := context.Background() ctx := context.Background()
group := normalizeGroup(provider.Group) group := groupx.Normalize(provider.Group)
models := strings.Split(provider.Models, ",") models := strings.Split(provider.Models, ",")
snap := providerSnapshot{ snap := providerSnapshot{
@@ -198,7 +199,7 @@ func (s *SyncService) SyncAll(db *gorm.DB) error {
// Ideally, we should scan "route:group:*" and del, but let's just rebuild. // Ideally, we should scan "route:group:*" and del, but let's just rebuild.
for _, p := range providers { for _, p := range providers {
group := normalizeGroup(p.Group) group := groupx.Normalize(p.Group)
models := strings.Split(p.Models, ",") models := strings.Split(p.Models, ",")
snap := providerSnapshot{ snap := providerSnapshot{
@@ -244,7 +245,7 @@ func (s *SyncService) SyncAll(db *gorm.DB) error {
for _, k := range keys { for _, k := range keys {
tokenHash := strings.TrimSpace(k.TokenHash) tokenHash := strings.TrimSpace(k.TokenHash)
if tokenHash == "" { if tokenHash == "" {
tokenHash = util.HashToken(k.KeySecret) // fallback for legacy rows tokenHash = tokenhash.HashToken(k.KeySecret) // fallback for legacy rows
} }
if tokenHash == "" { if tokenHash == "" {
return fmt.Errorf("token hash missing for key %d", k.ID) return fmt.Errorf("token hash missing for key %d", k.ID)
@@ -302,13 +303,6 @@ func (s *SyncService) hsetJSON(ctx context.Context, key, field string, val inter
return nil return nil
} }
func normalizeGroup(group string) string {
if strings.TrimSpace(group) == "" {
return "default"
}
return group
}
func normalizeStatus(status string) string { func normalizeStatus(status string) string {
st := strings.ToLower(strings.TrimSpace(status)) st := strings.ToLower(strings.TrimSpace(status))
if st == "" { if st == "" {

View File

@@ -2,22 +2,17 @@ package service
import ( import (
"encoding/json" "encoding/json"
"os"
"path/filepath"
"reflect" "reflect"
"testing" "testing"
"github.com/alicebob/miniredis/v2" "github.com/alicebob/miniredis/v2"
"github.com/ez-api/ez-api/internal/model" "github.com/ez-api/ez-api/internal/model"
"github.com/ez-api/foundation/contract"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
) )
func TestSyncProvider_WritesSnapshotAndRouting(t *testing.T) { func TestSyncProvider_WritesSnapshotAndRouting(t *testing.T) {
goldenPath := filepath.Join("testdata", "provider_snapshot.json") goldenRaw := contract.ProviderSnapshotJSON()
goldenRaw, err := os.ReadFile(goldenPath)
if err != nil {
t.Fatalf("read golden %s: %v", goldenPath, err)
}
var golden map[string]any var golden map[string]any
if err := json.Unmarshal(goldenRaw, &golden); err != nil { if err := json.Unmarshal(goldenRaw, &golden); err != nil {
t.Fatalf("parse golden json: %v", err) t.Fatalf("parse golden json: %v", err)

View File

@@ -1,13 +0,0 @@
{
"id": 42,
"name": "p1",
"type": "vertex-express",
"base_url": "",
"api_key": "",
"google_location": "global",
"group": "default",
"models": ["gemini-3-pro-preview"],
"status": "active",
"auto_ban": true
}

View File

@@ -6,7 +6,7 @@ import (
"fmt" "fmt"
"strconv" "strconv"
"github.com/ez-api/ez-api/internal/util" "github.com/ez-api/foundation/tokenhash"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
) )
@@ -28,7 +28,7 @@ type TokenInfo struct {
// ValidateToken checks a child key against Redis for validity. // ValidateToken checks a child key against Redis for validity.
// This is designed to be called by the data plane (balancer). // This is designed to be called by the data plane (balancer).
func (s *TokenService) ValidateToken(ctx context.Context, token string) (*TokenInfo, error) { func (s *TokenService) ValidateToken(ctx context.Context, token string) (*TokenInfo, error) {
tokenHash := util.HashToken(token) tokenHash := tokenhash.HashToken(token)
tokenKey := fmt.Sprintf("auth:token:%s", tokenHash) tokenKey := fmt.Sprintf("auth:token:%s", tokenHash)
// 1. Get token metadata from Redis // 1. Get token metadata from Redis

View File

@@ -1,12 +0,0 @@
package util
import (
"crypto/sha256"
"encoding/hex"
)
func HashToken(token string) string {
hasher := sha256.New()
hasher.Write([]byte(token))
return hex.EncodeToString(hasher.Sum(nil))
}