field.go

  1package codegen
  2
  3import (
  4	"fmt"
  5	"go/types"
  6	"log"
  7	"reflect"
  8	"strconv"
  9	"strings"
 10
 11	"github.com/99designs/gqlgen/codegen/config"
 12	"github.com/99designs/gqlgen/codegen/templates"
 13	"github.com/pkg/errors"
 14	"github.com/vektah/gqlparser/ast"
 15)
 16
 17type Field struct {
 18	*ast.FieldDefinition
 19
 20	TypeReference    *config.TypeReference
 21	GoFieldType      GoFieldType      // The field type in go, if any
 22	GoReceiverName   string           // The name of method & var receiver in go, if any
 23	GoFieldName      string           // The name of the method or var in go, if any
 24	IsResolver       bool             // Does this field need a resolver
 25	Args             []*FieldArgument // A list of arguments to be passed to this field
 26	MethodHasContext bool             // If this is bound to a go method, does the method also take a context
 27	NoErr            bool             // If this is bound to a go method, does that method have an error as the second argument
 28	Object           *Object          // A link back to the parent object
 29	Default          interface{}      // The default value
 30	Directives       []*Directive
 31}
 32
 33func (b *builder) buildField(obj *Object, field *ast.FieldDefinition) (*Field, error) {
 34	dirs, err := b.getDirectives(field.Directives)
 35	if err != nil {
 36		return nil, err
 37	}
 38
 39	f := Field{
 40		FieldDefinition: field,
 41		Object:          obj,
 42		Directives:      dirs,
 43		GoFieldName:     templates.ToGo(field.Name),
 44		GoFieldType:     GoFieldVariable,
 45		GoReceiverName:  "obj",
 46	}
 47
 48	if field.DefaultValue != nil {
 49		var err error
 50		f.Default, err = field.DefaultValue.Value(nil)
 51		if err != nil {
 52			return nil, errors.Errorf("default value %s is not valid: %s", field.Name, err.Error())
 53		}
 54	}
 55
 56	for _, arg := range field.Arguments {
 57		newArg, err := b.buildArg(obj, arg)
 58		if err != nil {
 59			return nil, err
 60		}
 61		f.Args = append(f.Args, newArg)
 62	}
 63
 64	if err = b.bindField(obj, &f); err != nil {
 65		f.IsResolver = true
 66		log.Println(err.Error())
 67	}
 68
 69	if f.IsResolver && !f.TypeReference.IsPtr() && f.TypeReference.IsStruct() {
 70		f.TypeReference = b.Binder.PointerTo(f.TypeReference)
 71	}
 72
 73	return &f, nil
 74}
 75
 76func (b *builder) bindField(obj *Object, f *Field) error {
 77	defer func() {
 78		if f.TypeReference == nil {
 79			tr, err := b.Binder.TypeReference(f.Type, nil)
 80			if err != nil {
 81				panic(err)
 82			}
 83			f.TypeReference = tr
 84		}
 85	}()
 86
 87	switch {
 88	case f.Name == "__schema":
 89		f.GoFieldType = GoFieldMethod
 90		f.GoReceiverName = "ec"
 91		f.GoFieldName = "introspectSchema"
 92		return nil
 93	case f.Name == "__type":
 94		f.GoFieldType = GoFieldMethod
 95		f.GoReceiverName = "ec"
 96		f.GoFieldName = "introspectType"
 97		return nil
 98	case obj.Root:
 99		f.IsResolver = true
100		return nil
101	case b.Config.Models[obj.Name].Fields[f.Name].Resolver:
102		f.IsResolver = true
103		return nil
104	case obj.Type == config.MapType:
105		f.GoFieldType = GoFieldMap
106		return nil
107	case b.Config.Models[obj.Name].Fields[f.Name].FieldName != "":
108		f.GoFieldName = b.Config.Models[obj.Name].Fields[f.Name].FieldName
109	}
110
111	target, err := b.findBindTarget(obj.Type.(*types.Named), f.GoFieldName)
112	if err != nil {
113		return err
114	}
115
116	pos := b.Binder.ObjectPosition(target)
117
118	switch target := target.(type) {
119	case nil:
120		objPos := b.Binder.TypePosition(obj.Type)
121		return fmt.Errorf(
122			"%s:%d adding resolver method for %s.%s, nothing matched",
123			objPos.Filename,
124			objPos.Line,
125			obj.Name,
126			f.Name,
127		)
128
129	case *types.Func:
130		sig := target.Type().(*types.Signature)
131		if sig.Results().Len() == 1 {
132			f.NoErr = true
133		} else if sig.Results().Len() != 2 {
134			return fmt.Errorf("method has wrong number of args")
135		}
136		params := sig.Params()
137		// If the first argument is the context, remove it from the comparison and set
138		// the MethodHasContext flag so that the context will be passed to this model's method
139		if params.Len() > 0 && params.At(0).Type().String() == "context.Context" {
140			f.MethodHasContext = true
141			vars := make([]*types.Var, params.Len()-1)
142			for i := 1; i < params.Len(); i++ {
143				vars[i-1] = params.At(i)
144			}
145			params = types.NewTuple(vars...)
146		}
147
148		if err = b.bindArgs(f, params); err != nil {
149			return errors.Wrapf(err, "%s:%d", pos.Filename, pos.Line)
150		}
151
152		result := sig.Results().At(0)
153		tr, err := b.Binder.TypeReference(f.Type, result.Type())
154		if err != nil {
155			return err
156		}
157
158		// success, args and return type match. Bind to method
159		f.GoFieldType = GoFieldMethod
160		f.GoReceiverName = "obj"
161		f.GoFieldName = target.Name()
162		f.TypeReference = tr
163
164		return nil
165
166	case *types.Var:
167		tr, err := b.Binder.TypeReference(f.Type, target.Type())
168		if err != nil {
169			return err
170		}
171
172		// success, bind to var
173		f.GoFieldType = GoFieldVariable
174		f.GoReceiverName = "obj"
175		f.GoFieldName = target.Name()
176		f.TypeReference = tr
177
178		return nil
179	default:
180		panic(fmt.Errorf("unknown bind target %T for %s", target, f.Name))
181	}
182}
183
184// findField attempts to match the name to a struct field with the following
185// priorites:
186// 1. Any method with a matching name
187// 2. Any Fields with a struct tag (see config.StructTag)
188// 3. Any fields with a matching name
189// 4. Same logic again for embedded fields
190func (b *builder) findBindTarget(named *types.Named, name string) (types.Object, error) {
191	for i := 0; i < named.NumMethods(); i++ {
192		method := named.Method(i)
193		if !method.Exported() {
194			continue
195		}
196
197		if !strings.EqualFold(method.Name(), name) {
198			continue
199		}
200
201		return method, nil
202	}
203
204	strukt, ok := named.Underlying().(*types.Struct)
205	if !ok {
206		return nil, fmt.Errorf("not a struct")
207	}
208	return b.findBindStructTarget(strukt, name)
209}
210
211func (b *builder) findBindStructTarget(strukt *types.Struct, name string) (types.Object, error) {
212	// struct tags have the highest priority
213	if b.Config.StructTag != "" {
214		var foundField *types.Var
215		for i := 0; i < strukt.NumFields(); i++ {
216			field := strukt.Field(i)
217			if !field.Exported() {
218				continue
219			}
220			tags := reflect.StructTag(strukt.Tag(i))
221			if val, ok := tags.Lookup(b.Config.StructTag); ok && equalFieldName(val, name) {
222				if foundField != nil {
223					return nil, errors.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", b.Config.StructTag, val)
224				}
225
226				foundField = field
227			}
228		}
229		if foundField != nil {
230			return foundField, nil
231		}
232	}
233
234	// Then matching field names
235	for i := 0; i < strukt.NumFields(); i++ {
236		field := strukt.Field(i)
237		if !field.Exported() {
238			continue
239		}
240		if equalFieldName(field.Name(), name) { // aqui!
241			return field, nil
242		}
243	}
244
245	// Then look in embedded structs
246	for i := 0; i < strukt.NumFields(); i++ {
247		field := strukt.Field(i)
248		if !field.Exported() {
249			continue
250		}
251
252		if !field.Anonymous() {
253			continue
254		}
255
256		fieldType := field.Type()
257		if ptr, ok := fieldType.(*types.Pointer); ok {
258			fieldType = ptr.Elem()
259		}
260
261		switch fieldType := fieldType.(type) {
262		case *types.Named:
263			f, err := b.findBindTarget(fieldType, name)
264			if err != nil {
265				return nil, err
266			}
267			if f != nil {
268				return f, nil
269			}
270		case *types.Struct:
271			f, err := b.findBindStructTarget(fieldType, name)
272			if err != nil {
273				return nil, err
274			}
275			if f != nil {
276				return f, nil
277			}
278		default:
279			panic(fmt.Errorf("unknown embedded field type %T", field.Type()))
280		}
281	}
282
283	return nil, nil
284}
285
286func (f *Field) HasDirectives() bool {
287	return len(f.ImplDirectives()) > 0
288}
289
290func (f *Field) DirectiveObjName() string {
291	if f.Object.Root {
292		return "nil"
293	}
294	return f.GoReceiverName
295}
296
297func (f *Field) ImplDirectives() []*Directive {
298	var d []*Directive
299	loc := ast.LocationFieldDefinition
300	if f.Object.IsInputType() {
301		loc = ast.LocationInputFieldDefinition
302	}
303	for i := range f.Directives {
304		if !f.Directives[i].Builtin && f.Directives[i].IsLocation(loc) {
305			d = append(d, f.Directives[i])
306		}
307	}
308	return d
309}
310
311func (f *Field) IsReserved() bool {
312	return strings.HasPrefix(f.Name, "__")
313}
314
315func (f *Field) IsMethod() bool {
316	return f.GoFieldType == GoFieldMethod
317}
318
319func (f *Field) IsVariable() bool {
320	return f.GoFieldType == GoFieldVariable
321}
322
323func (f *Field) IsMap() bool {
324	return f.GoFieldType == GoFieldMap
325}
326
327func (f *Field) IsConcurrent() bool {
328	if f.Object.DisableConcurrency {
329		return false
330	}
331	return f.MethodHasContext || f.IsResolver
332}
333
334func (f *Field) GoNameUnexported() string {
335	return templates.ToGoPrivate(f.Name)
336}
337
338func (f *Field) ShortInvocation() string {
339	return fmt.Sprintf("%s().%s(%s)", f.Object.Definition.Name, f.GoFieldName, f.CallArgs())
340}
341
342func (f *Field) ArgsFunc() string {
343	if len(f.Args) == 0 {
344		return ""
345	}
346
347	return "field_" + f.Object.Definition.Name + "_" + f.Name + "_args"
348}
349
350func (f *Field) ResolverType() string {
351	if !f.IsResolver {
352		return ""
353	}
354
355	return fmt.Sprintf("%s().%s(%s)", f.Object.Definition.Name, f.GoFieldName, f.CallArgs())
356}
357
358func (f *Field) ShortResolverDeclaration() string {
359	res := "(ctx context.Context"
360
361	if !f.Object.Root {
362		res += fmt.Sprintf(", obj *%s", templates.CurrentImports.LookupType(f.Object.Type))
363	}
364	for _, arg := range f.Args {
365		res += fmt.Sprintf(", %s %s", arg.VarName, templates.CurrentImports.LookupType(arg.TypeReference.GO))
366	}
367
368	result := templates.CurrentImports.LookupType(f.TypeReference.GO)
369	if f.Object.Stream {
370		result = "<-chan " + result
371	}
372
373	res += fmt.Sprintf(") (%s, error)", result)
374	return res
375}
376
377func (f *Field) ComplexitySignature() string {
378	res := fmt.Sprintf("func(childComplexity int")
379	for _, arg := range f.Args {
380		res += fmt.Sprintf(", %s %s", arg.VarName, templates.CurrentImports.LookupType(arg.TypeReference.GO))
381	}
382	res += ") int"
383	return res
384}
385
386func (f *Field) ComplexityArgs() string {
387	args := make([]string, len(f.Args))
388	for i, arg := range f.Args {
389		args[i] = "args[" + strconv.Quote(arg.Name) + "].(" + templates.CurrentImports.LookupType(arg.TypeReference.GO) + ")"
390	}
391
392	return strings.Join(args, ", ")
393}
394
395func (f *Field) CallArgs() string {
396	args := make([]string, 0, len(f.Args)+2)
397
398	if f.IsResolver {
399		args = append(args, "rctx")
400
401		if !f.Object.Root {
402			args = append(args, "obj")
403		}
404	} else if f.MethodHasContext {
405		args = append(args, "ctx")
406	}
407
408	for _, arg := range f.Args {
409		args = append(args, "args["+strconv.Quote(arg.Name)+"].("+templates.CurrentImports.LookupType(arg.TypeReference.GO)+")")
410	}
411
412	return strings.Join(args, ", ")
413}