// Copyright 2014 Manu Martinez-Almeida.  All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.

package gin

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net"
	"net/http"
	"net/http/httputil"
	"os"
	"runtime"
	"strings"
	"time"
)

// Byte-slice constants shared by the stack-trace formatting helpers below.
var (
	// dunno is returned when a source line or function name cannot be resolved.
	dunno = []byte("???")
	// centerDot is rewritten to dot when cleaning up function names.
	centerDot = []byte("·")
	// dot separates the package name from the function name.
	dot = []byte(".")
	// slash separates the elements of a package path.
	slash = []byte("/")
)

29
// RecoveryFunc defines the function passable to CustomRecovery.
30
type RecoveryFunc func(c *Context, err interface{})
31

32
// Recovery returns a middleware that recovers from any panics and writes a 500 if there was one.
33
func Recovery() HandlerFunc {
34 14
	return RecoveryWithWriter(DefaultErrorWriter)
35
}
36

37
//CustomRecovery returns a middleware that recovers from any panics and calls the provided handle func to handle it.
38
func CustomRecovery(handle RecoveryFunc) HandlerFunc {
39 14
	return RecoveryWithWriter(DefaultErrorWriter, handle)
40
}
41

42
// RecoveryWithWriter returns a middleware for a given writer that recovers from any panics and writes a 500 if there was one.
43
func RecoveryWithWriter(out io.Writer, recovery ...RecoveryFunc) HandlerFunc {
44 14
	if len(recovery) > 0 {
45 14
		return CustomRecoveryWithWriter(out, recovery[0])
46
	}
47 14
	return CustomRecoveryWithWriter(out, defaultHandleRecovery)
48
}
49

50
// CustomRecoveryWithWriter returns a middleware for a given writer that recovers from any panics and calls the provided handle func to handle it.
51
func CustomRecoveryWithWriter(out io.Writer, handle RecoveryFunc) HandlerFunc {
52 14
	var logger *log.Logger
53 14
	if out != nil {
54 14
		logger = log.New(out, "\n\n\x1b[31m", log.LstdFlags)
55
	}
56 14
	return func(c *Context) {
57 14
		defer func() {
58 14
			if err := recover(); err != nil {
59
				// Check for a broken connection, as it is not really a
60
				// condition that warrants a panic stack trace.
61 14
				var brokenPipe bool
62 14
				if ne, ok := err.(*net.OpError); ok {
63 14
					if se, ok := ne.Err.(*os.SyscallError); ok {
64 14
						if strings.Contains(strings.ToLower(se.Error()), "broken pipe") || strings.Contains(strings.ToLower(se.Error()), "connection reset by peer") {
65 14
							brokenPipe = true
66
						}
67
					}
68
				}
69 14
				if logger != nil {
70 14
					stack := stack(3)
71 14
					httpRequest, _ := httputil.DumpRequest(c.Request, false)
72 14
					headers := strings.Split(string(httpRequest), "\r\n")
73
					for idx, header := range headers {
74 14
						current := strings.Split(header, ":")
75 14
						if current[0] == "Authorization" {
76 14
							headers[idx] = current[0] + ": *"
77
						}
78
					}
79 14
					headersToStr := strings.Join(headers, "\r\n")
80 14
					if brokenPipe {
81 14
						logger.Printf("%s\n%s%s", err, headersToStr, reset)
82 14
					} else if IsDebugging() {
83 14
						logger.Printf("[Recovery] %s panic recovered:\n%s\n%s\n%s%s",
84 14
							timeFormat(time.Now()), headersToStr, err, stack, reset)
85 14
					} else {
86 14
						logger.Printf("[Recovery] %s panic recovered:\n%s\n%s%s",
87 14
							timeFormat(time.Now()), err, stack, reset)
88
					}
89
				}
90 14
				if brokenPipe {
91
					// If the connection is dead, we can't write a status to it.
92 14
					c.Error(err.(error)) // nolint: errcheck
93 14
					c.Abort()
94 14
				} else {
95 14
					handle(c, err)
96
				}
97
			}
98
		}()
99 14
		c.Next()
100
	}
101
}
102

103
func defaultHandleRecovery(c *Context, err interface{}) {
104 14
	c.AbortWithStatus(http.StatusInternalServerError)
105
}
106

107
// stack returns a nicely formatted stack frame, skipping skip frames.
108
func stack(skip int) []byte {
109 14
	buf := new(bytes.Buffer) // the returned data
110
	// As we loop, we open files and read them. These variables record the currently
111
	// loaded file.
112 14
	var lines [][]byte
113 14
	var lastFile string
114 14
	for i := skip; ; i++ { // Skip the expected number of frames
115 14
		pc, file, line, ok := runtime.Caller(i)
116 14
		if !ok {
117 14
			break
118
		}
119
		// Print this much at least.  If we can't find the source, it won't show.
120 14
		fmt.Fprintf(buf, "%s:%d (0x%x)\n", file, line, pc)
121 14
		if file != lastFile {
122 14
			data, err := ioutil.ReadFile(file)
123 14
			if err != nil {
124 0
				continue
125
			}
126 14
			lines = bytes.Split(data, []byte{'\n'})
127 14
			lastFile = file
128
		}
129 14
		fmt.Fprintf(buf, "\t%s: %s\n", function(pc), source(lines, line))
130
	}
131 14
	return buf.Bytes()
132
}
133

134
// source returns a space-trimmed slice of the n'th line.
135
func source(lines [][]byte, n int) []byte {
136 14
	n-- // in stack trace, lines are 1-indexed but our array is 0-indexed
137 14
	if n < 0 || n >= len(lines) {
138 14
		return dunno
139
	}
140 14
	return bytes.TrimSpace(lines[n])
141
}
142

143
// function returns, if possible, the name of the function containing the PC.
144
func function(pc uintptr) []byte {
145 14
	fn := runtime.FuncForPC(pc)
146 14
	if fn == nil {
147 14
		return dunno
148
	}
149 14
	name := []byte(fn.Name())
150
	// The name includes the path name to the package, which is unnecessary
151
	// since the file name is already included.  Plus, it has center dots.
152
	// That is, we see
153
	//	runtime/debug.*T·ptrmethod
154
	// and want
155
	//	*T.ptrmethod
156
	// Also the package path might contains dot (e.g. code.google.com/...),
157
	// so first eliminate the path prefix
158 14
	if lastSlash := bytes.LastIndex(name, slash); lastSlash >= 0 {
159 14
		name = name[lastSlash+1:]
160
	}
161 14
	if period := bytes.Index(name, dot); period >= 0 {
162 14
		name = name[period+1:]
163
	}
164 14
	name = bytes.Replace(name, centerDot, dot, -1)
165 14
	return name
166
}
167

// timeFormat renders t as "YYYY/MM/DD - hh:mm:ss", the timestamp layout used
// in the recovery log lines.
func timeFormat(t time.Time) string {
	return t.Format("2006/01/02 - 15:04:05")
}

Read our documentation on viewing source code.

Loading