package middleware

import (
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/gin-gonic/gin"
)

// envelopeTestRouter returns a gin engine in test mode with the
// ResponseEnvelope middleware installed. When withRequestID is true the
// RequestID middleware is registered before ResponseEnvelope.
func envelopeTestRouter(withRequestID bool) *gin.Engine {
	gin.SetMode(gin.TestMode)
	r := gin.New()
	if withRequestID {
		r.Use(RequestID())
	}
	r.Use(ResponseEnvelope())
	return r
}

// decodeEnvelopeTestBody unmarshals the recorded response body into v and
// fails the test (including the raw body for diagnosis) if decoding fails.
func decodeEnvelopeTestBody(t *testing.T, rr *httptest.ResponseRecorder, v any) {
	t.Helper()
	if err := json.Unmarshal(rr.Body.Bytes(), v); err != nil {
		t.Fatalf("unmarshal envelope: %v (body=%s)", err, rr.Body.String())
	}
}

// TestResponseEnvelope_Success checks that a 200 JSON payload is wrapped
// with the success code/message, kept under "data", and given a trace_id.
func TestResponseEnvelope_Success(t *testing.T) {
	r := envelopeTestRouter(true)
	r.GET("/ok", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"id": 1, "name": "test"})
	})

	rr := httptest.NewRecorder()
	r.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/ok", nil))
	if rr.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d body=%s", rr.Code, rr.Body.String())
	}

	var env struct {
		Code    int            `json:"code"`
		Message string         `json:"message"`
		Data    map[string]any `json:"data"`
		TraceID string         `json:"trace_id"`
	}
	decodeEnvelopeTestBody(t, rr, &env)

	if env.Code != CodeSuccess {
		t.Fatalf("expected code=%d, got %d", CodeSuccess, env.Code)
	}
	if env.Message != "success" {
		t.Fatalf("expected message='success', got %q", env.Message)
	}
	// encoding/json decodes JSON numbers into float64 when the target is any.
	if env.Data["id"] != float64(1) {
		t.Fatalf("expected data.id=1, got %v", env.Data["id"])
	}
	if env.Data["name"] != "test" {
		t.Fatalf("expected data.name='test', got %v", env.Data["name"])
	}
	if env.TraceID == "" {
		t.Fatal("expected trace_id to be set")
	}
}

// TestResponseEnvelope_NotFound checks that an error payload is mapped to
// the not-found business code, its "error" text becomes the message, and
// data is null.
func TestResponseEnvelope_NotFound(t *testing.T) {
	r := envelopeTestRouter(true)
	r.GET("/missing", func(c *gin.Context) {
		c.JSON(http.StatusNotFound, gin.H{"error": "not here"})
	})

	rr := httptest.NewRecorder()
	r.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/missing", nil))
	if rr.Code != http.StatusNotFound {
		t.Fatalf("expected 404, got %d body=%s", rr.Code, rr.Body.String())
	}

	var env struct {
		Code    int    `json:"code"`
		Message string `json:"message"`
		Data    any    `json:"data"`
		TraceID string `json:"trace_id"`
	}
	decodeEnvelopeTestBody(t, rr, &env)

	if env.Code != CodeResourceNotFound {
		t.Fatalf("expected code=%d, got %d", CodeResourceNotFound, env.Code)
	}
	if env.Message != "not here" {
		t.Fatalf("expected message='not here', got %q", env.Message)
	}
	if env.Data != nil {
		t.Fatalf("expected data=null, got %v", env.Data)
	}
}

// TestResponseEnvelope_OverrideBusinessCode checks that SetBusinessCode
// takes precedence over the code derived from the HTTP status.
func TestResponseEnvelope_OverrideBusinessCode(t *testing.T) {
	r := envelopeTestRouter(true)
	r.GET("/rate-limit", func(c *gin.Context) {
		SetBusinessCode(c, 1099) // custom code, not the status-derived default
		c.JSON(http.StatusTooManyRequests, gin.H{"error": "rate limited"})
	})

	rr := httptest.NewRecorder()
	r.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/rate-limit", nil))
	if rr.Code != http.StatusTooManyRequests {
		t.Fatalf("expected 429, got %d body=%s", rr.Code, rr.Body.String())
	}

	var env struct {
		Code int `json:"code"`
	}
	decodeEnvelopeTestBody(t, rr, &env)

	if env.Code != 1099 {
		t.Fatalf("expected code=1099, got %d", env.Code)
	}
}

// TestResponseEnvelope_WithDetails checks that SetErrorDetails is surfaced
// under "details" alongside the invalid-param code and error message.
func TestResponseEnvelope_WithDetails(t *testing.T) {
	r := envelopeTestRouter(true)
	r.POST("/validate", func(c *gin.Context) {
		SetErrorDetails(c, map[string]string{
			"email":   "格式错误",
			"user_id": "必填",
		})
		c.JSON(http.StatusBadRequest, gin.H{"error": "参数校验失败"})
	})

	rr := httptest.NewRecorder()
	r.ServeHTTP(rr, httptest.NewRequest(http.MethodPost, "/validate", nil))
	if rr.Code != http.StatusBadRequest {
		t.Fatalf("expected 400, got %d body=%s", rr.Code, rr.Body.String())
	}

	var env struct {
		Code    int               `json:"code"`
		Message string            `json:"message"`
		Data    any               `json:"data"`
		Details map[string]string `json:"details"`
	}
	decodeEnvelopeTestBody(t, rr, &env)

	if env.Code != CodeInvalidParam {
		t.Fatalf("expected code=%d, got %d", CodeInvalidParam, env.Code)
	}
	if env.Message != "参数校验失败" {
		t.Fatalf("expected message='参数校验失败', got %q", env.Message)
	}
	if env.Data != nil {
		t.Fatalf("expected data=null, got %v", env.Data)
	}
	if env.Details["email"] != "格式错误" {
		t.Fatalf("expected details.email='格式错误', got %v", env.Details["email"])
	}
	if env.Details["user_id"] != "必填" {
		t.Fatalf("expected details.user_id='必填', got %v", env.Details["user_id"])
	}
}

// TestResponseEnvelope_Idempotent checks that a payload already shaped like
// an envelope (numeric code) is not wrapped a second time. RequestID is
// deliberately not installed so the handler's own trace_id is observable.
func TestResponseEnvelope_Idempotent(t *testing.T) {
	r := envelopeTestRouter(false)
	r.GET("/wrapped", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{
			"code":     0,
			"data":     gin.H{"id": 1},
			"message":  "success",
			"trace_id": "test-123",
		})
	})

	rr := httptest.NewRecorder()
	r.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/wrapped", nil))
	if rr.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d body=%s", rr.Code, rr.Body.String())
	}

	var env map[string]any
	decodeEnvelopeTestBody(t, rr, &env)

	if env["code"] != float64(0) {
		t.Fatalf("expected code=0, got %v", env["code"])
	}
	if env["trace_id"] != "test-123" {
		t.Fatalf("expected trace_id=test-123, got %v", env["trace_id"])
	}
	data, ok := env["data"].(map[string]any)
	if !ok || data["id"] != float64(1) {
		t.Fatalf("unexpected data: %+v", env["data"])
	}
}

// TestResponseEnvelope_IdempotentOldFormat checks that a legacy envelope
// using a string code passes through unchanged.
func TestResponseEnvelope_IdempotentOldFormat(t *testing.T) {
	r := envelopeTestRouter(false)
	r.GET("/old-wrapped", func(c *gin.Context) {
		// Old format with a string code.
		c.JSON(http.StatusOK, gin.H{
			"code":    "ok",
			"data":    gin.H{"id": 1},
			"message": "",
		})
	})

	rr := httptest.NewRecorder()
	r.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/old-wrapped", nil))
	if rr.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d body=%s", rr.Code, rr.Body.String())
	}

	var env map[string]any
	decodeEnvelopeTestBody(t, rr, &env)

	// Should pass through unchanged.
	if env["code"] != "ok" {
		t.Fatalf("expected code='ok', got %v", env["code"])
	}
}

// TestResponseEnvelope_TraceIDFromHeader checks that an incoming
// X-Request-ID header is echoed back as the envelope's trace_id.
func TestResponseEnvelope_TraceIDFromHeader(t *testing.T) {
	r := envelopeTestRouter(true)
	r.GET("/trace", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"ok": true})
	})

	req := httptest.NewRequest(http.MethodGet, "/trace", nil)
	req.Header.Set("X-Request-ID", "custom-trace-123")
	rr := httptest.NewRecorder()
	r.ServeHTTP(rr, req)
	if rr.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d body=%s", rr.Code, rr.Body.String())
	}

	var env struct {
		TraceID string `json:"trace_id"`
	}
	decodeEnvelopeTestBody(t, rr, &env)

	if env.TraceID != "custom-trace-123" {
		t.Fatalf("expected trace_id='custom-trace-123', got %q", env.TraceID)
	}
}

// TestResponseEnvelope_NonJSON checks that non-JSON responses are written
// through untouched.
func TestResponseEnvelope_NonJSON(t *testing.T) {
	r := envelopeTestRouter(false)
	r.GET("/text", func(c *gin.Context) {
		c.String(http.StatusOK, "plain text")
	})

	rr := httptest.NewRecorder()
	r.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/text", nil))
	if rr.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d", rr.Code)
	}
	if rr.Body.String() != "plain text" {
		t.Fatalf("expected 'plain text', got %q", rr.Body.String())
	}
}

// TestResponseEnvelope_DefaultCodes checks the HTTP status -> business code
// mapping used when no explicit business code has been set. Each status is
// a separate subtest so every mismatch is reported individually.
func TestResponseEnvelope_DefaultCodes(t *testing.T) {
	testCases := []struct {
		status       int
		expectedCode int
	}{
		{http.StatusOK, CodeSuccess},
		{http.StatusCreated, CodeSuccess},
		{http.StatusBadRequest, CodeInvalidParam},
		{http.StatusUnauthorized, CodeUnauthorized},
		{http.StatusForbidden, CodeForbidden},
		{http.StatusNotFound, CodeResourceNotFound},
		{http.StatusConflict, CodeResourceConflict},
		{http.StatusTooManyRequests, CodeRateLimited},
		{http.StatusInternalServerError, CodeInternalError},
		{http.StatusServiceUnavailable, CodeInternalError},
	}

	for _, tc := range testCases {
		t.Run(http.StatusText(tc.status), func(t *testing.T) {
			if got := defaultBusinessCode(tc.status); got != tc.expectedCode {
				t.Errorf("defaultBusinessCode(%d) = %d, want %d", tc.status, got, tc.expectedCode)
			}
		})
	}
}