diff --git a/cmd/server/main.go b/cmd/server/main.go index 4dd8bf3..3c4c8b7 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -2,7 +2,10 @@ package main import ( "context" + "encoding/json" "expvar" + "flag" + "fmt" "log/slog" "net/http" "os" @@ -16,6 +19,7 @@ import ( "github.com/ez-api/ez-api/internal/config" "github.com/ez-api/ez-api/internal/cron" "github.com/ez-api/ez-api/internal/middleware" + "github.com/ez-api/ez-api/internal/migrate" "github.com/ez-api/ez-api/internal/model" "github.com/ez-api/ez-api/internal/service" "github.com/ez-api/foundation/logging" @@ -72,6 +76,10 @@ func isOriginAllowed(allowed []string, origin string) bool { func main() { logger, _ := logging.New(logging.Options{Service: "ez-api"}) + if len(os.Args) > 1 && os.Args[1] == "import" { + code := runImport(logger, os.Args[2:]) + os.Exit(code) + } // 1. Load Configuration cfg, err := config.Load() @@ -357,3 +365,71 @@ func main() { logger.Info("server exited properly") } + +func runImport(logger *slog.Logger, args []string) int { + fs := flag.NewFlagSet("import", flag.ContinueOnError) + fs.SetOutput(os.Stderr) + + var filePath string + var dryRun bool + var conflictPolicy string + var includeBindings bool + + fs.StringVar(&filePath, "file", "", "Path to export JSON") + fs.BoolVar(&dryRun, "dry-run", false, "Validate only, do not write to database") + fs.StringVar(&conflictPolicy, "conflict", migrate.ConflictSkip, "Conflict policy: skip or overwrite") + fs.BoolVar(&includeBindings, "include-bindings", false, "Import bindings from payload") + + if err := fs.Parse(args); err != nil { + logger.Error("failed to parse flags", "err", err) + return 2 + } + if strings.TrimSpace(filePath) == "" { + fmt.Fprintln(os.Stderr, "missing --file") + return 2 + } + conflictPolicy = strings.ToLower(strings.TrimSpace(conflictPolicy)) + if conflictPolicy != migrate.ConflictSkip && conflictPolicy != migrate.ConflictOverwrite { + fmt.Fprintf(os.Stderr, "invalid --conflict value: %s\n", conflictPolicy) + return 2 + } + + cfg, err := config.Load() + if err != nil { + logger.Error("failed to load config", "err", err) + return 1 + } + + db, err := gorm.Open(postgres.Open(cfg.Postgres.DSN), &gorm.Config{}) + if err != nil { + logger.Error("failed to connect to postgresql", "err", err) + return 1 + } + + if err := db.AutoMigrate(&model.Master{}, &model.Key{}, &model.Provider{}, &model.Model{}, &model.Binding{}, &model.Namespace{}); err != nil { + logger.Error("failed to auto migrate", "err", err) + return 1 + } + + importer := migrate.NewImporter(db, migrate.ImportOptions{ + DryRun: dryRun, + ConflictPolicy: conflictPolicy, + IncludeBindings: includeBindings, + }) + summary, err := importer.ImportFile(filePath) + if err != nil { + logger.Error("import failed", "err", err) + return 1 + } + + payload, err := json.MarshalIndent(summary, "", " ") + if err != nil { + logger.Error("failed to render import summary", "err", err) + return 1 + } + fmt.Fprintln(os.Stdout, string(payload)) + if dryRun { + fmt.Fprintln(os.Stdout, "dry-run only: no data written") + } + return 0 +} diff --git a/internal/migrate/importer.go b/internal/migrate/importer.go new file mode 100644 index 0000000..7f07381 --- /dev/null +++ b/internal/migrate/importer.go @@ -0,0 +1,602 @@ +package migrate + +import ( + "crypto/rand" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "os" + "strings" + + "github.com/ez-api/ez-api/internal/model" + "github.com/ez-api/foundation/tokenhash" + "golang.org/x/crypto/bcrypt" + "gorm.io/gorm" +) + +const ( + ConflictSkip = "skip" + ConflictOverwrite = "overwrite" +) + +type ImportOptions struct { + DryRun bool + ConflictPolicy string + IncludeBindings bool +} + +type MasterCredential struct { + Name string `json:"name"` + MasterKey string `json:"master_key"` +} + +type ImportSummary struct { + ProvidersCreated int `json:"providers_created"` + ProvidersUpdated int `json:"providers_updated"` + ProvidersSkipped int `json:"providers_skipped"` + + MastersCreated int `json:"masters_created"` + MastersUpdated int `json:"masters_updated"` + MastersSkipped int `json:"masters_skipped"` + + KeysCreated int `json:"keys_created"` + KeysUpdated int `json:"keys_updated"` + KeysSkipped int `json:"keys_skipped"` + + BindingsCreated int `json:"bindings_created"` + BindingsUpdated int `json:"bindings_updated"` + BindingsSkipped int `json:"bindings_skipped"` + + NamespacesCreated int `json:"namespaces_created"` + + MasterKeys []MasterCredential `json:"master_keys,omitempty"` + Warnings []string `json:"warnings,omitempty"` +} + +type Importer struct { + db *gorm.DB + opts ImportOptions +} + +func NewImporter(db *gorm.DB, opts ImportOptions) *Importer { + opts.ConflictPolicy = strings.ToLower(strings.TrimSpace(opts.ConflictPolicy)) + if opts.ConflictPolicy == "" { + opts.ConflictPolicy = ConflictSkip + } + return &Importer{db: db, opts: opts} +} + +func (i *Importer) ImportFile(path string) (*ImportSummary, error) { + raw, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("read file: %w", err) + } + var payload ExportResult + if err := json.Unmarshal(raw, &payload); err != nil { + return nil, fmt.Errorf("parse json: %w", err) + } + return i.Import(&payload) +} + +func (i *Importer) Import(payload *ExportResult) (*ImportSummary, error) { + if payload == nil { + return nil, errors.New("empty payload") + } + if i.db == nil { + return nil, errors.New("db is required") + } + + summary := &ImportSummary{} + if len(payload.Warnings) > 0 { + summary.Warnings = append(summary.Warnings, payload.Warnings...) + } + + if err := i.ensureNamespaces(payload, summary); err != nil { + return nil, err + } + + masterMap, err := i.importMasters(payload.Data.Masters, summary) + if err != nil { + return nil, err + } + if err := i.importProviders(payload.Data.Providers, summary); err != nil { + return nil, err + } + if err := i.importKeys(payload.Data.Keys, masterMap, summary); err != nil { + return nil, err + } + if i.opts.IncludeBindings { + if err := i.importBindings(payload.Data.Bindings, summary); err != nil { + return nil, err + } + } + + return summary, nil +} + +func (i *Importer) ensureNamespaces(payload *ExportResult, summary *ImportSummary) error { + namespaceSet := make(map[string]struct{}) + namespaceSet["default"] = struct{}{} + for _, master := range payload.Data.Masters { + addNamespaces(namespaceSet, master.DefaultNamespace) + addNamespaces(namespaceSet, master.Namespaces...) + } + for _, key := range payload.Data.Keys { + addNamespaces(namespaceSet, key.Namespaces...) + } + if i.opts.IncludeBindings { + for _, binding := range payload.Data.Bindings { + addNamespaces(namespaceSet, binding.Namespace) + } + } + + for name := range namespaceSet { + if err := i.ensureNamespace(name, summary); err != nil { + return err + } + } + return nil +} + +func (i *Importer) ensureNamespace(name string, summary *ImportSummary) error { + name = strings.TrimSpace(name) + if name == "" { + return nil + } + var existing model.Namespace + err := i.db.Where("name = ?", name).First(&existing).Error + if err == nil { + return nil + } + if !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + if i.opts.DryRun { + summary.NamespacesCreated++ + return nil + } + ns := &model.Namespace{Name: name, Status: "active"} + if err := i.db.Create(ns).Error; err != nil { + return err + } + summary.NamespacesCreated++ + return nil +} + +func (i *Importer) importMasters(items []Master, summary *ImportSummary) (map[string]model.Master, error) { + out := make(map[string]model.Master) + for _, item := range items { + name := strings.TrimSpace(item.Name) + if name == "" { + summary.Warnings = append(summary.Warnings, "skip master with empty name") + continue + } + var existing model.Master + err := i.db.Where("name = ?", name).First(&existing).Error + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return nil, err + } + if err == nil { + out[name] = existing + defaultNS := normalizeNamespace(item.DefaultNamespace) + switch i.opts.ConflictPolicy { + case ConflictOverwrite: + if i.opts.DryRun { + summary.MastersUpdated++ + continue + } + update := map[string]any{ + "group": normalizeGroup(item.Group), + "default_namespace": defaultNS, + "namespaces": normalizeNamespaces(item.Namespaces, defaultNS), + "max_child_keys": normalizeMaxChildKeys(item.MaxChildKeys), + "global_qps": item.GlobalQPS, + "status": normalizeStatus(item.Status, "active"), + } + if err := i.db.Model(&existing).Updates(update).Error; err != nil { + return nil, err + } + out[name] = existing + summary.MastersUpdated++ + default: + summary.MastersSkipped++ + } + continue + } + + defaultNS := normalizeNamespace(item.DefaultNamespace) + normalizedNamespaces := normalizeNamespaces(item.Namespaces, defaultNS) + if i.opts.DryRun { + summary.MastersCreated++ + out[name] = model.Master{ + Name: name, + Group: normalizeGroup(item.Group), + DefaultNamespace: defaultNS, + Namespaces: normalizedNamespaces, + MaxChildKeys: normalizeMaxChildKeys(item.MaxChildKeys), + GlobalQPS: item.GlobalQPS, + Status: normalizeStatus(item.Status, "active"), + Epoch: 1, + } + continue + } + + rawKey, err := generateRandomKey(32) + if err != nil { + return nil, err + } + hashedKey, err := bcrypt.GenerateFromPassword([]byte(rawKey), bcrypt.DefaultCost) + if err != nil { + return nil, err + } + digest := tokenhash.HashToken(rawKey) + + master := model.Master{ + Name: name, + MasterKey: string(hashedKey), + MasterKeyDigest: digest, + Group: normalizeGroup(item.Group), + DefaultNamespace: defaultNS, + Namespaces: normalizedNamespaces, + MaxChildKeys: normalizeMaxChildKeys(item.MaxChildKeys), + GlobalQPS: item.GlobalQPS, + Status: normalizeStatus(item.Status, "active"), + Epoch: 1, + } + if err := i.db.Create(&master).Error; err != nil { + return nil, err + } + out[name] = master + summary.MastersCreated++ + summary.MasterKeys = append(summary.MasterKeys, MasterCredential{ + Name: master.Name, + MasterKey: rawKey, + }) + } + return out, nil +} + +func (i *Importer) importProviders(items []Provider, summary *ImportSummary) error { + for _, item := range items { + name := strings.TrimSpace(item.Name) + if name == "" { + summary.Warnings = append(summary.Warnings, "skip provider with empty name") + continue + } + var existing model.Provider + err := i.db.Where("name = ?", name).First(&existing).Error + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + if err == nil { + switch i.opts.ConflictPolicy { + case ConflictOverwrite: + if i.opts.DryRun { + summary.ProvidersUpdated++ + continue + } + update := map[string]any{ + "type": strings.TrimSpace(item.Type), + "base_url": strings.TrimSpace(item.BaseURL), + "api_key": strings.TrimSpace(item.APIKey), + "group": normalizeGroup(item.PrimaryGroup), + "models": strings.Join(item.Models, ","), + "weight": resolveWeight(item.Weight, item.Priority), + "status": normalizeProviderStatus(item.Status), + "auto_ban": item.AutoBan, + } + if err := i.db.Model(&existing).Updates(update).Error; err != nil { + return err + } + summary.ProvidersUpdated++ + default: + summary.ProvidersSkipped++ + } + continue + } + + if i.opts.DryRun { + summary.ProvidersCreated++ + continue + } + + provider := model.Provider{ + Name: name, + Type: strings.TrimSpace(item.Type), + BaseURL: strings.TrimSpace(item.BaseURL), + APIKey: strings.TrimSpace(item.APIKey), + Group: normalizeGroup(item.PrimaryGroup), + Models: strings.Join(item.Models, ","), + Weight: resolveWeight(item.Weight, item.Priority), + Status: normalizeProviderStatus(item.Status), + AutoBan: item.AutoBan, + } + if err := i.db.Create(&provider).Error; err != nil { + return err + } + summary.ProvidersCreated++ + } + return nil +} + +func (i *Importer) importKeys(items []Key, masters map[string]model.Master, summary *ImportSummary) error { + for _, item := range items { + masterName := strings.TrimSpace(item.MasterRef) + if masterName == "" { + summary.Warnings = append(summary.Warnings, "skip key with empty master_ref") + continue + } + master, ok := masters[masterName] + if !ok { + return fmt.Errorf("master not found for key: %s", masterName) + } + rawToken := strings.TrimSpace(item.OriginalToken) + if rawToken == "" || !item.TokenPlaintextAvailable { + summary.Warnings = append(summary.Warnings, fmt.Sprintf("skip key for master %s: missing plaintext token", masterName)) + continue + } + tokenHash := tokenhash.HashToken(rawToken) + + var existing model.Key + err := i.db.Where("token_hash = ?", tokenHash).First(&existing).Error + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + if err == nil { + switch i.opts.ConflictPolicy { + case ConflictOverwrite: + if i.opts.DryRun { + summary.KeysUpdated++ + continue + } + update := map[string]any{ + "master_id": master.ID, + "group": normalizeGroup(item.Group, master.Group), + "scopes": strings.Join(item.Scopes, ","), + "default_namespace": normalizeNamespace(master.DefaultNamespace), + "namespaces": normalizeNamespaces(item.Namespaces, master.DefaultNamespace), + "issued_at_epoch": master.Epoch, + "status": normalizeStatus(item.Status, "active"), + "issued_by": "import", + "model_limits": strings.Join(item.ModelLimits, ","), + "model_limits_enabled": item.ModelLimitsEnabled, + "expires_at": item.ExpiresAt, + "allow_ips": strings.Join(item.AllowIPs, ","), + "quota_limit": normalizeQuotaLimit(item.QuotaLimit, item.UnlimitedQuota, summary), + "quota_used": normalizeQuotaUsed(item.QuotaUsed), + } + if err := i.db.Model(&existing).Updates(update).Error; err != nil { + return err + } + summary.KeysUpdated++ + default: + summary.KeysSkipped++ + } + continue + } + + if i.opts.DryRun { + summary.KeysCreated++ + continue + } + + hashedKey, err := bcrypt.GenerateFromPassword([]byte(rawToken), bcrypt.DefaultCost) + if err != nil { + return err + } + + key := model.Key{ + MasterID: master.ID, + KeySecret: string(hashedKey), + TokenHash: tokenHash, + Group: normalizeGroup(item.Group, master.Group), + Scopes: strings.Join(item.Scopes, ","), + DefaultNamespace: normalizeNamespace(master.DefaultNamespace), + Namespaces: normalizeNamespaces(item.Namespaces, master.DefaultNamespace), + IssuedAtEpoch: master.Epoch, + Status: normalizeStatus(item.Status, "active"), + IssuedBy: "import", + ModelLimits: strings.Join(item.ModelLimits, ","), + ModelLimitsEnabled: item.ModelLimitsEnabled, + ExpiresAt: item.ExpiresAt, + AllowIPs: strings.Join(item.AllowIPs, ","), + QuotaLimit: normalizeQuotaLimit(item.QuotaLimit, item.UnlimitedQuota, summary), + QuotaUsed: normalizeQuotaUsed(item.QuotaUsed), + } + if err := i.db.Create(&key).Error; err != nil { + return err + } + summary.KeysCreated++ + } + return nil +} + +func (i *Importer) importBindings(items []Binding, summary *ImportSummary) error { + for _, item := range items { + ns := normalizeNamespace(item.Namespace, "default") + publicModel := strings.TrimSpace(item.Model) + if publicModel == "" { + summary.Warnings = append(summary.Warnings, "skip binding with empty model") + continue + } + var existing model.Binding + err := i.db.Where("namespace = ? AND public_model = ?", ns, publicModel).First(&existing).Error + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + if err == nil { + switch i.opts.ConflictPolicy { + case ConflictOverwrite: + if i.opts.DryRun { + summary.BindingsUpdated++ + continue + } + update := map[string]any{ + "route_group": normalizeGroup(item.RouteGroup), + "selector_type": "exact", + "selector_value": publicModel, + "status": normalizeStatus(item.Status, "active"), + } + if err := i.db.Model(&existing).Updates(update).Error; err != nil { + return err + } + summary.BindingsUpdated++ + default: + summary.BindingsSkipped++ + } + continue + } + + if i.opts.DryRun { + summary.BindingsCreated++ + continue + } + + binding := model.Binding{ + Namespace: ns, + PublicModel: publicModel, + RouteGroup: normalizeGroup(item.RouteGroup), + SelectorType: "exact", + SelectorValue: publicModel, + Status: normalizeStatus(item.Status, "active"), + } + if err := i.db.Create(&binding).Error; err != nil { + return err + } + summary.BindingsCreated++ + } + return nil +} + +func addNamespaces(set map[string]struct{}, names ...string) { + for _, name := range names { + name = strings.TrimSpace(name) + if name == "" { + continue + } + set[name] = struct{}{} + } +} + +func normalizeGroup(values ...string) string { + for _, v := range values { + v = strings.TrimSpace(v) + if v != "" { + return v + } + } + return "default" +} + +func normalizeNamespace(values ...string) string { + for _, v := range values { + v = strings.TrimSpace(v) + if v != "" { + return v + } + } + return "default" +} + +func normalizeNamespaces(values []string, defaultNamespace string) string { + if len(values) == 0 { + return normalizeNamespace(defaultNamespace) + } + seen := make(map[string]struct{}, len(values)) + out := make([]string, 0, len(values)) + for _, v := range values { + v = strings.TrimSpace(v) + if v == "" { + continue + } + if _, ok := seen[v]; ok { + continue + } + seen[v] = struct{}{} + out = append(out, v) + } + if len(out) == 0 { + return normalizeNamespace(defaultNamespace) + } + if defaultNamespace != "" { + if _, ok := seen[defaultNamespace]; !ok { + out = append(out, defaultNamespace) + } + } + return strings.Join(out, ",") +} + +func normalizeMaxChildKeys(value int) int { + if value > 0 { + return value + } + return 5 +} + +func normalizeWeight(value int) int { + if value > 0 { + return value + } + return 1 +} + +func resolveWeight(weight int, priority int) int { + if weight > 0 { + return weight + } + if priority > 0 { + return priority + } + return normalizeWeight(weight) +} + +func normalizeProviderStatus(value string) string { + value = strings.ToLower(strings.TrimSpace(value)) + if value == "" || value == "active" { + return "active" + } + if value == "disabled" { + return "manual_disabled" + } + return value +} + +func normalizeStatus(value string, fallback string) string { + value = strings.ToLower(strings.TrimSpace(value)) + if value == "" { + return fallback + } + return value +} + +func normalizeQuotaLimit(limit *int64, unlimited bool, summary *ImportSummary) int64 { + if unlimited { + return -1 + } + if limit != nil { + return *limit + } + if summary != nil { + summary.Warnings = append(summary.Warnings, "quota_limit missing for limited key, defaulting to -1") + } + return -1 +} + +func normalizeQuotaUsed(used *int64) int64 { + if used == nil { + return 0 + } + return *used +} + +func generateRandomKey(length int) (string, error) { + if length <= 0 { + return "", errors.New("invalid key length") + } + buf := make([]byte, length) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return hex.EncodeToString(buf), nil +} diff --git a/internal/migrate/schema.go b/internal/migrate/schema.go new file mode 100644 index 0000000..7188aca --- /dev/null +++ b/internal/migrate/schema.go @@ -0,0 +1,98 @@ +package migrate + +import ( + "encoding/json" + "time" +) + +// ExportResult represents the complete export output. +type ExportResult struct { + Version string `json:"version"` + Source Source `json:"source"` + Data Data `json:"data"` + Warnings []string `json:"warnings,omitempty"` +} + +// Source represents the source system information. +type Source struct { + Type string `json:"type"` + Version string `json:"version"` + ExportedAt time.Time `json:"exported_at"` +} + +// Data contains all exported entities. +type Data struct { + Providers []Provider `json:"providers,omitempty"` + Masters []Master `json:"masters,omitempty"` + Keys []Key `json:"keys,omitempty"` + Bindings []Binding `json:"bindings,omitempty"` +} + +// Provider represents an EZ-API provider (mapped from New API channel). +type Provider struct { + OriginalID int `json:"original_id"` + Name string `json:"name"` + Type string `json:"type"` + BaseURL string `json:"base_url,omitempty"` + APIKey string `json:"api_key"` + Models []string `json:"models,omitempty"` + PrimaryGroup string `json:"primary_group"` + AllGroups []string `json:"all_groups,omitempty"` + Weight int `json:"weight"` + Priority int `json:"priority,omitempty"` + Status string `json:"status"` + AutoBan bool `json:"auto_ban"` + + IsMultiKey bool `json:"is_multi_key,omitempty"` + MultiKeyIndex int `json:"multi_key_index,omitempty"` + OriginalName string `json:"original_name,omitempty"` + + Original json.RawMessage `json:"_original,omitempty"` +} + +// Master represents an EZ-API master (inferred from New API user). +type Master struct { + Name string `json:"name"` + Group string `json:"group"` + Namespaces []string `json:"namespaces,omitempty"` + DefaultNamespace string `json:"default_namespace,omitempty"` + MaxChildKeys int `json:"max_child_keys,omitempty"` + GlobalQPS int `json:"global_qps,omitempty"` + Status string `json:"status"` + + SourceUserID int `json:"_source_user_id"` + SourceEmail string `json:"_source_email,omitempty"` +} + +// Key represents an EZ-API key (mapped from New API token). +type Key struct { + MasterRef string `json:"master_ref"` + OriginalToken string `json:"original_token"` + Group string `json:"group,omitempty"` + Status string `json:"status"` + + Scopes []string `json:"scopes,omitempty"` + Namespaces []string `json:"namespaces,omitempty"` + + ModelLimitsEnabled bool `json:"model_limits_enabled,omitempty"` + ModelLimits []string `json:"model_limits,omitempty"` + + ExpiresAt *time.Time `json:"expires_at,omitempty"` + + AllowIPs []string `json:"allow_ips,omitempty"` + + QuotaLimit *int64 `json:"quota_limit,omitempty"` + QuotaUsed *int64 `json:"quota_used,omitempty"` + UnlimitedQuota bool `json:"unlimited_quota,omitempty"` + + OriginalID int `json:"_original_id"` + TokenPlaintextAvailable bool `json:"_token_plaintext_available,omitempty"` +} + +// Binding represents an EZ-API binding (optional, from abilities). +type Binding struct { + Namespace string `json:"namespace"` + RouteGroup string `json:"route_group"` + Model string `json:"model"` + Status string `json:"status"` +}