util.go

  1package codegen
  2
  3import (
  4	"fmt"
  5	"go/types"
  6	"reflect"
  7	"regexp"
  8	"strings"
  9
 10	"github.com/pkg/errors"
 11	"golang.org/x/tools/go/loader"
 12)
 13
 14func findGoType(prog *loader.Program, pkgName string, typeName string) (types.Object, error) {
 15	if pkgName == "" {
 16		return nil, nil
 17	}
 18	fullName := typeName
 19	if pkgName != "" {
 20		fullName = pkgName + "." + typeName
 21	}
 22
 23	pkgName, err := resolvePkg(pkgName)
 24	if err != nil {
 25		return nil, errors.Errorf("unable to resolve package for %s: %s\n", fullName, err.Error())
 26	}
 27
 28	pkg := prog.Imported[pkgName]
 29	if pkg == nil {
 30		return nil, errors.Errorf("required package was not loaded: %s", fullName)
 31	}
 32
 33	for astNode, def := range pkg.Defs {
 34		if astNode.Name != typeName || def.Parent() == nil || def.Parent() != pkg.Pkg.Scope() {
 35			continue
 36		}
 37
 38		return def, nil
 39	}
 40
 41	return nil, errors.Errorf("unable to find type %s\n", fullName)
 42}
 43
 44func findGoNamedType(prog *loader.Program, pkgName string, typeName string) (*types.Named, error) {
 45	def, err := findGoType(prog, pkgName, typeName)
 46	if err != nil {
 47		return nil, err
 48	}
 49	if def == nil {
 50		return nil, nil
 51	}
 52
 53	namedType, ok := def.Type().(*types.Named)
 54	if !ok {
 55		return nil, errors.Errorf("expected %s to be a named type, instead found %T\n", typeName, def.Type())
 56	}
 57
 58	return namedType, nil
 59}
 60
 61func findGoInterface(prog *loader.Program, pkgName string, typeName string) (*types.Interface, error) {
 62	namedType, err := findGoNamedType(prog, pkgName, typeName)
 63	if err != nil {
 64		return nil, err
 65	}
 66	if namedType == nil {
 67		return nil, nil
 68	}
 69
 70	underlying, ok := namedType.Underlying().(*types.Interface)
 71	if !ok {
 72		return nil, errors.Errorf("expected %s to be a named interface, instead found %s", typeName, namedType.String())
 73	}
 74
 75	return underlying, nil
 76}
 77
 78func findMethod(typ *types.Named, name string) *types.Func {
 79	for i := 0; i < typ.NumMethods(); i++ {
 80		method := typ.Method(i)
 81		if !method.Exported() {
 82			continue
 83		}
 84
 85		if strings.EqualFold(method.Name(), name) {
 86			return method
 87		}
 88	}
 89
 90	if s, ok := typ.Underlying().(*types.Struct); ok {
 91		for i := 0; i < s.NumFields(); i++ {
 92			field := s.Field(i)
 93			if !field.Anonymous() {
 94				continue
 95			}
 96
 97			if named, ok := field.Type().(*types.Named); ok {
 98				if f := findMethod(named, name); f != nil {
 99					return f
100				}
101			}
102		}
103	}
104
105	return nil
106}
107
// findField attempts to match name (case-insensitively) to a field of typ,
// including fields promoted from embedded structs, with these priorities:
//  1. If structTag is non-empty, a field of typ whose structTag value equals
//     name wins over every other candidate. Two such fields in the same
//     struct are ambiguous and produce an error.
//  2. Otherwise the first candidate in declaration order wins. A candidate is
//     either a match found inside an embedded struct (embedded by value or by
//     pointer) or an exported field whose own name matches. For a single
//     embedded field, a match inside it is preferred over the embedded
//     field's own name. Tag matches inside embedded structs count as ordinary
//     candidates at this level.
//
// It returns an error "no field named <name>" when nothing matches.
func findField(typ *types.Struct, name, structTag string) (*types.Var, error) {
	field, err := findFieldIn(typ, name, structTag, map[*types.Struct]bool{})
	if err != nil {
		return nil, err
	}
	if field == nil {
		return nil, fmt.Errorf("no field named %s", name)
	}
	return field, nil
}

// findFieldIn does the work of findField but reports "not found" as (nil, nil)
// instead of an error. visiting holds the structs on the current embedding
// path, so self-referential embedding such as `type T struct{ *T }` terminates.
func findFieldIn(typ *types.Struct, name, structTag string, visiting map[*types.Struct]bool) (*types.Var, error) {
	if visiting[typ] {
		return nil, nil
	}
	visiting[typ] = true
	defer delete(visiting, typ)

	var found *types.Var
	foundByTag := false

	for i := 0; i < typ.NumFields(); i++ {
		field := typ.Field(i)

		if structTag != "" {
			if val, ok := reflect.StructTag(typ.Tag(i)).Lookup(structTag); ok && strings.EqualFold(val, name) {
				if foundByTag {
					return nil, fmt.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", structTag, val)
				}
				// A tag match replaces any earlier name/embedded candidate.
				found = field
				foundByTag = true
			}
		}

		if field.Anonymous() {
			// Go promotes fields through embedded structs whether they are
			// embedded by value (T) or by pointer (*T).
			embeddedType := field.Type()
			if ptr, ok := embeddedType.(*types.Pointer); ok {
				embeddedType = ptr.Elem()
			}
			if embedded, ok := embeddedType.Underlying().(*types.Struct); ok {
				f, err := findFieldIn(embedded, name, structTag, visiting)
				if err != nil {
					return nil, err
				}
				if f != nil && found == nil {
					found = f
				}
			}
		}

		if field.Exported() && found == nil && strings.EqualFold(field.Name(), name) {
			found = field
		}
	}

	return found, nil
}
171
// BindError records why a GraphQL field of an object could not be bound to
// either a method or a struct field of a Go type. bindObject produces it only
// when both attempts failed.
type BindError struct {
	object    *Object    // object whose Fields contains field
	field     *Field     // the field that could not be bound
	typ       types.Type // Go type the binding was attempted against
	methodErr error      // why binding to a method (bindMethod) failed
	varErr    error      // why binding to a struct field (bindVar) failed
}
179
// Error renders a header naming the GraphQL object/field and the Go type. It
// is followed by the method-binding reason and then the var-binding reason,
// each on its own line indented by two spaces.
func (b BindError) Error() string {
	header := fmt.Sprintf("Unable to bind %s.%s to %s", b.object.GQLType, b.field.GQLName, b.typ.String())
	return strings.Join([]string{header, b.methodErr.Error(), b.varErr.Error()}, "\n  ")
}
190
191type BindErrors []BindError
192
193func (b BindErrors) Error() string {
194	var errs []string
195	for _, err := range b {
196		errs = append(errs, err.Error())
197	}
198	return strings.Join(errs, "\n\n")
199}
200
// bindObject tries to bind every field of object to Go type t.
//
// Fields with ForceResolver set are skipped. For each remaining field, binding
// to a method (bindMethod) is tried first, then binding to a struct field
// (bindVar). A field that fails both contributes a BindError carrying both
// reasons. The fields are modified in place by the successful binder.
func bindObject(t types.Type, object *Object, imports *Imports, structTag string) BindErrors {
	var failures BindErrors
	for idx := range object.Fields {
		f := &object.Fields[idx]
		if f.ForceResolver {
			continue
		}

		mErr := bindMethod(imports, t, f)
		if mErr == nil {
			continue
		}

		vErr := bindVar(imports, t, f, structTag)
		if vErr == nil {
			continue
		}

		failures = append(failures, BindError{
			object:    object,
			field:     f,
			typ:       t,
			methodErr: mErr,
			varErr:    vErr,
		})
	}
	return failures
}
231
// bindMethod tries to bind field to a method of t.
//
// Requirements:
//   - t must be a *types.Named.
//   - The method is looked up with findMethod using field.GoFieldName, or
//     field.GQLName when no Go name is set.
//   - It must return either a single value, or a value followed by a result
//     that implements error.
//   - Every parameter must match a GraphQL argument of the field (matchArgs).
//   - The first result must be compatible with the field's type
//     (validateTypeBinding).
//
// On success the field is switched to a method binding on receiver "obj" and
// nil is returned. Otherwise an error describing the mismatch is returned.
// NoErr is set only after the binding has succeeded, so a rejected method
// cannot leave it set for the bindVar fallback in bindObject.
func bindMethod(imports *Imports, t types.Type, field *Field) error {
	namedType, ok := t.(*types.Named)
	if !ok {
		return fmt.Errorf("not a named type")
	}

	goName := field.GQLName
	if field.GoFieldName != "" {
		goName = field.GoFieldName
	}
	method := findMethod(namedType, goName)
	if method == nil {
		// Report the name that was actually searched for.
		return fmt.Errorf("no method named %s", goName)
	}
	sig := method.Type().(*types.Signature)

	results := sig.Results()
	switch results.Len() {
	case 1:
		// No error result; NoErr is set below once binding succeeds.
	case 2:
		// With two results NoErr stays unset, so the second result is
		// presumably handled as an error by the generated code. Reject
		// anything that cannot be one.
		errIface := types.Universe.Lookup("error").Type().Underlying().(*types.Interface)
		if !types.Implements(results.At(1).Type(), errIface) {
			return fmt.Errorf("method has wrong second return type: %s is not an error", results.At(1).Type())
		}
	default:
		return fmt.Errorf("method has wrong number of return values")
	}

	newArgs, err := matchArgs(field, sig.Params())
	if err != nil {
		return err
	}

	if err := validateTypeBinding(imports, field, results.At(0).Type()); err != nil {
		return errors.Wrap(err, "method has wrong return type")
	}

	// success, args and return type match. Bind to method
	if results.Len() == 1 {
		field.NoErr = true
	}
	field.GoFieldType = GoFieldMethod
	field.GoReceiverName = "obj"
	field.GoFieldName = method.Name()
	field.Args = newArgs
	return nil
}
270
// bindVar tries to bind field directly to a struct field of t.
//
// t's underlying type must be a struct. The Go field is located with findField
// using field.GoFieldName, or field.GQLName when no Go name is set, honouring
// structTag. Its type must pass validateTypeBinding. On success the field is
// switched to a variable binding on receiver "obj" and nil is returned.
func bindVar(imports *Imports, t types.Type, field *Field, structTag string) error {
	st, isStruct := t.Underlying().(*types.Struct)
	if !isStruct {
		return fmt.Errorf("not a struct")
	}

	lookupName := field.GoFieldName
	if lookupName == "" {
		lookupName = field.GQLName
	}
	target, err := findField(st, lookupName, structTag)
	if err != nil {
		return err
	}

	if bindErr := validateTypeBinding(imports, field, target.Type()); bindErr != nil {
		return errors.Wrap(bindErr, "field has wrong type")
	}

	// The Go field's type is compatible; read it straight off the receiver.
	field.GoFieldType = GoFieldVariable
	field.GoReceiverName = "obj"
	field.GoFieldName = target.Name()
	return nil
}
296
// matchArgs maps each method parameter in params, in order, to the GraphQL
// argument of field with the same name (case-insensitive).
//
// For each match, unless field.ForceResolver is set, the argument's
// Type.Modifiers are set from the parameter's Go type. The matched arguments
// are returned in parameter order. A parameter with no matching GraphQL
// argument causes an error.
func matchArgs(field *Field, params *types.Tuple) ([]FieldArgument, error) {
	var newArgs []FieldArgument

	for p := 0; p < params.Len(); p++ {
		param := params.At(p)

		found := false
		for _, arg := range field.Args {
			if !strings.EqualFold(arg.GQLName, param.Name()) {
				continue
			}
			if !field.ForceResolver {
				arg.Type.Modifiers = modifiersFromGoType(param.Type())
			}
			newArgs = append(newArgs, arg)
			found = true
			break
		}

		if !found {
			return nil, fmt.Errorf("arg %s not found on method", param.Name())
		}
	}
	return newArgs, nil
}
318
// validateTypeBinding checks that goType can back the GraphQL type of field.
//
// Both sides are rendered as strings and passed through normalizeVendor. A
// single extra pointer on either side is tolerated.
//   - If goType itself matches, field.Type.Modifiers is set from goType.
//   - Otherwise, if goType's underlying type matches (e.g. a defined type such
//     as `type MyString string`), the modifiers are set as well and
//     field.AliasedType records the named Go type and its import.
//   - Any other combination is reported as incompatible.
func validateTypeBinding(imports *Imports, field *Field, goType types.Type) error {
	gqlType := normalizeVendor(field.Type.FullSignature())
	goTypeStr := normalizeVendor(goType.String())

	matches := func(s string) bool {
		return s == gqlType || "*"+s == gqlType || s == "*"+gqlType
	}

	if matches(goTypeStr) {
		field.Type.Modifiers = modifiersFromGoType(goType)
		return nil
	}

	underlyingStr := normalizeVendor(goType.Underlying().String())
	if !matches(underlyingStr) {
		return fmt.Errorf("%s is not compatible with %s", gqlType, goTypeStr)
	}

	field.Type.Modifiers = modifiersFromGoType(goType)
	pkg, typ := pkgAndType(goType.String())
	field.AliasedType = &Ref{GoType: typ, Import: imports.findByPath(pkg)}
	return nil
}
340
// modifiersFromGoType lists, outermost first, the modifiers wrapping t. Each
// pointer contributes modPtr and each array or slice contributes modList. It
// stops at the first type that is none of those; a bare type yields nil.
func modifiersFromGoType(t types.Type) []string {
	var mods []string
	for {
		var next types.Type
		switch tt := t.(type) {
		case *types.Pointer:
			mods, next = append(mods, modPtr), tt.Elem()
		case *types.Array:
			mods, next = append(mods, modList), tt.Elem()
		case *types.Slice:
			mods, next = append(mods, modList), tt.Elem()
		default:
			return mods
		}
		t = next
	}
}
359
360var modsRegex = regexp.MustCompile(`^(\*|\[\])*`)
361
362func normalizeVendor(pkg string) string {
363	modifiers := modsRegex.FindAllString(pkg, 1)[0]
364	pkg = strings.TrimPrefix(pkg, modifiers)
365	parts := strings.Split(pkg, "/vendor/")
366	return modifiers + parts[len(parts)-1]
367}