mirror of
https://github.com/EZ-Api/ez-api.git
synced 2026-01-13 17:47:51 +00:00
Add shared response DTOs and enhance the response envelope middleware with excluded paths, trace ID generation fallback, and automatic extraction of error details from handler responses. Update default business code mapping for 503 and 504, and adjust idempotency detection to only treat the new envelope format as already-wrapped. BREAKING CHANGE: responses using the old envelope format (e.g., string `code`) are now wrapped into the new standard envelope.
392 lines
8.0 KiB
Go
392 lines
8.0 KiB
Go
package middleware
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/rand"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
// Business codes carried in the "code" field of the response envelope.
// They are independent of the HTTP status; defaultBusinessCode derives one
// from the status when a handler does not set a code explicitly.
const (
	// CodeSuccess marks a successful response.
	CodeSuccess = 0

	// 1xxx: common request-level errors.
	CodeInvalidParam = 1001
	CodeUnauthorized = 1002
	CodeForbidden    = 1003
	CodeRateLimited  = 1004

	// 4xxx: client errors tied to a resource.
	CodeResourceNotFound = 4001
	CodeResourceConflict = 4002
	CodeInvalidState     = 4003

	// 5xxx: server-side errors.
	CodeInternalError      = 5001
	CodeServiceUnavailable = 5002
	CodeTimeout            = 5003
)
|
|
|
|
// gin.Context keys used to pass envelope overrides from handlers to the
// ResponseEnvelope middleware.
const (
	// businessCodeKey stores the explicit business code (int).
	businessCodeKey = "response_business_code"
	// errorDetailsKey stores additional error details (any).
	errorDetailsKey = "response_error_details"
	// errorMessageKey stores a custom envelope message (string).
	errorMessageKey = "response_error_message"
)
|
|
|
|
// excludedPaths lists exact request paths whose responses are never wrapped
// in the envelope. Prefix-based exclusions are handled in isExcludedPath.
var excludedPaths = map[string]bool{
	"/health":           true,
	"/debug/vars":       true,
	"/internal/metrics": true,
}
|
|
|
|
// responseEnvelope is the JSON shape emitted for every wrapped response.
type responseEnvelope struct {
	// Code is the business code (see the Code* constants).
	Code int `json:"code"`
	// Message is a human-readable summary of the outcome.
	Message string `json:"message"`
	// Data holds the handler's original JSON body verbatim, or null on errors.
	Data json.RawMessage `json:"data"`
	// TraceID identifies the request for correlation.
	TraceID string `json:"trace_id"`
	// Details carries optional extra error information; omitted when nil.
	Details any `json:"details,omitempty"`
}
|
|
|
|
type envelopeWriter struct {
|
|
gin.ResponseWriter
|
|
body bytes.Buffer
|
|
status int
|
|
size int
|
|
wrote bool
|
|
}
|
|
|
|
func (w *envelopeWriter) WriteHeader(code int) {
|
|
w.status = code
|
|
w.wrote = true
|
|
}
|
|
|
|
func (w *envelopeWriter) WriteHeaderNow() {
|
|
if !w.wrote {
|
|
w.WriteHeader(http.StatusOK)
|
|
}
|
|
}
|
|
|
|
func (w *envelopeWriter) Write(data []byte) (int, error) {
|
|
if !w.wrote {
|
|
w.WriteHeader(http.StatusOK)
|
|
}
|
|
n, err := w.body.Write(data)
|
|
w.size += n
|
|
return n, err
|
|
}
|
|
|
|
func (w *envelopeWriter) WriteString(s string) (int, error) {
|
|
if !w.wrote {
|
|
w.WriteHeader(http.StatusOK)
|
|
}
|
|
n, err := w.body.WriteString(s)
|
|
w.size += n
|
|
return n, err
|
|
}
|
|
|
|
func (w *envelopeWriter) Status() int {
|
|
if w.status == 0 {
|
|
return http.StatusOK
|
|
}
|
|
return w.status
|
|
}
|
|
|
|
func (w *envelopeWriter) Size() int {
|
|
return w.size
|
|
}
|
|
|
|
func (w *envelopeWriter) Written() bool {
|
|
return w.wrote
|
|
}
|
|
|
|
// SetBusinessCode sets an explicit business code for the response envelope.
|
|
func SetBusinessCode(c *gin.Context, code int) {
|
|
if c == nil {
|
|
return
|
|
}
|
|
c.Set(businessCodeKey, code)
|
|
}
|
|
|
|
// SetErrorDetails sets additional error details for the response envelope.
|
|
func SetErrorDetails(c *gin.Context, details any) {
|
|
if c == nil || details == nil {
|
|
return
|
|
}
|
|
c.Set(errorDetailsKey, details)
|
|
}
|
|
|
|
// SetErrorMessage sets a custom error message for the response envelope.
|
|
func SetErrorMessage(c *gin.Context, message string) {
|
|
if c == nil {
|
|
return
|
|
}
|
|
c.Set(errorMessageKey, message)
|
|
}
|
|
|
|
func ResponseEnvelope() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
// Skip excluded paths
|
|
if isExcludedPath(c.Request.URL.Path) {
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
originalWriter := c.Writer
|
|
writer := &envelopeWriter{ResponseWriter: originalWriter}
|
|
c.Writer = writer
|
|
|
|
c.Next()
|
|
|
|
status := writer.Status()
|
|
body := writer.body.Bytes()
|
|
if !bodyAllowedForStatus(status) || len(body) == 0 {
|
|
writeStatusOnly(originalWriter, status)
|
|
return
|
|
}
|
|
|
|
contentType := originalWriter.Header().Get("Content-Type")
|
|
if !isJSONContentType(contentType) {
|
|
writeThrough(originalWriter, status, body)
|
|
return
|
|
}
|
|
|
|
obj, objOK := parseObject(body)
|
|
if objOK && isEnvelopeObject(obj) {
|
|
writeThrough(originalWriter, status, body)
|
|
return
|
|
}
|
|
|
|
code := businessCodeFromContext(c)
|
|
if code == 0 {
|
|
code = defaultBusinessCode(status)
|
|
}
|
|
|
|
traceID := getTraceID(c)
|
|
isError := status >= http.StatusBadRequest
|
|
|
|
var message string
|
|
if customMsg := errorMessageFromContext(c); customMsg != "" {
|
|
message = customMsg
|
|
} else if isError && objOK {
|
|
if raw, ok := obj["error"]; ok {
|
|
var msg string
|
|
if err := json.Unmarshal(raw, &msg); err == nil {
|
|
message = msg
|
|
}
|
|
}
|
|
}
|
|
if message == "" {
|
|
if isError {
|
|
message = http.StatusText(status)
|
|
} else {
|
|
message = "success"
|
|
}
|
|
}
|
|
|
|
// Get details: from context first, then from original response
|
|
details := errorDetailsFromContext(c)
|
|
if details == nil && isError && objOK {
|
|
if raw, ok := obj["details"]; ok {
|
|
var d any
|
|
if err := json.Unmarshal(raw, &d); err == nil {
|
|
details = d
|
|
}
|
|
}
|
|
}
|
|
|
|
var data json.RawMessage
|
|
if isError {
|
|
data = json.RawMessage("null")
|
|
} else {
|
|
data = json.RawMessage(body)
|
|
}
|
|
|
|
envelope := responseEnvelope{
|
|
Code: code,
|
|
Message: message,
|
|
Data: data,
|
|
TraceID: traceID,
|
|
Details: details,
|
|
}
|
|
payload, err := json.Marshal(envelope)
|
|
if err != nil {
|
|
writeThrough(originalWriter, status, body)
|
|
return
|
|
}
|
|
|
|
originalWriter.Header().Set("Content-Type", "application/json; charset=utf-8")
|
|
originalWriter.Header().Del("Content-Length")
|
|
writeThrough(originalWriter, status, payload)
|
|
}
|
|
}
|
|
|
|
func isExcludedPath(path string) bool {
|
|
if excludedPaths[path] {
|
|
return true
|
|
}
|
|
// Also exclude swagger paths
|
|
if strings.HasPrefix(path, "/swagger") {
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
func businessCodeFromContext(c *gin.Context) int {
|
|
if c == nil {
|
|
return 0
|
|
}
|
|
value, ok := c.Get(businessCodeKey)
|
|
if !ok {
|
|
return 0
|
|
}
|
|
code, ok := value.(int)
|
|
if !ok {
|
|
return 0
|
|
}
|
|
return code
|
|
}
|
|
|
|
func errorDetailsFromContext(c *gin.Context) any {
|
|
if c == nil {
|
|
return nil
|
|
}
|
|
value, ok := c.Get(errorDetailsKey)
|
|
if !ok {
|
|
return nil
|
|
}
|
|
return value
|
|
}
|
|
|
|
func errorMessageFromContext(c *gin.Context) string {
|
|
if c == nil {
|
|
return ""
|
|
}
|
|
value, ok := c.Get(errorMessageKey)
|
|
if !ok {
|
|
return ""
|
|
}
|
|
msg, ok := value.(string)
|
|
if !ok {
|
|
return ""
|
|
}
|
|
return msg
|
|
}
|
|
|
|
func getTraceID(c *gin.Context) string {
|
|
if c == nil {
|
|
return generateTraceID()
|
|
}
|
|
// Try to get from context first (set by RequestID middleware)
|
|
if id, ok := c.Get("request_id"); ok {
|
|
if s, ok := id.(string); ok && s != "" {
|
|
return s
|
|
}
|
|
}
|
|
// Fallback to header
|
|
if id := c.GetHeader("X-Request-ID"); id != "" {
|
|
return id
|
|
}
|
|
// Generate UUID as fallback
|
|
return generateTraceID()
|
|
}
|
|
|
|
// generateTraceID returns 16 random bytes from crypto/rand encoded as 32
// lowercase hexadecimal characters.
func generateTraceID() string {
	var raw [16]byte
	// Best effort: a trace ID is only a correlation aid, so a read failure is
	// deliberately ignored and whatever is in the buffer gets encoded.
	_, _ = rand.Read(raw[:])
	return hex.EncodeToString(raw[:])
}
|
|
|
|
func defaultBusinessCode(status int) int {
|
|
switch {
|
|
case status >= http.StatusOK && status < http.StatusMultipleChoices:
|
|
return CodeSuccess
|
|
case status == http.StatusBadRequest:
|
|
return CodeInvalidParam
|
|
case status == http.StatusUnauthorized:
|
|
return CodeUnauthorized
|
|
case status == http.StatusForbidden:
|
|
return CodeForbidden
|
|
case status == http.StatusNotFound:
|
|
return CodeResourceNotFound
|
|
case status == http.StatusConflict:
|
|
return CodeResourceConflict
|
|
case status == http.StatusTooManyRequests:
|
|
return CodeRateLimited
|
|
case status >= http.StatusBadRequest && status < http.StatusInternalServerError:
|
|
return CodeInvalidParam
|
|
case status == http.StatusServiceUnavailable:
|
|
return CodeServiceUnavailable
|
|
case status == http.StatusGatewayTimeout:
|
|
return CodeTimeout
|
|
case status >= http.StatusInternalServerError:
|
|
return CodeInternalError
|
|
default:
|
|
return CodeSuccess
|
|
}
|
|
}
|
|
|
|
// bodyAllowedForStatus reports whether a response with the given status may
// carry a body. 1xx, 204 No Content and 304 Not Modified may not.
func bodyAllowedForStatus(status int) bool {
	informational := status >= 100 && status <= 199
	if informational || status == http.StatusNoContent || status == http.StatusNotModified {
		return false
	}
	return true
}
|
|
|
|
// isJSONContentType reports whether contentType mentions "application/json",
// compared case-insensitively. An empty value is not JSON.
func isJSONContentType(contentType string) bool {
	return contentType != "" &&
		strings.Contains(strings.ToLower(contentType), "application/json")
}
|
|
|
|
// parseObject decodes body as a JSON object and returns its top-level fields
// still in raw form. The boolean is true only when body really is a JSON
// object; malformed input, other JSON types and a bare null all yield
// (nil, false).
func parseObject(body []byte) (map[string]json.RawMessage, bool) {
	var obj map[string]json.RawMessage
	// Unmarshal accepts JSON null for a map without error and leaves it nil,
	// so the nil check is needed to reject a non-object body.
	if err := json.Unmarshal(body, &obj); err != nil || obj == nil {
		return nil, false
	}
	return obj, true
}
|
|
|
|
// isEnvelopeObject reports whether obj already has the standard envelope
// shape, so it must not be wrapped again. Only the new format is recognised:
// the keys "code", "message", "data" and "trace_id" must all be present and
// "code" must be a JSON integer. A string or null code is not an envelope.
func isEnvelopeObject(obj map[string]json.RawMessage) bool {
	// Lookups on a nil map are safe and simply report the key as missing.
	for _, key := range [...]string{"code", "data", "message", "trace_id"} {
		if _, present := obj[key]; !present {
			return false
		}
	}

	// Decode into a pointer: Unmarshal accepts null for a plain int without
	// error, whereas with *int a null code stays nil and can be rejected.
	var code *int
	if err := json.Unmarshal(obj["code"], &code); err != nil {
		return false
	}
	return code != nil
}
|
|
|
|
func writeStatusOnly(w gin.ResponseWriter, status int) {
|
|
w.WriteHeader(status)
|
|
}
|
|
|
|
func writeThrough(w gin.ResponseWriter, status int, body []byte) {
|
|
w.WriteHeader(status)
|
|
_, _ = w.Write(body)
|
|
}
|