1package config
2
3import (
4 "fmt"
5 "go/token"
6 "go/types"
7
8 "github.com/99designs/gqlgen/codegen/templates"
9 "github.com/99designs/gqlgen/internal/code"
10 "github.com/pkg/errors"
11 "github.com/vektah/gqlparser/ast"
12 "golang.org/x/tools/go/packages"
13)
14
// Binder connects graphql types to golang types using static analysis
// of the Go packages referenced by the configured models.
type Binder struct {
	pkgs   []*packages.Package // packages loaded by NewBinder; searched by getPkg
	schema *ast.Schema         // schema whose type definitions are looked up by name
	cfg    *Config             // config supplying the model bindings (cfg.Models)
	// References accumulates the TypeReferences appended by PushRef,
	// PointerTo and successful TypeReference calls.
	References []*TypeReference
}
22
23func (c *Config) NewBinder(s *ast.Schema) (*Binder, error) {
24 pkgs, err := packages.Load(&packages.Config{Mode: packages.LoadTypes | packages.LoadSyntax}, c.Models.ReferencedPackages()...)
25 if err != nil {
26 return nil, err
27 }
28
29 for _, p := range pkgs {
30 for _, e := range p.Errors {
31 if e.Kind == packages.ListError {
32 return nil, p.Errors[0]
33 }
34 }
35 }
36
37 return &Binder{
38 pkgs: pkgs,
39 schema: s,
40 cfg: c,
41 }, nil
42}
43
// TypePosition returns the source position where a named type is declared.
// For any type that is not a *types.Named, it returns a Position whose
// Filename is "unknown".
func (b *Binder) TypePosition(typ types.Type) token.Position {
	if named, ok := typ.(*types.Named); ok {
		return b.ObjectPosition(named.Obj())
	}
	return token.Position{Filename: "unknown"}
}
54
// ObjectPosition returns the source position where typ is declared.
//
// It returns a Position whose Filename is "unknown" in three cases:
//   - typ is nil;
//   - typ has no package (universe objects such as the predeclared error type);
//   - typ's package was not loaded by this Binder.
func (b *Binder) ObjectPosition(typ types.Object) token.Position {
	unknown := token.Position{Filename: "unknown"}
	if typ == nil || typ.Pkg() == nil {
		return unknown
	}

	pkg := b.getPkg(typ.Pkg().Path())
	if pkg == nil || pkg.Fset == nil {
		return unknown
	}
	return pkg.Fset.Position(typ.Pos())
}
64
// FindType resolves typeName in package pkgName to a Go type.
//
// When FindObject yields a function (a "Marshal"+typeName marshaler), the
// type of that function's first parameter is returned instead of the
// function's own type. An error is returned if the lookup fails or if such
// a function declares no parameters.
func (b *Binder) FindType(pkgName string, typeName string) (types.Type, error) {
	obj, err := b.FindObject(pkgName, typeName)
	if err != nil {
		return nil, err
	}

	fun, isFunc := obj.(*types.Func)
	if !isFunc {
		return obj.Type(), nil
	}

	// Tuple.Len is nil-safe; At(0) on an empty tuple would panic.
	params := fun.Type().(*types.Signature).Params()
	if params.Len() == 0 {
		return nil, fmt.Errorf("marshal function %s.%s takes no parameters", pkgName, fun.Name())
	}
	return params.At(0).Type(), nil
}
76
77func (b *Binder) getPkg(find string) *packages.Package {
78 for _, p := range b.pkgs {
79 if code.NormalizeVendor(find) == code.NormalizeVendor(p.PkgPath) {
80 return p
81 }
82 }
83 return nil
84}
85
86var MapType = types.NewMap(types.Typ[types.String], types.NewInterfaceType(nil, nil).Complete())
87var InterfaceType = types.NewInterfaceType(nil, nil)
88
// DefaultUserObject returns the Go type of the first model configured for
// the GraphQL type name.
//
// The model strings "map[string]interface{}" and "interface{}" map to
// MapType and InterfaceType. Any other model must be package-qualified and
// is resolved with FindObject.
//
// NOTE(review): unlike FindType, a function-based marshaler found by
// FindObject is not unwrapped here, so its signature type is returned.
// Confirm that callers expect this.
func (b *Binder) DefaultUserObject(name string) (types.Type, error) {
	models := b.cfg.Models[name].Model
	if len(models) == 0 {
		// Constant format string: a name containing '%' must not be
		// interpreted as a formatting verb.
		return nil, fmt.Errorf("%s not found in typemap", name)
	}

	if models[0] == "map[string]interface{}" {
		return MapType, nil
	}

	if models[0] == "interface{}" {
		return InterfaceType, nil
	}

	pkgName, typeName := code.PkgAndType(models[0])
	if pkgName == "" {
		return nil, fmt.Errorf("missing package name for %s", name)
	}

	obj, err := b.FindObject(pkgName, typeName)
	if err != nil {
		return nil, err
	}

	return obj.Type(), nil
}
115
// FindObject looks up the package-level object named typeName in the loaded
// package pkgName.
//
// An object named "Marshal"+typeName takes precedence over typeName itself,
// so function-based marshalers win over the type they marshal. An error is
// returned in three cases:
//   - pkgName is empty;
//   - the package was not loaded by this Binder;
//   - neither name is declared in the package scope.
func (b *Binder) FindObject(pkgName string, typeName string) (types.Object, error) {
	if pkgName == "" {
		return nil, fmt.Errorf("package cannot be nil")
	}
	fullName := pkgName + "." + typeName

	pkg := b.getPkg(pkgName)
	if pkg == nil {
		return nil, errors.Errorf("required package was not loaded: %s", fullName)
	}

	// Only top-level declarations are considered. Looking them up in the
	// package scope is O(1) and deterministic, unlike ranging over the
	// TypesInfo.Defs map and filtering on Parent() == package scope.
	scope := pkg.Types.Scope()

	// function based marshalers take precedence
	if obj := scope.Lookup("Marshal" + typeName); obj != nil {
		return obj, nil
	}

	// then look for types directly
	if obj := scope.Lookup(typeName); obj != nil {
		return obj, nil
	}

	return nil, errors.Errorf("unable to find type %s\n", fullName)
}
156
// PointerTo returns a copy of ref whose Go type is a pointer to ref.GO. Every
// other field is carried over unchanged. The new reference is also appended
// to b.References.
func (b *Binder) PointerTo(ref *TypeReference) *TypeReference {
	ptrRef := *ref
	ptrRef.GO = types.NewPointer(ref.GO)

	b.References = append(b.References, &ptrRef)
	return &ptrRef
}
171
// TypeReference is used by args and field types. The Definition can refer to both input and output types.
//
// It pairs a GraphQL type (with its list/non-null modifiers) with the Go type
// it is bound to, plus the marshalling details needed to convert between them.
type TypeReference struct {
	Definition  *ast.Definition // schema definition of the named GraphQL type
	GQL         *ast.Type       // GraphQL type reference, including Elem (list) and NonNull modifiers
	GO          types.Type      // bound Go type, possibly wrapped in pointer/slice modifiers
	CastType    types.Type      // Before calling marshalling functions cast from/to this base type
	Marshaler   *types.Func     // When using external marshalling functions this will point to the Marshal function
	Unmarshaler *types.Func     // When using external marshalling functions this will point to the Unmarshal function
	IsMarshaler bool            // Does the type implement graphql.Marshaler and graphql.Unmarshaler
}
182
// Elem returns the reference one modifier level down. For a pointer Go type
// it is the pointee with the same GraphQL type. For a list bound to a slice
// it is the slice element paired with the GraphQL list element. Otherwise it
// is nil. All marshalling fields are copied unchanged.
func (t *TypeReference) Elem() *TypeReference {
	if ptr, ok := t.GO.(*types.Pointer); ok {
		inner := *t
		inner.GO = ptr.Elem()
		return &inner
	}

	if !t.IsSlice() {
		return nil
	}

	inner := *t
	inner.GO = t.GO.(*types.Slice).Elem()
	inner.GQL = t.GQL.Elem
	return &inner
}
209
210func (t *TypeReference) IsPtr() bool {
211 _, isPtr := t.GO.(*types.Pointer)
212 return isPtr
213}
214
215func (t *TypeReference) IsNilable() bool {
216 _, isPtr := t.GO.(*types.Pointer)
217 _, isMap := t.GO.(*types.Map)
218 _, isInterface := t.GO.(*types.Interface)
219 return isPtr || isMap || isInterface
220}
221
// IsSlice reports whether the reference is a GraphQL list bound to a Go slice.
func (t *TypeReference) IsSlice() bool {
	if t.GQL.Elem == nil {
		return false
	}
	_, ok := t.GO.(*types.Slice)
	return ok
}
226
// IsNamed reports whether the bound Go type is a named (defined) type.
func (t *TypeReference) IsNamed() bool {
	if _, named := t.GO.(*types.Named); named {
		return true
	}
	return false
}
231
// IsStruct reports whether the underlying type of the bound Go type is a
// struct. A pointer type is not considered a struct.
func (t *TypeReference) IsStruct() bool {
	underlying := t.GO.Underlying()
	_, ok := underlying.(*types.Struct)
	return ok
}
236
237func (t *TypeReference) IsScalar() bool {
238 return t.Definition.Kind == ast.Scalar
239}
240
241func (t *TypeReference) HasIsZero() bool {
242 it := t.GO
243 if ptr, isPtr := it.(*types.Pointer); isPtr {
244 it = ptr.Elem()
245 }
246 namedType, ok := it.(*types.Named)
247 if !ok {
248 return false
249 }
250
251 for i := 0; i < namedType.NumMethods(); i++ {
252 switch namedType.Method(i).Name() {
253 case "IsZero":
254 return true
255 }
256 }
257 return false
258}
259
260func (t *TypeReference) UniquenessKey() string {
261 var nullability = "O"
262 if t.GQL.NonNull {
263 nullability = "N"
264 }
265
266 return nullability + t.Definition.Name + "2" + templates.TypeIdentifier(t.GO)
267}
268
// MarshalFunc returns the name of the marshal function for this reference,
// or "" when the definition is an input object. It panics if Definition is
// nil.
func (t *TypeReference) MarshalFunc() string {
	switch {
	case t.Definition == nil:
		panic(errors.New("Definition missing for " + t.GQL.Name()))
	case t.Definition.Kind == ast.InputObject:
		return ""
	default:
		return "marshal" + t.UniquenessKey()
	}
}
280
// UnmarshalFunc returns the name of the unmarshal function for this
// reference, or "" when the definition is not an input type. It panics if
// Definition is nil.
func (t *TypeReference) UnmarshalFunc() string {
	def := t.Definition
	if def == nil {
		panic(errors.New("Definition missing for " + t.GQL.Name()))
	}

	if def.IsInputType() {
		return "unmarshal" + t.UniquenessKey()
	}
	return ""
}
292
293func (b *Binder) PushRef(ret *TypeReference) {
294 b.References = append(b.References, ret)
295}
296
297func isMap(t types.Type) bool {
298 if t == nil {
299 return true
300 }
301 _, ok := t.(*types.Map)
302 return ok
303}
304
305func isIntf(t types.Type) bool {
306 if t == nil {
307 return true
308 }
309 _, ok := t.(*types.Interface)
310 return ok
311}
312
313func (b *Binder) TypeReference(schemaType *ast.Type, bindTarget types.Type) (ret *TypeReference, err error) {
314 var pkgName, typeName string
315 def := b.schema.Types[schemaType.Name()]
316 defer func() {
317 if err == nil && ret != nil {
318 b.PushRef(ret)
319 }
320 }()
321
322 if len(b.cfg.Models[schemaType.Name()].Model) == 0 {
323 return nil, fmt.Errorf("%s was not found", schemaType.Name())
324 }
325
326 for _, model := range b.cfg.Models[schemaType.Name()].Model {
327 if model == "map[string]interface{}" {
328 if !isMap(bindTarget) {
329 continue
330 }
331 return &TypeReference{
332 Definition: def,
333 GQL: schemaType,
334 GO: MapType,
335 }, nil
336 }
337
338 if model == "interface{}" {
339 if !isIntf(bindTarget) {
340 continue
341 }
342 return &TypeReference{
343 Definition: def,
344 GQL: schemaType,
345 GO: InterfaceType,
346 }, nil
347 }
348
349 pkgName, typeName = code.PkgAndType(model)
350 if pkgName == "" {
351 return nil, fmt.Errorf("missing package name for %s", schemaType.Name())
352 }
353
354 ref := &TypeReference{
355 Definition: def,
356 GQL: schemaType,
357 }
358
359 obj, err := b.FindObject(pkgName, typeName)
360 if err != nil {
361 return nil, err
362 }
363
364 if fun, isFunc := obj.(*types.Func); isFunc {
365 ref.GO = fun.Type().(*types.Signature).Params().At(0).Type()
366 ref.Marshaler = fun
367 ref.Unmarshaler = types.NewFunc(0, fun.Pkg(), "Unmarshal"+typeName, nil)
368 } else if hasMethod(obj.Type(), "MarshalGQL") && hasMethod(obj.Type(), "UnmarshalGQL") {
369 ref.GO = obj.Type()
370 ref.IsMarshaler = true
371 } else if underlying := basicUnderlying(obj.Type()); underlying != nil && underlying.Kind() == types.String {
372 // Special case for named types wrapping strings. Used by default enum implementations.
373
374 ref.GO = obj.Type()
375 ref.CastType = underlying
376
377 underlyingRef, err := b.TypeReference(&ast.Type{NamedType: "String"}, nil)
378 if err != nil {
379 return nil, err
380 }
381
382 ref.Marshaler = underlyingRef.Marshaler
383 ref.Unmarshaler = underlyingRef.Unmarshaler
384 } else {
385 ref.GO = obj.Type()
386 }
387
388 ref.GO = b.CopyModifiersFromAst(schemaType, ref.GO)
389
390 if bindTarget != nil {
391 if err = code.CompatibleTypes(ref.GO, bindTarget); err != nil {
392 continue
393 }
394 ref.GO = bindTarget
395 }
396
397 return ref, nil
398 }
399
400 return nil, fmt.Errorf("%s has type compatible with %s", schemaType.Name(), bindTarget.String())
401}
402
403func (b *Binder) CopyModifiersFromAst(t *ast.Type, base types.Type) types.Type {
404 if t.Elem != nil {
405 return types.NewSlice(b.CopyModifiersFromAst(t.Elem, base))
406 }
407
408 var isInterface bool
409 if named, ok := base.(*types.Named); ok {
410 _, isInterface = named.Underlying().(*types.Interface)
411 }
412
413 if !isInterface && !t.NonNull {
414 return types.NewPointer(base)
415 }
416
417 return base
418}
419
// hasMethod reports whether it, or its element type if it is a pointer, is a
// named type that declares a method called name.
//
// Only methods declared directly on the named type (Named.Method) are
// inspected. Methods promoted from embedded fields are not.
func hasMethod(it types.Type, name string) bool {
	target := it
	if p, ok := target.(*types.Pointer); ok {
		target = p.Elem()
	}

	named, isNamed := target.(*types.Named)
	if !isNamed {
		return false
	}

	for i, n := 0, named.NumMethods(); i < n; i++ {
		if named.Method(i).Name() == name {
			return true
		}
	}
	return false
}
436
// basicUnderlying returns the basic type underlying a named type, after
// stripping at most one pointer. It returns nil if the type is not named or
// its underlying type is not a *types.Basic.
func basicUnderlying(it types.Type) *types.Basic {
	t := it
	if p, ok := t.(*types.Pointer); ok {
		t = p.Elem()
	}

	if named, ok := t.(*types.Named); ok {
		basic, _ := named.Underlying().(*types.Basic)
		return basic
	}
	return nil
}