Files
ez-api/internal/migrate/importer.go
zenfun dea8363e41 refactor(api): split Provider into ProviderGroup and APIKey models
Restructure the provider management system by separating the monolithic
Provider model into two distinct entities:

- ProviderGroup: defines shared upstream configuration (type, base_url,
  google settings, models, status)
- APIKey: represents individual credentials within a group (api_key,
  weight, status, auto_ban, ban settings)

This change also updates:
- Binding model to reference GroupID instead of RouteGroup string
- All CRUD handlers for the new provider-group and api-key endpoints
- Sync service to rebuild provider snapshots from joined tables
- Model registry to aggregate capabilities across group/key pairs
- Access handler to validate namespace existence and subset constraints
- Migration importer to handle the new schema structure
- All related tests to use the new model relationships

BREAKING CHANGE: Provider API endpoints replaced with /provider-groups
and /api-keys endpoints; Binding.RouteGroup replaced with Binding.GroupID
2025-12-24 02:15:52 +08:00

633 lines
16 KiB
Go

package migrate
import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"os"
"strings"
"github.com/ez-api/ez-api/internal/model"
"github.com/ez-api/foundation/tokenhash"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
// Conflict policies understood by the importer. They select what happens when
// an imported record already exists in the database.
const (
	// ConflictSkip keeps the existing record and counts the item as skipped.
	ConflictSkip = "skip"
	// ConflictOverwrite updates the existing record with the imported values.
	ConflictOverwrite = "overwrite"
)
// ImportOptions tunes how an Importer applies a payload.
type ImportOptions struct {
	// DryRun makes the importer update its counters without issuing any
	// create or update statements.
	DryRun bool
	// ConflictPolicy is ConflictSkip or ConflictOverwrite. NewImporter trims
	// and lowercases it and falls back to ConflictSkip when it is empty.
	ConflictPolicy string
	// IncludeBindings enables the binding import step and makes namespace
	// discovery consider binding namespaces as well.
	IncludeBindings bool
}
// MasterCredential pairs a newly created master's name with the plaintext key
// generated for it during import. Only a bcrypt hash and a digest of the key
// are stored on the master row, so this is the sole place the plaintext is
// reported.
type MasterCredential struct {
	Name      string `json:"name"`
	MasterKey string `json:"master_key"`
}
// ImportSummary reports what an import run did (or, in dry-run mode, what it
// would have done).
type ImportSummary struct {
	// Provider counters track the API key rows handled by importProviders.
	ProvidersCreated int `json:"providers_created"`
	ProvidersUpdated int `json:"providers_updated"`
	ProvidersSkipped int `json:"providers_skipped"`

	// Master counters.
	MastersCreated int `json:"masters_created"`
	MastersUpdated int `json:"masters_updated"`
	MastersSkipped int `json:"masters_skipped"`

	// Child key counters.
	KeysCreated int `json:"keys_created"`
	KeysUpdated int `json:"keys_updated"`
	KeysSkipped int `json:"keys_skipped"`

	// Binding counters; only touched when bindings are included in the run.
	BindingsCreated int `json:"bindings_created"`
	BindingsUpdated int `json:"bindings_updated"`
	BindingsSkipped int `json:"bindings_skipped"`

	// NamespacesCreated counts namespaces that did not exist before the run.
	NamespacesCreated int `json:"namespaces_created"`

	// MasterKeys lists the plaintext keys generated for newly created masters.
	MasterKeys []MasterCredential `json:"master_keys,omitempty"`
	// Warnings collects the payload's own warnings plus every item the
	// importer skipped or defaulted.
	Warnings []string `json:"warnings,omitempty"`
}
// Importer loads an exported payload into the database according to its
// ImportOptions. Construct it with NewImporter so the conflict policy is
// normalized.
type Importer struct {
	db   *gorm.DB      // database handle used for every lookup and write
	opts ImportOptions // normalized options for this importer
}
// NewImporter returns an Importer bound to db. The conflict policy in opts is
// trimmed and lowercased; an empty policy becomes ConflictSkip.
func NewImporter(db *gorm.DB, opts ImportOptions) *Importer {
	policy := strings.ToLower(strings.TrimSpace(opts.ConflictPolicy))
	if policy == "" {
		policy = ConflictSkip
	}
	opts.ConflictPolicy = policy
	return &Importer{db: db, opts: opts}
}
// ImportFile reads the JSON export stored at path, decodes it into an
// ExportResult and hands it to Import. Read and decode failures are wrapped
// with a short prefix naming the failing step.
func (i *Importer) ImportFile(path string) (*ImportSummary, error) {
	data, readErr := os.ReadFile(path)
	if readErr != nil {
		return nil, fmt.Errorf("read file: %w", readErr)
	}

	payload := new(ExportResult)
	if decodeErr := json.Unmarshal(data, payload); decodeErr != nil {
		return nil, fmt.Errorf("parse json: %w", decodeErr)
	}
	return i.Import(payload)
}
// Import applies payload to the database and returns a summary of the run.
//
// The steps run in a fixed order: namespaces, masters, providers, keys and,
// when IncludeBindings is set, bindings. The first failing step aborts the
// import and its error is returned with a nil summary.
func (i *Importer) Import(payload *ExportResult) (*ImportSummary, error) {
	switch {
	case payload == nil:
		return nil, errors.New("empty payload")
	case i.db == nil:
		return nil, errors.New("db is required")
	}

	summary := new(ImportSummary)
	// Warnings recorded in the payload are carried over into the summary.
	if carried := payload.Warnings; len(carried) > 0 {
		summary.Warnings = append(summary.Warnings, carried...)
	}

	if err := i.ensureNamespaces(payload, summary); err != nil {
		return nil, err
	}

	// Keys reference masters by name, so masters are imported first and the
	// resulting lookup table is passed on to the key import.
	masters, err := i.importMasters(payload.Data.Masters, summary)
	if err != nil {
		return nil, err
	}

	if err := i.importProviders(payload.Data.Providers, summary); err != nil {
		return nil, err
	}

	if err := i.importKeys(payload.Data.Keys, masters, summary); err != nil {
		return nil, err
	}

	if !i.opts.IncludeBindings {
		return summary, nil
	}
	if err := i.importBindings(payload.Data.Bindings, summary); err != nil {
		return nil, err
	}
	return summary, nil
}
// ensureNamespaces collects every namespace name referenced by the payload
// (masters, keys and, when bindings are included, bindings) plus "default",
// and makes sure each one exists via ensureNamespace.
func (i *Importer) ensureNamespaces(payload *ExportResult, summary *ImportSummary) error {
	// "default" is always ensured, whether or not the payload mentions it.
	required := map[string]struct{}{"default": {}}

	for _, m := range payload.Data.Masters {
		addNamespaces(required, m.DefaultNamespace)
		addNamespaces(required, m.Namespaces...)
	}

	for _, k := range payload.Data.Keys {
		addNamespaces(required, k.Namespaces...)
	}

	if i.opts.IncludeBindings {
		for _, b := range payload.Data.Bindings {
			addNamespaces(required, b.Namespace)
		}
	}

	for ns := range required {
		if err := i.ensureNamespace(ns, summary); err != nil {
			return err
		}
	}
	return nil
}
// ensureNamespace makes sure a namespace called name exists. Blank names are
// ignored. A missing namespace is created with status "active" (or merely
// counted in dry-run mode) and summary.NamespacesCreated is incremented.
func (i *Importer) ensureNamespace(name string, summary *ImportSummary) error {
	trimmed := strings.TrimSpace(name)
	if trimmed == "" {
		return nil
	}

	var found model.Namespace
	switch lookupErr := i.db.Where("name = ?", trimmed).First(&found).Error; {
	case lookupErr == nil:
		// Already present; nothing to do.
		return nil
	case !errors.Is(lookupErr, gorm.ErrRecordNotFound):
		return lookupErr
	}

	if !i.opts.DryRun {
		created := &model.Namespace{Name: trimmed, Status: "active"}
		if err := i.db.Create(created).Error; err != nil {
			return err
		}
	}
	summary.NamespacesCreated++
	return nil
}
// importMasters creates or updates masters from items and returns them keyed
// by trimmed name so keys can be attached to them afterwards.
//
// Items with an empty name are skipped with a warning. An existing master is
// updated under ConflictOverwrite and otherwise counted as skipped; it is
// placed in the returned map either way. A new master receives a freshly
// generated key whose plaintext is reported through summary.MasterKeys. In
// dry-run mode nothing is written and new masters appear in the map as
// unsaved placeholders without key material.
func (i *Importer) importMasters(items []Master, summary *ImportSummary) (map[string]model.Master, error) {
	byName := make(map[string]model.Master)

	for _, src := range items {
		name := strings.TrimSpace(src.Name)
		if name == "" {
			summary.Warnings = append(summary.Warnings, "skip master with empty name")
			continue
		}

		var current model.Master
		lookupErr := i.db.Where("name = ?", name).First(&current).Error
		found := lookupErr == nil
		if !found && !errors.Is(lookupErr, gorm.ErrRecordNotFound) {
			return nil, lookupErr
		}

		defaultNS := normalizeNamespace(src.DefaultNamespace)

		if found {
			byName[name] = current
			if i.opts.ConflictPolicy != ConflictOverwrite {
				summary.MastersSkipped++
				continue
			}
			if !i.opts.DryRun {
				changes := map[string]any{
					"group":             normalizeGroup(src.Group),
					"default_namespace": defaultNS,
					"namespaces":        normalizeNamespaces(src.Namespaces, defaultNS),
					"max_child_keys":    normalizeMaxChildKeys(src.MaxChildKeys),
					"global_qps":        src.GlobalQPS,
					"status":            normalizeStatus(src.Status, "active"),
				}
				if err := i.db.Model(&current).Updates(changes).Error; err != nil {
					return nil, err
				}
				// Store the struct again now that the update has gone through.
				byName[name] = current
			}
			summary.MastersUpdated++
			continue
		}

		// Fields shared by the dry-run placeholder and the persisted row.
		fresh := model.Master{
			Name:             name,
			Group:            normalizeGroup(src.Group),
			DefaultNamespace: defaultNS,
			Namespaces:       normalizeNamespaces(src.Namespaces, defaultNS),
			MaxChildKeys:     normalizeMaxChildKeys(src.MaxChildKeys),
			GlobalQPS:        src.GlobalQPS,
			Status:           normalizeStatus(src.Status, "active"),
			Epoch:            1,
		}

		if i.opts.DryRun {
			summary.MastersCreated++
			byName[name] = fresh
			continue
		}

		// Only a bcrypt hash and a digest of the generated key are stored; the
		// plaintext is handed back through the summary.
		rawKey, err := generateRandomKey(32)
		if err != nil {
			return nil, err
		}
		hashed, err := bcrypt.GenerateFromPassword([]byte(rawKey), bcrypt.DefaultCost)
		if err != nil {
			return nil, err
		}
		fresh.MasterKey = string(hashed)
		fresh.MasterKeyDigest = tokenhash.HashToken(rawKey)

		if err := i.db.Create(&fresh).Error; err != nil {
			return nil, err
		}
		byName[name] = fresh
		summary.MastersCreated++
		summary.MasterKeys = append(summary.MasterKeys, MasterCredential{
			Name:      fresh.Name,
			MasterKey: rawKey,
		})
	}

	return byName, nil
}
// importProviders turns each exported provider into an API key row inside a
// provider group named after the provider's primary group.
//
// A group that does not exist yet is created from the first provider that
// references it (type, base URL, models, status); later providers in the same
// group reuse it through a per-call cache. Providers with an empty api_key are
// skipped with a warning. An API key already present in the group is updated
// under ConflictOverwrite and otherwise counted as skipped. In dry-run mode
// nothing is written and only the counters change.
func (i *Importer) importProviders(items []Provider, summary *ImportSummary) error {
	groups := make(map[string]model.ProviderGroup)

	for _, src := range items {
		// normalizeGroup already falls back to "default" for blank input.
		groupName := normalizeGroup(src.PrimaryGroup)

		group, cached := groups[groupName]
		if !cached {
			var storedGroup model.ProviderGroup
			lookupErr := i.db.Where("name = ?", groupName).First(&storedGroup).Error
			switch {
			case lookupErr == nil:
				group = storedGroup
			case errors.Is(lookupErr, gorm.ErrRecordNotFound):
				group = model.ProviderGroup{
					Name:    groupName,
					Type:    strings.TrimSpace(src.Type),
					BaseURL: strings.TrimSpace(src.BaseURL),
					Models:  strings.Join(src.Models, ","),
					Status:  normalizeStatus(src.Status, "active"),
				}
				if !i.opts.DryRun {
					if err := i.db.Create(&group).Error; err != nil {
						return err
					}
				}
			default:
				return lookupErr
			}
			groups[groupName] = group
		}

		secret := strings.TrimSpace(src.APIKey)
		if secret == "" {
			summary.Warnings = append(summary.Warnings, "skip api key with empty api_key")
			continue
		}

		var storedKey model.APIKey
		keyErr := i.db.Where("group_id = ? AND api_key = ?", group.ID, secret).First(&storedKey).Error
		exists := keyErr == nil
		if !exists && !errors.Is(keyErr, gorm.ErrRecordNotFound) {
			return keyErr
		}

		if exists {
			if i.opts.ConflictPolicy != ConflictOverwrite {
				summary.ProvidersSkipped++
				continue
			}
			if !i.opts.DryRun {
				changes := map[string]any{
					"weight":   resolveWeight(src.Weight, src.Priority),
					"status":   normalizeProviderStatus(src.Status),
					"auto_ban": src.AutoBan,
				}
				if err := i.db.Model(&storedKey).Updates(changes).Error; err != nil {
					return err
				}
			}
			summary.ProvidersUpdated++
			continue
		}

		if !i.opts.DryRun {
			row := model.APIKey{
				GroupID: group.ID,
				APIKey:  secret,
				Weight:  resolveWeight(src.Weight, src.Priority),
				Status:  normalizeProviderStatus(src.Status),
				AutoBan: src.AutoBan,
			}
			if err := i.db.Create(&row).Error; err != nil {
				return err
			}
		}
		summary.ProvidersCreated++
	}

	return nil
}
// importKeys imports child keys and attaches each one to the master named by
// its master_ref, looked up in masters.
//
// Keys without a master_ref or without an available plaintext token are
// skipped with a warning; a master_ref that is not in masters aborts the
// import with an error. Existing keys are matched by token hash and are
// updated under ConflictOverwrite, otherwise counted as skipped. In dry-run
// mode nothing is written and only the counters change.
func (i *Importer) importKeys(items []Key, masters map[string]model.Master, summary *ImportSummary) error {
	for _, src := range items {
		ref := strings.TrimSpace(src.MasterRef)
		if ref == "" {
			summary.Warnings = append(summary.Warnings, "skip key with empty master_ref")
			continue
		}
		owner, known := masters[ref]
		if !known {
			return fmt.Errorf("master not found for key: %s", ref)
		}

		// The plaintext token is needed to compute both the hash and the secret.
		token := strings.TrimSpace(src.OriginalToken)
		if !src.TokenPlaintextAvailable || token == "" {
			summary.Warnings = append(summary.Warnings, fmt.Sprintf("skip key for master %s: missing plaintext token", ref))
			continue
		}

		digest := tokenhash.HashToken(token)

		var stored model.Key
		lookupErr := i.db.Where("token_hash = ?", digest).First(&stored).Error
		exists := lookupErr == nil
		if !exists && !errors.Is(lookupErr, gorm.ErrRecordNotFound) {
			return lookupErr
		}

		if exists {
			if i.opts.ConflictPolicy != ConflictOverwrite {
				summary.KeysSkipped++
				continue
			}
			if !i.opts.DryRun {
				changes := map[string]any{
					"master_id":            owner.ID,
					"group":                normalizeGroup(src.Group, owner.Group),
					"scopes":               strings.Join(src.Scopes, ","),
					"default_namespace":    normalizeNamespace(owner.DefaultNamespace),
					"namespaces":           normalizeNamespaces(src.Namespaces, owner.DefaultNamespace),
					"issued_at_epoch":      owner.Epoch,
					"status":               normalizeStatus(src.Status, "active"),
					"issued_by":            "import",
					"model_limits":         strings.Join(src.ModelLimits, ","),
					"model_limits_enabled": src.ModelLimitsEnabled,
					"expires_at":           src.ExpiresAt,
					"allow_ips":            strings.Join(src.AllowIPs, ","),
					"quota_limit":          normalizeQuotaLimit(src.QuotaLimit, src.UnlimitedQuota, summary),
					"quota_used":           normalizeQuotaUsed(src.QuotaUsed),
				}
				if err := i.db.Model(&stored).Updates(changes).Error; err != nil {
					return err
				}
			}
			summary.KeysUpdated++
			continue
		}

		if i.opts.DryRun {
			summary.KeysCreated++
			continue
		}

		hashed, err := bcrypt.GenerateFromPassword([]byte(token), bcrypt.DefaultCost)
		if err != nil {
			return err
		}
		row := model.Key{
			MasterID:           owner.ID,
			KeySecret:          string(hashed),
			TokenHash:          digest,
			Group:              normalizeGroup(src.Group, owner.Group),
			Scopes:             strings.Join(src.Scopes, ","),
			DefaultNamespace:   normalizeNamespace(owner.DefaultNamespace),
			Namespaces:         normalizeNamespaces(src.Namespaces, owner.DefaultNamespace),
			IssuedAtEpoch:      owner.Epoch,
			Status:             normalizeStatus(src.Status, "active"),
			IssuedBy:           "import",
			ModelLimits:        strings.Join(src.ModelLimits, ","),
			ModelLimitsEnabled: src.ModelLimitsEnabled,
			ExpiresAt:          src.ExpiresAt,
			AllowIPs:           strings.Join(src.AllowIPs, ","),
			QuotaLimit:         normalizeQuotaLimit(src.QuotaLimit, src.UnlimitedQuota, summary),
			QuotaUsed:          normalizeQuotaUsed(src.QuotaUsed),
		}
		if err := i.db.Create(&row).Error; err != nil {
			return err
		}
		summary.KeysCreated++
	}

	return nil
}
// importBindings imports model bindings, resolving each binding's route group
// name to an existing provider group.
//
// Bindings with an empty model, or whose provider group does not exist, are
// skipped with a warning. Any other failure while looking up the provider
// group is returned as an error instead of being reported as a missing group.
// An existing binding (same namespace, public model and group) is updated
// under ConflictOverwrite and otherwise counted as skipped. In dry-run mode
// nothing is written and only the counters change.
//
// NOTE(review): provider groups are not created during a dry run, so a dry
// run presumably reports bindings as skipped when their group would only be
// created by the same import — confirm whether that is acceptable.
func (i *Importer) importBindings(items []Binding, summary *ImportSummary) error {
	for _, item := range items {
		ns := normalizeNamespace(item.Namespace, "default")
		publicModel := strings.TrimSpace(item.Model)
		if publicModel == "" {
			summary.Warnings = append(summary.Warnings, "skip binding with empty model")
			continue
		}

		groupName := normalizeGroup(item.RouteGroup)
		var group model.ProviderGroup
		if err := i.db.Where("name = ?", groupName).First(&group).Error; err != nil {
			// Only a genuinely missing group is a skippable condition; a real
			// database error must abort the import like every other lookup.
			if !errors.Is(err, gorm.ErrRecordNotFound) {
				return err
			}
			summary.Warnings = append(summary.Warnings, "skip binding with missing provider group: "+groupName)
			continue
		}

		var existing model.Binding
		err := i.db.Where("namespace = ? AND public_model = ? AND group_id = ?", ns, publicModel, group.ID).First(&existing).Error
		if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
			return err
		}

		if err == nil {
			if i.opts.ConflictPolicy != ConflictOverwrite {
				summary.BindingsSkipped++
				continue
			}
			if !i.opts.DryRun {
				update := map[string]any{
					"group_id":       group.ID,
					"weight":         1,
					"selector_type":  "exact",
					"selector_value": publicModel,
					"status":         normalizeStatus(item.Status, "active"),
				}
				if err := i.db.Model(&existing).Updates(update).Error; err != nil {
					return err
				}
			}
			summary.BindingsUpdated++
			continue
		}

		if i.opts.DryRun {
			summary.BindingsCreated++
			continue
		}

		// Imported bindings always use weight 1 and an exact selector on the
		// public model name.
		binding := model.Binding{
			Namespace:     ns,
			PublicModel:   publicModel,
			GroupID:       group.ID,
			Weight:        1,
			SelectorType:  "exact",
			SelectorValue: publicModel,
			Status:        normalizeStatus(item.Status, "active"),
		}
		if err := i.db.Create(&binding).Error; err != nil {
			return err
		}
		summary.BindingsCreated++
	}
	return nil
}
// addNamespaces inserts every non-blank name into set after trimming
// surrounding whitespace.
func addNamespaces(set map[string]struct{}, names ...string) {
	for idx := range names {
		if trimmed := strings.TrimSpace(names[idx]); trimmed != "" {
			set[trimmed] = struct{}{}
		}
	}
}
// normalizeGroup returns the first value that is non-empty after trimming, or
// "default" when there is none.
func normalizeGroup(values ...string) string {
	for idx := range values {
		if candidate := strings.TrimSpace(values[idx]); candidate != "" {
			return candidate
		}
	}
	return "default"
}
// normalizeNamespace returns the first value that is non-empty after
// trimming, or "default" when there is none.
func normalizeNamespace(values ...string) string {
	for _, v := range values {
		v = strings.TrimSpace(v)
		if v != "" {
			return v
		}
	}
	return "default"
}

// normalizeNamespaces builds the comma-separated namespace list for a master
// or key.
//
// Entries of values are trimmed, blanks are dropped and duplicates removed
// while keeping first-seen order. The trimmed defaultNamespace is appended
// when it is non-empty and not already listed. When no usable entry remains,
// the result is the default namespace alone ("default" if that is blank too).
func normalizeNamespaces(values []string, defaultNamespace string) string {
	// Trim once up front so the membership check and the appended value agree
	// with the trimmed entries; an untrimmed default would otherwise yield a
	// duplicate (" a " vs "a") or a blank list element (" ").
	defaultNamespace = strings.TrimSpace(defaultNamespace)

	seen := make(map[string]struct{}, len(values)+1)
	out := make([]string, 0, len(values)+1)
	for _, v := range values {
		v = strings.TrimSpace(v)
		if v == "" {
			continue
		}
		if _, ok := seen[v]; ok {
			continue
		}
		seen[v] = struct{}{}
		out = append(out, v)
	}
	if len(out) == 0 {
		return normalizeNamespace(defaultNamespace)
	}
	if defaultNamespace != "" {
		if _, ok := seen[defaultNamespace]; !ok {
			out = append(out, defaultNamespace)
		}
	}
	return strings.Join(out, ",")
}
// normalizeMaxChildKeys returns value when it is positive and 5 otherwise.
func normalizeMaxChildKeys(value int) int {
	const fallback = 5
	if value <= 0 {
		return fallback
	}
	return value
}
// normalizeWeight returns value when it is positive and 1 otherwise.
func normalizeWeight(value int) int {
	if value < 1 {
		return 1
	}
	return value
}

// resolveWeight picks the weight for an imported API key: the explicit weight
// when positive, else the priority when positive, else the normalized weight
// (which is 1 at that point).
func resolveWeight(weight int, priority int) int {
	switch {
	case weight > 0:
		return weight
	case priority > 0:
		return priority
	default:
		return normalizeWeight(weight)
	}
}
// normalizeProviderStatus maps an exported provider status onto the value
// stored on an API key: blank and "active" become "active", "disabled"
// becomes "manual_disabled", and anything else is passed through trimmed and
// lowercased.
func normalizeProviderStatus(value string) string {
	switch status := strings.ToLower(strings.TrimSpace(value)); status {
	case "", "active":
		return "active"
	case "disabled":
		return "manual_disabled"
	default:
		return status
	}
}
// normalizeStatus trims and lowercases value; when nothing is left it returns
// fallback unchanged.
func normalizeStatus(value string, fallback string) string {
	if cleaned := strings.ToLower(strings.TrimSpace(value)); cleaned != "" {
		return cleaned
	}
	return fallback
}
// normalizeQuotaLimit resolves the quota limit stored for an imported key.
// Unlimited keys get -1, keys with an explicit limit keep it, and keys that
// are neither unlimited nor carry a limit default to -1 with a warning
// appended to summary (when summary is non-nil).
func normalizeQuotaLimit(limit *int64, unlimited bool, summary *ImportSummary) int64 {
	switch {
	case unlimited:
		return -1
	case limit != nil:
		return *limit
	}

	if summary != nil {
		summary.Warnings = append(summary.Warnings, "quota_limit missing for limited key, defaulting to -1")
	}
	return -1
}
// normalizeQuotaUsed dereferences used, treating a nil pointer as 0.
func normalizeQuotaUsed(used *int64) int64 {
	if used != nil {
		return *used
	}
	return 0
}
// generateRandomKey returns length random bytes from crypto/rand encoded as a
// hex string (so the result has 2*length characters). A non-positive length
// is rejected with an error.
func generateRandomKey(length int) (string, error) {
	if length < 1 {
		return "", errors.New("invalid key length")
	}

	raw := make([]byte, length)
	_, err := rand.Read(raw)
	if err != nil {
		return "", err
	}
	return hex.EncodeToString(raw), nil
}