1package config
2
3import (
4 "fmt"
5 "go/token"
6 "go/types"
7
8 "github.com/99designs/gqlgen/codegen/templates"
9 "github.com/99designs/gqlgen/internal/code"
10 "github.com/pkg/errors"
11 "github.com/vektah/gqlparser/ast"
12 "golang.org/x/tools/go/packages"
13)
14
15// Binder connects graphql types to golang types using static analysis
16type Binder struct {
17 pkgs []*packages.Package
18 schema *ast.Schema
19 cfg *Config
20 References []*TypeReference
21}
22
23func (c *Config) NewBinder(s *ast.Schema) (*Binder, error) {
24 pkgs, err := packages.Load(&packages.Config{Mode: packages.LoadTypes | packages.LoadSyntax}, c.Models.ReferencedPackages()...)
25 if err != nil {
26 return nil, err
27 }
28
29 for _, p := range pkgs {
30 for _, e := range p.Errors {
31 if e.Kind == packages.ListError {
32 return nil, p.Errors[0]
33 }
34 }
35 }
36
37 return &Binder{
38 pkgs: pkgs,
39 schema: s,
40 cfg: c,
41 }, nil
42}
43
44func (b *Binder) TypePosition(typ types.Type) token.Position {
45 named, isNamed := typ.(*types.Named)
46 if !isNamed {
47 return token.Position{
48 Filename: "unknown",
49 }
50 }
51
52 return b.ObjectPosition(named.Obj())
53}
54
55func (b *Binder) ObjectPosition(typ types.Object) token.Position {
56 if typ == nil {
57 return token.Position{
58 Filename: "unknown",
59 }
60 }
61 pkg := b.getPkg(typ.Pkg().Path())
62 return pkg.Fset.Position(typ.Pos())
63}
64
65func (b *Binder) FindType(pkgName string, typeName string) (types.Type, error) {
66 obj, err := b.FindObject(pkgName, typeName)
67 if err != nil {
68 return nil, err
69 }
70
71 if fun, isFunc := obj.(*types.Func); isFunc {
72 return fun.Type().(*types.Signature).Params().At(0).Type(), nil
73 }
74 return obj.Type(), nil
75}
76
77func (b *Binder) getPkg(find string) *packages.Package {
78 for _, p := range b.pkgs {
79 if code.NormalizeVendor(find) == code.NormalizeVendor(p.PkgPath) {
80 return p
81 }
82 }
83 return nil
84}
85
var (
	// MapType is the Go type map[string]interface{}, used for models
	// configured as "map[string]interface{}".
	MapType = types.NewMap(types.Typ[types.String], types.NewInterfaceType(nil, nil).Complete())

	// InterfaceType is the empty interface type interface{}, used for models
	// configured as "interface{}". It is completed, like MapType's element
	// type, so its (empty) method set is computed before the type is used.
	InterfaceType = types.NewInterfaceType(nil, nil).Complete()
)
88
89func (b *Binder) DefaultUserObject(name string) (types.Type, error) {
90 models := b.cfg.Models[name].Model
91 if len(models) == 0 {
92 return nil, fmt.Errorf(name + " not found in typemap")
93 }
94
95 if models[0] == "map[string]interface{}" {
96 return MapType, nil
97 }
98
99 if models[0] == "interface{}" {
100 return InterfaceType, nil
101 }
102
103 pkgName, typeName := code.PkgAndType(models[0])
104 if pkgName == "" {
105 return nil, fmt.Errorf("missing package name for %s", name)
106 }
107
108 obj, err := b.FindObject(pkgName, typeName)
109 if err != nil {
110 return nil, err
111 }
112
113 return obj.Type(), nil
114}
115
116func (b *Binder) FindObject(pkgName string, typeName string) (types.Object, error) {
117 if pkgName == "" {
118 return nil, fmt.Errorf("package cannot be nil")
119 }
120 fullName := typeName
121 if pkgName != "" {
122 fullName = pkgName + "." + typeName
123 }
124
125 pkg := b.getPkg(pkgName)
126 if pkg == nil {
127 return nil, errors.Errorf("required package was not loaded: %s", fullName)
128 }
129
130 // function based marshalers take precedence
131 for astNode, def := range pkg.TypesInfo.Defs {
132 // only look at defs in the top scope
133 if def == nil || def.Parent() == nil || def.Parent() != pkg.Types.Scope() {
134 continue
135 }
136
137 if astNode.Name == "Marshal"+typeName {
138 return def, nil
139 }
140 }
141
142 // then look for types directly
143 for astNode, def := range pkg.TypesInfo.Defs {
144 // only look at defs in the top scope
145 if def == nil || def.Parent() == nil || def.Parent() != pkg.Types.Scope() {
146 continue
147 }
148
149 if astNode.Name == typeName {
150 return def, nil
151 }
152 }
153
154 return nil, errors.Errorf("unable to find type %s\n", fullName)
155}
156
157func (b *Binder) PointerTo(ref *TypeReference) *TypeReference {
158 newRef := &TypeReference{
159 GO: types.NewPointer(ref.GO),
160 GQL: ref.GQL,
161 CastType: ref.CastType,
162 Definition: ref.Definition,
163 Unmarshaler: ref.Unmarshaler,
164 Marshaler: ref.Marshaler,
165 IsMarshaler: ref.IsMarshaler,
166 }
167
168 b.References = append(b.References, newRef)
169 return newRef
170}
171
172// TypeReference is used by args and field types. The Definition can refer to both input and output types.
173type TypeReference struct {
174 Definition *ast.Definition
175 GQL *ast.Type
176 GO types.Type
177 CastType types.Type // Before calling marshalling functions cast from/to this base type
178 Marshaler *types.Func // When using external marshalling functions this will point to the Marshal function
179 Unmarshaler *types.Func // When using external marshalling functions this will point to the Unmarshal function
180 IsMarshaler bool // Does the type implement graphql.Marshaler and graphql.Unmarshaler
181}
182
183func (ref *TypeReference) Elem() *TypeReference {
184 if p, isPtr := ref.GO.(*types.Pointer); isPtr {
185 return &TypeReference{
186 GO: p.Elem(),
187 GQL: ref.GQL,
188 CastType: ref.CastType,
189 Definition: ref.Definition,
190 Unmarshaler: ref.Unmarshaler,
191 Marshaler: ref.Marshaler,
192 IsMarshaler: ref.IsMarshaler,
193 }
194 }
195
196 if ref.IsSlice() {
197 return &TypeReference{
198 GO: ref.GO.(*types.Slice).Elem(),
199 GQL: ref.GQL.Elem,
200 CastType: ref.CastType,
201 Definition: ref.Definition,
202 Unmarshaler: ref.Unmarshaler,
203 Marshaler: ref.Marshaler,
204 IsMarshaler: ref.IsMarshaler,
205 }
206 }
207 return nil
208}
209
210func (t *TypeReference) IsPtr() bool {
211 _, isPtr := t.GO.(*types.Pointer)
212 return isPtr
213}
214
215func (t *TypeReference) IsNilable() bool {
216 _, isPtr := t.GO.(*types.Pointer)
217 _, isMap := t.GO.(*types.Map)
218 _, isInterface := t.GO.(*types.Interface)
219 return isPtr || isMap || isInterface
220}
221
222func (t *TypeReference) IsSlice() bool {
223 _, isSlice := t.GO.(*types.Slice)
224 return t.GQL.Elem != nil && isSlice
225}
226
227func (t *TypeReference) IsNamed() bool {
228 _, isSlice := t.GO.(*types.Named)
229 return isSlice
230}
231
232func (t *TypeReference) IsStruct() bool {
233 _, isStruct := t.GO.Underlying().(*types.Struct)
234 return isStruct
235}
236
237func (t *TypeReference) IsScalar() bool {
238 return t.Definition.Kind == ast.Scalar
239}
240
241func (t *TypeReference) UniquenessKey() string {
242 var nullability = "O"
243 if t.GQL.NonNull {
244 nullability = "N"
245 }
246
247 return nullability + t.Definition.Name + "2" + templates.TypeIdentifier(t.GO)
248}
249
250func (t *TypeReference) MarshalFunc() string {
251 if t.Definition == nil {
252 panic(errors.New("Definition missing for " + t.GQL.Name()))
253 }
254
255 if t.Definition.Kind == ast.InputObject {
256 return ""
257 }
258
259 return "marshal" + t.UniquenessKey()
260}
261
262func (t *TypeReference) UnmarshalFunc() string {
263 if t.Definition == nil {
264 panic(errors.New("Definition missing for " + t.GQL.Name()))
265 }
266
267 if !t.Definition.IsInputType() {
268 return ""
269 }
270
271 return "unmarshal" + t.UniquenessKey()
272}
273
274func (b *Binder) PushRef(ret *TypeReference) {
275 b.References = append(b.References, ret)
276}
277
// isMap reports whether t is a map type. A nil t also counts as a match,
// so that "no bind target" accepts a map model.
func isMap(t types.Type) bool {
	switch t.(type) {
	case nil, *types.Map:
		return true
	}
	return false
}
285
// isIntf reports whether t is an (unnamed) interface type. A nil t also
// counts as a match, so that "no bind target" accepts an interface model.
func isIntf(t types.Type) bool {
	switch t.(type) {
	case nil, *types.Interface:
		return true
	}
	return false
}
293
294func (b *Binder) TypeReference(schemaType *ast.Type, bindTarget types.Type) (ret *TypeReference, err error) {
295 var pkgName, typeName string
296 def := b.schema.Types[schemaType.Name()]
297 defer func() {
298 if err == nil && ret != nil {
299 b.PushRef(ret)
300 }
301 }()
302
303 if len(b.cfg.Models[schemaType.Name()].Model) == 0 {
304 return nil, fmt.Errorf("%s was not found", schemaType.Name())
305 }
306
307 for _, model := range b.cfg.Models[schemaType.Name()].Model {
308 if model == "map[string]interface{}" {
309 if !isMap(bindTarget) {
310 continue
311 }
312 return &TypeReference{
313 Definition: def,
314 GQL: schemaType,
315 GO: MapType,
316 }, nil
317 }
318
319 if model == "interface{}" {
320 if !isIntf(bindTarget) {
321 continue
322 }
323 return &TypeReference{
324 Definition: def,
325 GQL: schemaType,
326 GO: InterfaceType,
327 }, nil
328 }
329
330 pkgName, typeName = code.PkgAndType(model)
331 if pkgName == "" {
332 return nil, fmt.Errorf("missing package name for %s", schemaType.Name())
333 }
334
335 ref := &TypeReference{
336 Definition: def,
337 GQL: schemaType,
338 }
339
340 obj, err := b.FindObject(pkgName, typeName)
341 if err != nil {
342 return nil, err
343 }
344
345 if fun, isFunc := obj.(*types.Func); isFunc {
346 ref.GO = fun.Type().(*types.Signature).Params().At(0).Type()
347 ref.Marshaler = fun
348 ref.Unmarshaler = types.NewFunc(0, fun.Pkg(), "Unmarshal"+typeName, nil)
349 } else if hasMethod(obj.Type(), "MarshalGQL") && hasMethod(obj.Type(), "UnmarshalGQL") {
350 ref.GO = obj.Type()
351 ref.IsMarshaler = true
352 } else if underlying := basicUnderlying(obj.Type()); def.IsLeafType() && underlying != nil && underlying.Kind() == types.String {
353 // Special case for named types wrapping strings. Used by default enum implementations.
354
355 ref.GO = obj.Type()
356 ref.CastType = underlying
357
358 underlyingRef, err := b.TypeReference(&ast.Type{NamedType: "String"}, nil)
359 if err != nil {
360 return nil, err
361 }
362
363 ref.Marshaler = underlyingRef.Marshaler
364 ref.Unmarshaler = underlyingRef.Unmarshaler
365 } else {
366 ref.GO = obj.Type()
367 }
368
369 ref.GO = b.CopyModifiersFromAst(schemaType, ref.GO)
370
371 if bindTarget != nil {
372 if err = code.CompatibleTypes(ref.GO, bindTarget); err != nil {
373 continue
374 }
375 ref.GO = bindTarget
376 }
377
378 return ref, nil
379 }
380
381 return nil, fmt.Errorf("%s has type compatible with %s", schemaType.Name(), bindTarget.String())
382}
383
384func (b *Binder) CopyModifiersFromAst(t *ast.Type, base types.Type) types.Type {
385 if t.Elem != nil {
386 child := b.CopyModifiersFromAst(t.Elem, base)
387 if _, isStruct := child.Underlying().(*types.Struct); isStruct {
388 child = types.NewPointer(child)
389 }
390 return types.NewSlice(child)
391 }
392
393 var isInterface bool
394 if named, ok := base.(*types.Named); ok {
395 _, isInterface = named.Underlying().(*types.Interface)
396 }
397
398 if !isInterface && !t.NonNull {
399 return types.NewPointer(base)
400 }
401
402 return base
403}
404
// hasMethod reports whether the named type it, or the named type it points
// to, explicitly declares a method called name. Non-named types never match.
func hasMethod(it types.Type, name string) bool {
	target := it
	if p, ok := target.(*types.Pointer); ok {
		target = p.Elem()
	}

	if named, ok := target.(*types.Named); ok {
		for i, n := 0, named.NumMethods(); i < n; i++ {
			if named.Method(i).Name() == name {
				return true
			}
		}
	}
	return false
}
421
// basicUnderlying returns the basic underlying type of the named type it, or
// of the named type it points to. It returns nil when it is not such a named
// type or when the underlying type is not basic.
func basicUnderlying(it types.Type) *types.Basic {
	if p, ok := it.(*types.Pointer); ok {
		it = p.Elem()
	}

	named, ok := it.(*types.Named)
	if !ok {
		return nil
	}

	// A failed assertion leaves basic as a nil *types.Basic.
	basic, _ := named.Underlying().(*types.Basic)
	return basic
}