1package codegen
2
3import (
4 "fmt"
5 "go/types"
6 "reflect"
7 "regexp"
8 "strings"
9
10 "github.com/pkg/errors"
11 "golang.org/x/tools/go/loader"
12)
13
14func findGoType(prog *loader.Program, pkgName string, typeName string) (types.Object, error) {
15 if pkgName == "" {
16 return nil, nil
17 }
18 fullName := typeName
19 if pkgName != "" {
20 fullName = pkgName + "." + typeName
21 }
22
23 pkgName, err := resolvePkg(pkgName)
24 if err != nil {
25 return nil, errors.Errorf("unable to resolve package for %s: %s\n", fullName, err.Error())
26 }
27
28 pkg := prog.Imported[pkgName]
29 if pkg == nil {
30 return nil, errors.Errorf("required package was not loaded: %s", fullName)
31 }
32
33 for astNode, def := range pkg.Defs {
34 if astNode.Name != typeName || def.Parent() == nil || def.Parent() != pkg.Pkg.Scope() {
35 continue
36 }
37
38 return def, nil
39 }
40
41 return nil, errors.Errorf("unable to find type %s\n", fullName)
42}
43
44func findGoNamedType(prog *loader.Program, pkgName string, typeName string) (*types.Named, error) {
45 def, err := findGoType(prog, pkgName, typeName)
46 if err != nil {
47 return nil, err
48 }
49 if def == nil {
50 return nil, nil
51 }
52
53 namedType, ok := def.Type().(*types.Named)
54 if !ok {
55 return nil, errors.Errorf("expected %s to be a named type, instead found %T\n", typeName, def.Type())
56 }
57
58 return namedType, nil
59}
60
61func findGoInterface(prog *loader.Program, pkgName string, typeName string) (*types.Interface, error) {
62 namedType, err := findGoNamedType(prog, pkgName, typeName)
63 if err != nil {
64 return nil, err
65 }
66 if namedType == nil {
67 return nil, nil
68 }
69
70 underlying, ok := namedType.Underlying().(*types.Interface)
71 if !ok {
72 return nil, errors.Errorf("expected %s to be a named interface, instead found %s", typeName, namedType.String())
73 }
74
75 return underlying, nil
76}
77
// findMethod returns the exported method of typ whose name equals name under
// case-insensitive comparison, or nil when no such method exists.
//
// Methods declared directly on typ are checked first. If typ's underlying
// type is a struct, its embedded fields are then searched depth-first in
// declaration order. Both value embeds (T) and pointer embeds (*T) are
// followed, since methods are promoted through either form.
func findMethod(typ *types.Named, name string) *types.Func {
	return findMethodVisiting(typ, name, map[*types.Named]bool{})
}

// findMethodVisiting implements findMethod; seen records the named types
// already searched so that self-referential embeds cannot recurse forever.
func findMethodVisiting(typ *types.Named, name string, seen map[*types.Named]bool) *types.Func {
	if seen[typ] {
		return nil
	}
	seen[typ] = true

	for i := 0; i < typ.NumMethods(); i++ {
		method := typ.Method(i)
		if method.Exported() && strings.EqualFold(method.Name(), name) {
			return method
		}
	}

	s, ok := typ.Underlying().(*types.Struct)
	if !ok {
		return nil
	}

	for i := 0; i < s.NumFields(); i++ {
		field := s.Field(i)
		if !field.Anonymous() {
			continue
		}

		// An embedded *T promotes T's methods just like an embedded T does.
		fieldType := field.Type()
		if ptr, isPtr := fieldType.(*types.Pointer); isPtr {
			fieldType = ptr.Elem()
		}

		if named, isNamed := fieldType.(*types.Named); isNamed {
			if f := findMethodVisiting(named, name, seen); f != nil {
				return f
			}
		}
	}

	return nil
}
107
// equalFieldName reports whether source and target are the same name once
// every underscore has been removed and letter case is ignored.
func equalFieldName(source, target string) bool {
	stripUnderscores := func(s string) string {
		return strings.Replace(s, "_", "", -1)
	}
	return strings.EqualFold(stripUnderscores(source), stripUnderscores(target))
}
113
114// findField attempts to match the name to a struct field with the following
115// priorites:
116// 1. If struct tag is passed then struct tag has highest priority
117// 2. Field in an embedded struct
118// 3. Actual Field name
119func findField(typ *types.Struct, name, structTag string) (*types.Var, error) {
120 var foundField *types.Var
121 foundFieldWasTag := false
122
123 for i := 0; i < typ.NumFields(); i++ {
124 field := typ.Field(i)
125
126 if structTag != "" {
127 tags := reflect.StructTag(typ.Tag(i))
128 if val, ok := tags.Lookup(structTag); ok {
129 if equalFieldName(val, name) {
130 if foundField != nil && foundFieldWasTag {
131 return nil, errors.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", structTag, val)
132 }
133
134 foundField = field
135 foundFieldWasTag = true
136 }
137 }
138 }
139
140 if field.Anonymous() {
141
142 fieldType := field.Type()
143
144 if ptr, ok := fieldType.(*types.Pointer); ok {
145 fieldType = ptr.Elem()
146 }
147
148 // Type.Underlying() returns itself for all types except types.Named, where it returns a struct type.
149 // It should be safe to always call.
150 if named, ok := fieldType.Underlying().(*types.Struct); ok {
151 f, err := findField(named, name, structTag)
152 if err != nil && !strings.HasPrefix(err.Error(), "no field named") {
153 return nil, err
154 }
155 if f != nil && foundField == nil {
156 foundField = f
157 }
158 }
159 }
160
161 if !field.Exported() {
162 continue
163 }
164
165 if equalFieldName(field.Name(), name) && foundField == nil { // aqui!
166 foundField = field
167 }
168 }
169
170 if foundField == nil {
171 return nil, fmt.Errorf("no field named %s", name)
172 }
173
174 return foundField, nil
175}
176
177type BindError struct {
178 object *Object
179 field *Field
180 typ types.Type
181 methodErr error
182 varErr error
183}
184
185func (b BindError) Error() string {
186 return fmt.Sprintf(
187 "Unable to bind %s.%s to %s\n %s\n %s",
188 b.object.GQLType,
189 b.field.GQLName,
190 b.typ.String(),
191 b.methodErr.Error(),
192 b.varErr.Error(),
193 )
194}
195
196type BindErrors []BindError
197
198func (b BindErrors) Error() string {
199 var errs []string
200 for _, err := range b {
201 errs = append(errs, err.Error())
202 }
203 return strings.Join(errs, "\n\n")
204}
205
206func bindObject(t types.Type, object *Object, structTag string) BindErrors {
207 var errs BindErrors
208 for i := range object.Fields {
209 field := &object.Fields[i]
210
211 if field.ForceResolver {
212 continue
213 }
214
215 // first try binding to a method
216 methodErr := bindMethod(t, field)
217 if methodErr == nil {
218 continue
219 }
220
221 // otherwise try binding to a var
222 varErr := bindVar(t, field, structTag)
223
224 if varErr != nil {
225 errs = append(errs, BindError{
226 object: object,
227 typ: t,
228 field: field,
229 varErr: varErr,
230 methodErr: methodErr,
231 })
232 }
233 }
234 return errs
235}
236
237func bindMethod(t types.Type, field *Field) error {
238 namedType, ok := t.(*types.Named)
239 if !ok {
240 return fmt.Errorf("not a named type")
241 }
242
243 goName := field.GQLName
244 if field.GoFieldName != "" {
245 goName = field.GoFieldName
246 }
247 method := findMethod(namedType, goName)
248 if method == nil {
249 return fmt.Errorf("no method named %s", field.GQLName)
250 }
251 sig := method.Type().(*types.Signature)
252
253 if sig.Results().Len() == 1 {
254 field.NoErr = true
255 } else if sig.Results().Len() != 2 {
256 return fmt.Errorf("method has wrong number of args")
257 }
258 params := sig.Params()
259 // If the first argument is the context, remove it from the comparison and set
260 // the MethodHasContext flag so that the context will be passed to this model's method
261 if params.Len() > 0 && params.At(0).Type().String() == "context.Context" {
262 field.MethodHasContext = true
263 vars := make([]*types.Var, params.Len()-1)
264 for i := 1; i < params.Len(); i++ {
265 vars[i-1] = params.At(i)
266 }
267 params = types.NewTuple(vars...)
268 }
269
270 newArgs, err := matchArgs(field, params)
271 if err != nil {
272 return err
273 }
274
275 result := sig.Results().At(0)
276 if err := validateTypeBinding(field, result.Type()); err != nil {
277 return errors.Wrap(err, "method has wrong return type")
278 }
279
280 // success, args and return type match. Bind to method
281 field.GoFieldType = GoFieldMethod
282 field.GoReceiverName = "obj"
283 field.GoFieldName = method.Name()
284 field.Args = newArgs
285 return nil
286}
287
288func bindVar(t types.Type, field *Field, structTag string) error {
289 underlying, ok := t.Underlying().(*types.Struct)
290 if !ok {
291 return fmt.Errorf("not a struct")
292 }
293
294 goName := field.GQLName
295 if field.GoFieldName != "" {
296 goName = field.GoFieldName
297 }
298 structField, err := findField(underlying, goName, structTag)
299 if err != nil {
300 return err
301 }
302
303 if err := validateTypeBinding(field, structField.Type()); err != nil {
304 return errors.Wrap(err, "field has wrong type")
305 }
306
307 // success, bind to var
308 field.GoFieldType = GoFieldVariable
309 field.GoReceiverName = "obj"
310 field.GoFieldName = structField.Name()
311 return nil
312}
313
314func matchArgs(field *Field, params *types.Tuple) ([]FieldArgument, error) {
315 var newArgs []FieldArgument
316
317nextArg:
318 for j := 0; j < params.Len(); j++ {
319 param := params.At(j)
320 for _, oldArg := range field.Args {
321 if strings.EqualFold(oldArg.GQLName, param.Name()) {
322 if !field.ForceResolver {
323 oldArg.Type.Modifiers = modifiersFromGoType(param.Type())
324 }
325 newArgs = append(newArgs, oldArg)
326 continue nextArg
327 }
328 }
329
330 // no matching arg found, abort
331 return nil, fmt.Errorf("arg %s not found on method", param.Name())
332 }
333 return newArgs, nil
334}
335
336func validateTypeBinding(field *Field, goType types.Type) error {
337 gqlType := normalizeVendor(field.Type.FullSignature())
338 goTypeStr := normalizeVendor(goType.String())
339
340 if equalTypes(goTypeStr, gqlType) {
341 field.Type.Modifiers = modifiersFromGoType(goType)
342 return nil
343 }
344
345 // deal with type aliases
346 underlyingStr := normalizeVendor(goType.Underlying().String())
347 if equalTypes(underlyingStr, gqlType) {
348 field.Type.Modifiers = modifiersFromGoType(goType)
349 pkg, typ := pkgAndType(goType.String())
350 field.AliasedType = &Ref{GoType: typ, Package: pkg}
351 return nil
352 }
353
354 return fmt.Errorf("%s is not compatible with %s", gqlType, goTypeStr)
355}
356
357func modifiersFromGoType(t types.Type) []string {
358 var modifiers []string
359 for {
360 switch val := t.(type) {
361 case *types.Pointer:
362 modifiers = append(modifiers, modPtr)
363 t = val.Elem()
364 case *types.Array:
365 modifiers = append(modifiers, modList)
366 t = val.Elem()
367 case *types.Slice:
368 modifiers = append(modifiers, modList)
369 t = val.Elem()
370 default:
371 return modifiers
372 }
373 }
374}
375
// modsRegex matches the (possibly empty) run of "*" and "[]" modifiers at
// the start of a type string.
var modsRegex = regexp.MustCompile(`^(\*|\[\])*`)

// normalizeVendor strips any vendor directory prefix from a type string
// while keeping its leading "*" / "[]" modifiers. Everything after the
// modifiers up to and including the last "/vendor/" separator is dropped.
func normalizeVendor(pkg string) string {
	// The pattern can match the empty string, so this never fails.
	prefix := modsRegex.FindString(pkg)
	segments := strings.Split(pkg[len(prefix):], "/vendor/")
	return prefix + segments[len(segments)-1]
}
384
// equalTypes reports whether the Go type string and the GraphQL-side type
// string are considered the same. They match when they are identical, when
// they differ only by one leading "*" on either side, or when goType equals
// gqlType after every "[]*" in goType is rewritten to "[]".
func equalTypes(goType string, gqlType string) bool {
	switch gqlType {
	case goType, "*" + goType:
		return true
	}
	if goType == "*"+gqlType {
		return true
	}
	return strings.Replace(goType, "[]*", "[]", -1) == gqlType
}