diff --git a/internal/middleware/response_envelope.go b/internal/middleware/response_envelope.go index 4012529..f8f4a01 100644 --- a/internal/middleware/response_envelope.go +++ b/internal/middleware/response_envelope.go @@ -9,12 +9,39 @@ import ( "github.com/gin-gonic/gin" ) -const businessCodeKey = "response_business_code" +// Business code constants +const ( + CodeSuccess = 0 + + // Common errors (1xxx) + CodeInvalidParam = 1001 + CodeUnauthorized = 1002 + CodeForbidden = 1003 + CodeRateLimited = 1004 + + // Client errors (4xxx) + CodeResourceNotFound = 4001 + CodeResourceConflict = 4002 + CodeInvalidState = 4003 + + // Server errors (5xxx) + CodeInternalError = 5001 + CodeServiceUnavailable = 5002 + CodeTimeout = 5003 +) + +const ( + businessCodeKey = "response_business_code" + errorDetailsKey = "response_error_details" + errorMessageKey = "response_error_message" +) type responseEnvelope struct { - Code string `json:"code"` - Data json.RawMessage `json:"data"` + Code int `json:"code"` Message string `json:"message"` + Data json.RawMessage `json:"data"` + TraceID string `json:"trace_id"` + Details any `json:"details,omitempty"` } type envelopeWriter struct { @@ -70,15 +97,27 @@ func (w *envelopeWriter) Written() bool { } // SetBusinessCode sets an explicit business code for the response envelope. -func SetBusinessCode(c *gin.Context, code string) { +func SetBusinessCode(c *gin.Context, code int) { if c == nil { return } - code = strings.TrimSpace(code) - if code == "" { + c.Set(businessCodeKey, code) +} + +// SetErrorDetails sets additional error details for the response envelope. +func SetErrorDetails(c *gin.Context, details any) { + if c == nil || details == nil { return } - c.Set(businessCodeKey, code) + c.Set(errorDetailsKey, details) +} + +// SetErrorMessage sets a custom error message for the response envelope. +func SetErrorMessage(c *gin.Context, message string) { + if c == nil { + return + } + c.Set(errorMessageKey, message) } func ResponseEnvelope() gin.HandlerFunc { @@ -109,12 +148,17 @@ func ResponseEnvelope() gin.HandlerFunc { } code := businessCodeFromContext(c) - if code == "" { + if code == 0 { code = defaultBusinessCode(status) } - message := "" - if status >= http.StatusBadRequest && objOK { + traceID := getTraceID(c) + isError := status >= http.StatusBadRequest + + var message string + if customMsg := errorMessageFromContext(c); customMsg != "" { + message = customMsg + } else if isError && objOK { if raw, ok := obj["error"]; ok { var msg string if err := json.Unmarshal(raw, &msg); err == nil { @@ -122,11 +166,27 @@ func ResponseEnvelope() gin.HandlerFunc { } } } + if message == "" { + if isError { + message = http.StatusText(status) + } else { + message = "success" + } + } + + var data json.RawMessage + if isError { + data = json.RawMessage("null") + } else { + data = json.RawMessage(body) + } envelope := responseEnvelope{ Code: code, - Data: json.RawMessage(body), Message: message, + Data: data, + TraceID: traceID, + Details: errorDetailsFromContext(c), } payload, err := json.Marshal(envelope) if err != nil { @@ -140,43 +200,83 @@ func ResponseEnvelope() gin.HandlerFunc { } } -func businessCodeFromContext(c *gin.Context) string { +func businessCodeFromContext(c *gin.Context) int { if c == nil { - return "" + return 0 } value, ok := c.Get(businessCodeKey) if !ok { + return 0 + } + code, ok := value.(int) + if !ok { + return 0 + } + return code +} + +func errorDetailsFromContext(c *gin.Context) any { + if c == nil { + return nil + } + value, ok := c.Get(errorDetailsKey) + if !ok { + return nil + } + return value +} + +func errorMessageFromContext(c *gin.Context) string { + if c == nil { return "" } - code, ok := value.(string) + value, ok := c.Get(errorMessageKey) if !ok { return "" } - return strings.TrimSpace(code) + msg, ok := value.(string) + if !ok { + return "" + } + return msg } -func defaultBusinessCode(status int) string { +func getTraceID(c *gin.Context) string { + if c == nil { + return "" + } + // Try to get from context first (set by RequestID middleware) + if id, ok := c.Get("request_id"); ok { + if s, ok := id.(string); ok && s != "" { + return s + } + } + // Fallback to header + return c.GetHeader("X-Request-ID") +} + +func defaultBusinessCode(status int) int { switch { case status >= http.StatusOK && status < http.StatusMultipleChoices: - return "ok" + return CodeSuccess case status == http.StatusBadRequest: - return "invalid_request" + return CodeInvalidParam case status == http.StatusUnauthorized: - return "unauthorized" + return CodeUnauthorized case status == http.StatusForbidden: - return "forbidden" + return CodeForbidden case status == http.StatusNotFound: - return "not_found" + return CodeResourceNotFound case status == http.StatusConflict: - return "conflict" + return CodeResourceConflict case status == http.StatusTooManyRequests: - return "rate_limited" + return CodeRateLimited case status >= http.StatusBadRequest && status < http.StatusInternalServerError: - return "request_error" + return CodeInvalidParam case status >= http.StatusInternalServerError: - return "internal_error" + return CodeInternalError default: - return "ok" + return CodeSuccess } } @@ -211,16 +311,33 @@ func isEnvelopeObject(obj map[string]json.RawMessage) bool { if obj == nil { return false } - if _, ok := obj["code"]; !ok { + // Check for new envelope format (code is number, has trace_id) + codeRaw, hasCode := obj["code"] + _, hasData := obj["data"] + _, hasMessage := obj["message"] + _, hasTraceID := obj["trace_id"] + + if !hasCode || !hasData || !hasMessage { return false } - if _, ok := obj["data"]; !ok { - return false + + // If has trace_id, it's definitely our envelope + if hasTraceID { + return true } - if _, ok := obj["message"]; !ok { - return false + + // Check if code is a number (new format) or string (old format) + // Both should be treated as envelope to avoid double-wrapping + var codeNum int + if json.Unmarshal(codeRaw, &codeNum) == nil { + return true } - return true + var codeStr string + if json.Unmarshal(codeRaw, &codeStr) == nil { + return true + } + + return false } func writeStatusOnly(w gin.ResponseWriter, status int) { diff --git a/internal/middleware/response_envelope_test.go b/internal/middleware/response_envelope_test.go index 4d19c1d..00175fc 100644 --- a/internal/middleware/response_envelope_test.go +++ b/internal/middleware/response_envelope_test.go @@ -9,10 +9,52 @@ import ( "github.com/gin-gonic/gin" ) -func TestResponseEnvelope_DefaultMapping(t *testing.T) { +func TestResponseEnvelope_Success(t *testing.T) { gin.SetMode(gin.TestMode) r := gin.New() + r.Use(RequestID()) + r.Use(ResponseEnvelope()) + r.GET("/ok", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"id": 1, "name": "test"}) + }) + + req := httptest.NewRequest(http.MethodGet, "/ok", nil) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rr.Code, rr.Body.String()) + } + + var env struct { + Code int `json:"code"` + Message string `json:"message"` + Data map[string]any `json:"data"` + TraceID string `json:"trace_id"` + } + if err := json.Unmarshal(rr.Body.Bytes(), &env); err != nil { + t.Fatalf("unmarshal envelope: %v", err) + } + if env.Code != CodeSuccess { + t.Fatalf("expected code=0, got %d", env.Code) + } + if env.Message != "success" { + t.Fatalf("expected message='success', got %q", env.Message) + } + if env.Data["id"] != float64(1) { + t.Fatalf("expected data.id=1, got %v", env.Data["id"]) + } + if env.TraceID == "" { + t.Fatal("expected trace_id to be set") + } +} + +func TestResponseEnvelope_NotFound(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + r.Use(RequestID()) r.Use(ResponseEnvelope()) r.GET("/missing", func(c *gin.Context) { c.JSON(http.StatusNotFound, gin.H{"error": "not here"}) @@ -27,21 +69,22 @@ func TestResponseEnvelope_DefaultMapping(t *testing.T) { } var env struct { - Code string `json:"code"` - Message string `json:"message"` - Data map[string]any `json:"data"` + Code int `json:"code"` + Message string `json:"message"` + Data any `json:"data"` + TraceID string `json:"trace_id"` } if err := json.Unmarshal(rr.Body.Bytes(), &env); err != nil { t.Fatalf("unmarshal envelope: %v", err) } - if env.Code != "not_found" { - t.Fatalf("expected code=not_found, got %q", env.Code) + if env.Code != CodeResourceNotFound { + t.Fatalf("expected code=%d, got %d", CodeResourceNotFound, env.Code) } if env.Message != "not here" { - t.Fatalf("expected message 'not here', got %q", env.Message) + t.Fatalf("expected message='not here', got %q", env.Message) } - if env.Data["error"] != "not here" { - t.Fatalf("expected data.error 'not here', got %v", env.Data["error"]) + if env.Data != nil { + t.Fatalf("expected data=null, got %v", env.Data) } } @@ -49,9 +92,10 @@ func TestResponseEnvelope_OverrideBusinessCode(t *testing.T) { gin.SetMode(gin.TestMode) r := gin.New() + r.Use(RequestID()) r.Use(ResponseEnvelope()) r.GET("/rate-limit", func(c *gin.Context) { - SetBusinessCode(c, "quota_exceeded") + SetBusinessCode(c, 1099) // Custom code c.JSON(http.StatusTooManyRequests, gin.H{"error": "rate limited"}) }) @@ -64,13 +108,58 @@ func TestResponseEnvelope_OverrideBusinessCode(t *testing.T) { } var env struct { - Code string `json:"code"` + Code int `json:"code"` } if err := json.Unmarshal(rr.Body.Bytes(), &env); err != nil { t.Fatalf("unmarshal envelope: %v", err) } - if env.Code != "quota_exceeded" { - t.Fatalf("expected code=quota_exceeded, got %q", env.Code) + if env.Code != 1099 { + t.Fatalf("expected code=1099, got %d", env.Code) + } +} + +func TestResponseEnvelope_WithDetails(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + r.Use(RequestID()) + r.Use(ResponseEnvelope()) + r.POST("/validate", func(c *gin.Context) { + SetErrorDetails(c, map[string]string{ + "email": "格式错误", + "user_id": "必填", + }) + c.JSON(http.StatusBadRequest, gin.H{"error": "参数校验失败"}) + }) + + req := httptest.NewRequest(http.MethodPost, "/validate", nil) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d body=%s", rr.Code, rr.Body.String()) + } + + var env struct { + Code int `json:"code"` + Message string `json:"message"` + Data any `json:"data"` + Details map[string]string `json:"details"` + } + if err := json.Unmarshal(rr.Body.Bytes(), &env); err != nil { + t.Fatalf("unmarshal envelope: %v", err) + } + if env.Code != CodeInvalidParam { + t.Fatalf("expected code=%d, got %d", CodeInvalidParam, env.Code) + } + if env.Message != "参数校验失败" { + t.Fatalf("expected message='参数校验失败', got %q", env.Message) + } + if env.Data != nil { + t.Fatalf("expected data=null, got %v", env.Data) + } + if env.Details["email"] != "格式错误" { + t.Fatalf("expected details.email='格式错误', got %v", env.Details["email"]) } } @@ -81,9 +170,10 @@ func TestResponseEnvelope_Idempotent(t *testing.T) { r.Use(ResponseEnvelope()) r.GET("/wrapped", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{ - "code": "ok", - "data": gin.H{"id": 1}, - "message": "", + "code": 0, + "data": gin.H{"id": 1}, + "message": "success", + "trace_id": "test-123", }) }) @@ -99,11 +189,118 @@ func TestResponseEnvelope_Idempotent(t *testing.T) { if err := json.Unmarshal(rr.Body.Bytes(), &env); err != nil { t.Fatalf("unmarshal envelope: %v", err) } - if env["code"] != "ok" { - t.Fatalf("expected code=ok, got %v", env["code"]) + if env["code"] != float64(0) { + t.Fatalf("expected code=0, got %v", env["code"]) + } + if env["trace_id"] != "test-123" { + t.Fatalf("expected trace_id=test-123, got %v", env["trace_id"]) } data, ok := env["data"].(map[string]any) if !ok || data["id"] != float64(1) { t.Fatalf("unexpected data: %+v", env["data"]) } } + +func TestResponseEnvelope_IdempotentOldFormat(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + r.Use(ResponseEnvelope()) + r.GET("/old-wrapped", func(c *gin.Context) { + // Old format with string code + c.JSON(http.StatusOK, gin.H{ + "code": "ok", + "data": gin.H{"id": 1}, + "message": "", + }) + }) + + req := httptest.NewRequest(http.MethodGet, "/old-wrapped", nil) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rr.Code, rr.Body.String()) + } + + var env map[string]any + if err := json.Unmarshal(rr.Body.Bytes(), &env); err != nil { + t.Fatalf("unmarshal envelope: %v", err) + } + // Should pass through unchanged + if env["code"] != "ok" { + t.Fatalf("expected code='ok', got %v", env["code"]) + } +} + +func TestResponseEnvelope_TraceIDFromHeader(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + r.Use(RequestID()) + r.Use(ResponseEnvelope()) + r.GET("/trace", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + + req := httptest.NewRequest(http.MethodGet, "/trace", nil) + req.Header.Set("X-Request-ID", "custom-trace-123") + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + var env struct { + TraceID string `json:"trace_id"` + } + if err := json.Unmarshal(rr.Body.Bytes(), &env); err != nil { + t.Fatalf("unmarshal envelope: %v", err) + } + if env.TraceID != "custom-trace-123" { + t.Fatalf("expected trace_id='custom-trace-123', got %q", env.TraceID) + } +} + +func TestResponseEnvelope_NonJSON(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + r.Use(ResponseEnvelope()) + r.GET("/text", func(c *gin.Context) { + c.String(http.StatusOK, "plain text") + }) + + req := httptest.NewRequest(http.MethodGet, "/text", nil) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rr.Code) + } + if rr.Body.String() != "plain text" { + t.Fatalf("expected 'plain text', got %q", rr.Body.String()) + } +} + +func TestResponseEnvelope_DefaultCodes(t *testing.T) { + testCases := []struct { + status int + expectedCode int + }{ + {http.StatusOK, CodeSuccess}, + {http.StatusCreated, CodeSuccess}, + {http.StatusBadRequest, CodeInvalidParam}, + {http.StatusUnauthorized, CodeUnauthorized}, + {http.StatusForbidden, CodeForbidden}, + {http.StatusNotFound, CodeResourceNotFound}, + {http.StatusConflict, CodeResourceConflict}, + {http.StatusTooManyRequests, CodeRateLimited}, + {http.StatusInternalServerError, CodeInternalError}, + {http.StatusServiceUnavailable, CodeInternalError}, + } + + for _, tc := range testCases { + got := defaultBusinessCode(tc.status) + if got != tc.expectedCode { + t.Errorf("defaultBusinessCode(%d) = %d, want %d", tc.status, got, tc.expectedCode) + } + } +}