binder.go

  1package config
  2
  3import (
  4	"fmt"
  5	"go/token"
  6	"go/types"
  7
  8	"github.com/99designs/gqlgen/codegen/templates"
  9	"github.com/99designs/gqlgen/internal/code"
 10	"github.com/pkg/errors"
 11	"github.com/vektah/gqlparser/ast"
 12	"golang.org/x/tools/go/packages"
 13)
 14
// Binder connects graphql types to golang types using static analysis
//
// A Binder is created by Config.NewBinder. It holds the Go packages loaded for
// the configured models, the schema whose types are being bound, and the
// configuration that maps schema type names to Go models.
type Binder struct {
	// pkgs holds every loaded package (including transitive imports), keyed by
	// code.NormalizeVendor of the package path; see populatePkg and getPkg.
	pkgs map[string]*packages.Package
	// schema is consulted for the Definition of each bound schema type.
	schema *ast.Schema
	// cfg supplies the Models table that lists candidate Go types per schema type.
	cfg *Config
	// References collects the references appended by PushRef and PointerTo;
	// TypeReference records each successful result through PushRef.
	References []*TypeReference
}
 22
 23func (c *Config) NewBinder(s *ast.Schema) (*Binder, error) {
 24	pkgs, err := packages.Load(&packages.Config{Mode: packages.LoadTypes | packages.LoadSyntax}, c.Models.ReferencedPackages()...)
 25	if err != nil {
 26		return nil, err
 27	}
 28
 29	mp := map[string]*packages.Package{}
 30	for _, p := range pkgs {
 31		populatePkg(mp, p)
 32		for _, e := range p.Errors {
 33			if e.Kind == packages.ListError {
 34				return nil, p.Errors[0]
 35			}
 36		}
 37	}
 38
 39	return &Binder{
 40		pkgs:   mp,
 41		schema: s,
 42		cfg:    c,
 43	}, nil
 44}
 45
 46func populatePkg(mp map[string]*packages.Package, p *packages.Package) {
 47	imp := code.NormalizeVendor(p.PkgPath)
 48	if _, ok := mp[imp]; ok {
 49		return
 50	}
 51	mp[imp] = p
 52	for _, p := range p.Imports {
 53		populatePkg(mp, p)
 54	}
 55}
 56
 57func (b *Binder) TypePosition(typ types.Type) token.Position {
 58	named, isNamed := typ.(*types.Named)
 59	if !isNamed {
 60		return token.Position{
 61			Filename: "unknown",
 62		}
 63	}
 64
 65	return b.ObjectPosition(named.Obj())
 66}
 67
 68func (b *Binder) ObjectPosition(typ types.Object) token.Position {
 69	if typ == nil {
 70		return token.Position{
 71			Filename: "unknown",
 72		}
 73	}
 74	pkg := b.getPkg(typ.Pkg().Path())
 75	return pkg.Fset.Position(typ.Pos())
 76}
 77
 78func (b *Binder) FindType(pkgName string, typeName string) (types.Type, error) {
 79	obj, err := b.FindObject(pkgName, typeName)
 80	if err != nil {
 81		return nil, err
 82	}
 83
 84	if fun, isFunc := obj.(*types.Func); isFunc {
 85		return fun.Type().(*types.Signature).Params().At(0).Type(), nil
 86	}
 87	return obj.Type(), nil
 88}
 89
 90func (b *Binder) getPkg(find string) *packages.Package {
 91	imp := code.NormalizeVendor(find)
 92	if p, ok := b.pkgs[imp]; ok {
 93		return p
 94	}
 95	return nil
 96}
 97
 98var MapType = types.NewMap(types.Typ[types.String], types.NewInterfaceType(nil, nil).Complete())
 99var InterfaceType = types.NewInterfaceType(nil, nil)
100
101func (b *Binder) DefaultUserObject(name string) (types.Type, error) {
102	models := b.cfg.Models[name].Model
103	if len(models) == 0 {
104		return nil, fmt.Errorf(name + " not found in typemap")
105	}
106
107	if models[0] == "map[string]interface{}" {
108		return MapType, nil
109	}
110
111	if models[0] == "interface{}" {
112		return InterfaceType, nil
113	}
114
115	pkgName, typeName := code.PkgAndType(models[0])
116	if pkgName == "" {
117		return nil, fmt.Errorf("missing package name for %s", name)
118	}
119
120	obj, err := b.FindObject(pkgName, typeName)
121	if err != nil {
122		return nil, err
123	}
124
125	return obj.Type(), nil
126}
127
128func (b *Binder) FindObject(pkgName string, typeName string) (types.Object, error) {
129	if pkgName == "" {
130		return nil, fmt.Errorf("package cannot be nil")
131	}
132	fullName := typeName
133	if pkgName != "" {
134		fullName = pkgName + "." + typeName
135	}
136
137	pkg := b.getPkg(pkgName)
138	if pkg == nil {
139		return nil, errors.Errorf("required package was not loaded: %s", fullName)
140	}
141
142	// function based marshalers take precedence
143	for astNode, def := range pkg.TypesInfo.Defs {
144		// only look at defs in the top scope
145		if def == nil || def.Parent() == nil || def.Parent() != pkg.Types.Scope() {
146			continue
147		}
148
149		if astNode.Name == "Marshal"+typeName {
150			return def, nil
151		}
152	}
153
154	// then look for types directly
155	for astNode, def := range pkg.TypesInfo.Defs {
156		// only look at defs in the top scope
157		if def == nil || def.Parent() == nil || def.Parent() != pkg.Types.Scope() {
158			continue
159		}
160
161		if astNode.Name == typeName {
162			return def, nil
163		}
164	}
165
166	return nil, errors.Errorf("unable to find type %s\n", fullName)
167}
168
169func (b *Binder) PointerTo(ref *TypeReference) *TypeReference {
170	newRef := &TypeReference{
171		GO:          types.NewPointer(ref.GO),
172		GQL:         ref.GQL,
173		CastType:    ref.CastType,
174		Definition:  ref.Definition,
175		Unmarshaler: ref.Unmarshaler,
176		Marshaler:   ref.Marshaler,
177		IsMarshaler: ref.IsMarshaler,
178	}
179
180	b.References = append(b.References, newRef)
181	return newRef
182}
183
// TypeReference is used by args and field types. The Definition can refer to both input and output types.
//
// It pairs a GraphQL type (GQL) with the Go type (GO) it is bound to, along
// with the marshalling information recorded while binding.
type TypeReference struct {
	// Definition is the schema definition for GQL's named type. It can be nil
	// when the schema has no such type; MarshalFunc/UnmarshalFunc panic in that case.
	Definition *ast.Definition
	// GQL is the GraphQL type, including list and non-null modifiers.
	GQL *ast.Type
	// GO is the bound Go type, including pointer/slice modifiers.
	GO          types.Type
	CastType    types.Type  // Before calling marshalling functions cast from/to this base type
	Marshaler   *types.Func // When using external marshalling functions this will point to the Marshal function
	Unmarshaler *types.Func // When using external marshalling functions this will point to the Unmarshal function
	IsMarshaler bool        // Does the type implement graphql.Marshaler and graphql.Unmarshaler
}
194
195func (ref *TypeReference) Elem() *TypeReference {
196	if p, isPtr := ref.GO.(*types.Pointer); isPtr {
197		return &TypeReference{
198			GO:          p.Elem(),
199			GQL:         ref.GQL,
200			CastType:    ref.CastType,
201			Definition:  ref.Definition,
202			Unmarshaler: ref.Unmarshaler,
203			Marshaler:   ref.Marshaler,
204			IsMarshaler: ref.IsMarshaler,
205		}
206	}
207
208	if ref.IsSlice() {
209		return &TypeReference{
210			GO:          ref.GO.(*types.Slice).Elem(),
211			GQL:         ref.GQL.Elem,
212			CastType:    ref.CastType,
213			Definition:  ref.Definition,
214			Unmarshaler: ref.Unmarshaler,
215			Marshaler:   ref.Marshaler,
216			IsMarshaler: ref.IsMarshaler,
217		}
218	}
219	return nil
220}
221
222func (t *TypeReference) IsPtr() bool {
223	_, isPtr := t.GO.(*types.Pointer)
224	return isPtr
225}
226
227func (t *TypeReference) IsNilable() bool {
228	_, isPtr := t.GO.(*types.Pointer)
229	_, isMap := t.GO.(*types.Map)
230	_, isInterface := t.GO.(*types.Interface)
231	return isPtr || isMap || isInterface
232}
233
234func (t *TypeReference) IsSlice() bool {
235	_, isSlice := t.GO.(*types.Slice)
236	return t.GQL.Elem != nil && isSlice
237}
238
239func (t *TypeReference) IsNamed() bool {
240	_, isSlice := t.GO.(*types.Named)
241	return isSlice
242}
243
244func (t *TypeReference) IsStruct() bool {
245	_, isStruct := t.GO.Underlying().(*types.Struct)
246	return isStruct
247}
248
249func (t *TypeReference) IsScalar() bool {
250	return t.Definition.Kind == ast.Scalar
251}
252
253func (t *TypeReference) UniquenessKey() string {
254	var nullability = "O"
255	if t.GQL.NonNull {
256		nullability = "N"
257	}
258
259	return nullability + t.Definition.Name + "2" + templates.TypeIdentifier(t.GO)
260}
261
262func (t *TypeReference) MarshalFunc() string {
263	if t.Definition == nil {
264		panic(errors.New("Definition missing for " + t.GQL.Name()))
265	}
266
267	if t.Definition.Kind == ast.InputObject {
268		return ""
269	}
270
271	return "marshal" + t.UniquenessKey()
272}
273
274func (t *TypeReference) UnmarshalFunc() string {
275	if t.Definition == nil {
276		panic(errors.New("Definition missing for " + t.GQL.Name()))
277	}
278
279	if !t.Definition.IsInputType() {
280		return ""
281	}
282
283	return "unmarshal" + t.UniquenessKey()
284}
285
286func (b *Binder) PushRef(ret *TypeReference) {
287	b.References = append(b.References, ret)
288}
289
// isMap reports whether t can take a map model: either there is no target
// type at all (nil), or t is an unnamed map type.
func isMap(t types.Type) bool {
	switch t.(type) {
	case nil, *types.Map:
		return true
	default:
		return false
	}
}
297
// isIntf reports whether t can take an interface{} model: either there is no
// target type at all (nil), or t is an unnamed interface type.
func isIntf(t types.Type) bool {
	switch t.(type) {
	case nil, *types.Interface:
		return true
	default:
		return false
	}
}
305
306func (b *Binder) TypeReference(schemaType *ast.Type, bindTarget types.Type) (ret *TypeReference, err error) {
307	var pkgName, typeName string
308	def := b.schema.Types[schemaType.Name()]
309	defer func() {
310		if err == nil && ret != nil {
311			b.PushRef(ret)
312		}
313	}()
314
315	if len(b.cfg.Models[schemaType.Name()].Model) == 0 {
316		return nil, fmt.Errorf("%s was not found", schemaType.Name())
317	}
318
319	for _, model := range b.cfg.Models[schemaType.Name()].Model {
320		if model == "map[string]interface{}" {
321			if !isMap(bindTarget) {
322				continue
323			}
324			return &TypeReference{
325				Definition: def,
326				GQL:        schemaType,
327				GO:         MapType,
328			}, nil
329		}
330
331		if model == "interface{}" {
332			if !isIntf(bindTarget) {
333				continue
334			}
335			return &TypeReference{
336				Definition: def,
337				GQL:        schemaType,
338				GO:         InterfaceType,
339			}, nil
340		}
341
342		pkgName, typeName = code.PkgAndType(model)
343		if pkgName == "" {
344			return nil, fmt.Errorf("missing package name for %s", schemaType.Name())
345		}
346
347		ref := &TypeReference{
348			Definition: def,
349			GQL:        schemaType,
350		}
351
352		obj, err := b.FindObject(pkgName, typeName)
353		if err != nil {
354			return nil, err
355		}
356
357		if fun, isFunc := obj.(*types.Func); isFunc {
358			ref.GO = fun.Type().(*types.Signature).Params().At(0).Type()
359			ref.Marshaler = fun
360			ref.Unmarshaler = types.NewFunc(0, fun.Pkg(), "Unmarshal"+typeName, nil)
361		} else if hasMethod(obj.Type(), "MarshalGQL") && hasMethod(obj.Type(), "UnmarshalGQL") {
362			ref.GO = obj.Type()
363			ref.IsMarshaler = true
364		} else if underlying := basicUnderlying(obj.Type()); def.IsLeafType() && underlying != nil && underlying.Kind() == types.String {
365			// Special case for named types wrapping strings. Used by default enum implementations.
366
367			ref.GO = obj.Type()
368			ref.CastType = underlying
369
370			underlyingRef, err := b.TypeReference(&ast.Type{NamedType: "String"}, nil)
371			if err != nil {
372				return nil, err
373			}
374
375			ref.Marshaler = underlyingRef.Marshaler
376			ref.Unmarshaler = underlyingRef.Unmarshaler
377		} else {
378			ref.GO = obj.Type()
379		}
380
381		ref.GO = b.CopyModifiersFromAst(schemaType, ref.GO)
382
383		if bindTarget != nil {
384			if err = code.CompatibleTypes(ref.GO, bindTarget); err != nil {
385				continue
386			}
387			ref.GO = bindTarget
388		}
389
390		return ref, nil
391	}
392
393	return nil, fmt.Errorf("%s has type compatible with %s", schemaType.Name(), bindTarget.String())
394}
395
396func (b *Binder) CopyModifiersFromAst(t *ast.Type, base types.Type) types.Type {
397	if t.Elem != nil {
398		child := b.CopyModifiersFromAst(t.Elem, base)
399		if _, isStruct := child.Underlying().(*types.Struct); isStruct {
400			child = types.NewPointer(child)
401		}
402		return types.NewSlice(child)
403	}
404
405	var isInterface bool
406	if named, ok := base.(*types.Named); ok {
407		_, isInterface = named.Underlying().(*types.Interface)
408	}
409
410	if !isInterface && !t.NonNull {
411		return types.NewPointer(base)
412	}
413
414	return base
415}
416
// hasMethod reports whether the named type behind it declares a method called
// name. One level of pointer is stripped first. Only methods declared directly
// on the named type are considered, not methods promoted from embedded fields.
func hasMethod(it types.Type, name string) bool {
	if ptr, ok := it.(*types.Pointer); ok {
		it = ptr.Elem()
	}

	named, ok := it.(*types.Named)
	if !ok {
		return false
	}

	for i, count := 0, named.NumMethods(); i < count; i++ {
		if named.Method(i).Name() == name {
			return true
		}
	}
	return false
}
433
// basicUnderlying returns the basic type underlying a named type, such as
// string for `type Enum string`. One level of pointer is stripped first. It
// returns nil when it is not a named type or its underlying type is not basic.
func basicUnderlying(it types.Type) *types.Basic {
	if ptr, ok := it.(*types.Pointer); ok {
		it = ptr.Elem()
	}

	named, ok := it.(*types.Named)
	if !ok {
		return nil
	}

	// A failed assertion leaves basic as a nil *types.Basic.
	basic, _ := named.Underlying().(*types.Basic)
	return basic
}