gin-gonic / gin
1
// Copyright 2014 Manu Martinez-Almeida.  All rights reserved.
2
// Use of this source code is governed by a MIT style
3
// license that can be found in the LICENSE file.
4

5
package gin
6

7
import (
8
	"bytes"
9
	"fmt"
10
	"io"
11
	"io/ioutil"
12
	"log"
13
	"net"
14
	"net/http"
15
	"net/http/httputil"
16
	"os"
17
	"runtime"
18
	"strings"
19
	"time"
20
)
21

22
// dunno is returned by source and function when a source line or function
// name cannot be resolved.
var dunno = []byte("???")

// centerDot is the "·" separator that runtime function names may contain;
// function rewrites it to dot.
var centerDot = []byte("·")

// dot separates the package name from the rest of a function name.
var dot = []byte(".")

// slash separates package path components in a function name.
var slash = []byte("/")
28

29
// RecoveryFunc defines the function passable to CustomRecovery.
30
type RecoveryFunc func(c *Context, err interface{})
31

32
// Recovery returns a middleware that recovers from any panics and writes a 500 if there was one.
33
func Recovery() HandlerFunc {
34 9
	return RecoveryWithWriter(DefaultErrorWriter)
35
}
36

37
//CustomRecovery returns a middleware that recovers from any panics and calls the provided handle func to handle it.
38
func CustomRecovery(handle RecoveryFunc) HandlerFunc {
39 9
	return RecoveryWithWriter(DefaultErrorWriter, handle)
40
}
41

42
// RecoveryWithWriter returns a middleware for a given writer that recovers from any panics and writes a 500 if there was one.
43
func RecoveryWithWriter(out io.Writer, recovery ...RecoveryFunc) HandlerFunc {
44 9
	if len(recovery) > 0 {
45 9
		return CustomRecoveryWithWriter(out, recovery[0])
46
	}
47 9
	return CustomRecoveryWithWriter(out, defaultHandleRecovery)
48
}
49

50
// CustomRecoveryWithWriter returns a middleware for a given writer that recovers from any panics and calls the provided handle func to handle it.
51
func CustomRecoveryWithWriter(out io.Writer, handle RecoveryFunc) HandlerFunc {
52 9
	var logger *log.Logger
53 9
	if out != nil {
54 9
		logger = log.New(out, "\n\n\x1b[31m", log.LstdFlags)
55
	}
56 9
	return func(c *Context) {
57 9
		defer func() {
58 9
			if err := recover(); err != nil {
59
				// Check for a broken connection, as it is not really a
60
				// condition that warrants a panic stack trace.
61 9
				var brokenPipe bool
62 9
				if ne, ok := err.(*net.OpError); ok {
63 9
					if se, ok := ne.Err.(*os.SyscallError); ok {
64 9
						if strings.Contains(strings.ToLower(se.Error()), "broken pipe") || strings.Contains(strings.ToLower(se.Error()), "connection reset by peer") {
65 9
							brokenPipe = true
66
						}
67
					}
68
				}
69 9
				if logger != nil {
70 9
					stack := stack(3)
71 9
					httpRequest, _ := httputil.DumpRequest(c.Request, false)
72 9
					headers := strings.Split(string(httpRequest), "\r\n")
73
					for idx, header := range headers {
74 9
						current := strings.Split(header, ":")
75 9
						if current[0] == "Authorization" {
76 9
							headers[idx] = current[0] + ": *"
77
						}
78
					}
79 9
					headersToStr := strings.Join(headers, "\r\n")
80 9
					if brokenPipe {
81 9
						logger.Printf("%s\n%s%s", err, headersToStr, reset)
82 9
					} else if IsDebugging() {
83 9
						logger.Printf("[Recovery] %s panic recovered:\n%s\n%s\n%s%s",
84 9
							timeFormat(time.Now()), headersToStr, err, stack, reset)
85 9
					} else {
86 9
						logger.Printf("[Recovery] %s panic recovered:\n%s\n%s%s",
87 9
							timeFormat(time.Now()), err, stack, reset)
88
					}
89
				}
90 9
				if brokenPipe {
91
					// If the connection is dead, we can't write a status to it.
92 9
					c.Error(err.(error)) // nolint: errcheck
93 9
					c.Abort()
94 9
				} else {
95 9
					handle(c, err)
96
				}
97
			}
98
		}()
99 9
		c.Next()
100
	}
101
}
102

103
func defaultHandleRecovery(c *Context, err interface{}) {
104 9
	c.AbortWithStatus(http.StatusInternalServerError)
105
}
106

107
// stack returns a nicely formatted stack frame, skipping skip frames.
108
func stack(skip int) []byte {
109 9
	buf := new(bytes.Buffer) // the returned data
110
	// As we loop, we open files and read them. These variables record the currently
111
	// loaded file.
112 9
	var lines [][]byte
113 9
	var lastFile string
114 9
	for i := skip; ; i++ { // Skip the expected number of frames
115 9
		pc, file, line, ok := runtime.Caller(i)
116 9
		if !ok {
117 9
			break
118
		}
119
		// Print this much at least.  If we can't find the source, it won't show.
120 9
		fmt.Fprintf(buf, "%s:%d (0x%x)\n", file, line, pc)
121 9
		if file != lastFile {
122 9
			data, err := ioutil.ReadFile(file)
123 9
			if err != nil {
124 0
				continue
125
			}
126 9
			lines = bytes.Split(data, []byte{'\n'})
127 9
			lastFile = file
128
		}
129 9
		fmt.Fprintf(buf, "\t%s: %s\n", function(pc), source(lines, line))
130
	}
131 9
	return buf.Bytes()
132
}
133

134
// source returns a space-trimmed slice of the n'th line.
135
func source(lines [][]byte, n int) []byte {
136 9
	n-- // in stack trace, lines are 1-indexed but our array is 0-indexed
137 9
	if n < 0 || n >= len(lines) {
138 9
		return dunno
139
	}
140 9
	return bytes.TrimSpace(lines[n])
141
}
142

143
// function returns, if possible, the name of the function containing the PC.
144
func function(pc uintptr) []byte {
145 9
	fn := runtime.FuncForPC(pc)
146 9
	if fn == nil {
147 9
		return dunno
148
	}
149 9
	name := []byte(fn.Name())
150
	// The name includes the path name to the package, which is unnecessary
151
	// since the file name is already included.  Plus, it has center dots.
152
	// That is, we see
153
	//	runtime/debug.*T·ptrmethod
154
	// and want
155
	//	*T.ptrmethod
156
	// Also the package path might contains dot (e.g. code.google.com/...),
157
	// so first eliminate the path prefix
158 9
	if lastSlash := bytes.LastIndex(name, slash); lastSlash >= 0 {
159 9
		name = name[lastSlash+1:]
160
	}
161 9
	if period := bytes.Index(name, dot); period >= 0 {
162 9
		name = name[period+1:]
163
	}
164 9
	name = bytes.Replace(name, centerDot, dot, -1)
165 9
	return name
166
}
167

168
// timeFormat renders t as "YYYY/MM/DD - HH:MM:SS" on a 24-hour clock in t's
// own location; it is the timestamp format used in recovery log entries.
func timeFormat(t time.Time) string {
	const layout = "2006/01/02 - 15:04:05"
	return t.Format(layout)
}

Read our documentation on viewing source code .

Loading