package service import ( "crypto/rand" "encoding/hex" "errors" "fmt" "strings" "time" "github.com/ez-api/ez-api/internal/model" "github.com/ez-api/foundation/tokenhash" "golang.org/x/crypto/bcrypt" "gorm.io/gorm" ) var ( ErrMasterNotFound = errors.New("master not found") ErrMasterNotActive = errors.New("master is not active") ErrChildKeyLimitReached = errors.New("child key limit reached") ErrChildKeyGroupForbidden = errors.New("cannot issue key for a different group") ErrModelLimitForbidden = errors.New("model not in master's accessible models") ) type MasterService struct { db *gorm.DB } func NewMasterService(db *gorm.DB) *MasterService { return &MasterService{db: db} } func (s *MasterService) CreateMaster(name, group string, maxChildKeys, globalQPS int) (*model.Master, string, error) { rawMasterKey, err := generateRandomKey(32) if err != nil { return nil, "", fmt.Errorf("failed to generate master key: %w", err) } hashedMasterKey, err := bcrypt.GenerateFromPassword([]byte(rawMasterKey), bcrypt.DefaultCost) if err != nil { return nil, "", fmt.Errorf("failed to hash master key: %w", err) } masterKeyDigest := tokenhash.HashToken(rawMasterKey) master := &model.Master{ Name: name, MasterKey: string(hashedMasterKey), MasterKeyDigest: masterKeyDigest, Group: group, DefaultNamespace: "default", Namespaces: "default", MaxChildKeys: maxChildKeys, GlobalQPS: globalQPS, Status: "active", Epoch: 1, } if err := s.db.Create(master).Error; err != nil { return nil, "", err } return master, rawMasterKey, nil } func (s *MasterService) ValidateMasterKey(masterKey string) (*model.Master, error) { digest := tokenhash.HashToken(masterKey) var master model.Master if err := s.db.Where("master_key_digest = ?", digest).First(&master).Error; err != nil { if !errors.Is(err, gorm.ErrRecordNotFound) { return nil, err } // Backward compatibility: look for legacy rows without digest. var masters []model.Master if err := s.db.Where("master_key_digest = '' OR master_key_digest IS NULL").Find(&masters).Error; err != nil { return nil, err } for _, m := range masters { if bcrypt.CompareHashAndPassword([]byte(m.MasterKey), []byte(masterKey)) == nil { master = m // Opportunistically backfill digest for next time. if strings.TrimSpace(m.MasterKeyDigest) == "" { _ = s.db.Model(&m).Update("master_key_digest", digest).Error } goto verified } } return nil, errors.New("invalid master key") } if bcrypt.CompareHashAndPassword([]byte(master.MasterKey), []byte(masterKey)) != nil { return nil, errors.New("invalid master key") } verified: if master.Status != "active" { return nil, fmt.Errorf("master is not active") } return &master, nil } type IssueKeyOptions struct { Group string Scopes string ModelLimits string ModelLimitsEnabled bool ExpiresAt *time.Time AllowIPs string DenyIPs string } func (s *MasterService) IssueChildKey(masterID uint, opts IssueKeyOptions) (*model.Key, string, error) { return s.issueChildKey(masterID, opts, "master") } func (s *MasterService) IssueChildKeyAsAdmin(masterID uint, opts IssueKeyOptions) (*model.Key, string, error) { return s.issueChildKey(masterID, opts, "admin") } func (s *MasterService) issueChildKey(masterID uint, opts IssueKeyOptions, issuedBy string) (*model.Key, string, error) { var master model.Master if err := s.db.First(&master, masterID).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, "", fmt.Errorf("%w: %d", ErrMasterNotFound, masterID) } return nil, "", fmt.Errorf("load master: %w", err) } if master.Status != "active" { return nil, "", fmt.Errorf("%w", ErrMasterNotActive) } group := strings.TrimSpace(opts.Group) if group == "" { group = master.Group } if group != master.Group { return nil, "", fmt.Errorf("%w", ErrChildKeyGroupForbidden) } var count int64 s.db.Model(&model.Key{}).Where("master_id = ?", masterID).Count(&count) if count >= int64(master.MaxChildKeys) { return nil, "", fmt.Errorf("%w for master %d", ErrChildKeyLimitReached, masterID) } rawChildKey, err := generateRandomKey(32) if err != nil { return nil, "", fmt.Errorf("failed to generate child key: %w", err) } tokenHash := tokenhash.HashToken(rawChildKey) hashedChildKey, err := bcrypt.GenerateFromPassword([]byte(rawChildKey), bcrypt.DefaultCost) if err != nil { return nil, "", fmt.Errorf("failed to hash child key: %w", err) } if err := s.ValidateModelLimits(&master, opts.ModelLimits); err != nil { return nil, "", err } key := &model.Key{ MasterID: masterID, KeySecret: string(hashedChildKey), TokenHash: tokenHash, Group: group, Scopes: strings.TrimSpace(opts.Scopes), DefaultNamespace: strings.TrimSpace(master.DefaultNamespace), Namespaces: strings.TrimSpace(master.Namespaces), IssuedAtEpoch: master.Epoch, Status: "active", IssuedBy: strings.TrimSpace(issuedBy), ModelLimits: strings.TrimSpace(opts.ModelLimits), ModelLimitsEnabled: opts.ModelLimitsEnabled, ExpiresAt: opts.ExpiresAt, AllowIPs: strings.TrimSpace(opts.AllowIPs), DenyIPs: strings.TrimSpace(opts.DenyIPs), } if strings.TrimSpace(key.DefaultNamespace) == "" { key.DefaultNamespace = "default" } if strings.TrimSpace(key.Namespaces) == "" { key.Namespaces = key.DefaultNamespace } if key.IssuedBy == "" { key.IssuedBy = "master" } if err := s.db.Create(key).Error; err != nil { return nil, "", err } return key, rawChildKey, nil } func (s *MasterService) ValidateModelLimits(master *model.Master, limits string) error { if master == nil { return fmt.Errorf("master is required") } requested := splitList(limits) if len(requested) == 0 { return nil } namespaces := normalizeNamespaces(master.Namespaces, master.DefaultNamespace) if len(namespaces) == 0 { return fmt.Errorf("master has no namespaces configured") } var bindings []model.Binding if err := s.db.Where("namespace IN ?", namespaces).Find(&bindings).Error; err != nil { return fmt.Errorf("load bindings: %w", err) } allowedBindings := make(map[string]struct{}, len(bindings)) allowedPublic := make(map[string]struct{}, len(bindings)) for _, b := range bindings { if !bindingActive(b.Status) { continue } ns := strings.TrimSpace(b.Namespace) pm := strings.TrimSpace(b.PublicModel) if ns == "" || pm == "" { continue } allowedBindings[ns+"."+pm] = struct{}{} allowedPublic[pm] = struct{}{} } for _, m := range requested { if strings.Contains(m, ".") { if _, ok := allowedBindings[m]; !ok { return fmt.Errorf("%w: %s", ErrModelLimitForbidden, m) } continue } if _, ok := allowedPublic[m]; !ok { return fmt.Errorf("%w: %s", ErrModelLimitForbidden, m) } } return nil } func generateRandomKey(length int) (string, error) { bytes := make([]byte, length) if _, err := rand.Read(bytes); err != nil { return "", err } return hex.EncodeToString(bytes), nil } func normalizeDefaultNamespace(ns string) string { ns = strings.TrimSpace(ns) if ns == "" { return "default" } return ns } func normalizeNamespaces(raw string, defaultNamespace string) []string { defaultNamespace = normalizeDefaultNamespace(defaultNamespace) raw = strings.TrimSpace(raw) var parts []string if raw != "" { parts = strings.FieldsFunc(raw, func(r rune) bool { return r == ',' || r == ' ' || r == ';' || r == '\t' || r == '\n' }) } out := make([]string, 0, len(parts)+1) seen := make(map[string]struct{}, len(parts)+1) for _, p := range parts { p = strings.TrimSpace(p) if p == "" { continue } if _, ok := seen[p]; ok { continue } seen[p] = struct{}{} out = append(out, p) } if _, ok := seen[defaultNamespace]; !ok { out = append(out, defaultNamespace) } return out } func splitList(raw string) []string { raw = strings.TrimSpace(raw) if raw == "" { return nil } parts := strings.FieldsFunc(raw, func(r rune) bool { return r == ',' || r == ' ' || r == ';' || r == '\t' || r == '\n' }) out := make([]string, 0, len(parts)) seen := make(map[string]struct{}, len(parts)) for _, p := range parts { p = strings.TrimSpace(p) if p == "" { continue } if _, ok := seen[p]; ok { continue } seen[p] = struct{}{} out = append(out, p) } return out } func bindingActive(status string) bool { status = strings.ToLower(strings.TrimSpace(status)) return status == "" || status == "active" }