uber-go / dig
// Copyright (c) 2019 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

21
package dig
22

23
import (
24
	"errors"
25
	"fmt"
26
	"reflect"
27

28
	"go.uber.org/dig/internal/dot"
29
)
30

31
// The param interface represents a dependency for a constructor.
32
//
33
// The following implementations exist:
34
//  paramList     All arguments of the constructor.
35
//  paramSingle   An explicitly requested type.
36
//  paramObject   dig.In struct where each field in the struct can be another
37
//                param.
38
//  paramGroupedSlice
39
//                A slice consuming a value group. This will receive all
40
//                values produced with a `group:".."` tag with the same name
41
//                as a slice.
42
type param interface {
43
	fmt.Stringer
44

45
	// Builds this dependency and any of its dependencies from the provided
46
	// Container.
47
	//
48
	// This MAY panic if the param does not produce a single value.
49
	Build(containerStore) (reflect.Value, error)
50

51
	// DotParam returns a slice of dot.Param(s).
52
	DotParam() []*dot.Param
53
}
54

55
var (
56
	_ param = paramSingle{}
57
	_ param = paramObject{}
58
	_ param = paramList{}
59
	_ param = paramGroupedSlice{}
60
)
61

62
// newParam builds a param from the given type. If the provided type is a
63
// dig.In struct, an paramObject will be returned.
64
func newParam(t reflect.Type) (param, error) {
65 2
	switch {
66 2
	case IsOut(t) || (t.Kind() == reflect.Ptr && IsOut(t.Elem())) || embedsType(t, _outPtrType):
67 2
		return nil, errf("cannot depend on result objects", "%v embeds a dig.Out", t)
68 2
	case IsIn(t):
69 2
		return newParamObject(t)
70 2
	case embedsType(t, _inPtrType):
71 2
		return nil, errf(
72 2
			"cannot build a parameter object by embedding *dig.In, embed dig.In instead",
73 2
			"%v embeds *dig.In", t)
74 2
	case t.Kind() == reflect.Ptr && IsIn(t.Elem()):
75 2
		return nil, errf(
76 2
			"cannot depend on a pointer to a parameter object, use a value instead",
77 2
			"%v is a pointer to a struct that embeds dig.In", t)
78 2
	default:
79 2
		return paramSingle{Type: t}, nil
80
	}
81
}
82

83
// paramVisitor visits every param in a param tree, allowing tracking state at
84
// each level.
85
type paramVisitor interface {
86
	// Visit is called on the param being visited.
87
	//
88
	// If Visit returns a non-nil paramVisitor, that paramVisitor visits all
89
	// the child params of this param.
90
	Visit(param) paramVisitor
91

92
	// We can implement AnnotateWithField and AnnotateWithPosition like
93
	// resultVisitor if we need to track that information in the future.
94
}
95

96
// paramVisitorFunc is a paramVisitor that visits param in a tree with the
97
// return value deciding whether the descendants of this param should be
98
// recursed into.
99
type paramVisitorFunc func(param) (recurse bool)
100

101
func (f paramVisitorFunc) Visit(p param) paramVisitor {
102 2
	if f(p) {
103 2
		return f
104
	}
105 2
	return nil
106
}
107

108
// walkParam walks the param tree for the given param with the provided
109
// visitor.
110
//
111
// paramVisitor.Visit will be called on the provided param and if a non-nil
112
// paramVisitor is received, this param's descendants will be walked with that
113
// visitor.
114
//
115
// This is very similar to how go/ast.Walk works.
116
func walkParam(p param, v paramVisitor) {
117 2
	v = v.Visit(p)
118 2
	if v == nil {
119 2
		return
120
	}
121

122 2
	switch par := p.(type) {
123 2
	case paramSingle, paramGroupedSlice:
124
		// No sub-results
125 2
	case paramObject:
126
		for _, f := range par.Fields {
127 2
			walkParam(f.Param, v)
128
		}
129 2
	case paramList:
130
		for _, p := range par.Params {
131 2
			walkParam(p, v)
132
		}
133 0
	default:
134 0
		panic(fmt.Sprintf(
135 0
			"It looks like you have found a bug in dig. "+
136 0
				"Please file an issue at https://github.com/uber-go/dig/issues/ "+
137 0
				"and provide the following message: "+
138 0
				"received unknown param type %T", p))
139
	}
140
}
141

142
// paramList holds all arguments of the constructor as params.
143
//
144
// NOTE: Build() MUST NOT be called on paramList. Instead, BuildList
145
// must be called.
146
type paramList struct {
147
	ctype reflect.Type // type of the constructor
148

149
	Params []param
150
}
151

152
func (pl paramList) DotParam() []*dot.Param {
153 2
	var types []*dot.Param
154
	for _, param := range pl.Params {
155 2
		types = append(types, param.DotParam()...)
156
	}
157 2
	return types
158
}
159

160
// newParamList builds a paramList from the provided constructor type.
161
//
162
// Variadic arguments of a constructor are ignored and not included as
163
// dependencies.
164
func newParamList(ctype reflect.Type) (paramList, error) {
165 2
	numArgs := ctype.NumIn()
166 2
	if ctype.IsVariadic() {
167
		// NOTE: If the function is variadic, we skip the last argument
168
		// because we're not filling variadic arguments yet. See #120.
169 2
		numArgs--
170
	}
171

172 2
	pl := paramList{
173 2
		ctype:  ctype,
174 2
		Params: make([]param, 0, numArgs),
175
	}
176

177
	for i := 0; i < numArgs; i++ {
178 2
		p, err := newParam(ctype.In(i))
179 2
		if err != nil {
180 2
			return pl, errf("bad argument %d", i+1, err)
181
		}
182 2
		pl.Params = append(pl.Params, p)
183
	}
184

185 2
	return pl, nil
186
}
187

188
func (pl paramList) Build(containerStore) (reflect.Value, error) {
189 2
	panic("It looks like you have found a bug in dig. " +
190 2
		"Please file an issue at https://github.com/uber-go/dig/issues/ " +
191 2
		"and provide the following message: " +
192 2
		"paramList.Build() must never be called")
193
}
194

195
// BuildList returns an ordered list of values which may be passed directly
196
// to the underlying constructor.
197
func (pl paramList) BuildList(c containerStore) ([]reflect.Value, error) {
198 2
	args := make([]reflect.Value, len(pl.Params))
199
	for i, p := range pl.Params {
200 2
		var err error
201 2
		args[i], err = p.Build(c)
202 2
		if err != nil {
203 2
			return nil, err
204
		}
205
	}
206 2
	return args, nil
207
}
208

209
// paramSingle requests one value of a specific type, optionally qualified by
// a name.
//
// Unless Optional is set, the value must be present in the graph exactly as
// requested.
type paramSingle struct {
	Name     string       // from the `name:".."` tag; empty when unnamed
	Optional bool         // if set, a missing value builds to the zero value
	Type     reflect.Type // the requested type
}
218

219
func (ps paramSingle) DotParam() []*dot.Param {
220 2
	return []*dot.Param{
221
		{
222 2
			Node: &dot.Node{
223 2
				Type: ps.Type,
224 2
				Name: ps.Name,
225 2
			},
226 2
			Optional: ps.Optional,
227 2
		},
228
	}
229
}
230

231
func (ps paramSingle) Build(c containerStore) (reflect.Value, error) {
232 2
	if v, ok := c.getValue(ps.Name, ps.Type); ok {
233 2
		return v, nil
234
	}
235

236 2
	providers := c.getValueProviders(ps.Name, ps.Type)
237 2
	if len(providers) == 0 {
238 2
		if ps.Optional {
239 2
			return reflect.Zero(ps.Type), nil
240
		}
241 0
		return _noValue, newErrMissingTypes(c, key{name: ps.Name, t: ps.Type})
242
	}
243

244
	for _, n := range providers {
245 2
		err := n.Call(c)
246 2
		if err == nil {
247 2
			continue
248
		}
249

250
		// If we're missing dependencies but the parameter itself is optional,
251
		// we can just move on.
252 2
		if _, ok := err.(errMissingDependencies); ok && ps.Optional {
253 2
			return reflect.Zero(ps.Type), nil
254
		}
255

256 2
		return _noValue, errParamSingleFailed{
257 2
			CtorID: n.ID(),
258 2
			Key:    key{t: ps.Type, name: ps.Name},
259 2
			Reason: err,
260
		}
261
	}
262

263
	// If we get here, it's impossible for the value to be absent from the
264
	// container.
265 2
	v, _ := c.getValue(ps.Name, ps.Type)
266 2
	return v, nil
267
}
268

269
// paramObject is a dig.In struct where each field is another param.
270
//
271
// This object is not expected in the graph as-is.
272
type paramObject struct {
273
	Type   reflect.Type
274
	Fields []paramObjectField
275
}
276

277
func (po paramObject) DotParam() []*dot.Param {
278 2
	var types []*dot.Param
279
	for _, field := range po.Fields {
280 2
		types = append(types, field.DotParam()...)
281
	}
282 2
	return types
283
}
284

285
// newParamObject builds an paramObject from the provided type. The type MUST
286
// be a dig.In struct.
287
func newParamObject(t reflect.Type) (paramObject, error) {
288 2
	po := paramObject{Type: t}
289

290
	// Check if the In type supports ignoring unexported fields.
291 2
	var ignoreUnexported bool
292
	for i := 0; i < t.NumField(); i++ {
293 2
		f := t.Field(i)
294 2
		if f.Type == _inType {
295 2
			var err error
296 2
			ignoreUnexported, err = isIgnoreUnexportedSet(f)
297 2
			if err != nil {
298 2
				return po, err
299
			}
300 2
			break
301
		}
302
	}
303

304
	for i := 0; i < t.NumField(); i++ {
305 2
		f := t.Field(i)
306 2
		if f.Type == _inType {
307
			// Skip over the dig.In embed.
308 2
			continue
309
		}
310 2
		if f.PkgPath != "" && ignoreUnexported {
311
			// Skip over an unexported field if it is allowed.
312 2
			continue
313
		}
314 2
		pof, err := newParamObjectField(i, f)
315 2
		if err != nil {
316 2
			return po, errf("bad field %q of %v", f.Name, t, err)
317
		}
318

319 2
		po.Fields = append(po.Fields, pof)
320
	}
321

322 2
	return po, nil
323
}
324

325
func (po paramObject) Build(c containerStore) (reflect.Value, error) {
326 2
	dest := reflect.New(po.Type).Elem()
327
	for _, f := range po.Fields {
328 2
		v, err := f.Build(c)
329 2
		if err != nil {
330 2
			return dest, err
331
		}
332 2
		dest.Field(f.FieldIndex).Set(v)
333
	}
334 2
	return dest, nil
335
}
336

337
// paramObjectField is a single field of a dig.In struct.
338
type paramObjectField struct {
339
	// Name of the field in the struct.
340
	FieldName string
341

342
	// Index of this field in the target struct.
343
	//
344
	// We need to track this separately because not all fields of the
345
	// struct map to params.
346
	FieldIndex int
347

348
	// The dependency requested by this field.
349
	Param param
350
}
351

352
func (pof paramObjectField) DotParam() []*dot.Param {
353 2
	return pof.Param.DotParam()
354
}
355

356
func newParamObjectField(idx int, f reflect.StructField) (paramObjectField, error) {
357 2
	pof := paramObjectField{
358 2
		FieldName:  f.Name,
359 2
		FieldIndex: idx,
360
	}
361

362 2
	var p param
363 2
	switch {
364 2
	case f.PkgPath != "":
365 2
		return pof, errf(
366 2
			"unexported fields not allowed in dig.In, did you mean to export %q (%v)?",
367 2
			f.Name, f.Type)
368

369 2
	case f.Tag.Get(_groupTag) != "":
370 2
		var err error
371 2
		p, err = newParamGroupedSlice(f)
372 2
		if err != nil {
373 2
			return pof, err
374
		}
375

376 2
	default:
377 2
		var err error
378 2
		p, err = newParam(f.Type)
379 2
		if err != nil {
380 2
			return pof, err
381
		}
382
	}
383

384 2
	if ps, ok := p.(paramSingle); ok {
385 2
		ps.Name = f.Tag.Get(_nameTag)
386

387 2
		var err error
388 2
		ps.Optional, err = isFieldOptional(f)
389 2
		if err != nil {
390 2
			return pof, err
391
		}
392

393 2
		p = ps
394
	}
395

396 2
	pof.Param = p
397 2
	return pof, nil
398
}
399

400
func (pof paramObjectField) Build(c containerStore) (reflect.Value, error) {
401 2
	v, err := pof.Param.Build(c)
402 2
	if err != nil {
403 2
		return v, err
404
	}
405 2
	return v, nil
406
}
407

408
// paramGroupedSlice is a param that yields a slice holding every value
// provided to one value group.
type paramGroupedSlice struct {
	// Group is the group name taken from the `group:".."` tag.
	Group string

	// Type is the requested slice type. Its element type is the type of the
	// individual values in the group.
	Type reflect.Type
}
417

418
func (pt paramGroupedSlice) DotParam() []*dot.Param {
419 2
	return []*dot.Param{
420
		{
421 2
			Node: &dot.Node{
422 2
				Type:  pt.Type,
423 2
				Group: pt.Group,
424 2
			},
425 2
		},
426
	}
427
}
428

429
// newParamGroupedSlice builds a paramGroupedSlice from the provided type with
430
// the given name.
431
//
432
// The type MUST be a slice type.
433
func newParamGroupedSlice(f reflect.StructField) (paramGroupedSlice, error) {
434 2
	g, err := parseGroupString(f.Tag.Get(_groupTag))
435 2
	if err != nil {
436 0
		return paramGroupedSlice{}, err
437
	}
438 2
	pg := paramGroupedSlice{Group: g.Name, Type: f.Type}
439

440 2
	name := f.Tag.Get(_nameTag)
441 2
	optional, _ := isFieldOptional(f)
442 2
	switch {
443 2
	case f.Type.Kind() != reflect.Slice:
444 2
		return pg, errf("value groups may be consumed as slices only",
445 2
			"field %q (%v) is not a slice", f.Name, f.Type)
446 2
	case g.Flatten:
447 2
		return pg, errf("cannot use flatten in parameter value groups",
448 2
			"field %q (%v) specifies flatten", f.Name, f.Type)
449 2
	case name != "":
450 2
		return pg, errf(
451 2
			"cannot use named values with value groups",
452 2
			"name:%q requested with group:%q", name, pg.Group)
453

454 2
	case optional:
455 2
		return pg, errors.New("value groups cannot be optional")
456
	}
457

458 2
	return pg, nil
459
}
460

461
func (pt paramGroupedSlice) Build(c containerStore) (reflect.Value, error) {
462
	for _, n := range c.getGroupProviders(pt.Group, pt.Type.Elem()) {
463 2
		if err := n.Call(c); err != nil {
464 2
			return _noValue, errParamGroupFailed{
465 2
				CtorID: n.ID(),
466 2
				Key:    key{group: pt.Group, t: pt.Type.Elem()},
467 2
				Reason: err,
468
			}
469
		}
470
	}
471

472 2
	items := c.getValueGroup(pt.Group, pt.Type.Elem())
473

474 2
	result := reflect.MakeSlice(pt.Type, len(items), len(items))
475
	for i, v := range items {
476 2
		result.Index(i).Set(v)
477
	}
478 2
	return result, nil
479
}

Read our documentation on viewing source code .

Loading