util.go

  1package codegen
  2
  3import (
  4	"fmt"
  5	"go/types"
  6	"reflect"
  7	"regexp"
  8	"strings"
  9
 10	"github.com/pkg/errors"
 11	"golang.org/x/tools/go/loader"
 12)
 13
 14func findGoType(prog *loader.Program, pkgName string, typeName string) (types.Object, error) {
 15	if pkgName == "" {
 16		return nil, nil
 17	}
 18	fullName := typeName
 19	if pkgName != "" {
 20		fullName = pkgName + "." + typeName
 21	}
 22
 23	pkgName, err := resolvePkg(pkgName)
 24	if err != nil {
 25		return nil, errors.Errorf("unable to resolve package for %s: %s\n", fullName, err.Error())
 26	}
 27
 28	pkg := prog.Imported[pkgName]
 29	if pkg == nil {
 30		return nil, errors.Errorf("required package was not loaded: %s", fullName)
 31	}
 32
 33	for astNode, def := range pkg.Defs {
 34		if astNode.Name != typeName || def.Parent() == nil || def.Parent() != pkg.Pkg.Scope() {
 35			continue
 36		}
 37
 38		return def, nil
 39	}
 40
 41	return nil, errors.Errorf("unable to find type %s\n", fullName)
 42}
 43
 44func findGoNamedType(prog *loader.Program, pkgName string, typeName string) (*types.Named, error) {
 45	def, err := findGoType(prog, pkgName, typeName)
 46	if err != nil {
 47		return nil, err
 48	}
 49	if def == nil {
 50		return nil, nil
 51	}
 52
 53	namedType, ok := def.Type().(*types.Named)
 54	if !ok {
 55		return nil, errors.Errorf("expected %s to be a named type, instead found %T\n", typeName, def.Type())
 56	}
 57
 58	return namedType, nil
 59}
 60
 61func findGoInterface(prog *loader.Program, pkgName string, typeName string) (*types.Interface, error) {
 62	namedType, err := findGoNamedType(prog, pkgName, typeName)
 63	if err != nil {
 64		return nil, err
 65	}
 66	if namedType == nil {
 67		return nil, nil
 68	}
 69
 70	underlying, ok := namedType.Underlying().(*types.Interface)
 71	if !ok {
 72		return nil, errors.Errorf("expected %s to be a named interface, instead found %s", typeName, namedType.String())
 73	}
 74
 75	return underlying, nil
 76}
 77
 78func findMethod(typ *types.Named, name string) *types.Func {
 79	for i := 0; i < typ.NumMethods(); i++ {
 80		method := typ.Method(i)
 81		if !method.Exported() {
 82			continue
 83		}
 84
 85		if strings.EqualFold(method.Name(), name) {
 86			return method
 87		}
 88	}
 89
 90	if s, ok := typ.Underlying().(*types.Struct); ok {
 91		for i := 0; i < s.NumFields(); i++ {
 92			field := s.Field(i)
 93			if !field.Anonymous() {
 94				continue
 95			}
 96
 97			if named, ok := field.Type().(*types.Named); ok {
 98				if f := findMethod(named, name); f != nil {
 99					return f
100				}
101			}
102		}
103	}
104
105	return nil
106}
107
// equalFieldName reports whether source and target denote the same field name
// once every underscore is removed and letter case is ignored.
func equalFieldName(source, target string) bool {
	stripUnderscores := func(s string) string {
		return strings.Replace(s, "_", "", -1)
	}
	return strings.EqualFold(stripUnderscores(source), stripUnderscores(target))
}
113
114// findField attempts to match the name to a struct field with the following
115// priorites:
116// 1. If struct tag is passed then struct tag has highest priority
117// 2. Field in an embedded struct
118// 3. Actual Field name
119func findField(typ *types.Struct, name, structTag string) (*types.Var, error) {
120	var foundField *types.Var
121	foundFieldWasTag := false
122
123	for i := 0; i < typ.NumFields(); i++ {
124		field := typ.Field(i)
125
126		if structTag != "" {
127			tags := reflect.StructTag(typ.Tag(i))
128			if val, ok := tags.Lookup(structTag); ok {
129				if equalFieldName(val, name) {
130					if foundField != nil && foundFieldWasTag {
131						return nil, errors.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", structTag, val)
132					}
133
134					foundField = field
135					foundFieldWasTag = true
136				}
137			}
138		}
139
140		if field.Anonymous() {
141
142			fieldType := field.Type()
143
144			if ptr, ok := fieldType.(*types.Pointer); ok {
145				fieldType = ptr.Elem()
146			}
147
148			// Type.Underlying() returns itself for all types except types.Named, where it returns a struct type.
149			// It should be safe to always call.
150			if named, ok := fieldType.Underlying().(*types.Struct); ok {
151				f, err := findField(named, name, structTag)
152				if err != nil && !strings.HasPrefix(err.Error(), "no field named") {
153					return nil, err
154				}
155				if f != nil && foundField == nil {
156					foundField = f
157				}
158			}
159		}
160
161		if !field.Exported() {
162			continue
163		}
164
165		if equalFieldName(field.Name(), name) && foundField == nil { // aqui!
166			foundField = field
167		}
168	}
169
170	if foundField == nil {
171		return nil, fmt.Errorf("no field named %s", name)
172	}
173
174	return foundField, nil
175}
176
177type BindError struct {
178	object    *Object
179	field     *Field
180	typ       types.Type
181	methodErr error
182	varErr    error
183}
184
185func (b BindError) Error() string {
186	return fmt.Sprintf(
187		"Unable to bind %s.%s to %s\n  %s\n  %s",
188		b.object.GQLType,
189		b.field.GQLName,
190		b.typ.String(),
191		b.methodErr.Error(),
192		b.varErr.Error(),
193	)
194}
195
196type BindErrors []BindError
197
198func (b BindErrors) Error() string {
199	var errs []string
200	for _, err := range b {
201		errs = append(errs, err.Error())
202	}
203	return strings.Join(errs, "\n\n")
204}
205
206func bindObject(t types.Type, object *Object, structTag string) BindErrors {
207	var errs BindErrors
208	for i := range object.Fields {
209		field := &object.Fields[i]
210
211		if field.ForceResolver {
212			continue
213		}
214
215		// first try binding to a method
216		methodErr := bindMethod(t, field)
217		if methodErr == nil {
218			continue
219		}
220
221		// otherwise try binding to a var
222		varErr := bindVar(t, field, structTag)
223
224		if varErr != nil {
225			errs = append(errs, BindError{
226				object:    object,
227				typ:       t,
228				field:     field,
229				varErr:    varErr,
230				methodErr: methodErr,
231			})
232		}
233	}
234	return errs
235}
236
237func bindMethod(t types.Type, field *Field) error {
238	namedType, ok := t.(*types.Named)
239	if !ok {
240		return fmt.Errorf("not a named type")
241	}
242
243	goName := field.GQLName
244	if field.GoFieldName != "" {
245		goName = field.GoFieldName
246	}
247	method := findMethod(namedType, goName)
248	if method == nil {
249		return fmt.Errorf("no method named %s", field.GQLName)
250	}
251	sig := method.Type().(*types.Signature)
252
253	if sig.Results().Len() == 1 {
254		field.NoErr = true
255	} else if sig.Results().Len() != 2 {
256		return fmt.Errorf("method has wrong number of args")
257	}
258	params := sig.Params()
259	// If the first argument is the context, remove it from the comparison and set
260	// the MethodHasContext flag so that the context will be passed to this model's method
261	if params.Len() > 0 && params.At(0).Type().String() == "context.Context" {
262		field.MethodHasContext = true
263		vars := make([]*types.Var, params.Len()-1)
264		for i := 1; i < params.Len(); i++ {
265			vars[i-1] = params.At(i)
266		}
267		params = types.NewTuple(vars...)
268	}
269
270	newArgs, err := matchArgs(field, params)
271	if err != nil {
272		return err
273	}
274
275	result := sig.Results().At(0)
276	if err := validateTypeBinding(field, result.Type()); err != nil {
277		return errors.Wrap(err, "method has wrong return type")
278	}
279
280	// success, args and return type match. Bind to method
281	field.GoFieldType = GoFieldMethod
282	field.GoReceiverName = "obj"
283	field.GoFieldName = method.Name()
284	field.Args = newArgs
285	return nil
286}
287
288func bindVar(t types.Type, field *Field, structTag string) error {
289	underlying, ok := t.Underlying().(*types.Struct)
290	if !ok {
291		return fmt.Errorf("not a struct")
292	}
293
294	goName := field.GQLName
295	if field.GoFieldName != "" {
296		goName = field.GoFieldName
297	}
298	structField, err := findField(underlying, goName, structTag)
299	if err != nil {
300		return err
301	}
302
303	if err := validateTypeBinding(field, structField.Type()); err != nil {
304		return errors.Wrap(err, "field has wrong type")
305	}
306
307	// success, bind to var
308	field.GoFieldType = GoFieldVariable
309	field.GoReceiverName = "obj"
310	field.GoFieldName = structField.Name()
311	return nil
312}
313
314func matchArgs(field *Field, params *types.Tuple) ([]FieldArgument, error) {
315	var newArgs []FieldArgument
316
317nextArg:
318	for j := 0; j < params.Len(); j++ {
319		param := params.At(j)
320		for _, oldArg := range field.Args {
321			if strings.EqualFold(oldArg.GQLName, param.Name()) {
322				if !field.ForceResolver {
323					oldArg.Type.Modifiers = modifiersFromGoType(param.Type())
324				}
325				newArgs = append(newArgs, oldArg)
326				continue nextArg
327			}
328		}
329
330		// no matching arg found, abort
331		return nil, fmt.Errorf("arg %s not found on method", param.Name())
332	}
333	return newArgs, nil
334}
335
336func validateTypeBinding(field *Field, goType types.Type) error {
337	gqlType := normalizeVendor(field.Type.FullSignature())
338	goTypeStr := normalizeVendor(goType.String())
339
340	if equalTypes(goTypeStr, gqlType) {
341		field.Type.Modifiers = modifiersFromGoType(goType)
342		return nil
343	}
344
345	// deal with type aliases
346	underlyingStr := normalizeVendor(goType.Underlying().String())
347	if equalTypes(underlyingStr, gqlType) {
348		field.Type.Modifiers = modifiersFromGoType(goType)
349		pkg, typ := pkgAndType(goType.String())
350		field.AliasedType = &Ref{GoType: typ, Package: pkg}
351		return nil
352	}
353
354	return fmt.Errorf("%s is not compatible with %s", gqlType, goTypeStr)
355}
356
357func modifiersFromGoType(t types.Type) []string {
358	var modifiers []string
359	for {
360		switch val := t.(type) {
361		case *types.Pointer:
362			modifiers = append(modifiers, modPtr)
363			t = val.Elem()
364		case *types.Array:
365			modifiers = append(modifiers, modList)
366			t = val.Elem()
367		case *types.Slice:
368			modifiers = append(modifiers, modList)
369			t = val.Elem()
370		default:
371			return modifiers
372		}
373	}
374}
375
// modsRegex matches the (possibly empty) run of "*" and "[]" modifiers at the
// start of a type string.
var modsRegex = regexp.MustCompile(`^(\*|\[\])*`)

// normalizeVendor strips any vendor directory prefix from a type string:
// everything up to and including the last "/vendor/" separator is dropped,
// while leading "*" and "[]" modifiers are kept in front of the result.
func normalizeVendor(pkg string) string {
	// The pattern is anchored and may match the empty string, so prefix is
	// always a (possibly empty) prefix of pkg.
	prefix := modsRegex.FindString(pkg)
	segments := strings.Split(pkg[len(prefix):], "/vendor/")
	return prefix + segments[len(segments)-1]
}
384
// equalTypes reports whether the Go type string goType is considered
// equivalent to the GraphQL-derived type string gqlType. The two match when
// they are identical, when they differ only by one leading "*" on either
// side, or when goType equals gqlType after every "[]*" in goType is
// rewritten to "[]".
func equalTypes(goType string, gqlType string) bool {
	switch gqlType {
	case goType, "*" + goType, strings.Replace(goType, "[]*", "[]", -1):
		return true
	}
	return goType == "*"+gqlType
}