diff --git a/internal/api/auth_handler.go b/internal/api/auth_handler.go index 6d92762..690756a 100644 --- a/internal/api/auth_handler.go +++ b/internal/api/auth_handler.go @@ -186,14 +186,19 @@ func (h *AuthHandler) Whoami(c *gin.Context) { } // Get master ID and issued_at_epoch - masterIDStr := keyData["master_id"] + masterIDStr := strings.TrimSpace(keyData["master_id"]) masterID, err := strconv.ParseUint(masterIDStr, 10, 64) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token metadata"}) return } - issuedAtEpoch, _ := strconv.ParseInt(keyData["issued_at_epoch"], 10, 64) + issuedAtStr := strings.TrimSpace(keyData["issued_at_epoch"]) + issuedAtEpoch, err := strconv.ParseInt(issuedAtStr, 10, 64) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token metadata"}) + return + } // Get master metadata from Redis masterData, err := h.rdb.HGetAll(ctx, "auth:master:"+masterIDStr).Result() @@ -210,14 +215,27 @@ func (h *AuthHandler) Whoami(c *gin.Context) { } // Check epoch (key revocation) - masterEpoch, _ := strconv.ParseInt(masterData["epoch"], 10, 64) + masterEpochStr := strings.TrimSpace(masterData["epoch"]) + masterEpoch, err := strconv.ParseInt(masterEpochStr, 10, 64) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid master metadata"}) + return + } if issuedAtEpoch < masterEpoch { c.JSON(http.StatusUnauthorized, gin.H{"error": "token has been revoked"}) return } // Check expiration - expiresAt, _ := strconv.ParseInt(keyData["expires_at"], 10, 64) + expiresAt := int64(0) + expiresAtStr := strings.TrimSpace(keyData["expires_at"]) + if expiresAtStr != "" { + expiresAt, err = strconv.ParseInt(expiresAtStr, 10, 64) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token metadata"}) + return + } + } if expiresAt > 0 && time.Now().Unix() >= expiresAt { c.JSON(http.StatusUnauthorized, gin.H{"error": "token has expired"}) return diff --git a/internal/api/auth_handler_test.go b/internal/api/auth_handler_test.go new file mode 100644 index 0000000..4cdb5f4 --- /dev/null +++ b/internal/api/auth_handler_test.go @@ -0,0 +1,154 @@ +package api + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/ez-api/ez-api/internal/model" + "github.com/ez-api/ez-api/internal/service" + "github.com/ez-api/foundation/tokenhash" + "github.com/gin-gonic/gin" + "github.com/redis/go-redis/v9" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +func newAuthHandler(t *testing.T) (*AuthHandler, *gorm.DB, *miniredis.Miniredis) { + t.Helper() + gin.SetMode(gin.TestMode) + t.Setenv("EZ_ADMIN_TOKEN", "admin-secret") + + adminService, err := service.NewAdminService() + if err != nil { + t.Fatalf("NewAdminService: %v", err) + } + + db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{}) + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + if err := db.AutoMigrate(&model.Master{}, &model.Key{}); err != nil { + t.Fatalf("migrate: %v", err) + } + + mr := miniredis.RunT(t) + rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()}) + + masterService := service.NewMasterService(db) + handler := NewAuthHandler(db, rdb, adminService, masterService) + return handler, db, mr +} + +func TestAuthHandler_Whoami_InvalidIssuedAtEpoch_Returns401(t *testing.T) { + handler, _, mr := newAuthHandler(t) + + token := "sk-live-invalid-epoch" + hash := tokenhash.HashToken(token) + mr.HSet("auth:token:"+hash, "master_id", "1") + mr.HSet("auth:token:"+hash, "issued_at_epoch", "bad") + mr.HSet("auth:token:"+hash, "status", "active") + mr.HSet("auth:token:"+hash, "group", "default") + mr.HSet("auth:master:1", "epoch", "1") + mr.HSet("auth:master:1", "status", "active") + + r := gin.New() + r.GET("/auth/whoami", handler.Whoami) + + req := httptest.NewRequest(http.MethodGet, "/auth/whoami", nil) + req.Header.Set("Authorization", "Bearer "+token) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d body=%s", rr.Code, rr.Body.String()) + } +} + +func TestAuthHandler_Whoami_InvalidMasterEpoch_Returns401(t *testing.T) { + handler, _, mr := newAuthHandler(t) + + token := "sk-live-invalid-master-epoch" + hash := tokenhash.HashToken(token) + mr.HSet("auth:token:"+hash, "master_id", "1") + mr.HSet("auth:token:"+hash, "issued_at_epoch", "1") + mr.HSet("auth:token:"+hash, "status", "active") + mr.HSet("auth:token:"+hash, "group", "default") + mr.HSet("auth:master:1", "epoch", "bad") + mr.HSet("auth:master:1", "status", "active") + + r := gin.New() + r.GET("/auth/whoami", handler.Whoami) + + req := httptest.NewRequest(http.MethodGet, "/auth/whoami", nil) + req.Header.Set("Authorization", "Bearer "+token) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d body=%s", rr.Code, rr.Body.String()) + } +} + +func TestAuthHandler_Whoami_KeyResponseIncludesIPRules(t *testing.T) { + handler, db, mr := newAuthHandler(t) + + token := "sk-live-valid" + hash := tokenhash.HashToken(token) + expiresAt := time.Now().Add(time.Hour).Unix() + + mr.HSet("auth:token:"+hash, "master_id", "1") + mr.HSet("auth:token:"+hash, "issued_at_epoch", "1") + mr.HSet("auth:token:"+hash, "status", "active") + mr.HSet("auth:token:"+hash, "group", "default") + mr.HSet("auth:token:"+hash, "expires_at", fmt.Sprintf("%d", expiresAt)) + mr.HSet("auth:master:1", "epoch", "1") + mr.HSet("auth:master:1", "status", "active") + + key := &model.Key{ + MasterID: 1, + TokenHash: hash, + Group: "default", + Scopes: "chat:write", + DefaultNamespace: "default", + Namespaces: "default", + Status: "active", + IssuedAtEpoch: 1, + IssuedBy: "master", + AllowIPs: "1.2.3.4", + DenyIPs: "5.6.7.0/24", + } + if err := db.Create(key).Error; err != nil { + t.Fatalf("create key: %v", err) + } + + r := gin.New() + r.GET("/auth/whoami", handler.Whoami) + + req := httptest.NewRequest(http.MethodGet, "/auth/whoami", nil) + req.Header.Set("Authorization", "Bearer "+token) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rr.Code, rr.Body.String()) + } + + var resp map[string]any + if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { + t.Fatalf("decode response: %v", err) + } + if resp["allow_ips"] != "1.2.3.4" { + t.Fatalf("expected allow_ips, got %v", resp["allow_ips"]) + } + if resp["deny_ips"] != "5.6.7.0/24" { + t.Fatalf("expected deny_ips, got %v", resp["deny_ips"]) + } + if got, ok := resp["expires_at"].(float64); !ok || int64(got) != expiresAt { + t.Fatalf("expected expires_at=%d, got %v", expiresAt, resp["expires_at"]) + } +}