Files
ez-api/internal/middleware/response_envelope_test.go
zenfun 26733be020 feat(api): standardize response envelope behavior
Add shared response DTOs and enhance the response envelope middleware with
excluded paths, trace ID generation fallback, and automatic extraction of
error details from handler responses. Update default business code mapping
for 503 and 504, and adjust idempotency detection to only treat the new
envelope format as already-wrapped.

BREAKING CHANGE: responses using the old envelope format (e.g., string
`code`) are now wrapped into the new standard envelope.
2026-01-10 00:59:45 +08:00

412 lines
11 KiB
Go

package middleware
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
// TestResponseEnvelope_Success checks that a plain 200 JSON payload is wrapped
// in the standard envelope: success business code, "success" message, the
// handler payload under data, and a non-empty trace_id.
func TestResponseEnvelope_Success(t *testing.T) {
	gin.SetMode(gin.TestMode)

	engine := gin.New()
	// Registration order matters: RequestID runs before ResponseEnvelope.
	engine.Use(RequestID(), ResponseEnvelope())
	engine.GET("/ok", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"id": 1, "name": "test"})
	})

	recorder := httptest.NewRecorder()
	engine.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/ok", nil))

	if recorder.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d body=%s", recorder.Code, recorder.Body.String())
	}

	type envelope struct {
		Code    int            `json:"code"`
		Message string         `json:"message"`
		Data    map[string]any `json:"data"`
		TraceID string         `json:"trace_id"`
	}
	var got envelope
	if err := json.Unmarshal(recorder.Body.Bytes(), &got); err != nil {
		t.Fatalf("unmarshal envelope: %v", err)
	}

	switch {
	case got.Code != CodeSuccess:
		t.Fatalf("expected code=0, got %d", got.Code)
	case got.Message != "success":
		t.Fatalf("expected message='success', got %q", got.Message)
	case got.Data["id"] != float64(1):
		// encoding/json decodes JSON numbers into float64 inside map[string]any.
		t.Fatalf("expected data.id=1, got %v", got.Data["id"])
	case got.TraceID == "":
		t.Fatal("expected trace_id to be set")
	}
}
// TestResponseEnvelope_NotFound checks that a 404 JSON error response keeps its
// HTTP status, maps to CodeResourceNotFound, promotes the handler's "error"
// field to the envelope message, and emits data as null.
func TestResponseEnvelope_NotFound(t *testing.T) {
	gin.SetMode(gin.TestMode)

	router := gin.New()
	router.Use(RequestID(), ResponseEnvelope())
	router.GET("/missing", func(c *gin.Context) {
		c.JSON(http.StatusNotFound, gin.H{"error": "not here"})
	})

	w := httptest.NewRecorder()
	router.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/missing", nil))

	if w.Code != http.StatusNotFound {
		t.Fatalf("expected 404, got %d body=%s", w.Code, w.Body.String())
	}

	var body struct {
		Code    int    `json:"code"`
		Message string `json:"message"`
		Data    any    `json:"data"`
		TraceID string `json:"trace_id"`
	}
	if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
		t.Fatalf("unmarshal envelope: %v", err)
	}

	if got, want := body.Code, CodeResourceNotFound; got != want {
		t.Fatalf("expected code=%d, got %d", want, got)
	}
	if got := body.Message; got != "not here" {
		t.Fatalf("expected message='not here', got %q", got)
	}
	// A JSON null decodes into a nil interface value.
	if body.Data != nil {
		t.Fatalf("expected data=null, got %v", body.Data)
	}
}
// TestResponseEnvelope_OverrideBusinessCode checks that a business code set by
// the handler through SetBusinessCode wins over the status-derived default,
// while the HTTP status itself (429) is left intact.
func TestResponseEnvelope_OverrideBusinessCode(t *testing.T) {
	const customCode = 1099

	gin.SetMode(gin.TestMode)

	router := gin.New()
	router.Use(RequestID(), ResponseEnvelope())
	router.GET("/rate-limit", func(c *gin.Context) {
		SetBusinessCode(c, customCode)
		c.JSON(http.StatusTooManyRequests, gin.H{"error": "rate limited"})
	})

	w := httptest.NewRecorder()
	router.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/rate-limit", nil))

	if w.Code != http.StatusTooManyRequests {
		t.Fatalf("expected 429, got %d body=%s", w.Code, w.Body.String())
	}

	var body struct {
		Code int `json:"code"`
	}
	if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
		t.Fatalf("unmarshal envelope: %v", err)
	}
	if body.Code != customCode {
		t.Fatalf("expected code=1099, got %d", body.Code)
	}
}
// TestResponseEnvelope_WithDetails checks the envelope for a 400 response whose
// handler registered error details through SetErrorDetails:
//   - every registered entry appears under the envelope's details field;
//   - the code is CodeInvalidParam;
//   - the handler's "error" text becomes the message;
//   - data is null.
func TestResponseEnvelope_WithDetails(t *testing.T) {
	gin.SetMode(gin.TestMode)
	r := gin.New()
	r.Use(RequestID())
	r.Use(ResponseEnvelope())

	// Expected details (field -> validation message). Every entry is asserted
	// below, not just the first one.
	wantDetails := map[string]string{
		"email":   "格式错误", // "invalid format"
		"user_id": "必填",   // "required"
	}
	r.POST("/validate", func(c *gin.Context) {
		SetErrorDetails(c, map[string]string{
			"email":   "格式错误",
			"user_id": "必填",
		})
		c.JSON(http.StatusBadRequest, gin.H{"error": "参数校验失败"}) // "parameter validation failed"
	})

	req := httptest.NewRequest(http.MethodPost, "/validate", nil)
	rr := httptest.NewRecorder()
	r.ServeHTTP(rr, req)

	if rr.Code != http.StatusBadRequest {
		t.Fatalf("expected 400, got %d body=%s", rr.Code, rr.Body.String())
	}

	var env struct {
		Code    int               `json:"code"`
		Message string            `json:"message"`
		Data    any               `json:"data"`
		Details map[string]string `json:"details"`
	}
	if err := json.Unmarshal(rr.Body.Bytes(), &env); err != nil {
		t.Fatalf("unmarshal envelope: %v", err)
	}
	if env.Code != CodeInvalidParam {
		t.Fatalf("expected code=%d, got %d", CodeInvalidParam, env.Code)
	}
	if env.Message != "参数校验失败" {
		t.Fatalf("expected message='参数校验失败', got %q", env.Message)
	}
	if env.Data != nil {
		t.Fatalf("expected data=null, got %v", env.Data)
	}
	// Use Errorf so every missing or mismatched entry is reported in one run.
	for field, want := range wantDetails {
		if got, ok := env.Details[field]; !ok || got != want {
			t.Errorf("expected details.%s=%q, got %q (present=%t)", field, want, got, ok)
		}
	}
}
// TestResponseEnvelope_AutoExtractDetails checks that when a handler embeds a
// "details" object directly in its error body, instead of calling
// SetErrorDetails, the middleware lifts it into the envelope's details field.
// It also checks the status, business code and message, so a regression that
// wraps the error as a success still fails.
func TestResponseEnvelope_AutoExtractDetails(t *testing.T) {
	gin.SetMode(gin.TestMode)
	r := gin.New()
	r.Use(RequestID())
	r.Use(ResponseEnvelope())
	r.POST("/validate", func(c *gin.Context) {
		// The handler returns details in its own body; the middleware should extract them.
		c.JSON(http.StatusBadRequest, gin.H{
			"error":   "参数校验失败",                          // "parameter validation failed"
			"details": map[string]string{"email": "格式错误"}, // "invalid format"
		})
	})

	req := httptest.NewRequest(http.MethodPost, "/validate", nil)
	rr := httptest.NewRecorder()
	r.ServeHTTP(rr, req)

	if rr.Code != http.StatusBadRequest {
		t.Fatalf("expected 400, got %d body=%s", rr.Code, rr.Body.String())
	}

	var env struct {
		Code    int            `json:"code"`
		Message string         `json:"message"`
		Data    any            `json:"data"`
		Details map[string]any `json:"details"`
	}
	if err := json.Unmarshal(rr.Body.Bytes(), &env); err != nil {
		t.Fatalf("unmarshal envelope: %v", err)
	}
	if env.Code != CodeInvalidParam {
		t.Fatalf("expected code=%d, got %d", CodeInvalidParam, env.Code)
	}
	if env.Message != "参数校验失败" {
		t.Fatalf("expected message='参数校验失败', got %q", env.Message)
	}
	if env.Details["email"] != "格式错误" {
		t.Fatalf("expected details.email='格式错误', got %v", env.Details)
	}
}
// TestResponseEnvelope_Idempotent checks that a body already in the new
// envelope shape (integer code, data, message, trace_id) is passed through
// rather than wrapped again. No RequestID middleware is installed here, and the
// handler-supplied trace_id must survive unchanged.
func TestResponseEnvelope_Idempotent(t *testing.T) {
	gin.SetMode(gin.TestMode)

	router := gin.New()
	router.Use(ResponseEnvelope())
	router.GET("/wrapped", func(c *gin.Context) {
		alreadyWrapped := gin.H{
			"code":     0,
			"data":     gin.H{"id": 1},
			"message":  "success",
			"trace_id": "test-123",
		}
		c.JSON(http.StatusOK, alreadyWrapped)
	})

	w := httptest.NewRecorder()
	router.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/wrapped", nil))

	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d body=%s", w.Code, w.Body.String())
	}

	var decoded map[string]any
	if err := json.Unmarshal(w.Body.Bytes(), &decoded); err != nil {
		t.Fatalf("unmarshal envelope: %v", err)
	}

	if code := decoded["code"]; code != float64(0) {
		t.Fatalf("expected code=0, got %v", code)
	}
	if trace := decoded["trace_id"]; trace != "test-123" {
		t.Fatalf("expected trace_id=test-123, got %v", trace)
	}
	// A double wrap would nest the original envelope under data, so data.id
	// would no longer be at the top level of data.
	payload, isObject := decoded["data"].(map[string]any)
	if !isObject || payload["id"] != float64(1) {
		t.Fatalf("unexpected data: %+v", decoded["data"])
	}
}
// TestResponseEnvelope_OldFormatGetsWrapped checks the breaking change from
// the new envelope format. A body in the legacy envelope shape (string "code")
// is not treated as already wrapped. Instead it is placed, unchanged, under
// data of a new success envelope that carries a trace_id.
func TestResponseEnvelope_OldFormatGetsWrapped(t *testing.T) {
	gin.SetMode(gin.TestMode)
	r := gin.New()
	r.Use(RequestID())
	r.Use(ResponseEnvelope())
	r.GET("/old-wrapped", func(c *gin.Context) {
		// Legacy format with a string code; it must be wrapped now.
		c.JSON(http.StatusOK, gin.H{
			"code":    "ok",
			"data":    gin.H{"id": 1},
			"message": "",
		})
	})

	req := httptest.NewRequest(http.MethodGet, "/old-wrapped", nil)
	rr := httptest.NewRecorder()
	r.ServeHTTP(rr, req)

	if rr.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d body=%s", rr.Code, rr.Body.String())
	}

	var env struct {
		Code    int            `json:"code"`
		Message string         `json:"message"`
		Data    map[string]any `json:"data"`
		TraceID string         `json:"trace_id"`
	}
	// A legacy string "code" at the top level would make this Unmarshal fail,
	// because the field is an int. Success here already shows the body was wrapped.
	if err := json.Unmarshal(rr.Body.Bytes(), &env); err != nil {
		t.Fatalf("unmarshal envelope: %v", err)
	}
	if env.Code != CodeSuccess {
		t.Fatalf("expected code=%d, got %d", CodeSuccess, env.Code)
	}
	// The legacy body is carried through untouched under data.
	if env.Data["code"] != "ok" {
		t.Fatalf("expected data.code='ok', got %v", env.Data["code"])
	}
	if env.TraceID == "" {
		t.Fatal("expected trace_id to be set")
	}
}
// TestResponseEnvelope_TraceIDFromHeader checks that a caller-supplied
// X-Request-ID header, processed by RequestID, becomes the envelope's trace_id.
func TestResponseEnvelope_TraceIDFromHeader(t *testing.T) {
	const wantTrace = "custom-trace-123"

	gin.SetMode(gin.TestMode)

	router := gin.New()
	router.Use(RequestID(), ResponseEnvelope())
	router.GET("/trace", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"ok": true})
	})

	request := httptest.NewRequest(http.MethodGet, "/trace", nil)
	request.Header.Set("X-Request-ID", wantTrace)
	w := httptest.NewRecorder()
	router.ServeHTTP(w, request)

	var body struct {
		TraceID string `json:"trace_id"`
	}
	if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
		t.Fatalf("unmarshal envelope: %v", err)
	}
	if body.TraceID != wantTrace {
		t.Fatalf("expected trace_id='custom-trace-123', got %q", body.TraceID)
	}
}
// TestResponseEnvelope_TraceIDFallback checks that, without the RequestID
// middleware, ResponseEnvelope still generates a trace_id. The ID must be
// exactly 32 hexadecimal characters. Previously only the length was checked,
// although the failure message promised hex.
func TestResponseEnvelope_TraceIDFallback(t *testing.T) {
	gin.SetMode(gin.TestMode)
	r := gin.New()
	// RequestID middleware is deliberately omitted to force the fallback path.
	r.Use(ResponseEnvelope())
	r.GET("/trace", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"ok": true})
	})

	req := httptest.NewRequest(http.MethodGet, "/trace", nil)
	rr := httptest.NewRecorder()
	r.ServeHTTP(rr, req)

	if rr.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d body=%s", rr.Code, rr.Body.String())
	}

	var env struct {
		TraceID string `json:"trace_id"`
	}
	if err := json.Unmarshal(rr.Body.Bytes(), &env); err != nil {
		t.Fatalf("unmarshal envelope: %v", err)
	}
	if env.TraceID == "" {
		t.Fatal("expected trace_id to be generated")
	}
	if len(env.TraceID) != 32 {
		t.Fatalf("expected 32-char hex trace_id, got %q", env.TraceID)
	}
	// Accept either letter case: the requirement is hex, not a particular casing.
	for _, ch := range env.TraceID {
		isHex := ('0' <= ch && ch <= '9') || ('a' <= ch && ch <= 'f') || ('A' <= ch && ch <= 'F')
		if !isHex {
			t.Fatalf("expected 32-char hex trace_id, got %q (non-hex character %q)", env.TraceID, ch)
		}
	}
}
// TestResponseEnvelope_NonJSON checks that a non-JSON response (a plain-text
// body written with c.String) passes through the middleware unchanged.
func TestResponseEnvelope_NonJSON(t *testing.T) {
	const plain = "plain text"

	gin.SetMode(gin.TestMode)

	router := gin.New()
	router.Use(ResponseEnvelope())
	router.GET("/text", func(c *gin.Context) {
		c.String(http.StatusOK, plain)
	})

	w := httptest.NewRecorder()
	router.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/text", nil))

	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d", w.Code)
	}
	if got := w.Body.String(); got != plain {
		t.Fatalf("expected 'plain text', got %q", got)
	}
}
// TestResponseEnvelope_ExcludedPaths checks that responses on excluded paths
// (health check and swagger docs) are returned exactly as the handler wrote
// them. There must be no envelope "code" field, the status must be 200, and the
// handler's own top-level field must be present. Each path runs as a named
// subtest, so one failure does not hide the others.
func TestResponseEnvelope_ExcludedPaths(t *testing.T) {
	gin.SetMode(gin.TestMode)
	r := gin.New()
	r.Use(ResponseEnvelope())
	r.GET("/health", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"status": "ok"})
	})
	r.GET("/swagger/doc.json", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"openapi": "3.0.0"})
	})

	tests := []struct {
		name  string
		path  string
		key   string // a top-level field the handler writes
		value string // the value the handler writes for key
	}{
		{"health", "/health", "status", "ok"},
		{"swagger", "/swagger/doc.json", "openapi", "3.0.0"},
	}
	for _, tt := range tests {
		// Subtests run sequentially (no t.Parallel), so capturing tt and
		// sharing the router is safe on every Go version.
		t.Run(tt.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, tt.path, nil)
			rr := httptest.NewRecorder()
			r.ServeHTTP(rr, req)

			if rr.Code != http.StatusOK {
				t.Fatalf("expected 200 for %s, got %d body=%s", tt.path, rr.Code, rr.Body.String())
			}

			var resp map[string]any
			if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
				t.Fatalf("unmarshal response for %s: %v", tt.path, err)
			}
			// The envelope fields must not be present.
			if _, hasCode := resp["code"]; hasCode {
				t.Fatalf("expected %s to not be wrapped, but found 'code' field", tt.path)
			}
			// The handler's own payload must still be at the top level.
			if resp[tt.key] != tt.value {
				t.Fatalf("expected %s to return %s=%q unchanged, got %v", tt.path, tt.key, tt.value, resp)
			}
		})
	}
}
// TestResponseEnvelope_DefaultCodes pins the default HTTP status to business
// code mapping returned by defaultBusinessCode. It includes the 503 and 504
// entries, which map to CodeServiceUnavailable and CodeTimeout.
func TestResponseEnvelope_DefaultCodes(t *testing.T) {
	type mapping struct {
		status int
		want   int
	}
	mappings := []mapping{
		{status: http.StatusOK, want: CodeSuccess},
		{status: http.StatusCreated, want: CodeSuccess},
		{status: http.StatusBadRequest, want: CodeInvalidParam},
		{status: http.StatusUnauthorized, want: CodeUnauthorized},
		{status: http.StatusForbidden, want: CodeForbidden},
		{status: http.StatusNotFound, want: CodeResourceNotFound},
		{status: http.StatusConflict, want: CodeResourceConflict},
		{status: http.StatusTooManyRequests, want: CodeRateLimited},
		{status: http.StatusInternalServerError, want: CodeInternalError},
		{status: http.StatusServiceUnavailable, want: CodeServiceUnavailable},
		{status: http.StatusGatewayTimeout, want: CodeTimeout},
	}

	// Errorf (not Fatalf) so every wrong mapping is reported in a single run.
	for i := range mappings {
		m := mappings[i]
		if got := defaultBusinessCode(m.status); got != m.want {
			t.Errorf("defaultBusinessCode(%d) = %d, want %d", m.status, got, m.want)
		}
	}
}