feat(migrate): add import CLI command and importer for migration data

Introduce a new `import` subcommand to the server binary that reads
exported JSON files and imports masters, providers, keys, bindings,
and namespaces into the database.

Key features:
- Support for dry-run mode to validate without writing
- Conflict policies: skip existing or overwrite
- Optional binding import via --include-bindings flag
- Auto-generation of master keys with secure hashing
- Namespace auto-creation for referenced namespaces
- Detailed import summary with warnings and created credentials
This commit is contained in:
zenfun
2025-12-23 20:13:45 +08:00
parent ee6c28afc9
commit cd5616dc26
3 changed files with 776 additions and 0 deletions

View File

@@ -2,7 +2,10 @@ package main
import ( import (
"context" "context"
"encoding/json"
"expvar" "expvar"
"flag"
"fmt"
"log/slog" "log/slog"
"net/http" "net/http"
"os" "os"
@@ -16,6 +19,7 @@ import (
"github.com/ez-api/ez-api/internal/config" "github.com/ez-api/ez-api/internal/config"
"github.com/ez-api/ez-api/internal/cron" "github.com/ez-api/ez-api/internal/cron"
"github.com/ez-api/ez-api/internal/middleware" "github.com/ez-api/ez-api/internal/middleware"
"github.com/ez-api/ez-api/internal/migrate"
"github.com/ez-api/ez-api/internal/model" "github.com/ez-api/ez-api/internal/model"
"github.com/ez-api/ez-api/internal/service" "github.com/ez-api/ez-api/internal/service"
"github.com/ez-api/foundation/logging" "github.com/ez-api/foundation/logging"
@@ -72,6 +76,10 @@ func isOriginAllowed(allowed []string, origin string) bool {
func main() { func main() {
logger, _ := logging.New(logging.Options{Service: "ez-api"}) logger, _ := logging.New(logging.Options{Service: "ez-api"})
if len(os.Args) > 1 && os.Args[1] == "import" {
code := runImport(logger, os.Args[2:])
os.Exit(code)
}
// 1. Load Configuration // 1. Load Configuration
cfg, err := config.Load() cfg, err := config.Load()
@@ -357,3 +365,71 @@ func main() {
logger.Info("server exited properly") logger.Info("server exited properly")
} }
func runImport(logger *slog.Logger, args []string) int {
fs := flag.NewFlagSet("import", flag.ContinueOnError)
fs.SetOutput(os.Stderr)
var filePath string
var dryRun bool
var conflictPolicy string
var includeBindings bool
fs.StringVar(&filePath, "file", "", "Path to export JSON")
fs.BoolVar(&dryRun, "dry-run", false, "Validate only, do not write to database")
fs.StringVar(&conflictPolicy, "conflict", migrate.ConflictSkip, "Conflict policy: skip or overwrite")
fs.BoolVar(&includeBindings, "include-bindings", false, "Import bindings from payload")
if err := fs.Parse(args); err != nil {
logger.Error("failed to parse flags", "err", err)
return 2
}
if strings.TrimSpace(filePath) == "" {
fmt.Fprintln(os.Stderr, "missing --file")
return 2
}
conflictPolicy = strings.ToLower(strings.TrimSpace(conflictPolicy))
if conflictPolicy != migrate.ConflictSkip && conflictPolicy != migrate.ConflictOverwrite {
fmt.Fprintf(os.Stderr, "invalid --conflict value: %s\n", conflictPolicy)
return 2
}
cfg, err := config.Load()
if err != nil {
logger.Error("failed to load config", "err", err)
return 1
}
db, err := gorm.Open(postgres.Open(cfg.Postgres.DSN), &gorm.Config{})
if err != nil {
logger.Error("failed to connect to postgresql", "err", err)
return 1
}
if err := db.AutoMigrate(&model.Master{}, &model.Key{}, &model.Provider{}, &model.Model{}, &model.Binding{}, &model.Namespace{}); err != nil {
logger.Error("failed to auto migrate", "err", err)
return 1
}
importer := migrate.NewImporter(db, migrate.ImportOptions{
DryRun: dryRun,
ConflictPolicy: conflictPolicy,
IncludeBindings: includeBindings,
})
summary, err := importer.ImportFile(filePath)
if err != nil {
logger.Error("import failed", "err", err)
return 1
}
payload, err := json.MarshalIndent(summary, "", " ")
if err != nil {
logger.Error("failed to render import summary", "err", err)
return 1
}
fmt.Fprintln(os.Stdout, string(payload))
if dryRun {
fmt.Fprintln(os.Stdout, "dry-run only: no data written")
}
return 0
}

View File

@@ -0,0 +1,602 @@
package migrate
import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"os"
"strings"
"github.com/ez-api/ez-api/internal/model"
"github.com/ez-api/foundation/tokenhash"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
const (
ConflictSkip = "skip"
ConflictOverwrite = "overwrite"
)
type ImportOptions struct {
DryRun bool
ConflictPolicy string
IncludeBindings bool
}
type MasterCredential struct {
Name string `json:"name"`
MasterKey string `json:"master_key"`
}
type ImportSummary struct {
ProvidersCreated int `json:"providers_created"`
ProvidersUpdated int `json:"providers_updated"`
ProvidersSkipped int `json:"providers_skipped"`
MastersCreated int `json:"masters_created"`
MastersUpdated int `json:"masters_updated"`
MastersSkipped int `json:"masters_skipped"`
KeysCreated int `json:"keys_created"`
KeysUpdated int `json:"keys_updated"`
KeysSkipped int `json:"keys_skipped"`
BindingsCreated int `json:"bindings_created"`
BindingsUpdated int `json:"bindings_updated"`
BindingsSkipped int `json:"bindings_skipped"`
NamespacesCreated int `json:"namespaces_created"`
MasterKeys []MasterCredential `json:"master_keys,omitempty"`
Warnings []string `json:"warnings,omitempty"`
}
type Importer struct {
db *gorm.DB
opts ImportOptions
}
func NewImporter(db *gorm.DB, opts ImportOptions) *Importer {
opts.ConflictPolicy = strings.ToLower(strings.TrimSpace(opts.ConflictPolicy))
if opts.ConflictPolicy == "" {
opts.ConflictPolicy = ConflictSkip
}
return &Importer{db: db, opts: opts}
}
func (i *Importer) ImportFile(path string) (*ImportSummary, error) {
raw, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("read file: %w", err)
}
var payload ExportResult
if err := json.Unmarshal(raw, &payload); err != nil {
return nil, fmt.Errorf("parse json: %w", err)
}
return i.Import(&payload)
}
func (i *Importer) Import(payload *ExportResult) (*ImportSummary, error) {
if payload == nil {
return nil, errors.New("empty payload")
}
if i.db == nil {
return nil, errors.New("db is required")
}
summary := &ImportSummary{}
if len(payload.Warnings) > 0 {
summary.Warnings = append(summary.Warnings, payload.Warnings...)
}
if err := i.ensureNamespaces(payload, summary); err != nil {
return nil, err
}
masterMap, err := i.importMasters(payload.Data.Masters, summary)
if err != nil {
return nil, err
}
if err := i.importProviders(payload.Data.Providers, summary); err != nil {
return nil, err
}
if err := i.importKeys(payload.Data.Keys, masterMap, summary); err != nil {
return nil, err
}
if i.opts.IncludeBindings {
if err := i.importBindings(payload.Data.Bindings, summary); err != nil {
return nil, err
}
}
return summary, nil
}
func (i *Importer) ensureNamespaces(payload *ExportResult, summary *ImportSummary) error {
namespaceSet := make(map[string]struct{})
namespaceSet["default"] = struct{}{}
for _, master := range payload.Data.Masters {
addNamespaces(namespaceSet, master.DefaultNamespace)
addNamespaces(namespaceSet, master.Namespaces...)
}
for _, key := range payload.Data.Keys {
addNamespaces(namespaceSet, key.Namespaces...)
}
if i.opts.IncludeBindings {
for _, binding := range payload.Data.Bindings {
addNamespaces(namespaceSet, binding.Namespace)
}
}
for name := range namespaceSet {
if err := i.ensureNamespace(name, summary); err != nil {
return err
}
}
return nil
}
func (i *Importer) ensureNamespace(name string, summary *ImportSummary) error {
name = strings.TrimSpace(name)
if name == "" {
return nil
}
var existing model.Namespace
err := i.db.Where("name = ?", name).First(&existing).Error
if err == nil {
return nil
}
if !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
if i.opts.DryRun {
summary.NamespacesCreated++
return nil
}
ns := &model.Namespace{Name: name, Status: "active"}
if err := i.db.Create(ns).Error; err != nil {
return err
}
summary.NamespacesCreated++
return nil
}
func (i *Importer) importMasters(items []Master, summary *ImportSummary) (map[string]model.Master, error) {
out := make(map[string]model.Master)
for _, item := range items {
name := strings.TrimSpace(item.Name)
if name == "" {
summary.Warnings = append(summary.Warnings, "skip master with empty name")
continue
}
var existing model.Master
err := i.db.Where("name = ?", name).First(&existing).Error
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}
if err == nil {
out[name] = existing
defaultNS := normalizeNamespace(item.DefaultNamespace)
switch i.opts.ConflictPolicy {
case ConflictOverwrite:
if i.opts.DryRun {
summary.MastersUpdated++
continue
}
update := map[string]any{
"group": normalizeGroup(item.Group),
"default_namespace": defaultNS,
"namespaces": normalizeNamespaces(item.Namespaces, defaultNS),
"max_child_keys": normalizeMaxChildKeys(item.MaxChildKeys),
"global_qps": item.GlobalQPS,
"status": normalizeStatus(item.Status, "active"),
}
if err := i.db.Model(&existing).Updates(update).Error; err != nil {
return nil, err
}
out[name] = existing
summary.MastersUpdated++
default:
summary.MastersSkipped++
}
continue
}
defaultNS := normalizeNamespace(item.DefaultNamespace)
normalizedNamespaces := normalizeNamespaces(item.Namespaces, defaultNS)
if i.opts.DryRun {
summary.MastersCreated++
out[name] = model.Master{
Name: name,
Group: normalizeGroup(item.Group),
DefaultNamespace: defaultNS,
Namespaces: normalizedNamespaces,
MaxChildKeys: normalizeMaxChildKeys(item.MaxChildKeys),
GlobalQPS: item.GlobalQPS,
Status: normalizeStatus(item.Status, "active"),
Epoch: 1,
}
continue
}
rawKey, err := generateRandomKey(32)
if err != nil {
return nil, err
}
hashedKey, err := bcrypt.GenerateFromPassword([]byte(rawKey), bcrypt.DefaultCost)
if err != nil {
return nil, err
}
digest := tokenhash.HashToken(rawKey)
master := model.Master{
Name: name,
MasterKey: string(hashedKey),
MasterKeyDigest: digest,
Group: normalizeGroup(item.Group),
DefaultNamespace: defaultNS,
Namespaces: normalizedNamespaces,
MaxChildKeys: normalizeMaxChildKeys(item.MaxChildKeys),
GlobalQPS: item.GlobalQPS,
Status: normalizeStatus(item.Status, "active"),
Epoch: 1,
}
if err := i.db.Create(&master).Error; err != nil {
return nil, err
}
out[name] = master
summary.MastersCreated++
summary.MasterKeys = append(summary.MasterKeys, MasterCredential{
Name: master.Name,
MasterKey: rawKey,
})
}
return out, nil
}
func (i *Importer) importProviders(items []Provider, summary *ImportSummary) error {
for _, item := range items {
name := strings.TrimSpace(item.Name)
if name == "" {
summary.Warnings = append(summary.Warnings, "skip provider with empty name")
continue
}
var existing model.Provider
err := i.db.Where("name = ?", name).First(&existing).Error
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
if err == nil {
switch i.opts.ConflictPolicy {
case ConflictOverwrite:
if i.opts.DryRun {
summary.ProvidersUpdated++
continue
}
update := map[string]any{
"type": strings.TrimSpace(item.Type),
"base_url": strings.TrimSpace(item.BaseURL),
"api_key": strings.TrimSpace(item.APIKey),
"group": normalizeGroup(item.PrimaryGroup),
"models": strings.Join(item.Models, ","),
"weight": resolveWeight(item.Weight, item.Priority),
"status": normalizeProviderStatus(item.Status),
"auto_ban": item.AutoBan,
}
if err := i.db.Model(&existing).Updates(update).Error; err != nil {
return err
}
summary.ProvidersUpdated++
default:
summary.ProvidersSkipped++
}
continue
}
if i.opts.DryRun {
summary.ProvidersCreated++
continue
}
provider := model.Provider{
Name: name,
Type: strings.TrimSpace(item.Type),
BaseURL: strings.TrimSpace(item.BaseURL),
APIKey: strings.TrimSpace(item.APIKey),
Group: normalizeGroup(item.PrimaryGroup),
Models: strings.Join(item.Models, ","),
Weight: resolveWeight(item.Weight, item.Priority),
Status: normalizeProviderStatus(item.Status),
AutoBan: item.AutoBan,
}
if err := i.db.Create(&provider).Error; err != nil {
return err
}
summary.ProvidersCreated++
}
return nil
}
func (i *Importer) importKeys(items []Key, masters map[string]model.Master, summary *ImportSummary) error {
for _, item := range items {
masterName := strings.TrimSpace(item.MasterRef)
if masterName == "" {
summary.Warnings = append(summary.Warnings, "skip key with empty master_ref")
continue
}
master, ok := masters[masterName]
if !ok {
return fmt.Errorf("master not found for key: %s", masterName)
}
rawToken := strings.TrimSpace(item.OriginalToken)
if rawToken == "" || !item.TokenPlaintextAvailable {
summary.Warnings = append(summary.Warnings, fmt.Sprintf("skip key for master %s: missing plaintext token", masterName))
continue
}
tokenHash := tokenhash.HashToken(rawToken)
var existing model.Key
err := i.db.Where("token_hash = ?", tokenHash).First(&existing).Error
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
if err == nil {
switch i.opts.ConflictPolicy {
case ConflictOverwrite:
if i.opts.DryRun {
summary.KeysUpdated++
continue
}
update := map[string]any{
"master_id": master.ID,
"group": normalizeGroup(item.Group, master.Group),
"scopes": strings.Join(item.Scopes, ","),
"default_namespace": normalizeNamespace(master.DefaultNamespace),
"namespaces": normalizeNamespaces(item.Namespaces, master.DefaultNamespace),
"issued_at_epoch": master.Epoch,
"status": normalizeStatus(item.Status, "active"),
"issued_by": "import",
"model_limits": strings.Join(item.ModelLimits, ","),
"model_limits_enabled": item.ModelLimitsEnabled,
"expires_at": item.ExpiresAt,
"allow_ips": strings.Join(item.AllowIPs, ","),
"quota_limit": normalizeQuotaLimit(item.QuotaLimit, item.UnlimitedQuota, summary),
"quota_used": normalizeQuotaUsed(item.QuotaUsed),
}
if err := i.db.Model(&existing).Updates(update).Error; err != nil {
return err
}
summary.KeysUpdated++
default:
summary.KeysSkipped++
}
continue
}
if i.opts.DryRun {
summary.KeysCreated++
continue
}
hashedKey, err := bcrypt.GenerateFromPassword([]byte(rawToken), bcrypt.DefaultCost)
if err != nil {
return err
}
key := model.Key{
MasterID: master.ID,
KeySecret: string(hashedKey),
TokenHash: tokenHash,
Group: normalizeGroup(item.Group, master.Group),
Scopes: strings.Join(item.Scopes, ","),
DefaultNamespace: normalizeNamespace(master.DefaultNamespace),
Namespaces: normalizeNamespaces(item.Namespaces, master.DefaultNamespace),
IssuedAtEpoch: master.Epoch,
Status: normalizeStatus(item.Status, "active"),
IssuedBy: "import",
ModelLimits: strings.Join(item.ModelLimits, ","),
ModelLimitsEnabled: item.ModelLimitsEnabled,
ExpiresAt: item.ExpiresAt,
AllowIPs: strings.Join(item.AllowIPs, ","),
QuotaLimit: normalizeQuotaLimit(item.QuotaLimit, item.UnlimitedQuota, summary),
QuotaUsed: normalizeQuotaUsed(item.QuotaUsed),
}
if err := i.db.Create(&key).Error; err != nil {
return err
}
summary.KeysCreated++
}
return nil
}
func (i *Importer) importBindings(items []Binding, summary *ImportSummary) error {
for _, item := range items {
ns := normalizeNamespace(item.Namespace, "default")
publicModel := strings.TrimSpace(item.Model)
if publicModel == "" {
summary.Warnings = append(summary.Warnings, "skip binding with empty model")
continue
}
var existing model.Binding
err := i.db.Where("namespace = ? AND public_model = ?", ns, publicModel).First(&existing).Error
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
if err == nil {
switch i.opts.ConflictPolicy {
case ConflictOverwrite:
if i.opts.DryRun {
summary.BindingsUpdated++
continue
}
update := map[string]any{
"route_group": normalizeGroup(item.RouteGroup),
"selector_type": "exact",
"selector_value": publicModel,
"status": normalizeStatus(item.Status, "active"),
}
if err := i.db.Model(&existing).Updates(update).Error; err != nil {
return err
}
summary.BindingsUpdated++
default:
summary.BindingsSkipped++
}
continue
}
if i.opts.DryRun {
summary.BindingsCreated++
continue
}
binding := model.Binding{
Namespace: ns,
PublicModel: publicModel,
RouteGroup: normalizeGroup(item.RouteGroup),
SelectorType: "exact",
SelectorValue: publicModel,
Status: normalizeStatus(item.Status, "active"),
}
if err := i.db.Create(&binding).Error; err != nil {
return err
}
summary.BindingsCreated++
}
return nil
}
func addNamespaces(set map[string]struct{}, names ...string) {
for _, name := range names {
name = strings.TrimSpace(name)
if name == "" {
continue
}
set[name] = struct{}{}
}
}
func normalizeGroup(values ...string) string {
for _, v := range values {
v = strings.TrimSpace(v)
if v != "" {
return v
}
}
return "default"
}
func normalizeNamespace(values ...string) string {
for _, v := range values {
v = strings.TrimSpace(v)
if v != "" {
return v
}
}
return "default"
}
func normalizeNamespaces(values []string, defaultNamespace string) string {
if len(values) == 0 {
return normalizeNamespace(defaultNamespace)
}
seen := make(map[string]struct{}, len(values))
out := make([]string, 0, len(values))
for _, v := range values {
v = strings.TrimSpace(v)
if v == "" {
continue
}
if _, ok := seen[v]; ok {
continue
}
seen[v] = struct{}{}
out = append(out, v)
}
if len(out) == 0 {
return normalizeNamespace(defaultNamespace)
}
if defaultNamespace != "" {
if _, ok := seen[defaultNamespace]; !ok {
out = append(out, defaultNamespace)
}
}
return strings.Join(out, ",")
}
func normalizeMaxChildKeys(value int) int {
if value > 0 {
return value
}
return 5
}
func normalizeWeight(value int) int {
if value > 0 {
return value
}
return 1
}
func resolveWeight(weight int, priority int) int {
if weight > 0 {
return weight
}
if priority > 0 {
return priority
}
return normalizeWeight(weight)
}
func normalizeProviderStatus(value string) string {
value = strings.ToLower(strings.TrimSpace(value))
if value == "" || value == "active" {
return "active"
}
if value == "disabled" {
return "manual_disabled"
}
return value
}
func normalizeStatus(value string, fallback string) string {
value = strings.ToLower(strings.TrimSpace(value))
if value == "" {
return fallback
}
return value
}
func normalizeQuotaLimit(limit *int64, unlimited bool, summary *ImportSummary) int64 {
if unlimited {
return -1
}
if limit != nil {
return *limit
}
if summary != nil {
summary.Warnings = append(summary.Warnings, "quota_limit missing for limited key, defaulting to -1")
}
return -1
}
func normalizeQuotaUsed(used *int64) int64 {
if used == nil {
return 0
}
return *used
}
func generateRandomKey(length int) (string, error) {
if length <= 0 {
return "", errors.New("invalid key length")
}
buf := make([]byte, length)
if _, err := rand.Read(buf); err != nil {
return "", err
}
return hex.EncodeToString(buf), nil
}

View File

@@ -0,0 +1,98 @@
package migrate
import (
"encoding/json"
"time"
)
// ExportResult represents the complete export output.
type ExportResult struct {
Version string `json:"version"`
Source Source `json:"source"`
Data Data `json:"data"`
Warnings []string `json:"warnings,omitempty"`
}
// Source represents the source system information.
type Source struct {
Type string `json:"type"`
Version string `json:"version"`
ExportedAt time.Time `json:"exported_at"`
}
// Data contains all exported entities.
type Data struct {
Providers []Provider `json:"providers,omitempty"`
Masters []Master `json:"masters,omitempty"`
Keys []Key `json:"keys,omitempty"`
Bindings []Binding `json:"bindings,omitempty"`
}
// Provider represents an EZ-API provider (mapped from New API channel).
type Provider struct {
OriginalID int `json:"original_id"`
Name string `json:"name"`
Type string `json:"type"`
BaseURL string `json:"base_url,omitempty"`
APIKey string `json:"api_key"`
Models []string `json:"models,omitempty"`
PrimaryGroup string `json:"primary_group"`
AllGroups []string `json:"all_groups,omitempty"`
Weight int `json:"weight"`
Priority int `json:"priority,omitempty"`
Status string `json:"status"`
AutoBan bool `json:"auto_ban"`
IsMultiKey bool `json:"is_multi_key,omitempty"`
MultiKeyIndex int `json:"multi_key_index,omitempty"`
OriginalName string `json:"original_name,omitempty"`
Original json.RawMessage `json:"_original,omitempty"`
}
// Master represents an EZ-API master (inferred from New API user).
type Master struct {
Name string `json:"name"`
Group string `json:"group"`
Namespaces []string `json:"namespaces,omitempty"`
DefaultNamespace string `json:"default_namespace,omitempty"`
MaxChildKeys int `json:"max_child_keys,omitempty"`
GlobalQPS int `json:"global_qps,omitempty"`
Status string `json:"status"`
SourceUserID int `json:"_source_user_id"`
SourceEmail string `json:"_source_email,omitempty"`
}
// Key represents an EZ-API key (mapped from New API token).
type Key struct {
MasterRef string `json:"master_ref"`
OriginalToken string `json:"original_token"`
Group string `json:"group,omitempty"`
Status string `json:"status"`
Scopes []string `json:"scopes,omitempty"`
Namespaces []string `json:"namespaces,omitempty"`
ModelLimitsEnabled bool `json:"model_limits_enabled,omitempty"`
ModelLimits []string `json:"model_limits,omitempty"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
AllowIPs []string `json:"allow_ips,omitempty"`
QuotaLimit *int64 `json:"quota_limit,omitempty"`
QuotaUsed *int64 `json:"quota_used,omitempty"`
UnlimitedQuota bool `json:"unlimited_quota,omitempty"`
OriginalID int `json:"_original_id"`
TokenPlaintextAvailable bool `json:"_token_plaintext_available,omitempty"`
}
// Binding represents an EZ-API binding (optional, from abilities).
type Binding struct {
Namespace string `json:"namespace"`
RouteGroup string `json:"route_group"`
Model string `json:"model"`
Status string `json:"status"`
}