1
// Copyright 2019 Aporeto Inc.
2
// Licensed under the Apache License, Version 2.0 (the "License");
3
// you may not use this file except in compliance with the License.
4
// You may obtain a copy of the License at
5
//     http://www.apache.org/licenses/LICENSE-2.0
6
// Unless required by applicable law or agreed to in writing, software
7
// distributed under the License is distributed on an "AS IS" BASIS,
8
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
// See the License for the specific language governing permissions and
10
// limitations under the License.
11

12
package elemental
13

14
import (
15
	"bytes"
16
	"fmt"
17
	"io"
18
	"mime"
19
	"net/http"
20
	"reflect"
21
	"strings"
22
	"sync"
23

24
	"github.com/ugorji/go/codec"
25
)
26

27
// Sets of externally registered media types, filled by
// RegisterSupportedContentType and RegisterSupportedAcceptType.
// Only the keys are meaningful; values are zero-width.
var (
	externalSupportedContentType = make(map[string]struct{})
	externalSupportedAcceptType  = make(map[string]struct{})
)
31

32
// RegisterSupportedContentType registers a new media type
33
// that elemental should support for Content-Type.
34
// Note that this needs external intervention to handle encoding.
35
func RegisterSupportedContentType(mimetype string) {
36 14
	externalSupportedContentType[mimetype] = struct{}{}
37
}
38

39
// RegisterSupportedAcceptType registers a new media type
40
// that elemental should support for Accept.
41
// Note that this needs external intervention to handle decoding.
42
func RegisterSupportedAcceptType(mimetype string) {
43 14
	externalSupportedAcceptType[mimetype] = struct{}{}
44
}
45

46
// Types describing how values are encoded and decoded.
type (
	// An Encodable is the interface of objects that can hold
	// encoding information, exposed through GetEncoding.
	Encodable interface {
		GetEncoding() EncodingType
	}

	// An Encoder is an Encodable that encodes the objects it is given.
	Encoder interface {
		Encodable
		Encode(obj interface{}) error
	}

	// A Decoder is an Encodable that decodes data into a destination.
	Decoder interface {
		Encodable
		Decode(dst interface{}) error
	}

	// An EncodingType represents one type of data encoding,
	// expressed as its media type.
	EncodingType string
)

// Various values for EncodingType.
const (
	// EncodingTypeJSON is the JSON media type.
	EncodingTypeJSON EncodingType = "application/json"

	// EncodingTypeMSGPACK is the MessagePack media type.
	EncodingTypeMSGPACK EncodingType = "application/msgpack"
)
72

73
var (
74
	jsonHandle       = &codec.JsonHandle{}
75
	jsonEncodersPool = sync.Pool{
76 14
		New: func() interface{} {
77 14
			return codec.NewEncoder(nil, jsonHandle)
78 14
		},
79
	}
80
	jsonDecodersPool = sync.Pool{
81 14
		New: func() interface{} {
82 14
			return codec.NewDecoder(nil, jsonHandle)
83 14
		},
84
	}
85

86
	msgpackHandle       = &codec.MsgpackHandle{}
87
	msgpackEncodersPool = sync.Pool{
88 14
		New: func() interface{} {
89 14
			return codec.NewEncoder(nil, msgpackHandle)
90 14
		},
91
	}
92
	msgpackDecodersPool = sync.Pool{
93 14
		New: func() interface{} {
94 14
			return codec.NewDecoder(nil, msgpackHandle)
95 14
		},
96
	}
97
)
98

99
func init() {
100
	// If you need to understand all of this, go there http://ugorji.net/blog/go-codec-primer
101
	// But you should not need to touch that.
102 14
	jsonHandle.Canonical = true
103 14
	jsonHandle.MapType = reflect.TypeOf(map[string]interface{}(nil))
104

105 14
	msgpackHandle.Canonical = true
106 14
	msgpackHandle.WriteExt = true
107 14
	msgpackHandle.MapType = reflect.TypeOf(map[string]interface{}(nil))
108 14
	msgpackHandle.TypeInfos = codec.NewTypeInfos([]string{"msgpack"})
109
}
110

111
// Decode decodes the given data using an appropriate decoder chosen
112
// from the given encoding.
113
func Decode(encoding EncodingType, data []byte, dest interface{}) error {
114

115 14
	var pool *sync.Pool
116

117 14
	switch encoding {
118 14
	case EncodingTypeMSGPACK:
119 14
		pool = &msgpackDecodersPool
120 14
	default:
121 14
		pool = &jsonDecodersPool
122 14
		encoding = EncodingTypeJSON
123
	}
124

125 14
	dec := pool.Get().(*codec.Decoder)
126 14
	defer pool.Put(dec)
127

128 14
	dec.Reset(bytes.NewBuffer(data))
129

130 14
	if err := dec.Decode(dest); err != nil {
131 14
		return fmt.Errorf("unable to decode %s: %s", encoding, err.Error())
132
	}
133

134 14
	return nil
135
}
136

137
// Encode encodes the given object using an appropriate encoder chosen
138
// from the given acceptType.
139
func Encode(encoding EncodingType, obj interface{}) ([]byte, error) {
140

141 14
	if obj == nil {
142 14
		return nil, fmt.Errorf("encode received a nil object")
143
	}
144

145 14
	var pool *sync.Pool
146

147 14
	switch encoding {
148 14
	case EncodingTypeMSGPACK:
149 14
		pool = &msgpackEncodersPool
150 14
	default:
151 14
		pool = &jsonEncodersPool
152 14
		encoding = EncodingTypeJSON
153
	}
154

155 14
	enc := pool.Get().(*codec.Encoder)
156 14
	defer pool.Put(enc)
157

158 14
	buf := bytes.NewBuffer(nil)
159 14
	enc.Reset(buf)
160

161 14
	if err := enc.Encode(obj); err != nil {
162 0
		return nil, fmt.Errorf("unable to encode %s: %s", encoding, err.Error())
163
	}
164

165 14
	return buf.Bytes(), nil
166
}
167

168
// MakeStreamDecoder returns a function that can be used to decode a stream from the
169
// given reader using the given encoding.
170
//
171
// This function returns the decoder function that can be called until it returns an
172
// io.EOF error, indicating the stream is over, and a dispose function that will
173
// put back the decoder in the memory pool.
174
// The dispose function will be called automatically when the decoding is over,
175
// but not on a single decoding error.
176
// In any case, the dispose function should be always called, in a defer for example.
177
func MakeStreamDecoder(encoding EncodingType, reader io.Reader) (func(dest interface{}) error, func()) {
178

179 14
	var pool *sync.Pool
180

181 14
	switch encoding {
182 14
	case EncodingTypeMSGPACK:
183 14
		pool = &msgpackDecodersPool
184 14
	default:
185 14
		pool = &jsonDecodersPool
186
	}
187

188 14
	dec := pool.Get().(*codec.Decoder)
189 14
	dec.Reset(reader)
190

191
	clean := func() {
192 14
		if pool != nil {
193 14
			pool.Put(dec)
194 14
			pool = nil
195
		}
196
	}
197

198 14
	return func(dest interface{}) error {
199

200 14
			if err := dec.Decode(dest); err != nil {
201

202 14
				if err == io.EOF {
203 14
					clean()
204 14
					return err
205
				}
206

207 14
				return fmt.Errorf("unable to decode %s: %s", encoding, err.Error())
208
			}
209

210 14
			return nil
211 14
		}, func() {
212 14
			clean()
213
		}
214
}
215

216
// MakeStreamEncoder returns a function that can be user en encode given data
217
// into the given io.Writer using the given encoding.
218
//
219
// It also returns a function must be called once the encoding procedure
220
// is complete, so the internal encoders can be put back into the shared
221
// memory pools.
222
func MakeStreamEncoder(encoding EncodingType, writer io.Writer) (func(obj interface{}) error, func()) {
223

224 14
	var pool *sync.Pool
225

226 14
	switch encoding {
227 14
	case EncodingTypeMSGPACK:
228 14
		pool = &msgpackEncodersPool
229 14
	default:
230 14
		pool = &jsonEncodersPool
231
	}
232

233 14
	enc := pool.Get().(*codec.Encoder)
234 14
	enc.Reset(writer)
235

236
	clean := func() {
237 14
		if pool != nil {
238 14
			pool.Put(enc)
239 14
			pool = nil
240
		}
241
	}
242

243 14
	return func(dest interface{}) error {
244

245 14
			if err := enc.Encode(dest); err != nil {
246 0
				return fmt.Errorf("unable to encode %s: %s", encoding, err.Error())
247
			}
248

249 14
			return nil
250 14
		}, func() {
251 14
			clean()
252
		}
253
}
254

255
// Convert converts from one EncodingType to another
256
func Convert(from EncodingType, to EncodingType, data []byte) ([]byte, error) {
257

258 14
	if from == to {
259 14
		return data, nil
260
	}
261

262 14
	m := map[string]interface{}{}
263 14
	if err := Decode(from, data, &m); err != nil {
264 14
		return nil, err
265
	}
266

267 14
	return Encode(to, m)
268
}
269

270
// EncodingFromHeaders returns the read (Content-Type) and write (Accept) encoding
271
// from the given http.Header.
272
func EncodingFromHeaders(header http.Header) (read EncodingType, write EncodingType, err error) {
273

274 14
	read = EncodingTypeJSON
275 14
	write = EncodingTypeJSON
276

277 14
	if header == nil {
278 14
		return read, write, nil
279
	}
280

281 14
	if v := header.Get("Content-Type"); v != "" {
282 14
		ct, _, err := mime.ParseMediaType(v)
283 14
		if err != nil {
284 14
			return "", "", NewError("Bad Request", fmt.Sprintf("Invalid Content-Type header: %s", err), "elemental", http.StatusBadRequest)
285
		}
286

287 14
		switch ct {
288

289 14
		case "application/msgpack":
290 14
			read = EncodingTypeMSGPACK
291

292 14
		case "application/*", "*/*", "application/json":
293 14
			read = EncodingTypeJSON
294

295 14
		default:
296 14
			var supported bool
297
			for t := range externalSupportedContentType {
298 14
				if ct == t {
299 14
					supported = true
300 14
					break
301
				}
302
			}
303 14
			if !supported {
304 14
				return "", "", NewError("Unsupported Media Type", fmt.Sprintf("Cannot find any acceptable Content-Type media type in provided header: %s", v), "elemental", http.StatusUnsupportedMediaType)
305
			}
306

307 14
			read = EncodingType(ct)
308
		}
309
	}
310

311 14
	if v := header.Get("Accept"); v != "" {
312 14
		var agreed bool
313 14
	L:
314
		for _, item := range strings.Split(v, ",") {
315

316 14
			at, _, err := mime.ParseMediaType(item)
317 14
			if err != nil {
318 14
				return "", "", NewError("Bad Request", fmt.Sprintf("Invalid Accept header: %s", err), "elemental", http.StatusBadRequest)
319
			}
320

321 14
			switch at {
322

323 14
			case "application/msgpack":
324 14
				write = EncodingTypeMSGPACK
325 14
				agreed = true
326 14
				break L
327

328 14
			case "application/*", "*/*", "application/json":
329 14
				write = EncodingTypeJSON
330 14
				agreed = true
331 14
				break L
332

333 14
			default:
334
				for t := range externalSupportedAcceptType {
335 14
					if at == t {
336 14
						agreed = true
337 14
						write = EncodingType(at)
338 14
						break L
339
					}
340
				}
341
			}
342
		}
343

344 14
		if !agreed {
345 14
			return "", "", NewError("Unsupported Media Type", fmt.Sprintf("Cannot find any acceptable Accept media type in provided header: %s", v), "elemental", http.StatusUnsupportedMediaType)
346
		}
347
	}
348

349 14
	return read, write, nil
350
}

Read our documentation on viewing source code .

Loading