1package config
2
3import (
4 "fmt"
5 "go/token"
6 "go/types"
7
8 "github.com/99designs/gqlgen/codegen/templates"
9 "github.com/99designs/gqlgen/internal/code"
10 "github.com/pkg/errors"
11 "github.com/vektah/gqlparser/ast"
12 "golang.org/x/tools/go/packages"
13)
14
// Binder connects graphql types to golang types using static analysis.
//
// It holds the Go packages loaded by Config.NewBinder, the GraphQL schema that
// is being bound, and every TypeReference recorded while binding.
type Binder struct {
	// pkgs maps a vendor-normalized import path to its loaded package. It
	// contains the packages referenced by the model config plus all of their
	// transitive imports (filled in by populatePkg).
	pkgs map[string]*packages.Package
	// schema is the GraphQL schema whose type definitions are looked up.
	schema *ast.Schema
	// cfg supplies the model mappings (cfg.Models) from GraphQL type names to
	// Go types.
	cfg *Config
	// References collects the TypeReferences recorded via PushRef and
	// PointerTo, appended in the order they are recorded.
	References []*TypeReference
}
22
23func (c *Config) NewBinder(s *ast.Schema) (*Binder, error) {
24 pkgs, err := packages.Load(&packages.Config{Mode: packages.LoadTypes | packages.LoadSyntax}, c.Models.ReferencedPackages()...)
25 if err != nil {
26 return nil, err
27 }
28
29 mp := map[string]*packages.Package{}
30 for _, p := range pkgs {
31 populatePkg(mp, p)
32 for _, e := range p.Errors {
33 if e.Kind == packages.ListError {
34 return nil, p.Errors[0]
35 }
36 }
37 }
38
39 return &Binder{
40 pkgs: mp,
41 schema: s,
42 cfg: c,
43 }, nil
44}
45
46func populatePkg(mp map[string]*packages.Package, p *packages.Package) {
47 imp := code.NormalizeVendor(p.PkgPath)
48 if _, ok := mp[imp]; ok {
49 return
50 }
51 mp[imp] = p
52 for _, p := range p.Imports {
53 populatePkg(mp, p)
54 }
55}
56
57func (b *Binder) TypePosition(typ types.Type) token.Position {
58 named, isNamed := typ.(*types.Named)
59 if !isNamed {
60 return token.Position{
61 Filename: "unknown",
62 }
63 }
64
65 return b.ObjectPosition(named.Obj())
66}
67
68func (b *Binder) ObjectPosition(typ types.Object) token.Position {
69 if typ == nil {
70 return token.Position{
71 Filename: "unknown",
72 }
73 }
74 pkg := b.getPkg(typ.Pkg().Path())
75 return pkg.Fset.Position(typ.Pos())
76}
77
78func (b *Binder) FindType(pkgName string, typeName string) (types.Type, error) {
79 obj, err := b.FindObject(pkgName, typeName)
80 if err != nil {
81 return nil, err
82 }
83
84 if fun, isFunc := obj.(*types.Func); isFunc {
85 return fun.Type().(*types.Signature).Params().At(0).Type(), nil
86 }
87 return obj.Type(), nil
88}
89
90func (b *Binder) getPkg(find string) *packages.Package {
91 imp := code.NormalizeVendor(find)
92 if p, ok := b.pkgs[imp]; ok {
93 return p
94 }
95 return nil
96}
97
// MapType is the Go type map[string]interface{}, bound to GraphQL types whose
// configured model is "map[string]interface{}".
var MapType = types.NewMap(types.Typ[types.String], types.NewInterfaceType(nil, nil).Complete())

// InterfaceType is the empty Go interface type interface{}, bound to GraphQL
// types whose configured model is "interface{}". go/types requires Complete to
// be called on interfaces built with NewInterfaceType before they are used, so
// it is completed here exactly like MapType's element type.
var InterfaceType = types.NewInterfaceType(nil, nil).Complete()
100
101func (b *Binder) DefaultUserObject(name string) (types.Type, error) {
102 models := b.cfg.Models[name].Model
103 if len(models) == 0 {
104 return nil, fmt.Errorf(name + " not found in typemap")
105 }
106
107 if models[0] == "map[string]interface{}" {
108 return MapType, nil
109 }
110
111 if models[0] == "interface{}" {
112 return InterfaceType, nil
113 }
114
115 pkgName, typeName := code.PkgAndType(models[0])
116 if pkgName == "" {
117 return nil, fmt.Errorf("missing package name for %s", name)
118 }
119
120 obj, err := b.FindObject(pkgName, typeName)
121 if err != nil {
122 return nil, err
123 }
124
125 return obj.Type(), nil
126}
127
128func (b *Binder) FindObject(pkgName string, typeName string) (types.Object, error) {
129 if pkgName == "" {
130 return nil, fmt.Errorf("package cannot be nil")
131 }
132 fullName := typeName
133 if pkgName != "" {
134 fullName = pkgName + "." + typeName
135 }
136
137 pkg := b.getPkg(pkgName)
138 if pkg == nil {
139 return nil, errors.Errorf("required package was not loaded: %s", fullName)
140 }
141
142 // function based marshalers take precedence
143 for astNode, def := range pkg.TypesInfo.Defs {
144 // only look at defs in the top scope
145 if def == nil || def.Parent() == nil || def.Parent() != pkg.Types.Scope() {
146 continue
147 }
148
149 if astNode.Name == "Marshal"+typeName {
150 return def, nil
151 }
152 }
153
154 // then look for types directly
155 for astNode, def := range pkg.TypesInfo.Defs {
156 // only look at defs in the top scope
157 if def == nil || def.Parent() == nil || def.Parent() != pkg.Types.Scope() {
158 continue
159 }
160
161 if astNode.Name == typeName {
162 return def, nil
163 }
164 }
165
166 return nil, errors.Errorf("unable to find type %s\n", fullName)
167}
168
169func (b *Binder) PointerTo(ref *TypeReference) *TypeReference {
170 newRef := &TypeReference{
171 GO: types.NewPointer(ref.GO),
172 GQL: ref.GQL,
173 CastType: ref.CastType,
174 Definition: ref.Definition,
175 Unmarshaler: ref.Unmarshaler,
176 Marshaler: ref.Marshaler,
177 IsMarshaler: ref.IsMarshaler,
178 }
179
180 b.References = append(b.References, newRef)
181 return newRef
182}
183
// TypeReference is used by args and field types. The Definition can refer to both input and output types.
//
// It pairs a GraphQL type reference (GQL, plus the schema Definition of its
// named type) with the Go type (GO) it is bound to, together with the
// information needed to marshal and unmarshal values of that type.
type TypeReference struct {
	Definition  *ast.Definition // Schema definition of the named GraphQL type referenced by GQL
	GQL         *ast.Type       // GraphQL type reference, including its list (Elem) and NonNull modifiers
	GO          types.Type      // Go type bound to GQL
	CastType    types.Type      // Before calling marshalling functions cast from/to this base type
	Marshaler   *types.Func     // When using external marshalling functions this will point to the Marshal function
	Unmarshaler *types.Func     // When using external marshalling functions this will point to the Unmarshal function
	IsMarshaler bool            // Does the type implement graphql.Marshaler and graphql.Unmarshaler
}
194
195func (ref *TypeReference) Elem() *TypeReference {
196 if p, isPtr := ref.GO.(*types.Pointer); isPtr {
197 return &TypeReference{
198 GO: p.Elem(),
199 GQL: ref.GQL,
200 CastType: ref.CastType,
201 Definition: ref.Definition,
202 Unmarshaler: ref.Unmarshaler,
203 Marshaler: ref.Marshaler,
204 IsMarshaler: ref.IsMarshaler,
205 }
206 }
207
208 if ref.IsSlice() {
209 return &TypeReference{
210 GO: ref.GO.(*types.Slice).Elem(),
211 GQL: ref.GQL.Elem,
212 CastType: ref.CastType,
213 Definition: ref.Definition,
214 Unmarshaler: ref.Unmarshaler,
215 Marshaler: ref.Marshaler,
216 IsMarshaler: ref.IsMarshaler,
217 }
218 }
219 return nil
220}
221
222func (t *TypeReference) IsPtr() bool {
223 _, isPtr := t.GO.(*types.Pointer)
224 return isPtr
225}
226
227func (t *TypeReference) IsNilable() bool {
228 _, isPtr := t.GO.(*types.Pointer)
229 _, isMap := t.GO.(*types.Map)
230 _, isInterface := t.GO.(*types.Interface)
231 return isPtr || isMap || isInterface
232}
233
234func (t *TypeReference) IsSlice() bool {
235 _, isSlice := t.GO.(*types.Slice)
236 return t.GQL.Elem != nil && isSlice
237}
238
239func (t *TypeReference) IsNamed() bool {
240 _, isSlice := t.GO.(*types.Named)
241 return isSlice
242}
243
244func (t *TypeReference) IsStruct() bool {
245 _, isStruct := t.GO.Underlying().(*types.Struct)
246 return isStruct
247}
248
249func (t *TypeReference) IsScalar() bool {
250 return t.Definition.Kind == ast.Scalar
251}
252
253func (t *TypeReference) UniquenessKey() string {
254 var nullability = "O"
255 if t.GQL.NonNull {
256 nullability = "N"
257 }
258
259 return nullability + t.Definition.Name + "2" + templates.TypeIdentifier(t.GO)
260}
261
262func (t *TypeReference) MarshalFunc() string {
263 if t.Definition == nil {
264 panic(errors.New("Definition missing for " + t.GQL.Name()))
265 }
266
267 if t.Definition.Kind == ast.InputObject {
268 return ""
269 }
270
271 return "marshal" + t.UniquenessKey()
272}
273
274func (t *TypeReference) UnmarshalFunc() string {
275 if t.Definition == nil {
276 panic(errors.New("Definition missing for " + t.GQL.Name()))
277 }
278
279 if !t.Definition.IsInputType() {
280 return ""
281 }
282
283 return "unmarshal" + t.UniquenessKey()
284}
285
286func (b *Binder) PushRef(ret *TypeReference) {
287 b.References = append(b.References, ret)
288}
289
// isMap reports whether t may be bound to the "map[string]interface{}" model.
// This is true when t is nil (no bind target) or when t is a map type literal.
// Named map types are *types.Named and therefore do not match.
func isMap(t types.Type) bool {
	switch t.(type) {
	case nil, *types.Map:
		return true
	default:
		return false
	}
}
297
// isIntf reports whether t may be bound to the "interface{}" model. This is
// true when t is nil (no bind target) or when t is an interface type literal.
// Named interface types are *types.Named and therefore do not match.
func isIntf(t types.Type) bool {
	switch t.(type) {
	case nil, *types.Interface:
		return true
	default:
		return false
	}
}
305
306func (b *Binder) TypeReference(schemaType *ast.Type, bindTarget types.Type) (ret *TypeReference, err error) {
307 var pkgName, typeName string
308 def := b.schema.Types[schemaType.Name()]
309 defer func() {
310 if err == nil && ret != nil {
311 b.PushRef(ret)
312 }
313 }()
314
315 if len(b.cfg.Models[schemaType.Name()].Model) == 0 {
316 return nil, fmt.Errorf("%s was not found", schemaType.Name())
317 }
318
319 for _, model := range b.cfg.Models[schemaType.Name()].Model {
320 if model == "map[string]interface{}" {
321 if !isMap(bindTarget) {
322 continue
323 }
324 return &TypeReference{
325 Definition: def,
326 GQL: schemaType,
327 GO: MapType,
328 }, nil
329 }
330
331 if model == "interface{}" {
332 if !isIntf(bindTarget) {
333 continue
334 }
335 return &TypeReference{
336 Definition: def,
337 GQL: schemaType,
338 GO: InterfaceType,
339 }, nil
340 }
341
342 pkgName, typeName = code.PkgAndType(model)
343 if pkgName == "" {
344 return nil, fmt.Errorf("missing package name for %s", schemaType.Name())
345 }
346
347 ref := &TypeReference{
348 Definition: def,
349 GQL: schemaType,
350 }
351
352 obj, err := b.FindObject(pkgName, typeName)
353 if err != nil {
354 return nil, err
355 }
356
357 if fun, isFunc := obj.(*types.Func); isFunc {
358 ref.GO = fun.Type().(*types.Signature).Params().At(0).Type()
359 ref.Marshaler = fun
360 ref.Unmarshaler = types.NewFunc(0, fun.Pkg(), "Unmarshal"+typeName, nil)
361 } else if hasMethod(obj.Type(), "MarshalGQL") && hasMethod(obj.Type(), "UnmarshalGQL") {
362 ref.GO = obj.Type()
363 ref.IsMarshaler = true
364 } else if underlying := basicUnderlying(obj.Type()); def.IsLeafType() && underlying != nil && underlying.Kind() == types.String {
365 // Special case for named types wrapping strings. Used by default enum implementations.
366
367 ref.GO = obj.Type()
368 ref.CastType = underlying
369
370 underlyingRef, err := b.TypeReference(&ast.Type{NamedType: "String"}, nil)
371 if err != nil {
372 return nil, err
373 }
374
375 ref.Marshaler = underlyingRef.Marshaler
376 ref.Unmarshaler = underlyingRef.Unmarshaler
377 } else {
378 ref.GO = obj.Type()
379 }
380
381 ref.GO = b.CopyModifiersFromAst(schemaType, ref.GO)
382
383 if bindTarget != nil {
384 if err = code.CompatibleTypes(ref.GO, bindTarget); err != nil {
385 continue
386 }
387 ref.GO = bindTarget
388 }
389
390 return ref, nil
391 }
392
393 return nil, fmt.Errorf("%s has type compatible with %s", schemaType.Name(), bindTarget.String())
394}
395
396func (b *Binder) CopyModifiersFromAst(t *ast.Type, base types.Type) types.Type {
397 if t.Elem != nil {
398 child := b.CopyModifiersFromAst(t.Elem, base)
399 if _, isStruct := child.Underlying().(*types.Struct); isStruct {
400 child = types.NewPointer(child)
401 }
402 return types.NewSlice(child)
403 }
404
405 var isInterface bool
406 if named, ok := base.(*types.Named); ok {
407 _, isInterface = named.Underlying().(*types.Interface)
408 }
409
410 if !isInterface && !t.NonNull {
411 return types.NewPointer(base)
412 }
413
414 return base
415}
416
// hasMethod reports whether it (or, for a pointer, its element type) is a named
// type that explicitly declares a method called name. Only methods declared on
// the named type itself are considered; promoted methods of embedded fields
// are not.
func hasMethod(it types.Type, name string) bool {
	target := it
	if ptr, isPtr := target.(*types.Pointer); isPtr {
		target = ptr.Elem()
	}

	named, isNamed := target.(*types.Named)
	found := false
	for i := 0; isNamed && !found && i < named.NumMethods(); i++ {
		found = named.Method(i).Name() == name
	}
	return found
}
433
// basicUnderlying returns the basic type underlying it when it (or, for a
// pointer, its element type) is a named type defined over a basic type, such
// as `type Color string`. It returns nil for unnamed types and for named types
// whose underlying type is not basic.
func basicUnderlying(it types.Type) *types.Basic {
	if ptr, isPtr := it.(*types.Pointer); isPtr {
		it = ptr.Elem()
	}
	if _, isNamed := it.(*types.Named); !isNamed {
		return nil
	}

	basic, _ := it.Underlying().(*types.Basic)
	return basic
}