1package codegen
2
3import (
4 "fmt"
5 "go/types"
6 "reflect"
7 "regexp"
8 "strings"
9
10 "github.com/pkg/errors"
11 "golang.org/x/tools/go/loader"
12)
13
14func findGoType(prog *loader.Program, pkgName string, typeName string) (types.Object, error) {
15 if pkgName == "" {
16 return nil, nil
17 }
18 fullName := typeName
19 if pkgName != "" {
20 fullName = pkgName + "." + typeName
21 }
22
23 pkgName, err := resolvePkg(pkgName)
24 if err != nil {
25 return nil, errors.Errorf("unable to resolve package for %s: %s\n", fullName, err.Error())
26 }
27
28 pkg := prog.Imported[pkgName]
29 if pkg == nil {
30 return nil, errors.Errorf("required package was not loaded: %s", fullName)
31 }
32
33 for astNode, def := range pkg.Defs {
34 if astNode.Name != typeName || def.Parent() == nil || def.Parent() != pkg.Pkg.Scope() {
35 continue
36 }
37
38 return def, nil
39 }
40
41 return nil, errors.Errorf("unable to find type %s\n", fullName)
42}
43
44func findGoNamedType(prog *loader.Program, pkgName string, typeName string) (*types.Named, error) {
45 def, err := findGoType(prog, pkgName, typeName)
46 if err != nil {
47 return nil, err
48 }
49 if def == nil {
50 return nil, nil
51 }
52
53 namedType, ok := def.Type().(*types.Named)
54 if !ok {
55 return nil, errors.Errorf("expected %s to be a named type, instead found %T\n", typeName, def.Type())
56 }
57
58 return namedType, nil
59}
60
61func findGoInterface(prog *loader.Program, pkgName string, typeName string) (*types.Interface, error) {
62 namedType, err := findGoNamedType(prog, pkgName, typeName)
63 if err != nil {
64 return nil, err
65 }
66 if namedType == nil {
67 return nil, nil
68 }
69
70 underlying, ok := namedType.Underlying().(*types.Interface)
71 if !ok {
72 return nil, errors.Errorf("expected %s to be a named interface, instead found %s", typeName, namedType.String())
73 }
74
75 return underlying, nil
76}
77
// findMethod returns the exported method of typ whose name equals name
// (compared case-insensitively), or nil if there is none.
//
// Methods declared on typ itself are checked first. After that, if typ's
// underlying type is a struct, each embedded field whose type is a named type
// is searched recursively in field order. Embedded pointers (*T) are not
// followed.
func findMethod(typ *types.Named, name string) *types.Func {
	for i, n := 0, typ.NumMethods(); i < n; i++ {
		if m := typ.Method(i); m.Exported() && strings.EqualFold(m.Name(), name) {
			return m
		}
	}

	st, isStruct := typ.Underlying().(*types.Struct)
	if !isStruct {
		return nil
	}
	for i, n := 0, st.NumFields(); i < n; i++ {
		field := st.Field(i)
		if !field.Anonymous() {
			continue
		}
		embedded, isNamed := field.Type().(*types.Named)
		if !isNamed {
			continue
		}
		if m := findMethod(embedded, name); m != nil {
			return m
		}
	}
	return nil
}
107
// findField looks up the field of typ that should be bound to name. Names and
// tag values are compared case-insensitively.
//
// Fields are visited in declaration order. For each field:
//  1. If structTag is non-empty and the field's structTag value equals name,
//     the field is selected, overriding any earlier non-tag match. A second
//     tag match at the same nesting level is reported as ambiguous. This
//     applies to unexported fields too.
//  2. If the field is embedded and its type's underlying type is a struct
//     (embedded pointers are not followed), that struct is searched
//     recursively. A match there is kept only if nothing matched earlier.
//     Ambiguity errors from the nested search are returned.
//  3. If the field is exported and its own name equals name, it is kept only
//     if nothing matched earlier.
//
// Between rules 2 and 3, therefore, the first match in declaration order
// wins. If nothing matches, the returned error reads "no field named <name>".
//
// NOTE(review): Go resolves a selector to the shallowest field, so a field
// found inside an embedded struct may be shadowed by an outer field with the
// same exact name. Callers that access the result by name through typ should
// confirm this cannot happen.
func findField(typ *types.Struct, name, structTag string) (*types.Var, error) {
	// search reports "no match" as (nil, nil). Recursion into embedded structs
	// can therefore tell it apart from a real (ambiguity) error without
	// inspecting error text.
	var search func(s *types.Struct) (*types.Var, error)
	search = func(s *types.Struct) (*types.Var, error) {
		var found *types.Var
		foundByTag := false

		for i := 0; i < s.NumFields(); i++ {
			field := s.Field(i)

			if structTag != "" {
				val, ok := reflect.StructTag(s.Tag(i)).Lookup(structTag)
				if ok && strings.EqualFold(val, name) {
					if found != nil && foundByTag {
						return nil, fmt.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", structTag, val)
					}
					found = field
					foundByTag = true
				}
			}

			if field.Anonymous() {
				// Underlying() of an unnamed struct is the struct itself, so
				// this single check covers named and unnamed embeddings.
				if embedded, ok := field.Type().Underlying().(*types.Struct); ok {
					nested, err := search(embedded)
					if err != nil {
						return nil, err
					}
					if nested != nil && found == nil {
						found = nested
					}
				}
			}

			if found == nil && field.Exported() && strings.EqualFold(field.Name(), name) {
				found = field
			}
		}
		return found, nil
	}

	found, err := search(typ)
	if err != nil {
		return nil, err
	}
	if found == nil {
		return nil, fmt.Errorf("no field named %s", name)
	}
	return found, nil
}
171
172type BindError struct {
173 object *Object
174 field *Field
175 typ types.Type
176 methodErr error
177 varErr error
178}
179
180func (b BindError) Error() string {
181 return fmt.Sprintf(
182 "Unable to bind %s.%s to %s\n %s\n %s",
183 b.object.GQLType,
184 b.field.GQLName,
185 b.typ.String(),
186 b.methodErr.Error(),
187 b.varErr.Error(),
188 )
189}
190
191type BindErrors []BindError
192
193func (b BindErrors) Error() string {
194 var errs []string
195 for _, err := range b {
196 errs = append(errs, err.Error())
197 }
198 return strings.Join(errs, "\n\n")
199}
200
201func bindObject(t types.Type, object *Object, imports *Imports, structTag string) BindErrors {
202 var errs BindErrors
203 for i := range object.Fields {
204 field := &object.Fields[i]
205
206 if field.ForceResolver {
207 continue
208 }
209
210 // first try binding to a method
211 methodErr := bindMethod(imports, t, field)
212 if methodErr == nil {
213 continue
214 }
215
216 // otherwise try binding to a var
217 varErr := bindVar(imports, t, field, structTag)
218
219 if varErr != nil {
220 errs = append(errs, BindError{
221 object: object,
222 typ: t,
223 field: field,
224 varErr: varErr,
225 methodErr: methodErr,
226 })
227 }
228 }
229 return errs
230}
231
232func bindMethod(imports *Imports, t types.Type, field *Field) error {
233 namedType, ok := t.(*types.Named)
234 if !ok {
235 return fmt.Errorf("not a named type")
236 }
237
238 goName := field.GQLName
239 if field.GoFieldName != "" {
240 goName = field.GoFieldName
241 }
242 method := findMethod(namedType, goName)
243 if method == nil {
244 return fmt.Errorf("no method named %s", field.GQLName)
245 }
246 sig := method.Type().(*types.Signature)
247
248 if sig.Results().Len() == 1 {
249 field.NoErr = true
250 } else if sig.Results().Len() != 2 {
251 return fmt.Errorf("method has wrong number of args")
252 }
253 newArgs, err := matchArgs(field, sig.Params())
254 if err != nil {
255 return err
256 }
257
258 result := sig.Results().At(0)
259 if err := validateTypeBinding(imports, field, result.Type()); err != nil {
260 return errors.Wrap(err, "method has wrong return type")
261 }
262
263 // success, args and return type match. Bind to method
264 field.GoFieldType = GoFieldMethod
265 field.GoReceiverName = "obj"
266 field.GoFieldName = method.Name()
267 field.Args = newArgs
268 return nil
269}
270
271func bindVar(imports *Imports, t types.Type, field *Field, structTag string) error {
272 underlying, ok := t.Underlying().(*types.Struct)
273 if !ok {
274 return fmt.Errorf("not a struct")
275 }
276
277 goName := field.GQLName
278 if field.GoFieldName != "" {
279 goName = field.GoFieldName
280 }
281 structField, err := findField(underlying, goName, structTag)
282 if err != nil {
283 return err
284 }
285
286 if err := validateTypeBinding(imports, field, structField.Type()); err != nil {
287 return errors.Wrap(err, "field has wrong type")
288 }
289
290 // success, bind to var
291 field.GoFieldType = GoFieldVariable
292 field.GoReceiverName = "obj"
293 field.GoFieldName = structField.Name()
294 return nil
295}
296
297func matchArgs(field *Field, params *types.Tuple) ([]FieldArgument, error) {
298 var newArgs []FieldArgument
299
300nextArg:
301 for j := 0; j < params.Len(); j++ {
302 param := params.At(j)
303 for _, oldArg := range field.Args {
304 if strings.EqualFold(oldArg.GQLName, param.Name()) {
305 if !field.ForceResolver {
306 oldArg.Type.Modifiers = modifiersFromGoType(param.Type())
307 }
308 newArgs = append(newArgs, oldArg)
309 continue nextArg
310 }
311 }
312
313 // no matching arg found, abort
314 return nil, fmt.Errorf("arg %s not found on method", param.Name())
315 }
316 return newArgs, nil
317}
318
319func validateTypeBinding(imports *Imports, field *Field, goType types.Type) error {
320 gqlType := normalizeVendor(field.Type.FullSignature())
321 goTypeStr := normalizeVendor(goType.String())
322
323 if goTypeStr == gqlType || "*"+goTypeStr == gqlType || goTypeStr == "*"+gqlType {
324 field.Type.Modifiers = modifiersFromGoType(goType)
325 return nil
326 }
327
328 // deal with type aliases
329 underlyingStr := normalizeVendor(goType.Underlying().String())
330 if underlyingStr == gqlType || "*"+underlyingStr == gqlType || underlyingStr == "*"+gqlType {
331 field.Type.Modifiers = modifiersFromGoType(goType)
332 pkg, typ := pkgAndType(goType.String())
333 imp := imports.findByPath(pkg)
334 field.AliasedType = &Ref{GoType: typ, Import: imp}
335 return nil
336 }
337
338 return fmt.Errorf("%s is not compatible with %s", gqlType, goTypeStr)
339}
340
341func modifiersFromGoType(t types.Type) []string {
342 var modifiers []string
343 for {
344 switch val := t.(type) {
345 case *types.Pointer:
346 modifiers = append(modifiers, modPtr)
347 t = val.Elem()
348 case *types.Array:
349 modifiers = append(modifiers, modList)
350 t = val.Elem()
351 case *types.Slice:
352 modifiers = append(modifiers, modList)
353 t = val.Elem()
354 default:
355 return modifiers
356 }
357 }
358}
359
var (
	// modsRegex matches the (possibly empty) run of pointer ("*") and slice
	// ("[]") modifiers at the very start of a type string.
	modsRegex = regexp.MustCompile(`^(\*|\[\])*`)
)
361
// vendorPrefixRegex matches a package path up to and including its last
// "/vendor/" segment. The character class excludes the delimiters that can
// precede a package path inside a types.Type string: modifiers, brackets,
// parentheses, braces, commas, semicolons and spaces. None of these are legal
// in an import path, so every vendored path in the string is matched
// separately.
var vendorPrefixRegex = regexp.MustCompile(`[^\[\]*(){},; ]*/vendor/`)

// normalizeVendor strips vendor-directory prefixes from every package path in
// a type string. For example, "*github.com/a/vendor/github.com/b.T" becomes
// "*github.com/b.T", so a vendored type compares equal to its non-vendored
// spelling.
//
// Paths nested inside composite types are handled too:
// "map[string]github.com/a/vendor/b.T" becomes "map[string]b.T". Stripping
// only after leading "*"/"[]" modifiers would discard the "map[string]".
func normalizeVendor(pkg string) string {
	return vendorPrefixRegex.ReplaceAllString(pkg, "")
}