1package config
  2
  3import (
  4	"fmt"
  5	"go/token"
  6	"go/types"
  7
  8	"github.com/99designs/gqlgen/codegen/templates"
  9	"github.com/99designs/gqlgen/internal/code"
 10	"github.com/pkg/errors"
 11	"github.com/vektah/gqlparser/ast"
 12	"golang.org/x/tools/go/packages"
 13)
 14
// Binder connects graphql types to golang types using static analysis.
//
// A Binder is created by Config.NewBinder, which loads the Go packages
// referenced by the configured models.
type Binder struct {
	// pkgs are the packages loaded by NewBinder; getPkg searches them by import path.
	pkgs       []*packages.Package
	// schema supplies the GraphQL definitions attached to each TypeReference.
	schema     *ast.Schema
	// cfg supplies the model bindings (cfg.Models) for each schema type name.
	cfg        *Config
	// References collects the references appended by PushRef and PointerTo
	// (the TypeReference method records its successful results via PushRef).
	References []*TypeReference
}
 22
 23func (c *Config) NewBinder(s *ast.Schema) (*Binder, error) {
 24	pkgs, err := packages.Load(&packages.Config{Mode: packages.LoadTypes | packages.LoadSyntax}, c.Models.ReferencedPackages()...)
 25	if err != nil {
 26		return nil, err
 27	}
 28
 29	for _, p := range pkgs {
 30		for _, e := range p.Errors {
 31			if e.Kind == packages.ListError {
 32				return nil, p.Errors[0]
 33			}
 34		}
 35	}
 36
 37	return &Binder{
 38		pkgs:   pkgs,
 39		schema: s,
 40		cfg:    c,
 41	}, nil
 42}
 43
 44func (b *Binder) TypePosition(typ types.Type) token.Position {
 45	named, isNamed := typ.(*types.Named)
 46	if !isNamed {
 47		return token.Position{
 48			Filename: "unknown",
 49		}
 50	}
 51
 52	return b.ObjectPosition(named.Obj())
 53}
 54
 55func (b *Binder) ObjectPosition(typ types.Object) token.Position {
 56	if typ == nil {
 57		return token.Position{
 58			Filename: "unknown",
 59		}
 60	}
 61	pkg := b.getPkg(typ.Pkg().Path())
 62	return pkg.Fset.Position(typ.Pos())
 63}
 64
 65func (b *Binder) FindType(pkgName string, typeName string) (types.Type, error) {
 66	obj, err := b.FindObject(pkgName, typeName)
 67	if err != nil {
 68		return nil, err
 69	}
 70
 71	if fun, isFunc := obj.(*types.Func); isFunc {
 72		return fun.Type().(*types.Signature).Params().At(0).Type(), nil
 73	}
 74	return obj.Type(), nil
 75}
 76
 77func (b *Binder) getPkg(find string) *packages.Package {
 78	for _, p := range b.pkgs {
 79		if code.NormalizeVendor(find) == code.NormalizeVendor(p.PkgPath) {
 80			return p
 81		}
 82	}
 83	return nil
 84}
 85
 86var MapType = types.NewMap(types.Typ[types.String], types.NewInterfaceType(nil, nil).Complete())
 87var InterfaceType = types.NewInterfaceType(nil, nil)
 88
 89func (b *Binder) DefaultUserObject(name string) (types.Type, error) {
 90	models := b.cfg.Models[name].Model
 91	if len(models) == 0 {
 92		return nil, fmt.Errorf(name + " not found in typemap")
 93	}
 94
 95	if models[0] == "map[string]interface{}" {
 96		return MapType, nil
 97	}
 98
 99	if models[0] == "interface{}" {
100		return InterfaceType, nil
101	}
102
103	pkgName, typeName := code.PkgAndType(models[0])
104	if pkgName == "" {
105		return nil, fmt.Errorf("missing package name for %s", name)
106	}
107
108	obj, err := b.FindObject(pkgName, typeName)
109	if err != nil {
110		return nil, err
111	}
112
113	return obj.Type(), nil
114}
115
116func (b *Binder) FindObject(pkgName string, typeName string) (types.Object, error) {
117	if pkgName == "" {
118		return nil, fmt.Errorf("package cannot be nil")
119	}
120	fullName := typeName
121	if pkgName != "" {
122		fullName = pkgName + "." + typeName
123	}
124
125	pkg := b.getPkg(pkgName)
126	if pkg == nil {
127		return nil, errors.Errorf("required package was not loaded: %s", fullName)
128	}
129
130	// function based marshalers take precedence
131	for astNode, def := range pkg.TypesInfo.Defs {
132		// only look at defs in the top scope
133		if def == nil || def.Parent() == nil || def.Parent() != pkg.Types.Scope() {
134			continue
135		}
136
137		if astNode.Name == "Marshal"+typeName {
138			return def, nil
139		}
140	}
141
142	// then look for types directly
143	for astNode, def := range pkg.TypesInfo.Defs {
144		// only look at defs in the top scope
145		if def == nil || def.Parent() == nil || def.Parent() != pkg.Types.Scope() {
146			continue
147		}
148
149		if astNode.Name == typeName {
150			return def, nil
151		}
152	}
153
154	return nil, errors.Errorf("unable to find type %s\n", fullName)
155}
156
157func (b *Binder) PointerTo(ref *TypeReference) *TypeReference {
158	newRef := &TypeReference{
159		GO:          types.NewPointer(ref.GO),
160		GQL:         ref.GQL,
161		CastType:    ref.CastType,
162		Definition:  ref.Definition,
163		Unmarshaler: ref.Unmarshaler,
164		Marshaler:   ref.Marshaler,
165		IsMarshaler: ref.IsMarshaler,
166	}
167
168	b.References = append(b.References, newRef)
169	return newRef
170}
171
// TypeReference is used by args and field types. The Definition can refer to both input and output types.
//
// It pairs a GraphQL type (GQL, described by Definition) with the Go type (GO)
// it is bound to, plus the marshalling details needed to convert between them.
type TypeReference struct {
	// Definition is the schema definition of the named GraphQL type.
	Definition  *ast.Definition
	// GQL is the GraphQL type, including its list and non-null modifiers.
	GQL         *ast.Type
	// GO is the Go type bound to GQL, including any pointer or slice wrapping.
	GO          types.Type
	CastType    types.Type  // Before calling marshalling functions cast from/to this base type
	Marshaler   *types.Func // When using external marshalling functions this will point to the Marshal function
	Unmarshaler *types.Func // When using external marshalling functions this will point to the Unmarshal function
	IsMarshaler bool        // Does the type implement graphql.Marshaler and graphql.Unmarshaler
}
182
183func (ref *TypeReference) Elem() *TypeReference {
184	if p, isPtr := ref.GO.(*types.Pointer); isPtr {
185		return &TypeReference{
186			GO:          p.Elem(),
187			GQL:         ref.GQL,
188			CastType:    ref.CastType,
189			Definition:  ref.Definition,
190			Unmarshaler: ref.Unmarshaler,
191			Marshaler:   ref.Marshaler,
192			IsMarshaler: ref.IsMarshaler,
193		}
194	}
195
196	if ref.IsSlice() {
197		return &TypeReference{
198			GO:          ref.GO.(*types.Slice).Elem(),
199			GQL:         ref.GQL.Elem,
200			CastType:    ref.CastType,
201			Definition:  ref.Definition,
202			Unmarshaler: ref.Unmarshaler,
203			Marshaler:   ref.Marshaler,
204			IsMarshaler: ref.IsMarshaler,
205		}
206	}
207	return nil
208}
209
210func (t *TypeReference) IsPtr() bool {
211	_, isPtr := t.GO.(*types.Pointer)
212	return isPtr
213}
214
215func (t *TypeReference) IsNilable() bool {
216	_, isPtr := t.GO.(*types.Pointer)
217	_, isMap := t.GO.(*types.Map)
218	_, isInterface := t.GO.(*types.Interface)
219	return isPtr || isMap || isInterface
220}
221
222func (t *TypeReference) IsSlice() bool {
223	_, isSlice := t.GO.(*types.Slice)
224	return t.GQL.Elem != nil && isSlice
225}
226
227func (t *TypeReference) IsNamed() bool {
228	_, isSlice := t.GO.(*types.Named)
229	return isSlice
230}
231
232func (t *TypeReference) IsStruct() bool {
233	_, isStruct := t.GO.Underlying().(*types.Struct)
234	return isStruct
235}
236
237func (t *TypeReference) IsScalar() bool {
238	return t.Definition.Kind == ast.Scalar
239}
240
241func (t *TypeReference) HasIsZero() bool {
242	it := t.GO
243	if ptr, isPtr := it.(*types.Pointer); isPtr {
244		it = ptr.Elem()
245	}
246	namedType, ok := it.(*types.Named)
247	if !ok {
248		return false
249	}
250
251	for i := 0; i < namedType.NumMethods(); i++ {
252		switch namedType.Method(i).Name() {
253		case "IsZero":
254			return true
255		}
256	}
257	return false
258}
259
260func (t *TypeReference) UniquenessKey() string {
261	var nullability = "O"
262	if t.GQL.NonNull {
263		nullability = "N"
264	}
265
266	return nullability + t.Definition.Name + "2" + templates.TypeIdentifier(t.GO)
267}
268
269func (t *TypeReference) MarshalFunc() string {
270	if t.Definition == nil {
271		panic(errors.New("Definition missing for " + t.GQL.Name()))
272	}
273
274	if t.Definition.Kind == ast.InputObject {
275		return ""
276	}
277
278	return "marshal" + t.UniquenessKey()
279}
280
281func (t *TypeReference) UnmarshalFunc() string {
282	if t.Definition == nil {
283		panic(errors.New("Definition missing for " + t.GQL.Name()))
284	}
285
286	if !t.Definition.IsInputType() {
287		return ""
288	}
289
290	return "unmarshal" + t.UniquenessKey()
291}
292
293func (b *Binder) PushRef(ret *TypeReference) {
294	b.References = append(b.References, ret)
295}
296
// isMap reports whether t is an unnamed map type. A nil t also reports true,
// so an absent bind target accepts a map model.
func isMap(t types.Type) bool {
	switch t.(type) {
	case nil, *types.Map:
		return true
	default:
		return false
	}
}
304
// isIntf reports whether t is an unnamed interface type. A nil t also
// reports true, so an absent bind target accepts an interface model.
func isIntf(t types.Type) bool {
	switch t.(type) {
	case nil, *types.Interface:
		return true
	default:
		return false
	}
}
312
313func (b *Binder) TypeReference(schemaType *ast.Type, bindTarget types.Type) (ret *TypeReference, err error) {
314	var pkgName, typeName string
315	def := b.schema.Types[schemaType.Name()]
316	defer func() {
317		if err == nil && ret != nil {
318			b.PushRef(ret)
319		}
320	}()
321
322	if len(b.cfg.Models[schemaType.Name()].Model) == 0 {
323		return nil, fmt.Errorf("%s was not found", schemaType.Name())
324	}
325
326	for _, model := range b.cfg.Models[schemaType.Name()].Model {
327		if model == "map[string]interface{}" {
328			if !isMap(bindTarget) {
329				continue
330			}
331			return &TypeReference{
332				Definition: def,
333				GQL:        schemaType,
334				GO:         MapType,
335			}, nil
336		}
337
338		if model == "interface{}" {
339			if !isIntf(bindTarget) {
340				continue
341			}
342			return &TypeReference{
343				Definition: def,
344				GQL:        schemaType,
345				GO:         InterfaceType,
346			}, nil
347		}
348
349		pkgName, typeName = code.PkgAndType(model)
350		if pkgName == "" {
351			return nil, fmt.Errorf("missing package name for %s", schemaType.Name())
352		}
353
354		ref := &TypeReference{
355			Definition: def,
356			GQL:        schemaType,
357		}
358
359		obj, err := b.FindObject(pkgName, typeName)
360		if err != nil {
361			return nil, err
362		}
363
364		if fun, isFunc := obj.(*types.Func); isFunc {
365			ref.GO = fun.Type().(*types.Signature).Params().At(0).Type()
366			ref.Marshaler = fun
367			ref.Unmarshaler = types.NewFunc(0, fun.Pkg(), "Unmarshal"+typeName, nil)
368		} else if hasMethod(obj.Type(), "MarshalGQL") && hasMethod(obj.Type(), "UnmarshalGQL") {
369			ref.GO = obj.Type()
370			ref.IsMarshaler = true
371		} else if underlying := basicUnderlying(obj.Type()); underlying != nil && underlying.Kind() == types.String {
372			// Special case for named types wrapping strings. Used by default enum implementations.
373
374			ref.GO = obj.Type()
375			ref.CastType = underlying
376
377			underlyingRef, err := b.TypeReference(&ast.Type{NamedType: "String"}, nil)
378			if err != nil {
379				return nil, err
380			}
381
382			ref.Marshaler = underlyingRef.Marshaler
383			ref.Unmarshaler = underlyingRef.Unmarshaler
384		} else {
385			ref.GO = obj.Type()
386		}
387
388		ref.GO = b.CopyModifiersFromAst(schemaType, ref.GO)
389
390		if bindTarget != nil {
391			if err = code.CompatibleTypes(ref.GO, bindTarget); err != nil {
392				continue
393			}
394			ref.GO = bindTarget
395		}
396
397		return ref, nil
398	}
399
400	return nil, fmt.Errorf("%s has type compatible with %s", schemaType.Name(), bindTarget.String())
401}
402
403func (b *Binder) CopyModifiersFromAst(t *ast.Type, base types.Type) types.Type {
404	if t.Elem != nil {
405		return types.NewSlice(b.CopyModifiersFromAst(t.Elem, base))
406	}
407
408	var isInterface bool
409	if named, ok := base.(*types.Named); ok {
410		_, isInterface = named.Underlying().(*types.Interface)
411	}
412
413	if !isInterface && !t.NonNull {
414		return types.NewPointer(base)
415	}
416
417	return base
418}
419
// hasMethod reports whether it (or its pointee, when it is a pointer) is a
// named type that explicitly declares a method called name. Only methods
// declared directly on the named type are considered.
func hasMethod(it types.Type, name string) bool {
	target := it
	if ptr, ok := target.(*types.Pointer); ok {
		target = ptr.Elem()
	}

	named, ok := target.(*types.Named)
	if !ok {
		return false
	}

	found := false
	for i, n := 0, named.NumMethods(); i < n && !found; i++ {
		found = named.Method(i).Name() == name
	}
	return found
}
436
// basicUnderlying returns the basic type underlying a named type, looking
// through one level of pointer first. It returns nil when the (dereferenced)
// type is not named, or is named but not backed by a basic type.
func basicUnderlying(it types.Type) *types.Basic {
	if ptr, ok := it.(*types.Pointer); ok {
		it = ptr.Elem()
	}

	named, ok := it.(*types.Named)
	if !ok {
		return nil
	}

	basic, _ := named.Underlying().(*types.Basic)
	return basic
}