mirror of
https://github.com/EZ-Api/ez-api.git
synced 2026-01-13 17:47:51 +00:00
refactor(deps): use foundation shared utilities
This commit is contained in:
@@ -44,6 +44,7 @@ go test ./...
|
|||||||
### 阶段 3(已落地一部分:契约测试)
|
### 阶段 3(已落地一部分:契约测试)
|
||||||
|
|
||||||
- 与 DP 的 provider snapshot schema 契约:`internal/service/testdata/provider_snapshot.json` + SyncProvider 输出回归
|
- 与 DP 的 provider snapshot schema 契约:`internal/service/testdata/provider_snapshot.json` + SyncProvider 输出回归
|
||||||
|
- golden 来自 `github.com/ez-api/foundation/contract`(go:embed),避免 DP/CP 两边复制文件
|
||||||
|
|
||||||
待扩展:
|
待扩展:
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"github.com/ez-api/ez-api/internal/dto"
|
"github.com/ez-api/ez-api/internal/dto"
|
||||||
"github.com/ez-api/ez-api/internal/model"
|
"github.com/ez-api/ez-api/internal/model"
|
||||||
"github.com/ez-api/ez-api/internal/service"
|
"github.com/ez-api/ez-api/internal/service"
|
||||||
|
groupx "github.com/ez-api/foundation/group"
|
||||||
"github.com/ez-api/foundation/provider"
|
"github.com/ez-api/foundation/provider"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -157,7 +158,7 @@ func (h *Handler) UpdateProvider(c *gin.Context) {
|
|||||||
update["models"] = strings.Join(req.Models, ",")
|
update["models"] = strings.Join(req.Models, ",")
|
||||||
}
|
}
|
||||||
if strings.TrimSpace(req.Group) != "" {
|
if strings.TrimSpace(req.Group) != "" {
|
||||||
update["group"] = normalizeGroup(req.Group)
|
update["group"] = groupx.Normalize(req.Group)
|
||||||
}
|
}
|
||||||
if req.AutoBan != nil {
|
if req.AutoBan != nil {
|
||||||
update["auto_ban"] = *req.AutoBan
|
update["auto_ban"] = *req.AutoBan
|
||||||
@@ -354,10 +355,3 @@ func (h *Handler) IngestLog(c *gin.Context) {
|
|||||||
h.logger.Write(rec)
|
h.logger.Write(rec)
|
||||||
c.JSON(http.StatusAccepted, gin.H{"status": "queued"})
|
c.JSON(http.StatusAccepted, gin.H{"status": "queued"})
|
||||||
}
|
}
|
||||||
|
|
||||||
func normalizeGroup(group string) string {
|
|
||||||
if strings.TrimSpace(group) == "" {
|
|
||||||
return "default"
|
|
||||||
}
|
|
||||||
return group
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,35 +1,20 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"github.com/ez-api/foundation/requestid"
|
||||||
"encoding/hex"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RequestID ensures every request has an X-Request-ID and echoes it back to the client.
|
// RequestID ensures every request has an X-Request-ID and echoes it back to the client.
|
||||||
func RequestID() gin.HandlerFunc {
|
func RequestID() gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
id := strings.TrimSpace(c.GetHeader("X-Request-ID"))
|
id := requestid.Extract(c.GetHeader)
|
||||||
if id == "" {
|
if id == "" {
|
||||||
id = strings.TrimSpace(c.GetHeader("X-Request-Id"))
|
id = requestid.New()
|
||||||
}
|
}
|
||||||
if id == "" {
|
c.Request.Header.Set(requestid.HeaderName, id)
|
||||||
id = newRequestID()
|
c.Writer.Header().Set(requestid.HeaderName, id)
|
||||||
}
|
|
||||||
c.Request.Header.Set("X-Request-ID", id)
|
|
||||||
c.Writer.Header().Set("X-Request-ID", id)
|
|
||||||
c.Set("request_id", id)
|
c.Set("request_id", id)
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newRequestID() string {
|
|
||||||
var b [16]byte
|
|
||||||
if _, err := rand.Read(b[:]); err != nil {
|
|
||||||
return hex.EncodeToString([]byte(time.Now().Format(time.RFC3339Nano)))
|
|
||||||
}
|
|
||||||
return hex.EncodeToString(b[:])
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/ez-api/ez-api/internal/model"
|
"github.com/ez-api/ez-api/internal/model"
|
||||||
"github.com/ez-api/ez-api/internal/util"
|
"github.com/ez-api/foundation/tokenhash"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
@@ -32,7 +32,7 @@ func (s *MasterService) CreateMaster(name, group string, maxChildKeys, globalQPS
|
|||||||
return nil, "", fmt.Errorf("failed to hash master key: %w", err)
|
return nil, "", fmt.Errorf("failed to hash master key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
masterKeyDigest := util.HashToken(rawMasterKey)
|
masterKeyDigest := tokenhash.HashToken(rawMasterKey)
|
||||||
|
|
||||||
master := &model.Master{
|
master := &model.Master{
|
||||||
Name: name,
|
Name: name,
|
||||||
@@ -53,7 +53,7 @@ func (s *MasterService) CreateMaster(name, group string, maxChildKeys, globalQPS
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *MasterService) ValidateMasterKey(masterKey string) (*model.Master, error) {
|
func (s *MasterService) ValidateMasterKey(masterKey string) (*model.Master, error) {
|
||||||
digest := util.HashToken(masterKey)
|
digest := tokenhash.HashToken(masterKey)
|
||||||
|
|
||||||
var master model.Master
|
var master model.Master
|
||||||
if err := s.db.Where("master_key_digest = ?", digest).First(&master).Error; err != nil {
|
if err := s.db.Where("master_key_digest = ?", digest).First(&master).Error; err != nil {
|
||||||
@@ -108,7 +108,7 @@ func (s *MasterService) IssueChildKey(masterID uint, group string, scopes string
|
|||||||
return nil, "", fmt.Errorf("failed to generate child key: %w", err)
|
return nil, "", fmt.Errorf("failed to generate child key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenHash := util.HashToken(rawChildKey)
|
tokenHash := tokenhash.HashToken(rawChildKey)
|
||||||
|
|
||||||
hashedChildKey, err := bcrypt.GenerateFromPassword([]byte(rawChildKey), bcrypt.DefaultCost)
|
hashedChildKey, err := bcrypt.GenerateFromPassword([]byte(rawChildKey), bcrypt.DefaultCost)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -7,8 +7,9 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ez-api/ez-api/internal/model"
|
"github.com/ez-api/ez-api/internal/model"
|
||||||
"github.com/ez-api/ez-api/internal/util"
|
groupx "github.com/ez-api/foundation/group"
|
||||||
"github.com/ez-api/foundation/jsoncodec"
|
"github.com/ez-api/foundation/jsoncodec"
|
||||||
|
"github.com/ez-api/foundation/tokenhash"
|
||||||
"github.com/redis/go-redis/v9"
|
"github.com/redis/go-redis/v9"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
@@ -26,7 +27,7 @@ func (s *SyncService) SyncKey(key *model.Key) error {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
tokenHash := key.TokenHash
|
tokenHash := key.TokenHash
|
||||||
if strings.TrimSpace(tokenHash) == "" {
|
if strings.TrimSpace(tokenHash) == "" {
|
||||||
tokenHash = util.HashToken(key.KeySecret) // backward compatibility
|
tokenHash = tokenhash.HashToken(key.KeySecret) // backward compatibility
|
||||||
}
|
}
|
||||||
if strings.TrimSpace(tokenHash) == "" {
|
if strings.TrimSpace(tokenHash) == "" {
|
||||||
return fmt.Errorf("token hash missing for key %d", key.ID)
|
return fmt.Errorf("token hash missing for key %d", key.ID)
|
||||||
@@ -62,7 +63,7 @@ func (s *SyncService) SyncMaster(master *model.Master) error {
|
|||||||
// SyncProvider writes a single provider into Redis hash storage and updates routing tables.
|
// SyncProvider writes a single provider into Redis hash storage and updates routing tables.
|
||||||
func (s *SyncService) SyncProvider(provider *model.Provider) error {
|
func (s *SyncService) SyncProvider(provider *model.Provider) error {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
group := normalizeGroup(provider.Group)
|
group := groupx.Normalize(provider.Group)
|
||||||
models := strings.Split(provider.Models, ",")
|
models := strings.Split(provider.Models, ",")
|
||||||
|
|
||||||
snap := providerSnapshot{
|
snap := providerSnapshot{
|
||||||
@@ -198,7 +199,7 @@ func (s *SyncService) SyncAll(db *gorm.DB) error {
|
|||||||
// Ideally, we should scan "route:group:*" and del, but let's just rebuild.
|
// Ideally, we should scan "route:group:*" and del, but let's just rebuild.
|
||||||
|
|
||||||
for _, p := range providers {
|
for _, p := range providers {
|
||||||
group := normalizeGroup(p.Group)
|
group := groupx.Normalize(p.Group)
|
||||||
models := strings.Split(p.Models, ",")
|
models := strings.Split(p.Models, ",")
|
||||||
|
|
||||||
snap := providerSnapshot{
|
snap := providerSnapshot{
|
||||||
@@ -244,7 +245,7 @@ func (s *SyncService) SyncAll(db *gorm.DB) error {
|
|||||||
for _, k := range keys {
|
for _, k := range keys {
|
||||||
tokenHash := strings.TrimSpace(k.TokenHash)
|
tokenHash := strings.TrimSpace(k.TokenHash)
|
||||||
if tokenHash == "" {
|
if tokenHash == "" {
|
||||||
tokenHash = util.HashToken(k.KeySecret) // fallback for legacy rows
|
tokenHash = tokenhash.HashToken(k.KeySecret) // fallback for legacy rows
|
||||||
}
|
}
|
||||||
if tokenHash == "" {
|
if tokenHash == "" {
|
||||||
return fmt.Errorf("token hash missing for key %d", k.ID)
|
return fmt.Errorf("token hash missing for key %d", k.ID)
|
||||||
@@ -302,13 +303,6 @@ func (s *SyncService) hsetJSON(ctx context.Context, key, field string, val inter
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func normalizeGroup(group string) string {
|
|
||||||
if strings.TrimSpace(group) == "" {
|
|
||||||
return "default"
|
|
||||||
}
|
|
||||||
return group
|
|
||||||
}
|
|
||||||
|
|
||||||
func normalizeStatus(status string) string {
|
func normalizeStatus(status string) string {
|
||||||
st := strings.ToLower(strings.TrimSpace(status))
|
st := strings.ToLower(strings.TrimSpace(status))
|
||||||
if st == "" {
|
if st == "" {
|
||||||
|
|||||||
@@ -2,22 +2,17 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/alicebob/miniredis/v2"
|
"github.com/alicebob/miniredis/v2"
|
||||||
"github.com/ez-api/ez-api/internal/model"
|
"github.com/ez-api/ez-api/internal/model"
|
||||||
|
"github.com/ez-api/foundation/contract"
|
||||||
"github.com/redis/go-redis/v9"
|
"github.com/redis/go-redis/v9"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSyncProvider_WritesSnapshotAndRouting(t *testing.T) {
|
func TestSyncProvider_WritesSnapshotAndRouting(t *testing.T) {
|
||||||
goldenPath := filepath.Join("testdata", "provider_snapshot.json")
|
goldenRaw := contract.ProviderSnapshotJSON()
|
||||||
goldenRaw, err := os.ReadFile(goldenPath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("read golden %s: %v", goldenPath, err)
|
|
||||||
}
|
|
||||||
var golden map[string]any
|
var golden map[string]any
|
||||||
if err := json.Unmarshal(goldenRaw, &golden); err != nil {
|
if err := json.Unmarshal(goldenRaw, &golden); err != nil {
|
||||||
t.Fatalf("parse golden json: %v", err)
|
t.Fatalf("parse golden json: %v", err)
|
||||||
|
|||||||
13
internal/service/testdata/provider_snapshot.json
vendored
13
internal/service/testdata/provider_snapshot.json
vendored
@@ -1,13 +0,0 @@
|
|||||||
{
|
|
||||||
"id": 42,
|
|
||||||
"name": "p1",
|
|
||||||
"type": "vertex-express",
|
|
||||||
"base_url": "",
|
|
||||||
"api_key": "",
|
|
||||||
"google_location": "global",
|
|
||||||
"group": "default",
|
|
||||||
"models": ["gemini-3-pro-preview"],
|
|
||||||
"status": "active",
|
|
||||||
"auto_ban": true
|
|
||||||
}
|
|
||||||
|
|
||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/ez-api/ez-api/internal/util"
|
"github.com/ez-api/foundation/tokenhash"
|
||||||
"github.com/redis/go-redis/v9"
|
"github.com/redis/go-redis/v9"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -28,7 +28,7 @@ type TokenInfo struct {
|
|||||||
// ValidateToken checks a child key against Redis for validity.
|
// ValidateToken checks a child key against Redis for validity.
|
||||||
// This is designed to be called by the data plane (balancer).
|
// This is designed to be called by the data plane (balancer).
|
||||||
func (s *TokenService) ValidateToken(ctx context.Context, token string) (*TokenInfo, error) {
|
func (s *TokenService) ValidateToken(ctx context.Context, token string) (*TokenInfo, error) {
|
||||||
tokenHash := util.HashToken(token)
|
tokenHash := tokenhash.HashToken(token)
|
||||||
tokenKey := fmt.Sprintf("auth:token:%s", tokenHash)
|
tokenKey := fmt.Sprintf("auth:token:%s", tokenHash)
|
||||||
|
|
||||||
// 1. Get token metadata from Redis
|
// 1. Get token metadata from Redis
|
||||||
|
|||||||
@@ -1,12 +0,0 @@
|
|||||||
package util
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/hex"
|
|
||||||
)
|
|
||||||
|
|
||||||
func HashToken(token string) string {
|
|
||||||
hasher := sha256.New()
|
|
||||||
hasher.Write([]byte(token))
|
|
||||||
return hex.EncodeToString(hasher.Sum(nil))
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user