package migrate import ( "crypto/rand" "encoding/hex" "encoding/json" "errors" "fmt" "os" "strings" "github.com/ez-api/ez-api/internal/model" "github.com/ez-api/foundation/tokenhash" "golang.org/x/crypto/bcrypt" "gorm.io/gorm" ) const ( ConflictSkip = "skip" ConflictOverwrite = "overwrite" ) type ImportOptions struct { DryRun bool ConflictPolicy string IncludeBindings bool } type MasterCredential struct { Name string `json:"name"` MasterKey string `json:"master_key"` } type ImportSummary struct { ProvidersCreated int `json:"providers_created"` ProvidersUpdated int `json:"providers_updated"` ProvidersSkipped int `json:"providers_skipped"` MastersCreated int `json:"masters_created"` MastersUpdated int `json:"masters_updated"` MastersSkipped int `json:"masters_skipped"` KeysCreated int `json:"keys_created"` KeysUpdated int `json:"keys_updated"` KeysSkipped int `json:"keys_skipped"` BindingsCreated int `json:"bindings_created"` BindingsUpdated int `json:"bindings_updated"` BindingsSkipped int `json:"bindings_skipped"` NamespacesCreated int `json:"namespaces_created"` MasterKeys []MasterCredential `json:"master_keys,omitempty"` Warnings []string `json:"warnings,omitempty"` } type Importer struct { db *gorm.DB opts ImportOptions } func NewImporter(db *gorm.DB, opts ImportOptions) *Importer { opts.ConflictPolicy = strings.ToLower(strings.TrimSpace(opts.ConflictPolicy)) if opts.ConflictPolicy == "" { opts.ConflictPolicy = ConflictSkip } return &Importer{db: db, opts: opts} } func (i *Importer) ImportFile(path string) (*ImportSummary, error) { raw, err := os.ReadFile(path) if err != nil { return nil, fmt.Errorf("read file: %w", err) } var payload ExportResult if err := json.Unmarshal(raw, &payload); err != nil { return nil, fmt.Errorf("parse json: %w", err) } return i.Import(&payload) } func (i *Importer) Import(payload *ExportResult) (*ImportSummary, error) { if payload == nil { return nil, errors.New("empty payload") } if i.db == nil { return nil, errors.New("db is required") } summary := &ImportSummary{} if len(payload.Warnings) > 0 { summary.Warnings = append(summary.Warnings, payload.Warnings...) } if err := i.ensureNamespaces(payload, summary); err != nil { return nil, err } masterMap, err := i.importMasters(payload.Data.Masters, summary) if err != nil { return nil, err } if err := i.importProviders(payload.Data.Providers, summary); err != nil { return nil, err } if err := i.importKeys(payload.Data.Keys, masterMap, summary); err != nil { return nil, err } if i.opts.IncludeBindings { if err := i.importBindings(payload.Data.Bindings, summary); err != nil { return nil, err } } return summary, nil } func (i *Importer) ensureNamespaces(payload *ExportResult, summary *ImportSummary) error { namespaceSet := make(map[string]struct{}) namespaceSet["default"] = struct{}{} for _, master := range payload.Data.Masters { addNamespaces(namespaceSet, master.DefaultNamespace) addNamespaces(namespaceSet, master.Namespaces...) } for _, key := range payload.Data.Keys { addNamespaces(namespaceSet, key.Namespaces...) } if i.opts.IncludeBindings { for _, binding := range payload.Data.Bindings { addNamespaces(namespaceSet, binding.Namespace) } } for name := range namespaceSet { if err := i.ensureNamespace(name, summary); err != nil { return err } } return nil } func (i *Importer) ensureNamespace(name string, summary *ImportSummary) error { name = strings.TrimSpace(name) if name == "" { return nil } var existing model.Namespace err := i.db.Where("name = ?", name).First(&existing).Error if err == nil { return nil } if !errors.Is(err, gorm.ErrRecordNotFound) { return err } if i.opts.DryRun { summary.NamespacesCreated++ return nil } ns := &model.Namespace{Name: name, Status: "active"} if err := i.db.Create(ns).Error; err != nil { return err } summary.NamespacesCreated++ return nil } func (i *Importer) importMasters(items []Master, summary *ImportSummary) (map[string]model.Master, error) { out := make(map[string]model.Master) for _, item := range items { name := strings.TrimSpace(item.Name) if name == "" { summary.Warnings = append(summary.Warnings, "skip master with empty name") continue } var existing model.Master err := i.db.Where("name = ?", name).First(&existing).Error if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return nil, err } if err == nil { out[name] = existing defaultNS := normalizeNamespace(item.DefaultNamespace) switch i.opts.ConflictPolicy { case ConflictOverwrite: if i.opts.DryRun { summary.MastersUpdated++ continue } update := map[string]any{ "group": normalizeGroup(item.Group), "default_namespace": defaultNS, "namespaces": normalizeNamespaces(item.Namespaces, defaultNS), "max_child_keys": normalizeMaxChildKeys(item.MaxChildKeys), "global_qps": item.GlobalQPS, "status": normalizeStatus(item.Status, "active"), } if err := i.db.Model(&existing).Updates(update).Error; err != nil { return nil, err } out[name] = existing summary.MastersUpdated++ default: summary.MastersSkipped++ } continue } defaultNS := normalizeNamespace(item.DefaultNamespace) normalizedNamespaces := normalizeNamespaces(item.Namespaces, defaultNS) if i.opts.DryRun { summary.MastersCreated++ out[name] = model.Master{ Name: name, Group: normalizeGroup(item.Group), DefaultNamespace: defaultNS, Namespaces: normalizedNamespaces, MaxChildKeys: normalizeMaxChildKeys(item.MaxChildKeys), GlobalQPS: item.GlobalQPS, Status: normalizeStatus(item.Status, "active"), Epoch: 1, } continue } rawKey, err := generateRandomKey(32) if err != nil { return nil, err } hashedKey, err := bcrypt.GenerateFromPassword([]byte(rawKey), bcrypt.DefaultCost) if err != nil { return nil, err } digest := tokenhash.HashToken(rawKey) master := model.Master{ Name: name, MasterKey: string(hashedKey), MasterKeyDigest: digest, Group: normalizeGroup(item.Group), DefaultNamespace: defaultNS, Namespaces: normalizedNamespaces, MaxChildKeys: normalizeMaxChildKeys(item.MaxChildKeys), GlobalQPS: item.GlobalQPS, Status: normalizeStatus(item.Status, "active"), Epoch: 1, } if err := i.db.Create(&master).Error; err != nil { return nil, err } out[name] = master summary.MastersCreated++ summary.MasterKeys = append(summary.MasterKeys, MasterCredential{ Name: master.Name, MasterKey: rawKey, }) } return out, nil } func (i *Importer) importProviders(items []Provider, summary *ImportSummary) error { groupCache := make(map[string]model.ProviderGroup) for _, item := range items { groupName := normalizeGroup(item.PrimaryGroup) if strings.TrimSpace(groupName) == "" { groupName = "default" } group, ok := groupCache[groupName] if !ok { var existing model.ProviderGroup err := i.db.Where("name = ?", groupName).First(&existing).Error if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return err } if err == nil { group = existing } else { group = model.ProviderGroup{ Name: groupName, Type: strings.TrimSpace(item.Type), BaseURL: strings.TrimSpace(item.BaseURL), Models: strings.Join(item.Models, ","), Status: normalizeStatus(item.Status, "active"), } if !i.opts.DryRun { if err := i.db.Create(&group).Error; err != nil { return err } } } groupCache[groupName] = group } apiKey := strings.TrimSpace(item.APIKey) if apiKey == "" { summary.Warnings = append(summary.Warnings, "skip api key with empty api_key") continue } var existingKey model.APIKey err := i.db.Where("group_id = ? AND api_key = ?", group.ID, apiKey).First(&existingKey).Error if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return err } if err == nil { switch i.opts.ConflictPolicy { case ConflictOverwrite: if i.opts.DryRun { summary.ProvidersUpdated++ continue } update := map[string]any{ "weight": resolveWeight(item.Weight, item.Priority), "status": normalizeProviderStatus(item.Status), "auto_ban": item.AutoBan, } if err := i.db.Model(&existingKey).Updates(update).Error; err != nil { return err } summary.ProvidersUpdated++ default: summary.ProvidersSkipped++ } continue } if i.opts.DryRun { summary.ProvidersCreated++ continue } key := model.APIKey{ GroupID: group.ID, APIKey: apiKey, Weight: resolveWeight(item.Weight, item.Priority), Status: normalizeProviderStatus(item.Status), AutoBan: item.AutoBan, } if err := i.db.Create(&key).Error; err != nil { return err } summary.ProvidersCreated++ } return nil } func (i *Importer) importKeys(items []Key, masters map[string]model.Master, summary *ImportSummary) error { for _, item := range items { masterName := strings.TrimSpace(item.MasterRef) if masterName == "" { summary.Warnings = append(summary.Warnings, "skip key with empty master_ref") continue } master, ok := masters[masterName] if !ok { return fmt.Errorf("master not found for key: %s", masterName) } rawToken := strings.TrimSpace(item.OriginalToken) if rawToken == "" || !item.TokenPlaintextAvailable { summary.Warnings = append(summary.Warnings, fmt.Sprintf("skip key for master %s: missing plaintext token", masterName)) continue } tokenHash := tokenhash.HashToken(rawToken) var existing model.Key err := i.db.Where("token_hash = ?", tokenHash).First(&existing).Error if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return err } if err == nil { switch i.opts.ConflictPolicy { case ConflictOverwrite: if i.opts.DryRun { summary.KeysUpdated++ continue } update := map[string]any{ "master_id": master.ID, "group": normalizeGroup(item.Group, master.Group), "scopes": strings.Join(item.Scopes, ","), "default_namespace": normalizeNamespace(master.DefaultNamespace), "namespaces": normalizeNamespaces(item.Namespaces, master.DefaultNamespace), "issued_at_epoch": master.Epoch, "status": normalizeStatus(item.Status, "active"), "issued_by": "import", "model_limits": strings.Join(item.ModelLimits, ","), "model_limits_enabled": item.ModelLimitsEnabled, "expires_at": item.ExpiresAt, "allow_ips": strings.Join(item.AllowIPs, ","), "quota_limit": normalizeQuotaLimit(item.QuotaLimit, item.UnlimitedQuota, summary), "quota_used": normalizeQuotaUsed(item.QuotaUsed), } if err := i.db.Model(&existing).Updates(update).Error; err != nil { return err } summary.KeysUpdated++ default: summary.KeysSkipped++ } continue } if i.opts.DryRun { summary.KeysCreated++ continue } hashedKey, err := bcrypt.GenerateFromPassword([]byte(rawToken), bcrypt.DefaultCost) if err != nil { return err } key := model.Key{ MasterID: master.ID, KeySecret: string(hashedKey), TokenHash: tokenHash, Group: normalizeGroup(item.Group, master.Group), Scopes: strings.Join(item.Scopes, ","), DefaultNamespace: normalizeNamespace(master.DefaultNamespace), Namespaces: normalizeNamespaces(item.Namespaces, master.DefaultNamespace), IssuedAtEpoch: master.Epoch, Status: normalizeStatus(item.Status, "active"), IssuedBy: "import", ModelLimits: strings.Join(item.ModelLimits, ","), ModelLimitsEnabled: item.ModelLimitsEnabled, ExpiresAt: item.ExpiresAt, AllowIPs: strings.Join(item.AllowIPs, ","), QuotaLimit: normalizeQuotaLimit(item.QuotaLimit, item.UnlimitedQuota, summary), QuotaUsed: normalizeQuotaUsed(item.QuotaUsed), } if err := i.db.Create(&key).Error; err != nil { return err } summary.KeysCreated++ } return nil } func (i *Importer) importBindings(items []Binding, summary *ImportSummary) error { for _, item := range items { ns := normalizeNamespace(item.Namespace, "default") publicModel := strings.TrimSpace(item.Model) if publicModel == "" { summary.Warnings = append(summary.Warnings, "skip binding with empty model") continue } groupName := normalizeGroup(item.RouteGroup) var group model.ProviderGroup if err := i.db.Where("name = ?", groupName).First(&group).Error; err != nil { summary.Warnings = append(summary.Warnings, "skip binding with missing provider group: "+groupName) continue } var existing model.Binding err := i.db.Where("namespace = ? AND public_model = ? AND group_id = ?", ns, publicModel, group.ID).First(&existing).Error if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return err } if err == nil { switch i.opts.ConflictPolicy { case ConflictOverwrite: if i.opts.DryRun { summary.BindingsUpdated++ continue } update := map[string]any{ "group_id": group.ID, "weight": 1, "selector_type": "exact", "selector_value": publicModel, "status": normalizeStatus(item.Status, "active"), } if err := i.db.Model(&existing).Updates(update).Error; err != nil { return err } summary.BindingsUpdated++ default: summary.BindingsSkipped++ } continue } if i.opts.DryRun { summary.BindingsCreated++ continue } binding := model.Binding{ Namespace: ns, PublicModel: publicModel, GroupID: group.ID, Weight: 1, SelectorType: "exact", SelectorValue: publicModel, Status: normalizeStatus(item.Status, "active"), } if err := i.db.Create(&binding).Error; err != nil { return err } summary.BindingsCreated++ } return nil } func addNamespaces(set map[string]struct{}, names ...string) { for _, name := range names { name = strings.TrimSpace(name) if name == "" { continue } set[name] = struct{}{} } } func normalizeGroup(values ...string) string { for _, v := range values { v = strings.TrimSpace(v) if v != "" { return v } } return "default" } func normalizeNamespace(values ...string) string { for _, v := range values { v = strings.TrimSpace(v) if v != "" { return v } } return "default" } func normalizeNamespaces(values []string, defaultNamespace string) string { if len(values) == 0 { return normalizeNamespace(defaultNamespace) } seen := make(map[string]struct{}, len(values)) out := make([]string, 0, len(values)) for _, v := range values { v = strings.TrimSpace(v) if v == "" { continue } if _, ok := seen[v]; ok { continue } seen[v] = struct{}{} out = append(out, v) } if len(out) == 0 { return normalizeNamespace(defaultNamespace) } if defaultNamespace != "" { if _, ok := seen[defaultNamespace]; !ok { out = append(out, defaultNamespace) } } return strings.Join(out, ",") } func normalizeMaxChildKeys(value int) int { if value > 0 { return value } return 5 } func normalizeWeight(value int) int { if value > 0 { return value } return 1 } func resolveWeight(weight int, priority int) int { if weight > 0 { return weight } if priority > 0 { return priority } return normalizeWeight(weight) } func normalizeProviderStatus(value string) string { value = strings.ToLower(strings.TrimSpace(value)) if value == "" || value == "active" { return "active" } if value == "disabled" { return "manual_disabled" } return value } func normalizeStatus(value string, fallback string) string { value = strings.ToLower(strings.TrimSpace(value)) if value == "" { return fallback } return value } func normalizeQuotaLimit(limit *int64, unlimited bool, summary *ImportSummary) int64 { if unlimited { return -1 } if limit != nil { return *limit } if summary != nil { summary.Warnings = append(summary.Warnings, "quota_limit missing for limited key, defaulting to -1") } return -1 } func normalizeQuotaUsed(used *int64) int64 { if used == nil { return 0 } return *used } func generateRandomKey(length int) (string, error) { if length <= 0 { return "", errors.New("invalid key length") } buf := make([]byte, length) if _, err := rand.Read(buf); err != nil { return "", err } return hex.EncodeToString(buf), nil }