1package codegen
2
3import (
4 "fmt"
5 "go/types"
6 "strings"
7
8 "github.com/99designs/gqlgen/codegen/config"
9 "github.com/99designs/gqlgen/codegen/templates"
10 "github.com/pkg/errors"
11 "github.com/vektah/gqlparser/ast"
12)
13
14type ArgSet struct {
15 Args []*FieldArgument
16 FuncDecl string
17}
18
19type FieldArgument struct {
20 *ast.ArgumentDefinition
21 TypeReference *config.TypeReference
22 VarName string // The name of the var in go
23 Object *Object // A link back to the parent object
24 Default interface{} // The default value
25 Directives []*Directive
26 Value interface{} // value set in Data
27}
28
29func (f *FieldArgument) Stream() bool {
30 return f.Object != nil && f.Object.Stream
31}
32
33func (b *builder) buildArg(obj *Object, arg *ast.ArgumentDefinition) (*FieldArgument, error) {
34 tr, err := b.Binder.TypeReference(arg.Type, nil)
35 if err != nil {
36 return nil, err
37 }
38
39 argDirs, err := b.getDirectives(arg.Directives)
40 if err != nil {
41 return nil, err
42 }
43 newArg := FieldArgument{
44 ArgumentDefinition: arg,
45 TypeReference: tr,
46 Object: obj,
47 VarName: templates.ToGoPrivate(arg.Name),
48 Directives: argDirs,
49 }
50
51 if arg.DefaultValue != nil {
52 newArg.Default, err = arg.DefaultValue.Value(nil)
53 if err != nil {
54 return nil, errors.Errorf("default value is not valid: %s", err.Error())
55 }
56 }
57
58 return &newArg, nil
59}
60
61func (b *builder) bindArgs(field *Field, params *types.Tuple) error {
62 var newArgs []*FieldArgument
63
64nextArg:
65 for j := 0; j < params.Len(); j++ {
66 param := params.At(j)
67 for _, oldArg := range field.Args {
68 if strings.EqualFold(oldArg.Name, param.Name()) {
69 tr, err := b.Binder.TypeReference(oldArg.Type, param.Type())
70 if err != nil {
71 return err
72 }
73 oldArg.TypeReference = tr
74
75 newArgs = append(newArgs, oldArg)
76 continue nextArg
77 }
78 }
79
80 // no matching arg found, abort
81 return fmt.Errorf("arg %s not in schema", param.Name())
82 }
83
84 field.Args = newArgs
85 return nil
86}
87
88func (a *Data) Args() map[string][]*FieldArgument {
89 ret := map[string][]*FieldArgument{}
90 for _, o := range a.Objects {
91 for _, f := range o.Fields {
92 if len(f.Args) > 0 {
93 ret[f.ArgsFunc()] = f.Args
94 }
95 }
96 }
97
98 for _, d := range a.Directives {
99 if len(d.Args) > 0 {
100 ret[d.ArgsFunc()] = d.Args
101 }
102 }
103 return ret
104}