mirror of
https://github.com/EZ-Api/ez-api.git
synced 2026-01-13 17:47:51 +00:00
Restructure the provider management system by splitting the monolithic Provider model into two distinct entities:

- ProviderGroup: defines the shared upstream configuration (type, base_url, Google settings, models, status).
- APIKey: represents an individual credential within a group (api_key, weight, status, auto_ban, ban settings).

This change also updates:

- the Binding model, to reference GroupID instead of the RouteGroup string;
- all CRUD handlers, for the new provider-group and api-key endpoints;
- the sync service, to rebuild provider snapshots from the joined tables;
- the model registry, to aggregate capabilities across group/key pairs;
- the access handler, to validate namespace existence and subset constraints;
- the migration importer, to handle the new schema structure;
- all related tests, to use the new model relationships.

BREAKING CHANGE: the Provider API endpoints are replaced by the /provider-groups and /api-keys endpoints; Binding.RouteGroup is replaced by Binding.GroupID.
633 lines
16 KiB
Go
633 lines
16 KiB
Go
package migrate
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"os"
|
|
"strings"
|
|
|
|
"github.com/ez-api/ez-api/internal/model"
|
|
"github.com/ez-api/foundation/tokenhash"
|
|
"golang.org/x/crypto/bcrypt"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// Conflict policies accepted by ImportOptions.ConflictPolicy. They decide what
// the importer does when a record it is about to import already exists.
// Any value other than ConflictOverwrite is handled like ConflictSkip by the
// import routines.
const (
	// ConflictSkip leaves the existing record untouched and counts it as
	// skipped. NewImporter falls back to this policy when none is configured.
	ConflictSkip = "skip"
	// ConflictOverwrite updates the existing record with the imported values.
	ConflictOverwrite = "overwrite"
)
|
|
|
|
// ImportOptions controls how an Importer applies an export payload.
type ImportOptions struct {
	// DryRun computes the summary without writing any rows. Existing rows are
	// still read so conflicts can be classified.
	DryRun bool
	// ConflictPolicy selects what happens to records that already exist
	// (ConflictSkip or ConflictOverwrite). NewImporter trims and lower-cases
	// it and defaults an empty value to ConflictSkip.
	ConflictPolicy string
	// IncludeBindings enables importing model bindings and creating the
	// namespaces those bindings reference.
	IncludeBindings bool
}
|
|
|
|
// MasterCredential pairs a newly created master's name with its plaintext
// master key. The importer generates the key itself and stores only a bcrypt
// hash and a token hash of it, so this is the only place the plaintext is
// reported.
type MasterCredential struct {
	Name      string `json:"name"`
	MasterKey string `json:"master_key"`
}
|
|
|
|
// ImportSummary reports the outcome of an import. Counters are incremented in
// dry-run mode too, describing what a real run would do.
type ImportSummary struct {
	// Provider counters count API keys derived from exported providers (one
	// key per provider entry), not provider groups.
	ProvidersCreated int `json:"providers_created"`
	ProvidersUpdated int `json:"providers_updated"`
	ProvidersSkipped int `json:"providers_skipped"`

	MastersCreated int `json:"masters_created"`
	MastersUpdated int `json:"masters_updated"`
	MastersSkipped int `json:"masters_skipped"`

	// Key counters cover child keys issued under masters.
	KeysCreated int `json:"keys_created"`
	KeysUpdated int `json:"keys_updated"`
	KeysSkipped int `json:"keys_skipped"`

	BindingsCreated int `json:"bindings_created"`
	BindingsUpdated int `json:"bindings_updated"`
	BindingsSkipped int `json:"bindings_skipped"`

	// NamespacesCreated counts namespaces that did not exist before the import.
	NamespacesCreated int `json:"namespaces_created"`

	// MasterKeys holds the plaintext keys generated for newly created masters.
	// Only hashes are persisted, so this is the sole copy; treat the summary
	// as secret. It stays empty in dry-run mode.
	MasterKeys []MasterCredential `json:"master_keys,omitempty"`
	// Warnings collects non-fatal issues: warnings carried over from the export
	// payload plus records skipped or defaulted during the import.
	Warnings []string `json:"warnings,omitempty"`
}
|
|
|
|
// Importer writes an export payload (ExportResult) into the database according
// to its ImportOptions. Construct it with NewImporter so the conflict policy is
// normalized.
type Importer struct {
	// db is the handle every lookup and write goes through.
	db *gorm.DB
	// opts holds the normalized import options.
	opts ImportOptions
}
|
|
|
|
func NewImporter(db *gorm.DB, opts ImportOptions) *Importer {
|
|
opts.ConflictPolicy = strings.ToLower(strings.TrimSpace(opts.ConflictPolicy))
|
|
if opts.ConflictPolicy == "" {
|
|
opts.ConflictPolicy = ConflictSkip
|
|
}
|
|
return &Importer{db: db, opts: opts}
|
|
}
|
|
|
|
func (i *Importer) ImportFile(path string) (*ImportSummary, error) {
|
|
raw, err := os.ReadFile(path)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read file: %w", err)
|
|
}
|
|
var payload ExportResult
|
|
if err := json.Unmarshal(raw, &payload); err != nil {
|
|
return nil, fmt.Errorf("parse json: %w", err)
|
|
}
|
|
return i.Import(&payload)
|
|
}
|
|
|
|
func (i *Importer) Import(payload *ExportResult) (*ImportSummary, error) {
|
|
if payload == nil {
|
|
return nil, errors.New("empty payload")
|
|
}
|
|
if i.db == nil {
|
|
return nil, errors.New("db is required")
|
|
}
|
|
|
|
summary := &ImportSummary{}
|
|
if len(payload.Warnings) > 0 {
|
|
summary.Warnings = append(summary.Warnings, payload.Warnings...)
|
|
}
|
|
|
|
if err := i.ensureNamespaces(payload, summary); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
masterMap, err := i.importMasters(payload.Data.Masters, summary)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if err := i.importProviders(payload.Data.Providers, summary); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := i.importKeys(payload.Data.Keys, masterMap, summary); err != nil {
|
|
return nil, err
|
|
}
|
|
if i.opts.IncludeBindings {
|
|
if err := i.importBindings(payload.Data.Bindings, summary); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return summary, nil
|
|
}
|
|
|
|
func (i *Importer) ensureNamespaces(payload *ExportResult, summary *ImportSummary) error {
|
|
namespaceSet := make(map[string]struct{})
|
|
namespaceSet["default"] = struct{}{}
|
|
for _, master := range payload.Data.Masters {
|
|
addNamespaces(namespaceSet, master.DefaultNamespace)
|
|
addNamespaces(namespaceSet, master.Namespaces...)
|
|
}
|
|
for _, key := range payload.Data.Keys {
|
|
addNamespaces(namespaceSet, key.Namespaces...)
|
|
}
|
|
if i.opts.IncludeBindings {
|
|
for _, binding := range payload.Data.Bindings {
|
|
addNamespaces(namespaceSet, binding.Namespace)
|
|
}
|
|
}
|
|
|
|
for name := range namespaceSet {
|
|
if err := i.ensureNamespace(name, summary); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (i *Importer) ensureNamespace(name string, summary *ImportSummary) error {
|
|
name = strings.TrimSpace(name)
|
|
if name == "" {
|
|
return nil
|
|
}
|
|
var existing model.Namespace
|
|
err := i.db.Where("name = ?", name).First(&existing).Error
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return err
|
|
}
|
|
if i.opts.DryRun {
|
|
summary.NamespacesCreated++
|
|
return nil
|
|
}
|
|
ns := &model.Namespace{Name: name, Status: "active"}
|
|
if err := i.db.Create(ns).Error; err != nil {
|
|
return err
|
|
}
|
|
summary.NamespacesCreated++
|
|
return nil
|
|
}
|
|
|
|
func (i *Importer) importMasters(items []Master, summary *ImportSummary) (map[string]model.Master, error) {
|
|
out := make(map[string]model.Master)
|
|
for _, item := range items {
|
|
name := strings.TrimSpace(item.Name)
|
|
if name == "" {
|
|
summary.Warnings = append(summary.Warnings, "skip master with empty name")
|
|
continue
|
|
}
|
|
var existing model.Master
|
|
err := i.db.Where("name = ?", name).First(&existing).Error
|
|
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, err
|
|
}
|
|
if err == nil {
|
|
out[name] = existing
|
|
defaultNS := normalizeNamespace(item.DefaultNamespace)
|
|
switch i.opts.ConflictPolicy {
|
|
case ConflictOverwrite:
|
|
if i.opts.DryRun {
|
|
summary.MastersUpdated++
|
|
continue
|
|
}
|
|
update := map[string]any{
|
|
"group": normalizeGroup(item.Group),
|
|
"default_namespace": defaultNS,
|
|
"namespaces": normalizeNamespaces(item.Namespaces, defaultNS),
|
|
"max_child_keys": normalizeMaxChildKeys(item.MaxChildKeys),
|
|
"global_qps": item.GlobalQPS,
|
|
"status": normalizeStatus(item.Status, "active"),
|
|
}
|
|
if err := i.db.Model(&existing).Updates(update).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
out[name] = existing
|
|
summary.MastersUpdated++
|
|
default:
|
|
summary.MastersSkipped++
|
|
}
|
|
continue
|
|
}
|
|
|
|
defaultNS := normalizeNamespace(item.DefaultNamespace)
|
|
normalizedNamespaces := normalizeNamespaces(item.Namespaces, defaultNS)
|
|
if i.opts.DryRun {
|
|
summary.MastersCreated++
|
|
out[name] = model.Master{
|
|
Name: name,
|
|
Group: normalizeGroup(item.Group),
|
|
DefaultNamespace: defaultNS,
|
|
Namespaces: normalizedNamespaces,
|
|
MaxChildKeys: normalizeMaxChildKeys(item.MaxChildKeys),
|
|
GlobalQPS: item.GlobalQPS,
|
|
Status: normalizeStatus(item.Status, "active"),
|
|
Epoch: 1,
|
|
}
|
|
continue
|
|
}
|
|
|
|
rawKey, err := generateRandomKey(32)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
hashedKey, err := bcrypt.GenerateFromPassword([]byte(rawKey), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
digest := tokenhash.HashToken(rawKey)
|
|
|
|
master := model.Master{
|
|
Name: name,
|
|
MasterKey: string(hashedKey),
|
|
MasterKeyDigest: digest,
|
|
Group: normalizeGroup(item.Group),
|
|
DefaultNamespace: defaultNS,
|
|
Namespaces: normalizedNamespaces,
|
|
MaxChildKeys: normalizeMaxChildKeys(item.MaxChildKeys),
|
|
GlobalQPS: item.GlobalQPS,
|
|
Status: normalizeStatus(item.Status, "active"),
|
|
Epoch: 1,
|
|
}
|
|
if err := i.db.Create(&master).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
out[name] = master
|
|
summary.MastersCreated++
|
|
summary.MasterKeys = append(summary.MasterKeys, MasterCredential{
|
|
Name: master.Name,
|
|
MasterKey: rawKey,
|
|
})
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func (i *Importer) importProviders(items []Provider, summary *ImportSummary) error {
|
|
groupCache := make(map[string]model.ProviderGroup)
|
|
for _, item := range items {
|
|
groupName := normalizeGroup(item.PrimaryGroup)
|
|
if strings.TrimSpace(groupName) == "" {
|
|
groupName = "default"
|
|
}
|
|
group, ok := groupCache[groupName]
|
|
if !ok {
|
|
var existing model.ProviderGroup
|
|
err := i.db.Where("name = ?", groupName).First(&existing).Error
|
|
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return err
|
|
}
|
|
if err == nil {
|
|
group = existing
|
|
} else {
|
|
group = model.ProviderGroup{
|
|
Name: groupName,
|
|
Type: strings.TrimSpace(item.Type),
|
|
BaseURL: strings.TrimSpace(item.BaseURL),
|
|
Models: strings.Join(item.Models, ","),
|
|
Status: normalizeStatus(item.Status, "active"),
|
|
}
|
|
if !i.opts.DryRun {
|
|
if err := i.db.Create(&group).Error; err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
groupCache[groupName] = group
|
|
}
|
|
|
|
apiKey := strings.TrimSpace(item.APIKey)
|
|
if apiKey == "" {
|
|
summary.Warnings = append(summary.Warnings, "skip api key with empty api_key")
|
|
continue
|
|
}
|
|
var existingKey model.APIKey
|
|
err := i.db.Where("group_id = ? AND api_key = ?", group.ID, apiKey).First(&existingKey).Error
|
|
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return err
|
|
}
|
|
if err == nil {
|
|
switch i.opts.ConflictPolicy {
|
|
case ConflictOverwrite:
|
|
if i.opts.DryRun {
|
|
summary.ProvidersUpdated++
|
|
continue
|
|
}
|
|
update := map[string]any{
|
|
"weight": resolveWeight(item.Weight, item.Priority),
|
|
"status": normalizeProviderStatus(item.Status),
|
|
"auto_ban": item.AutoBan,
|
|
}
|
|
if err := i.db.Model(&existingKey).Updates(update).Error; err != nil {
|
|
return err
|
|
}
|
|
summary.ProvidersUpdated++
|
|
default:
|
|
summary.ProvidersSkipped++
|
|
}
|
|
continue
|
|
}
|
|
|
|
if i.opts.DryRun {
|
|
summary.ProvidersCreated++
|
|
continue
|
|
}
|
|
|
|
key := model.APIKey{
|
|
GroupID: group.ID,
|
|
APIKey: apiKey,
|
|
Weight: resolveWeight(item.Weight, item.Priority),
|
|
Status: normalizeProviderStatus(item.Status),
|
|
AutoBan: item.AutoBan,
|
|
}
|
|
if err := i.db.Create(&key).Error; err != nil {
|
|
return err
|
|
}
|
|
summary.ProvidersCreated++
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (i *Importer) importKeys(items []Key, masters map[string]model.Master, summary *ImportSummary) error {
|
|
for _, item := range items {
|
|
masterName := strings.TrimSpace(item.MasterRef)
|
|
if masterName == "" {
|
|
summary.Warnings = append(summary.Warnings, "skip key with empty master_ref")
|
|
continue
|
|
}
|
|
master, ok := masters[masterName]
|
|
if !ok {
|
|
return fmt.Errorf("master not found for key: %s", masterName)
|
|
}
|
|
rawToken := strings.TrimSpace(item.OriginalToken)
|
|
if rawToken == "" || !item.TokenPlaintextAvailable {
|
|
summary.Warnings = append(summary.Warnings, fmt.Sprintf("skip key for master %s: missing plaintext token", masterName))
|
|
continue
|
|
}
|
|
tokenHash := tokenhash.HashToken(rawToken)
|
|
|
|
var existing model.Key
|
|
err := i.db.Where("token_hash = ?", tokenHash).First(&existing).Error
|
|
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return err
|
|
}
|
|
if err == nil {
|
|
switch i.opts.ConflictPolicy {
|
|
case ConflictOverwrite:
|
|
if i.opts.DryRun {
|
|
summary.KeysUpdated++
|
|
continue
|
|
}
|
|
update := map[string]any{
|
|
"master_id": master.ID,
|
|
"group": normalizeGroup(item.Group, master.Group),
|
|
"scopes": strings.Join(item.Scopes, ","),
|
|
"default_namespace": normalizeNamespace(master.DefaultNamespace),
|
|
"namespaces": normalizeNamespaces(item.Namespaces, master.DefaultNamespace),
|
|
"issued_at_epoch": master.Epoch,
|
|
"status": normalizeStatus(item.Status, "active"),
|
|
"issued_by": "import",
|
|
"model_limits": strings.Join(item.ModelLimits, ","),
|
|
"model_limits_enabled": item.ModelLimitsEnabled,
|
|
"expires_at": item.ExpiresAt,
|
|
"allow_ips": strings.Join(item.AllowIPs, ","),
|
|
"quota_limit": normalizeQuotaLimit(item.QuotaLimit, item.UnlimitedQuota, summary),
|
|
"quota_used": normalizeQuotaUsed(item.QuotaUsed),
|
|
}
|
|
if err := i.db.Model(&existing).Updates(update).Error; err != nil {
|
|
return err
|
|
}
|
|
summary.KeysUpdated++
|
|
default:
|
|
summary.KeysSkipped++
|
|
}
|
|
continue
|
|
}
|
|
|
|
if i.opts.DryRun {
|
|
summary.KeysCreated++
|
|
continue
|
|
}
|
|
|
|
hashedKey, err := bcrypt.GenerateFromPassword([]byte(rawToken), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
key := model.Key{
|
|
MasterID: master.ID,
|
|
KeySecret: string(hashedKey),
|
|
TokenHash: tokenHash,
|
|
Group: normalizeGroup(item.Group, master.Group),
|
|
Scopes: strings.Join(item.Scopes, ","),
|
|
DefaultNamespace: normalizeNamespace(master.DefaultNamespace),
|
|
Namespaces: normalizeNamespaces(item.Namespaces, master.DefaultNamespace),
|
|
IssuedAtEpoch: master.Epoch,
|
|
Status: normalizeStatus(item.Status, "active"),
|
|
IssuedBy: "import",
|
|
ModelLimits: strings.Join(item.ModelLimits, ","),
|
|
ModelLimitsEnabled: item.ModelLimitsEnabled,
|
|
ExpiresAt: item.ExpiresAt,
|
|
AllowIPs: strings.Join(item.AllowIPs, ","),
|
|
QuotaLimit: normalizeQuotaLimit(item.QuotaLimit, item.UnlimitedQuota, summary),
|
|
QuotaUsed: normalizeQuotaUsed(item.QuotaUsed),
|
|
}
|
|
if err := i.db.Create(&key).Error; err != nil {
|
|
return err
|
|
}
|
|
summary.KeysCreated++
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (i *Importer) importBindings(items []Binding, summary *ImportSummary) error {
|
|
for _, item := range items {
|
|
ns := normalizeNamespace(item.Namespace, "default")
|
|
publicModel := strings.TrimSpace(item.Model)
|
|
if publicModel == "" {
|
|
summary.Warnings = append(summary.Warnings, "skip binding with empty model")
|
|
continue
|
|
}
|
|
groupName := normalizeGroup(item.RouteGroup)
|
|
var group model.ProviderGroup
|
|
if err := i.db.Where("name = ?", groupName).First(&group).Error; err != nil {
|
|
summary.Warnings = append(summary.Warnings, "skip binding with missing provider group: "+groupName)
|
|
continue
|
|
}
|
|
var existing model.Binding
|
|
err := i.db.Where("namespace = ? AND public_model = ? AND group_id = ?", ns, publicModel, group.ID).First(&existing).Error
|
|
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return err
|
|
}
|
|
if err == nil {
|
|
switch i.opts.ConflictPolicy {
|
|
case ConflictOverwrite:
|
|
if i.opts.DryRun {
|
|
summary.BindingsUpdated++
|
|
continue
|
|
}
|
|
update := map[string]any{
|
|
"group_id": group.ID,
|
|
"weight": 1,
|
|
"selector_type": "exact",
|
|
"selector_value": publicModel,
|
|
"status": normalizeStatus(item.Status, "active"),
|
|
}
|
|
if err := i.db.Model(&existing).Updates(update).Error; err != nil {
|
|
return err
|
|
}
|
|
summary.BindingsUpdated++
|
|
default:
|
|
summary.BindingsSkipped++
|
|
}
|
|
continue
|
|
}
|
|
|
|
if i.opts.DryRun {
|
|
summary.BindingsCreated++
|
|
continue
|
|
}
|
|
|
|
binding := model.Binding{
|
|
Namespace: ns,
|
|
PublicModel: publicModel,
|
|
GroupID: group.ID,
|
|
Weight: 1,
|
|
SelectorType: "exact",
|
|
SelectorValue: publicModel,
|
|
Status: normalizeStatus(item.Status, "active"),
|
|
}
|
|
if err := i.db.Create(&binding).Error; err != nil {
|
|
return err
|
|
}
|
|
summary.BindingsCreated++
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// addNamespaces inserts every name into set after trimming surrounding
// whitespace. Names that are blank after trimming are ignored.
func addNamespaces(set map[string]struct{}, names ...string) {
	for idx := range names {
		if trimmed := strings.TrimSpace(names[idx]); trimmed != "" {
			set[trimmed] = struct{}{}
		}
	}
}
|
|
|
|
// normalizeGroup returns the first candidate that is non-blank after trimming,
// in its trimmed form. If every candidate is blank, or there are none, it
// returns "default".
func normalizeGroup(values ...string) string {
	for idx := 0; idx < len(values); idx++ {
		if candidate := strings.TrimSpace(values[idx]); candidate != "" {
			return candidate
		}
	}
	return "default"
}
|
|
|
|
// normalizeNamespace picks the first candidate that is non-blank after
// trimming and returns it trimmed. If no candidate qualifies, it returns
// "default".
func normalizeNamespace(values ...string) string {
	result := "default"
	for _, raw := range values {
		if ns := strings.TrimSpace(raw); ns != "" {
			result = ns
			break
		}
	}
	return result
}
|
|
|
|
// normalizeNamespaces builds the comma-separated namespace list stored on
// masters and keys.
//
// Entries are trimmed, blanks are dropped, and duplicates are removed; the
// first-seen order is kept. The trimmed defaultNamespace is appended when it
// is not already present, so the stored list always contains the default.
// If no usable entries remain, the result is the trimmed defaultNamespace, or
// "default" when that is blank.
func normalizeNamespaces(values []string, defaultNamespace string) string {
	// Trim once so a padded default neither slips past the duplicate check nor
	// gets stored with surrounding whitespace.
	defaultNamespace = strings.TrimSpace(defaultNamespace)

	seen := make(map[string]struct{}, len(values)+1)
	out := make([]string, 0, len(values)+1)
	for _, v := range values {
		v = strings.TrimSpace(v)
		if v == "" {
			continue
		}
		if _, dup := seen[v]; dup {
			continue
		}
		seen[v] = struct{}{}
		out = append(out, v)
	}

	if len(out) == 0 {
		if defaultNamespace == "" {
			return "default"
		}
		return defaultNamespace
	}
	if defaultNamespace != "" {
		if _, ok := seen[defaultNamespace]; !ok {
			out = append(out, defaultNamespace)
		}
	}
	return strings.Join(out, ",")
}
|
|
|
|
// normalizeMaxChildKeys returns value when it is positive. Otherwise it
// returns the default limit of 5 child keys.
func normalizeMaxChildKeys(value int) int {
	const defaultMaxChildKeys = 5
	if value <= 0 {
		return defaultMaxChildKeys
	}
	return value
}
|
|
|
|
// normalizeWeight clamps a routing weight to be at least 1. Non-positive
// inputs become 1; positive inputs are returned unchanged.
func normalizeWeight(value int) int {
	const minWeight = 1
	if value < minWeight {
		return minWeight
	}
	return value
}
|
|
|
|
func resolveWeight(weight int, priority int) int {
|
|
if weight > 0 {
|
|
return weight
|
|
}
|
|
if priority > 0 {
|
|
return priority
|
|
}
|
|
return normalizeWeight(weight)
|
|
}
|
|
|
|
// normalizeProviderStatus maps an exported provider status onto API key
// status values:
//   - blank or "active" becomes "active";
//   - "disabled" becomes "manual_disabled";
//   - anything else is returned trimmed and lower-cased.
//
// Matching is case-insensitive and ignores surrounding whitespace.
func normalizeProviderStatus(value string) string {
	switch s := strings.ToLower(strings.TrimSpace(value)); s {
	case "", "active":
		return "active"
	case "disabled":
		return "manual_disabled"
	default:
		return s
	}
}
|
|
|
|
// normalizeStatus returns value trimmed and lower-cased. If that leaves an
// empty string, it returns fallback exactly as given.
func normalizeStatus(value string, fallback string) string {
	if s := strings.ToLower(strings.TrimSpace(value)); s != "" {
		return s
	}
	return fallback
}
|
|
|
|
func normalizeQuotaLimit(limit *int64, unlimited bool, summary *ImportSummary) int64 {
|
|
if unlimited {
|
|
return -1
|
|
}
|
|
if limit != nil {
|
|
return *limit
|
|
}
|
|
if summary != nil {
|
|
summary.Warnings = append(summary.Warnings, "quota_limit missing for limited key, defaulting to -1")
|
|
}
|
|
return -1
|
|
}
|
|
|
|
// normalizeQuotaUsed dereferences used, treating a nil pointer as 0.
func normalizeQuotaUsed(used *int64) int64 {
	var v int64
	if used != nil {
		v = *used
	}
	return v
}
|
|
|
|
// generateRandomKey reads length bytes from crypto/rand and returns them
// hex-encoded. The result is 2*length lowercase hex characters. It fails for a
// non-positive length or when the randomness source returns an error.
func generateRandomKey(length int) (string, error) {
	if length < 1 {
		return "", errors.New("invalid key length")
	}
	raw := make([]byte, length)
	if _, readErr := rand.Read(raw); readErr != nil {
		return "", readErr
	}
	key := hex.EncodeToString(raw)
	return key, nil
}
|