feat(key): extend key metadata and validation

This commit is contained in:
zenfun
2025-12-19 21:24:24 +08:00
parent 5e98368428
commit 524f8c5a4e
6 changed files with 351 additions and 71 deletions

View File

@@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"strings"
"time"
"github.com/ez-api/ez-api/internal/model"
"github.com/ez-api/foundation/tokenhash"
@@ -18,6 +19,7 @@ var (
ErrMasterNotActive = errors.New("master is not active")
ErrChildKeyLimitReached = errors.New("child key limit reached")
ErrChildKeyGroupForbidden = errors.New("cannot issue key for a different group")
ErrModelLimitForbidden = errors.New("model not in master's accessible models")
)
type MasterService struct {
@@ -100,15 +102,25 @@ verified:
return &master, nil
}
func (s *MasterService) IssueChildKey(masterID uint, group string, scopes string) (*model.Key, string, error) {
return s.issueChildKey(masterID, group, scopes, "master")
type IssueKeyOptions struct {
Group string
Scopes string
ModelLimits string
ModelLimitsEnabled bool
ExpiresAt *time.Time
AllowIPs string
DenyIPs string
}
func (s *MasterService) IssueChildKeyAsAdmin(masterID uint, group string, scopes string) (*model.Key, string, error) {
return s.issueChildKey(masterID, group, scopes, "admin")
func (s *MasterService) IssueChildKey(masterID uint, opts IssueKeyOptions) (*model.Key, string, error) {
return s.issueChildKey(masterID, opts, "master")
}
func (s *MasterService) issueChildKey(masterID uint, group string, scopes string, issuedBy string) (*model.Key, string, error) {
func (s *MasterService) IssueChildKeyAsAdmin(masterID uint, opts IssueKeyOptions) (*model.Key, string, error) {
return s.issueChildKey(masterID, opts, "admin")
}
func (s *MasterService) issueChildKey(masterID uint, opts IssueKeyOptions, issuedBy string) (*model.Key, string, error) {
var master model.Master
if err := s.db.First(&master, masterID).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
@@ -120,7 +132,7 @@ func (s *MasterService) issueChildKey(masterID uint, group string, scopes string
return nil, "", fmt.Errorf("%w", ErrMasterNotActive)
}
group = strings.TrimSpace(group)
group := strings.TrimSpace(opts.Group)
if group == "" {
group = master.Group
}
@@ -146,17 +158,26 @@ func (s *MasterService) issueChildKey(masterID uint, group string, scopes string
return nil, "", fmt.Errorf("failed to hash child key: %w", err)
}
if err := s.ValidateModelLimits(&master, opts.ModelLimits); err != nil {
return nil, "", err
}
key := &model.Key{
MasterID: masterID,
KeySecret: string(hashedChildKey),
TokenHash: tokenHash,
Group: group,
Scopes: scopes,
DefaultNamespace: strings.TrimSpace(master.DefaultNamespace),
Namespaces: strings.TrimSpace(master.Namespaces),
IssuedAtEpoch: master.Epoch,
Status: "active",
IssuedBy: strings.TrimSpace(issuedBy),
MasterID: masterID,
KeySecret: string(hashedChildKey),
TokenHash: tokenHash,
Group: group,
Scopes: strings.TrimSpace(opts.Scopes),
DefaultNamespace: strings.TrimSpace(master.DefaultNamespace),
Namespaces: strings.TrimSpace(master.Namespaces),
IssuedAtEpoch: master.Epoch,
Status: "active",
IssuedBy: strings.TrimSpace(issuedBy),
ModelLimits: strings.TrimSpace(opts.ModelLimits),
ModelLimitsEnabled: opts.ModelLimitsEnabled,
ExpiresAt: opts.ExpiresAt,
AllowIPs: strings.TrimSpace(opts.AllowIPs),
DenyIPs: strings.TrimSpace(opts.DenyIPs),
}
if strings.TrimSpace(key.DefaultNamespace) == "" {
key.DefaultNamespace = "default"
@@ -175,6 +196,53 @@ func (s *MasterService) issueChildKey(masterID uint, group string, scopes string
return key, rawChildKey, nil
}
func (s *MasterService) ValidateModelLimits(master *model.Master, limits string) error {
if master == nil {
return fmt.Errorf("master is required")
}
requested := splitList(limits)
if len(requested) == 0 {
return nil
}
namespaces := normalizeNamespaces(master.Namespaces, master.DefaultNamespace)
if len(namespaces) == 0 {
return fmt.Errorf("master has no namespaces configured")
}
var bindings []model.Binding
if err := s.db.Where("namespace IN ?", namespaces).Find(&bindings).Error; err != nil {
return fmt.Errorf("load bindings: %w", err)
}
allowedBindings := make(map[string]struct{}, len(bindings))
allowedPublic := make(map[string]struct{}, len(bindings))
for _, b := range bindings {
if !bindingActive(b.Status) {
continue
}
ns := strings.TrimSpace(b.Namespace)
pm := strings.TrimSpace(b.PublicModel)
if ns == "" || pm == "" {
continue
}
allowedBindings[ns+"."+pm] = struct{}{}
allowedPublic[pm] = struct{}{}
}
for _, m := range requested {
if strings.Contains(m, ".") {
if _, ok := allowedBindings[m]; !ok {
return fmt.Errorf("%w: %s", ErrModelLimitForbidden, m)
}
continue
}
if _, ok := allowedPublic[m]; !ok {
return fmt.Errorf("%w: %s", ErrModelLimitForbidden, m)
}
}
return nil
}
func generateRandomKey(length int) (string, error) {
bytes := make([]byte, length)
if _, err := rand.Read(bytes); err != nil {
@@ -182,3 +250,68 @@ func generateRandomKey(length int) (string, error) {
}
return hex.EncodeToString(bytes), nil
}
func normalizeDefaultNamespace(ns string) string {
ns = strings.TrimSpace(ns)
if ns == "" {
return "default"
}
return ns
}
func normalizeNamespaces(raw string, defaultNamespace string) []string {
defaultNamespace = normalizeDefaultNamespace(defaultNamespace)
raw = strings.TrimSpace(raw)
var parts []string
if raw != "" {
parts = strings.FieldsFunc(raw, func(r rune) bool {
return r == ',' || r == ' ' || r == ';' || r == '\t' || r == '\n'
})
}
out := make([]string, 0, len(parts)+1)
seen := make(map[string]struct{}, len(parts)+1)
for _, p := range parts {
p = strings.TrimSpace(p)
if p == "" {
continue
}
if _, ok := seen[p]; ok {
continue
}
seen[p] = struct{}{}
out = append(out, p)
}
if _, ok := seen[defaultNamespace]; !ok {
out = append(out, defaultNamespace)
}
return out
}
func splitList(raw string) []string {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil
}
parts := strings.FieldsFunc(raw, func(r rune) bool {
return r == ',' || r == ' ' || r == ';' || r == '\t' || r == '\n'
})
out := make([]string, 0, len(parts))
seen := make(map[string]struct{}, len(parts))
for _, p := range parts {
p = strings.TrimSpace(p)
if p == "" {
continue
}
if _, ok := seen[p]; ok {
continue
}
seen[p] = struct{}{}
out = append(out, p)
}
return out
}
func bindingActive(status string) bool {
status = strings.ToLower(strings.TrimSpace(status))
return status == "" || status == "active"
}

View File

@@ -53,7 +53,10 @@ func TestMasterService_IssueChildKey_RespectsLimit(t *testing.T) {
t.Fatalf("CreateMaster: %v", err)
}
_, raw1, err := svc.IssueChildKey(m.ID, "default", "chat:write")
_, raw1, err := svc.IssueChildKey(m.ID, IssueKeyOptions{
Group: "default",
Scopes: "chat:write",
})
if err != nil {
t.Fatalf("IssueChildKey #1: %v", err)
}
@@ -61,7 +64,10 @@ func TestMasterService_IssueChildKey_RespectsLimit(t *testing.T) {
t.Fatalf("expected raw child key")
}
_, _, err = svc.IssueChildKey(m.ID, "default", "chat:write")
_, _, err = svc.IssueChildKey(m.ID, IssueKeyOptions{
Group: "default",
Scopes: "chat:write",
})
if err == nil {
t.Fatalf("expected child key limit error")
}
@@ -76,7 +82,9 @@ func TestMasterService_IssueChildKeyAsAdmin_SetsIssuedBy(t *testing.T) {
t.Fatalf("CreateMaster: %v", err)
}
key, raw, err := svc.IssueChildKeyAsAdmin(m.ID, "", "chat:write")
key, raw, err := svc.IssueChildKeyAsAdmin(m.ID, IssueKeyOptions{
Scopes: "chat:write",
})
if err != nil {
t.Fatalf("IssueChildKeyAsAdmin: %v", err)
}

View File

@@ -36,14 +36,26 @@ func (s *SyncService) SyncKey(key *model.Key) error {
}
fields := map[string]interface{}{
"id": key.ID,
"master_id": key.MasterID,
"issued_at_epoch": key.IssuedAtEpoch,
"status": key.Status,
"group": key.Group,
"scopes": key.Scopes,
"default_namespace": key.DefaultNamespace,
"namespaces": key.Namespaces,
"id": key.ID,
"master_id": key.MasterID,
"issued_at_epoch": key.IssuedAtEpoch,
"status": key.Status,
"group": key.Group,
"scopes": key.Scopes,
"default_namespace": key.DefaultNamespace,
"namespaces": key.Namespaces,
"model_limits": strings.TrimSpace(key.ModelLimits),
"model_limits_enabled": key.ModelLimitsEnabled,
"expires_at": unixOrZero(key.ExpiresAt),
"allow_ips": strings.TrimSpace(key.AllowIPs),
"deny_ips": strings.TrimSpace(key.DenyIPs),
"last_accessed_at": unixOrZero(key.LastAccessedAt),
"request_count": key.RequestCount,
"used_tokens": key.UsedTokens,
"quota_limit": key.QuotaLimit,
"quota_used": key.QuotaUsed,
"quota_reset_at": unixOrZero(key.QuotaResetAt),
"quota_reset_type": strings.TrimSpace(key.QuotaResetType),
}
if err := s.rdb.HSet(ctx, fmt.Sprintf("auth:token:%s", tokenHash), fields).Err(); err != nil {
return fmt.Errorf("write auth token: %w", err)
@@ -261,14 +273,26 @@ func (s *SyncService) SyncAll(db *gorm.DB) error {
return fmt.Errorf("token hash missing for key %d", k.ID)
}
pipe.HSet(ctx, fmt.Sprintf("auth:token:%s", tokenHash), map[string]interface{}{
"id": k.ID,
"master_id": k.MasterID,
"issued_at_epoch": k.IssuedAtEpoch,
"status": k.Status,
"group": k.Group,
"scopes": k.Scopes,
"default_namespace": k.DefaultNamespace,
"namespaces": k.Namespaces,
"id": k.ID,
"master_id": k.MasterID,
"issued_at_epoch": k.IssuedAtEpoch,
"status": k.Status,
"group": k.Group,
"scopes": k.Scopes,
"default_namespace": k.DefaultNamespace,
"namespaces": k.Namespaces,
"model_limits": strings.TrimSpace(k.ModelLimits),
"model_limits_enabled": k.ModelLimitsEnabled,
"expires_at": unixOrZero(k.ExpiresAt),
"allow_ips": strings.TrimSpace(k.AllowIPs),
"deny_ips": strings.TrimSpace(k.DenyIPs),
"last_accessed_at": unixOrZero(k.LastAccessedAt),
"request_count": k.RequestCount,
"used_tokens": k.UsedTokens,
"quota_limit": k.QuotaLimit,
"quota_used": k.QuotaUsed,
"quota_reset_at": unixOrZero(k.QuotaResetAt),
"quota_reset_type": strings.TrimSpace(k.QuotaResetType),
})
}
@@ -464,6 +488,13 @@ func normalizeStatus(status string) string {
}
}
func unixOrZero(t *time.Time) int64 {
if t == nil {
return 0
}
return t.UTC().Unix()
}
func (s *SyncService) refreshModelsMetaFromRedis(ctx context.Context, source string) error {
raw, err := s.rdb.HGetAll(ctx, "meta:models").Result()
if err != nil {