binder.go

  1package config
  2
  3import (
  4	"fmt"
  5	"go/token"
  6	"go/types"
  7
  8	"github.com/99designs/gqlgen/codegen/templates"
  9	"github.com/99designs/gqlgen/internal/code"
 10	"github.com/pkg/errors"
 11	"github.com/vektah/gqlparser/ast"
 12	"golang.org/x/tools/go/packages"
 13)
 14
// Binder connects graphql types to golang types using static analysis
//
// A Binder is created by Config.NewBinder. It resolves the model names from
// the config against the Go packages loaded at construction time.
type Binder struct {
	pkgs       []*packages.Package // packages loaded by NewBinder; searched by getPkg
	schema     *ast.Schema         // schema whose types are being bound
	cfg        *Config             // config supplying the model mappings (cfg.Models)
	References []*TypeReference    // references appended by PushRef, PointerTo and TypeReference
}
 22
 23func (c *Config) NewBinder(s *ast.Schema) (*Binder, error) {
 24	pkgs, err := packages.Load(&packages.Config{Mode: packages.LoadTypes | packages.LoadSyntax}, c.Models.ReferencedPackages()...)
 25	if err != nil {
 26		return nil, err
 27	}
 28
 29	for _, p := range pkgs {
 30		for _, e := range p.Errors {
 31			if e.Kind == packages.ListError {
 32				return nil, p.Errors[0]
 33			}
 34		}
 35	}
 36
 37	return &Binder{
 38		pkgs:   pkgs,
 39		schema: s,
 40		cfg:    c,
 41	}, nil
 42}
 43
 44func (b *Binder) TypePosition(typ types.Type) token.Position {
 45	named, isNamed := typ.(*types.Named)
 46	if !isNamed {
 47		return token.Position{
 48			Filename: "unknown",
 49		}
 50	}
 51
 52	return b.ObjectPosition(named.Obj())
 53}
 54
 55func (b *Binder) ObjectPosition(typ types.Object) token.Position {
 56	if typ == nil {
 57		return token.Position{
 58			Filename: "unknown",
 59		}
 60	}
 61	pkg := b.getPkg(typ.Pkg().Path())
 62	return pkg.Fset.Position(typ.Pos())
 63}
 64
 65func (b *Binder) FindType(pkgName string, typeName string) (types.Type, error) {
 66	obj, err := b.FindObject(pkgName, typeName)
 67	if err != nil {
 68		return nil, err
 69	}
 70
 71	if fun, isFunc := obj.(*types.Func); isFunc {
 72		return fun.Type().(*types.Signature).Params().At(0).Type(), nil
 73	}
 74	return obj.Type(), nil
 75}
 76
 77func (b *Binder) getPkg(find string) *packages.Package {
 78	for _, p := range b.pkgs {
 79		if code.NormalizeVendor(find) == code.NormalizeVendor(p.PkgPath) {
 80			return p
 81		}
 82	}
 83	return nil
 84}
 85
 86var MapType = types.NewMap(types.Typ[types.String], types.NewInterfaceType(nil, nil).Complete())
 87var InterfaceType = types.NewInterfaceType(nil, nil)
 88
 89func (b *Binder) DefaultUserObject(name string) (types.Type, error) {
 90	models := b.cfg.Models[name].Model
 91	if len(models) == 0 {
 92		return nil, fmt.Errorf(name + " not found in typemap")
 93	}
 94
 95	if models[0] == "map[string]interface{}" {
 96		return MapType, nil
 97	}
 98
 99	if models[0] == "interface{}" {
100		return InterfaceType, nil
101	}
102
103	pkgName, typeName := code.PkgAndType(models[0])
104	if pkgName == "" {
105		return nil, fmt.Errorf("missing package name for %s", name)
106	}
107
108	obj, err := b.FindObject(pkgName, typeName)
109	if err != nil {
110		return nil, err
111	}
112
113	return obj.Type(), nil
114}
115
116func (b *Binder) FindObject(pkgName string, typeName string) (types.Object, error) {
117	if pkgName == "" {
118		return nil, fmt.Errorf("package cannot be nil")
119	}
120	fullName := typeName
121	if pkgName != "" {
122		fullName = pkgName + "." + typeName
123	}
124
125	pkg := b.getPkg(pkgName)
126	if pkg == nil {
127		return nil, errors.Errorf("required package was not loaded: %s", fullName)
128	}
129
130	// function based marshalers take precedence
131	for astNode, def := range pkg.TypesInfo.Defs {
132		// only look at defs in the top scope
133		if def == nil || def.Parent() == nil || def.Parent() != pkg.Types.Scope() {
134			continue
135		}
136
137		if astNode.Name == "Marshal"+typeName {
138			return def, nil
139		}
140	}
141
142	// then look for types directly
143	for astNode, def := range pkg.TypesInfo.Defs {
144		// only look at defs in the top scope
145		if def == nil || def.Parent() == nil || def.Parent() != pkg.Types.Scope() {
146			continue
147		}
148
149		if astNode.Name == typeName {
150			return def, nil
151		}
152	}
153
154	return nil, errors.Errorf("unable to find type %s\n", fullName)
155}
156
157func (b *Binder) PointerTo(ref *TypeReference) *TypeReference {
158	newRef := &TypeReference{
159		GO:          types.NewPointer(ref.GO),
160		GQL:         ref.GQL,
161		CastType:    ref.CastType,
162		Definition:  ref.Definition,
163		Unmarshaler: ref.Unmarshaler,
164		Marshaler:   ref.Marshaler,
165		IsMarshaler: ref.IsMarshaler,
166	}
167
168	b.References = append(b.References, newRef)
169	return newRef
170}
171
// TypeReference is used by args and field types. The Definition can refer to both input and output types.
//
// It pairs a GraphQL type (GQL and its schema Definition) with the Go type
// (GO) it is bound to, plus the details needed to marshal between the two.
type TypeReference struct {
	Definition  *ast.Definition // schema definition of the named GraphQL type
	GQL         *ast.Type       // GraphQL type, including its list (Elem) and NonNull modifiers
	GO          types.Type      // Go type bound to GQL
	CastType    types.Type      // Before calling marshalling functions cast from/to this base type
	Marshaler   *types.Func     // When using external marshalling functions this will point to the Marshal function
	Unmarshaler *types.Func     // When using external marshalling functions this will point to the Unmarshal function
	IsMarshaler bool            // Does the type implement graphql.Marshaler and graphql.Unmarshaler
}
182
183func (ref *TypeReference) Elem() *TypeReference {
184	if p, isPtr := ref.GO.(*types.Pointer); isPtr {
185		return &TypeReference{
186			GO:          p.Elem(),
187			GQL:         ref.GQL,
188			CastType:    ref.CastType,
189			Definition:  ref.Definition,
190			Unmarshaler: ref.Unmarshaler,
191			Marshaler:   ref.Marshaler,
192			IsMarshaler: ref.IsMarshaler,
193		}
194	}
195
196	if ref.IsSlice() {
197		return &TypeReference{
198			GO:          ref.GO.(*types.Slice).Elem(),
199			GQL:         ref.GQL.Elem,
200			CastType:    ref.CastType,
201			Definition:  ref.Definition,
202			Unmarshaler: ref.Unmarshaler,
203			Marshaler:   ref.Marshaler,
204			IsMarshaler: ref.IsMarshaler,
205		}
206	}
207	return nil
208}
209
210func (t *TypeReference) IsPtr() bool {
211	_, isPtr := t.GO.(*types.Pointer)
212	return isPtr
213}
214
215func (t *TypeReference) IsNilable() bool {
216	_, isPtr := t.GO.(*types.Pointer)
217	_, isMap := t.GO.(*types.Map)
218	_, isInterface := t.GO.(*types.Interface)
219	return isPtr || isMap || isInterface
220}
221
222func (t *TypeReference) IsSlice() bool {
223	_, isSlice := t.GO.(*types.Slice)
224	return t.GQL.Elem != nil && isSlice
225}
226
227func (t *TypeReference) IsNamed() bool {
228	_, isSlice := t.GO.(*types.Named)
229	return isSlice
230}
231
232func (t *TypeReference) IsStruct() bool {
233	_, isStruct := t.GO.Underlying().(*types.Struct)
234	return isStruct
235}
236
237func (t *TypeReference) IsScalar() bool {
238	return t.Definition.Kind == ast.Scalar
239}
240
241func (t *TypeReference) UniquenessKey() string {
242	var nullability = "O"
243	if t.GQL.NonNull {
244		nullability = "N"
245	}
246
247	return nullability + t.Definition.Name + "2" + templates.TypeIdentifier(t.GO)
248}
249
250func (t *TypeReference) MarshalFunc() string {
251	if t.Definition == nil {
252		panic(errors.New("Definition missing for " + t.GQL.Name()))
253	}
254
255	if t.Definition.Kind == ast.InputObject {
256		return ""
257	}
258
259	return "marshal" + t.UniquenessKey()
260}
261
262func (t *TypeReference) UnmarshalFunc() string {
263	if t.Definition == nil {
264		panic(errors.New("Definition missing for " + t.GQL.Name()))
265	}
266
267	if !t.Definition.IsInputType() {
268		return ""
269	}
270
271	return "unmarshal" + t.UniquenessKey()
272}
273
274func (b *Binder) PushRef(ret *TypeReference) {
275	b.References = append(b.References, ret)
276}
277
// isMap reports whether t is an unnamed map type. A nil t also reports true,
// meaning "no bind target, so any model is acceptable".
func isMap(t types.Type) bool {
	switch t.(type) {
	case nil, *types.Map:
		return true
	default:
		return false
	}
}
285
// isIntf reports whether t is an unnamed interface type. A nil t also reports
// true, meaning "no bind target, so any model is acceptable".
func isIntf(t types.Type) bool {
	switch t.(type) {
	case nil, *types.Interface:
		return true
	default:
		return false
	}
}
293
294func (b *Binder) TypeReference(schemaType *ast.Type, bindTarget types.Type) (ret *TypeReference, err error) {
295	var pkgName, typeName string
296	def := b.schema.Types[schemaType.Name()]
297	defer func() {
298		if err == nil && ret != nil {
299			b.PushRef(ret)
300		}
301	}()
302
303	if len(b.cfg.Models[schemaType.Name()].Model) == 0 {
304		return nil, fmt.Errorf("%s was not found", schemaType.Name())
305	}
306
307	for _, model := range b.cfg.Models[schemaType.Name()].Model {
308		if model == "map[string]interface{}" {
309			if !isMap(bindTarget) {
310				continue
311			}
312			return &TypeReference{
313				Definition: def,
314				GQL:        schemaType,
315				GO:         MapType,
316			}, nil
317		}
318
319		if model == "interface{}" {
320			if !isIntf(bindTarget) {
321				continue
322			}
323			return &TypeReference{
324				Definition: def,
325				GQL:        schemaType,
326				GO:         InterfaceType,
327			}, nil
328		}
329
330		pkgName, typeName = code.PkgAndType(model)
331		if pkgName == "" {
332			return nil, fmt.Errorf("missing package name for %s", schemaType.Name())
333		}
334
335		ref := &TypeReference{
336			Definition: def,
337			GQL:        schemaType,
338		}
339
340		obj, err := b.FindObject(pkgName, typeName)
341		if err != nil {
342			return nil, err
343		}
344
345		if fun, isFunc := obj.(*types.Func); isFunc {
346			ref.GO = fun.Type().(*types.Signature).Params().At(0).Type()
347			ref.Marshaler = fun
348			ref.Unmarshaler = types.NewFunc(0, fun.Pkg(), "Unmarshal"+typeName, nil)
349		} else if hasMethod(obj.Type(), "MarshalGQL") && hasMethod(obj.Type(), "UnmarshalGQL") {
350			ref.GO = obj.Type()
351			ref.IsMarshaler = true
352		} else if underlying := basicUnderlying(obj.Type()); def.IsLeafType() && underlying != nil && underlying.Kind() == types.String {
353			// Special case for named types wrapping strings. Used by default enum implementations.
354
355			ref.GO = obj.Type()
356			ref.CastType = underlying
357
358			underlyingRef, err := b.TypeReference(&ast.Type{NamedType: "String"}, nil)
359			if err != nil {
360				return nil, err
361			}
362
363			ref.Marshaler = underlyingRef.Marshaler
364			ref.Unmarshaler = underlyingRef.Unmarshaler
365		} else {
366			ref.GO = obj.Type()
367		}
368
369		ref.GO = b.CopyModifiersFromAst(schemaType, ref.GO)
370
371		if bindTarget != nil {
372			if err = code.CompatibleTypes(ref.GO, bindTarget); err != nil {
373				continue
374			}
375			ref.GO = bindTarget
376		}
377
378		return ref, nil
379	}
380
381	return nil, fmt.Errorf("%s has type compatible with %s", schemaType.Name(), bindTarget.String())
382}
383
384func (b *Binder) CopyModifiersFromAst(t *ast.Type, base types.Type) types.Type {
385	if t.Elem != nil {
386		child := b.CopyModifiersFromAst(t.Elem, base)
387		if _, isStruct := child.Underlying().(*types.Struct); isStruct {
388			child = types.NewPointer(child)
389		}
390		return types.NewSlice(child)
391	}
392
393	var isInterface bool
394	if named, ok := base.(*types.Named); ok {
395		_, isInterface = named.Underlying().(*types.Interface)
396	}
397
398	if !isInterface && !t.NonNull {
399		return types.NewPointer(base)
400	}
401
402	return base
403}
404
// hasMethod reports whether the named type it (or the named type it points
// to) declares a method called name. Only methods declared directly on the
// named type are considered; promoted methods from embedded fields are not.
func hasMethod(it types.Type, name string) bool {
	target := it
	if p, ok := target.(*types.Pointer); ok {
		target = p.Elem()
	}

	named, isNamed := target.(*types.Named)
	if !isNamed {
		return false
	}

	for idx, total := 0, named.NumMethods(); idx < total; idx++ {
		if named.Method(idx).Name() == name {
			return true
		}
	}
	return false
}
421
// basicUnderlying returns the basic underlying type of the named type it (or
// of the named type it points to). It returns nil when it is not named or
// when its underlying type is not basic.
func basicUnderlying(it types.Type) *types.Basic {
	target := it
	if p, ok := target.(*types.Pointer); ok {
		target = p.Elem()
	}

	named, isNamed := target.(*types.Named)
	if !isNamed {
		return nil
	}

	// A failed assertion yields a nil *types.Basic, which is the desired result.
	basic, _ := named.Underlying().(*types.Basic)
	return basic
}