1package codegen
2
3import (
4 "fmt"
5 "go/types"
6 "regexp"
7 "strings"
8
9 "github.com/pkg/errors"
10 "golang.org/x/tools/go/loader"
11)
12
13func findGoType(prog *loader.Program, pkgName string, typeName string) (types.Object, error) {
14 if pkgName == "" {
15 return nil, nil
16 }
17 fullName := typeName
18 if pkgName != "" {
19 fullName = pkgName + "." + typeName
20 }
21
22 pkgName, err := resolvePkg(pkgName)
23 if err != nil {
24 return nil, errors.Errorf("unable to resolve package for %s: %s\n", fullName, err.Error())
25 }
26
27 pkg := prog.Imported[pkgName]
28 if pkg == nil {
29 return nil, errors.Errorf("required package was not loaded: %s", fullName)
30 }
31
32 for astNode, def := range pkg.Defs {
33 if astNode.Name != typeName || def.Parent() == nil || def.Parent() != pkg.Pkg.Scope() {
34 continue
35 }
36
37 return def, nil
38 }
39
40 return nil, errors.Errorf("unable to find type %s\n", fullName)
41}
42
43func findGoNamedType(prog *loader.Program, pkgName string, typeName string) (*types.Named, error) {
44 def, err := findGoType(prog, pkgName, typeName)
45 if err != nil {
46 return nil, err
47 }
48 if def == nil {
49 return nil, nil
50 }
51
52 namedType, ok := def.Type().(*types.Named)
53 if !ok {
54 return nil, errors.Errorf("expected %s to be a named type, instead found %T\n", typeName, def.Type())
55 }
56
57 return namedType, nil
58}
59
60func findGoInterface(prog *loader.Program, pkgName string, typeName string) (*types.Interface, error) {
61 namedType, err := findGoNamedType(prog, pkgName, typeName)
62 if err != nil {
63 return nil, err
64 }
65 if namedType == nil {
66 return nil, nil
67 }
68
69 underlying, ok := namedType.Underlying().(*types.Interface)
70 if !ok {
71 return nil, errors.Errorf("expected %s to be a named interface, instead found %s", typeName, namedType.String())
72 }
73
74 return underlying, nil
75}
76
// findMethod returns the exported method reachable from typ whose name
// matches name case-insensitively, or nil when there is none.
//
// Methods declared directly on typ are checked first, in declaration order.
// If none match and typ's underlying type is a struct, the search continues
// depth-first into its embedded fields. Fields embedded by pointer
// (struct{ *Inner }) are followed as well, since Go promotes their methods
// just like those of value embeds.
func findMethod(typ *types.Named, name string) *types.Func {
	return findMethodVisiting(typ, name, map[*types.Named]bool{})
}

// findMethodVisiting implements findMethod. seen records the named types that
// were already searched so that self-referential pointer embeds terminate.
func findMethodVisiting(typ *types.Named, name string, seen map[*types.Named]bool) *types.Func {
	if seen[typ] {
		return nil
	}
	seen[typ] = true

	for i := 0; i < typ.NumMethods(); i++ {
		method := typ.Method(i)
		if method.Exported() && strings.EqualFold(method.Name(), name) {
			return method
		}
	}

	s, ok := typ.Underlying().(*types.Struct)
	if !ok {
		return nil
	}

	for i := 0; i < s.NumFields(); i++ {
		field := s.Field(i)
		if !field.Anonymous() {
			continue
		}

		// An embedded field is either T or *T; look through the pointer so
		// both forms are searched.
		embedded := field.Type()
		if ptr, isPtr := embedded.(*types.Pointer); isPtr {
			embedded = ptr.Elem()
		}

		if named, isNamed := embedded.(*types.Named); isNamed {
			if f := findMethodVisiting(named, name, seen); f != nil {
				return f
			}
		}
	}

	return nil
}
106
// findField returns the exported field reachable from typ whose name matches
// name case-insensitively, or nil when there is none.
//
// Fields declared directly on typ are checked first, so a direct field wins
// over a same-named field promoted from an embedded struct, mirroring how Go
// resolves a selector. Only if no direct field matches does the search
// descend depth-first into embedded fields, whether embedded by value or by
// pointer.
func findField(typ *types.Struct, name string) *types.Var {
	return findFieldVisiting(typ, name, map[*types.Struct]bool{})
}

// findFieldVisiting implements findField. seen records the struct types that
// were already searched so that self-referential pointer embeds terminate.
func findFieldVisiting(typ *types.Struct, name string, seen map[*types.Struct]bool) *types.Var {
	if seen[typ] {
		return nil
	}
	seen[typ] = true

	// Pass 1: the struct's own fields (an embedded field also counts here
	// under its own name).
	for i := 0; i < typ.NumFields(); i++ {
		field := typ.Field(i)
		if field.Exported() && strings.EqualFold(field.Name(), name) {
			return field
		}
	}

	// Pass 2: fields promoted from embedded structs.
	for i := 0; i < typ.NumFields(); i++ {
		field := typ.Field(i)
		if !field.Anonymous() {
			continue
		}

		// An embedded field is either T or *T; look through the pointer so
		// both forms are searched.
		embedded := field.Type()
		if ptr, isPtr := embedded.(*types.Pointer); isPtr {
			embedded = ptr.Elem()
		}

		if inner, isStruct := embedded.Underlying().(*types.Struct); isStruct {
			if f := findFieldVisiting(inner, name, seen); f != nil {
				return f
			}
		}
	}

	return nil
}
134
// BindError describes one schema field that could not be bound to a Go type.
// bindObject creates it only after both the method attempt and the struct
// field attempt have returned an error, so it carries both failures.
type BindError struct {
 // object is the schema object that owns the unbound field.
 object *Object
 // field is the schema field that could not be bound.
 field *Field
 // typ is the Go type the binding was attempted against.
 typ types.Type
 // methodErr is the error returned by bindMethod for this field.
 methodErr error
 // varErr is the error returned by bindVar for this field.
 varErr error
}
142
143func (b BindError) Error() string {
144 return fmt.Sprintf(
145 "Unable to bind %s.%s to %s\n %s\n %s",
146 b.object.GQLType,
147 b.field.GQLName,
148 b.typ.String(),
149 b.methodErr.Error(),
150 b.varErr.Error(),
151 )
152}
153
// BindErrors is the list of BindError values returned by bindObject, one per
// field that could be bound neither to a method nor to a struct field.
type BindErrors []BindError
155
156func (b BindErrors) Error() string {
157 var errs []string
158 for _, err := range b {
159 errs = append(errs, err.Error())
160 }
161 return strings.Join(errs, "\n\n")
162}
163
164func bindObject(t types.Type, object *Object, imports *Imports) BindErrors {
165 var errs BindErrors
166 for i := range object.Fields {
167 field := &object.Fields[i]
168
169 // first try binding to a method
170 methodErr := bindMethod(imports, t, field)
171 if methodErr == nil {
172 continue
173 }
174
175 // otherwise try binding to a var
176 varErr := bindVar(imports, t, field)
177
178 if varErr != nil {
179 errs = append(errs, BindError{
180 object: object,
181 typ: t,
182 field: field,
183 varErr: varErr,
184 methodErr: methodErr,
185 })
186 }
187 }
188 return errs
189}
190
191func bindMethod(imports *Imports, t types.Type, field *Field) error {
192 namedType, ok := t.(*types.Named)
193 if !ok {
194 return fmt.Errorf("not a named type")
195 }
196
197 method := findMethod(namedType, field.GQLName)
198 if method == nil {
199 return fmt.Errorf("no method named %s", field.GQLName)
200 }
201 sig := method.Type().(*types.Signature)
202
203 if sig.Results().Len() == 1 {
204 field.NoErr = true
205 } else if sig.Results().Len() != 2 {
206 return fmt.Errorf("method has wrong number of args")
207 }
208 newArgs, err := matchArgs(field, sig.Params())
209 if err != nil {
210 return err
211 }
212
213 result := sig.Results().At(0)
214 if err := validateTypeBinding(imports, field, result.Type()); err != nil {
215 return errors.Wrap(err, "method has wrong return type")
216 }
217
218 // success, args and return type match. Bind to method
219 field.GoMethodName = "obj." + method.Name()
220 field.Args = newArgs
221 return nil
222}
223
224func bindVar(imports *Imports, t types.Type, field *Field) error {
225 underlying, ok := t.Underlying().(*types.Struct)
226 if !ok {
227 return fmt.Errorf("not a struct")
228 }
229
230 structField := findField(underlying, field.GQLName)
231 if structField == nil {
232 return fmt.Errorf("no field named %s", field.GQLName)
233 }
234
235 if err := validateTypeBinding(imports, field, structField.Type()); err != nil {
236 return errors.Wrap(err, "field has wrong type")
237 }
238
239 // success, bind to var
240 field.GoVarName = structField.Name()
241 return nil
242}
243
244func matchArgs(field *Field, params *types.Tuple) ([]FieldArgument, error) {
245 var newArgs []FieldArgument
246
247nextArg:
248 for j := 0; j < params.Len(); j++ {
249 param := params.At(j)
250 for _, oldArg := range field.Args {
251 if strings.EqualFold(oldArg.GQLName, param.Name()) {
252 oldArg.Type.Modifiers = modifiersFromGoType(param.Type())
253 newArgs = append(newArgs, oldArg)
254 continue nextArg
255 }
256 }
257
258 // no matching arg found, abort
259 return nil, fmt.Errorf("arg %s not found on method", param.Name())
260 }
261 return newArgs, nil
262}
263
264func validateTypeBinding(imports *Imports, field *Field, goType types.Type) error {
265 gqlType := normalizeVendor(field.Type.FullSignature())
266 goTypeStr := normalizeVendor(goType.String())
267
268 if goTypeStr == gqlType || "*"+goTypeStr == gqlType || goTypeStr == "*"+gqlType {
269 field.Type.Modifiers = modifiersFromGoType(goType)
270 return nil
271 }
272
273 // deal with type aliases
274 underlyingStr := normalizeVendor(goType.Underlying().String())
275 if underlyingStr == gqlType || "*"+underlyingStr == gqlType || underlyingStr == "*"+gqlType {
276 field.Type.Modifiers = modifiersFromGoType(goType)
277 pkg, typ := pkgAndType(goType.String())
278 imp := imports.findByPath(pkg)
279 field.CastType = &Ref{GoType: typ, Import: imp}
280 return nil
281 }
282
283 return fmt.Errorf("%s is not compatible with %s", gqlType, goTypeStr)
284}
285
286func modifiersFromGoType(t types.Type) []string {
287 var modifiers []string
288 for {
289 switch val := t.(type) {
290 case *types.Pointer:
291 modifiers = append(modifiers, modPtr)
292 t = val.Elem()
293 case *types.Array:
294 modifiers = append(modifiers, modList)
295 t = val.Elem()
296 case *types.Slice:
297 modifiers = append(modifiers, modList)
298 t = val.Elem()
299 default:
300 return modifiers
301 }
302 }
303}
304
// modsRegex matches the (possibly empty) run of "*" and "[]" modifiers at the
// very start of a type string.
var modsRegex = regexp.MustCompile(`^(\*|\[\])*`)

// normalizeVendor strips any vendor directory prefix from the package path in
// a type string while keeping its leading "*" / "[]" modifiers. Everything up
// to and including the last "/vendor/" segment is dropped; a string without
// such a segment is returned unchanged.
func normalizeVendor(pkg string) string {
	const vendorDir = "/vendor/"

	// The pattern is anchored and may match the empty string, so prefix is
	// always a (possibly empty) prefix of pkg.
	prefix := modsRegex.FindString(pkg)
	rest := pkg[len(prefix):]

	// Skip past each vendor segment in turn, scanning left to right.
	for {
		at := strings.Index(rest, vendorDir)
		if at < 0 {
			return prefix + rest
		}
		rest = rest[at+len(vendorDir):]
	}
}