package middleware import ( "bytes" "crypto/rand" "encoding/hex" "encoding/json" "net/http" "strings" "github.com/gin-gonic/gin" ) // Business code constants const ( CodeSuccess = 0 // Common errors (1xxx) CodeInvalidParam = 1001 CodeUnauthorized = 1002 CodeForbidden = 1003 CodeRateLimited = 1004 // Client errors (4xxx) CodeResourceNotFound = 4001 CodeResourceConflict = 4002 CodeInvalidState = 4003 // Server errors (5xxx) CodeInternalError = 5001 CodeServiceUnavailable = 5002 CodeTimeout = 5003 ) const ( businessCodeKey = "response_business_code" errorDetailsKey = "response_error_details" errorMessageKey = "response_error_message" ) // Paths that should not be wrapped var excludedPaths = map[string]bool{ "/health": true, "/debug/vars": true, "/internal/metrics": true, } type responseEnvelope struct { Code int `json:"code"` Message string `json:"message"` Data json.RawMessage `json:"data"` TraceID string `json:"trace_id"` Details any `json:"details,omitempty"` } type envelopeWriter struct { gin.ResponseWriter body bytes.Buffer status int size int wrote bool } func (w *envelopeWriter) WriteHeader(code int) { w.status = code w.wrote = true } func (w *envelopeWriter) WriteHeaderNow() { if !w.wrote { w.WriteHeader(http.StatusOK) } } func (w *envelopeWriter) Write(data []byte) (int, error) { if !w.wrote { w.WriteHeader(http.StatusOK) } n, err := w.body.Write(data) w.size += n return n, err } func (w *envelopeWriter) WriteString(s string) (int, error) { if !w.wrote { w.WriteHeader(http.StatusOK) } n, err := w.body.WriteString(s) w.size += n return n, err } func (w *envelopeWriter) Status() int { if w.status == 0 { return http.StatusOK } return w.status } func (w *envelopeWriter) Size() int { return w.size } func (w *envelopeWriter) Written() bool { return w.wrote } // SetBusinessCode sets an explicit business code for the response envelope. func SetBusinessCode(c *gin.Context, code int) { if c == nil { return } c.Set(businessCodeKey, code) } // SetErrorDetails sets additional error details for the response envelope. func SetErrorDetails(c *gin.Context, details any) { if c == nil || details == nil { return } c.Set(errorDetailsKey, details) } // SetErrorMessage sets a custom error message for the response envelope. func SetErrorMessage(c *gin.Context, message string) { if c == nil { return } c.Set(errorMessageKey, message) } func ResponseEnvelope() gin.HandlerFunc { return func(c *gin.Context) { // Skip excluded paths if isExcludedPath(c.Request.URL.Path) { c.Next() return } originalWriter := c.Writer writer := &envelopeWriter{ResponseWriter: originalWriter} c.Writer = writer c.Next() status := writer.Status() body := writer.body.Bytes() if !bodyAllowedForStatus(status) || len(body) == 0 { writeStatusOnly(originalWriter, status) return } contentType := originalWriter.Header().Get("Content-Type") if !isJSONContentType(contentType) { writeThrough(originalWriter, status, body) return } obj, objOK := parseObject(body) if objOK && isEnvelopeObject(obj) { writeThrough(originalWriter, status, body) return } code := businessCodeFromContext(c) if code == 0 { code = defaultBusinessCode(status) } traceID := getTraceID(c) isError := status >= http.StatusBadRequest var message string if customMsg := errorMessageFromContext(c); customMsg != "" { message = customMsg } else if isError && objOK { if raw, ok := obj["error"]; ok { var msg string if err := json.Unmarshal(raw, &msg); err == nil { message = msg } } } if message == "" { if isError { message = http.StatusText(status) } else { message = "success" } } // Get details: from context first, then from original response details := errorDetailsFromContext(c) if details == nil && isError && objOK { if raw, ok := obj["details"]; ok { var d any if err := json.Unmarshal(raw, &d); err == nil { details = d } } } var data json.RawMessage if isError { data = json.RawMessage("null") } else { data = json.RawMessage(body) } envelope := responseEnvelope{ Code: code, Message: message, Data: data, TraceID: traceID, Details: details, } payload, err := json.Marshal(envelope) if err != nil { writeThrough(originalWriter, status, body) return } originalWriter.Header().Set("Content-Type", "application/json; charset=utf-8") originalWriter.Header().Del("Content-Length") writeThrough(originalWriter, status, payload) } } func isExcludedPath(path string) bool { if excludedPaths[path] { return true } // Also exclude swagger paths if strings.HasPrefix(path, "/swagger") { return true } return false } func businessCodeFromContext(c *gin.Context) int { if c == nil { return 0 } value, ok := c.Get(businessCodeKey) if !ok { return 0 } code, ok := value.(int) if !ok { return 0 } return code } func errorDetailsFromContext(c *gin.Context) any { if c == nil { return nil } value, ok := c.Get(errorDetailsKey) if !ok { return nil } return value } func errorMessageFromContext(c *gin.Context) string { if c == nil { return "" } value, ok := c.Get(errorMessageKey) if !ok { return "" } msg, ok := value.(string) if !ok { return "" } return msg } func getTraceID(c *gin.Context) string { if c == nil { return generateTraceID() } // Try to get from context first (set by RequestID middleware) if id, ok := c.Get("request_id"); ok { if s, ok := id.(string); ok && s != "" { return s } } // Fallback to header if id := c.GetHeader("X-Request-ID"); id != "" { return id } // Generate UUID as fallback return generateTraceID() } func generateTraceID() string { b := make([]byte, 16) _, _ = rand.Read(b) return hex.EncodeToString(b) } func defaultBusinessCode(status int) int { switch { case status >= http.StatusOK && status < http.StatusMultipleChoices: return CodeSuccess case status == http.StatusBadRequest: return CodeInvalidParam case status == http.StatusUnauthorized: return CodeUnauthorized case status == http.StatusForbidden: return CodeForbidden case status == http.StatusNotFound: return CodeResourceNotFound case status == http.StatusConflict: return CodeResourceConflict case status == http.StatusTooManyRequests: return CodeRateLimited case status >= http.StatusBadRequest && status < http.StatusInternalServerError: return CodeInvalidParam case status == http.StatusServiceUnavailable: return CodeServiceUnavailable case status == http.StatusGatewayTimeout: return CodeTimeout case status >= http.StatusInternalServerError: return CodeInternalError default: return CodeSuccess } } func bodyAllowedForStatus(status int) bool { switch { case status >= 100 && status <= 199: return false case status == http.StatusNoContent: return false case status == http.StatusNotModified: return false } return true } func isJSONContentType(contentType string) bool { if contentType == "" { return false } return strings.Contains(strings.ToLower(contentType), "application/json") } func parseObject(body []byte) (map[string]json.RawMessage, bool) { var obj map[string]json.RawMessage if err := json.Unmarshal(body, &obj); err != nil { return nil, false } return obj, true } func isEnvelopeObject(obj map[string]json.RawMessage) bool { if obj == nil { return false } // Only recognize new envelope format: must have trace_id and code must be number codeRaw, hasCode := obj["code"] _, hasData := obj["data"] _, hasMessage := obj["message"] _, hasTraceID := obj["trace_id"] if !hasCode || !hasData || !hasMessage || !hasTraceID { return false } // Code must be a number (int) var codeNum int if json.Unmarshal(codeRaw, &codeNum) != nil { return false } return true } func writeStatusOnly(w gin.ResponseWriter, status int) { w.WriteHeader(status) } func writeThrough(w gin.ResponseWriter, status int, body []byte) { w.WriteHeader(status) _, _ = w.Write(body) }