util.go

  1package codegen
  2
  3import (
  4	"fmt"
  5	"go/types"
  6	"regexp"
  7	"strings"
  8
  9	"github.com/pkg/errors"
 10	"golang.org/x/tools/go/loader"
 11)
 12
 13func findGoType(prog *loader.Program, pkgName string, typeName string) (types.Object, error) {
 14	if pkgName == "" {
 15		return nil, nil
 16	}
 17	fullName := typeName
 18	if pkgName != "" {
 19		fullName = pkgName + "." + typeName
 20	}
 21
 22	pkgName, err := resolvePkg(pkgName)
 23	if err != nil {
 24		return nil, errors.Errorf("unable to resolve package for %s: %s\n", fullName, err.Error())
 25	}
 26
 27	pkg := prog.Imported[pkgName]
 28	if pkg == nil {
 29		return nil, errors.Errorf("required package was not loaded: %s", fullName)
 30	}
 31
 32	for astNode, def := range pkg.Defs {
 33		if astNode.Name != typeName || def.Parent() == nil || def.Parent() != pkg.Pkg.Scope() {
 34			continue
 35		}
 36
 37		return def, nil
 38	}
 39
 40	return nil, errors.Errorf("unable to find type %s\n", fullName)
 41}
 42
 43func findGoNamedType(prog *loader.Program, pkgName string, typeName string) (*types.Named, error) {
 44	def, err := findGoType(prog, pkgName, typeName)
 45	if err != nil {
 46		return nil, err
 47	}
 48	if def == nil {
 49		return nil, nil
 50	}
 51
 52	namedType, ok := def.Type().(*types.Named)
 53	if !ok {
 54		return nil, errors.Errorf("expected %s to be a named type, instead found %T\n", typeName, def.Type())
 55	}
 56
 57	return namedType, nil
 58}
 59
 60func findGoInterface(prog *loader.Program, pkgName string, typeName string) (*types.Interface, error) {
 61	namedType, err := findGoNamedType(prog, pkgName, typeName)
 62	if err != nil {
 63		return nil, err
 64	}
 65	if namedType == nil {
 66		return nil, nil
 67	}
 68
 69	underlying, ok := namedType.Underlying().(*types.Interface)
 70	if !ok {
 71		return nil, errors.Errorf("expected %s to be a named interface, instead found %s", typeName, namedType.String())
 72	}
 73
 74	return underlying, nil
 75}
 76
 77func findMethod(typ *types.Named, name string) *types.Func {
 78	for i := 0; i < typ.NumMethods(); i++ {
 79		method := typ.Method(i)
 80		if !method.Exported() {
 81			continue
 82		}
 83
 84		if strings.EqualFold(method.Name(), name) {
 85			return method
 86		}
 87	}
 88
 89	if s, ok := typ.Underlying().(*types.Struct); ok {
 90		for i := 0; i < s.NumFields(); i++ {
 91			field := s.Field(i)
 92			if !field.Anonymous() {
 93				continue
 94			}
 95
 96			if named, ok := field.Type().(*types.Named); ok {
 97				if f := findMethod(named, name); f != nil {
 98					return f
 99				}
100			}
101		}
102	}
103
104	return nil
105}
106
// findField searches the struct typ for an exported field whose name matches
// name, ignoring case. Fields promoted from embedded structs are included.
//
// For each embedded field the search first descends into the embedded struct,
// whether it is embedded as T or *T, and only then considers the embedded
// field itself by name. Unexported fields are never returned. The search still
// descends through unexported embedded types, because their exported fields
// are promoted. Each struct type is visited at most once, so self-referential
// embeddings such as `type Node struct{ *Node }` terminate.
//
// It returns nil if no matching field exists.
func findField(typ *types.Struct, name string) *types.Var {
	visited := make(map[*types.Struct]bool)

	var search func(s *types.Struct) *types.Var
	search = func(s *types.Struct) *types.Var {
		if visited[s] {
			return nil
		}
		visited[s] = true

		for i := 0; i < s.NumFields(); i++ {
			field := s.Field(i)
			if field.Anonymous() {
				embedded := field.Type()
				// Fields are promoted through an embedded *T as well as T.
				if ptr, isPtr := embedded.(*types.Pointer); isPtr {
					embedded = ptr.Elem()
				}
				// Underlying() covers named structs and unnamed struct types
				// alike, since an unnamed struct is its own underlying type.
				if inner, isStruct := embedded.Underlying().(*types.Struct); isStruct {
					if f := search(inner); f != nil {
						return f
					}
				}
			}

			if field.Exported() && strings.EqualFold(field.Name(), name) {
				return field
			}
		}
		return nil
	}

	return search(typ)
}
134
135type BindError struct {
136	object    *Object
137	field     *Field
138	typ       types.Type
139	methodErr error
140	varErr    error
141}
142
143func (b BindError) Error() string {
144	return fmt.Sprintf(
145		"Unable to bind %s.%s to %s\n  %s\n  %s",
146		b.object.GQLType,
147		b.field.GQLName,
148		b.typ.String(),
149		b.methodErr.Error(),
150		b.varErr.Error(),
151	)
152}
153
154type BindErrors []BindError
155
156func (b BindErrors) Error() string {
157	var errs []string
158	for _, err := range b {
159		errs = append(errs, err.Error())
160	}
161	return strings.Join(errs, "\n\n")
162}
163
164func bindObject(t types.Type, object *Object, imports *Imports) BindErrors {
165	var errs BindErrors
166	for i := range object.Fields {
167		field := &object.Fields[i]
168
169		// first try binding to a method
170		methodErr := bindMethod(imports, t, field)
171		if methodErr == nil {
172			continue
173		}
174
175		// otherwise try binding to a var
176		varErr := bindVar(imports, t, field)
177
178		if varErr != nil {
179			errs = append(errs, BindError{
180				object:    object,
181				typ:       t,
182				field:     field,
183				varErr:    varErr,
184				methodErr: methodErr,
185			})
186		}
187	}
188	return errs
189}
190
191func bindMethod(imports *Imports, t types.Type, field *Field) error {
192	namedType, ok := t.(*types.Named)
193	if !ok {
194		return fmt.Errorf("not a named type")
195	}
196
197	method := findMethod(namedType, field.GQLName)
198	if method == nil {
199		return fmt.Errorf("no method named %s", field.GQLName)
200	}
201	sig := method.Type().(*types.Signature)
202
203	if sig.Results().Len() == 1 {
204		field.NoErr = true
205	} else if sig.Results().Len() != 2 {
206		return fmt.Errorf("method has wrong number of args")
207	}
208	newArgs, err := matchArgs(field, sig.Params())
209	if err != nil {
210		return err
211	}
212
213	result := sig.Results().At(0)
214	if err := validateTypeBinding(imports, field, result.Type()); err != nil {
215		return errors.Wrap(err, "method has wrong return type")
216	}
217
218	// success, args and return type match. Bind to method
219	field.GoMethodName = "obj." + method.Name()
220	field.Args = newArgs
221	return nil
222}
223
224func bindVar(imports *Imports, t types.Type, field *Field) error {
225	underlying, ok := t.Underlying().(*types.Struct)
226	if !ok {
227		return fmt.Errorf("not a struct")
228	}
229
230	structField := findField(underlying, field.GQLName)
231	if structField == nil {
232		return fmt.Errorf("no field named %s", field.GQLName)
233	}
234
235	if err := validateTypeBinding(imports, field, structField.Type()); err != nil {
236		return errors.Wrap(err, "field has wrong type")
237	}
238
239	// success, bind to var
240	field.GoVarName = structField.Name()
241	return nil
242}
243
244func matchArgs(field *Field, params *types.Tuple) ([]FieldArgument, error) {
245	var newArgs []FieldArgument
246
247nextArg:
248	for j := 0; j < params.Len(); j++ {
249		param := params.At(j)
250		for _, oldArg := range field.Args {
251			if strings.EqualFold(oldArg.GQLName, param.Name()) {
252				oldArg.Type.Modifiers = modifiersFromGoType(param.Type())
253				newArgs = append(newArgs, oldArg)
254				continue nextArg
255			}
256		}
257
258		// no matching arg found, abort
259		return nil, fmt.Errorf("arg %s not found on method", param.Name())
260	}
261	return newArgs, nil
262}
263
264func validateTypeBinding(imports *Imports, field *Field, goType types.Type) error {
265	gqlType := normalizeVendor(field.Type.FullSignature())
266	goTypeStr := normalizeVendor(goType.String())
267
268	if goTypeStr == gqlType || "*"+goTypeStr == gqlType || goTypeStr == "*"+gqlType {
269		field.Type.Modifiers = modifiersFromGoType(goType)
270		return nil
271	}
272
273	// deal with type aliases
274	underlyingStr := normalizeVendor(goType.Underlying().String())
275	if underlyingStr == gqlType || "*"+underlyingStr == gqlType || underlyingStr == "*"+gqlType {
276		field.Type.Modifiers = modifiersFromGoType(goType)
277		pkg, typ := pkgAndType(goType.String())
278		imp := imports.findByPath(pkg)
279		field.CastType = &Ref{GoType: typ, Import: imp}
280		return nil
281	}
282
283	return fmt.Errorf("%s is not compatible with %s", gqlType, goTypeStr)
284}
285
286func modifiersFromGoType(t types.Type) []string {
287	var modifiers []string
288	for {
289		switch val := t.(type) {
290		case *types.Pointer:
291			modifiers = append(modifiers, modPtr)
292			t = val.Elem()
293		case *types.Array:
294			modifiers = append(modifiers, modList)
295			t = val.Elem()
296		case *types.Slice:
297			modifiers = append(modifiers, modList)
298			t = val.Elem()
299		default:
300			return modifiers
301		}
302	}
303}
304
// modsRegex matches the run of leading pointer ("*") and slice ("[]") markers
// on a type string. It always matches, possibly with an empty string.
var modsRegex = regexp.MustCompile(`^(\*|\[\])*`)

// normalizeVendor removes vendor directory prefixes from a type string. It
// keeps the leading "*"/"[]" markers and everything after the last
// non-overlapping "/vendor/". For example,
// "*[]github.com/a/vendor/github.com/b.T" becomes "*[]github.com/b.T".
func normalizeVendor(pkg string) string {
	prefix := modsRegex.FindString(pkg)
	segments := strings.Split(pkg[len(prefix):], "/vendor/")
	return prefix + segments[len(segments)-1]
}