1package codegen
2
3import (
4 "fmt"
5 "strconv"
6 "strings"
7
8 "github.com/99designs/gqlgen/codegen/templates"
9 "github.com/pkg/errors"
10 "github.com/vektah/gqlparser/ast"
11)
12
13type Directive struct {
14 Name string
15 Args []*FieldArgument
16 Builtin bool
17}
18
19func (b *builder) buildDirectives() (map[string]*Directive, error) {
20 directives := make(map[string]*Directive, len(b.Schema.Directives))
21
22 for name, dir := range b.Schema.Directives {
23 if _, ok := directives[name]; ok {
24 return nil, errors.Errorf("directive with name %s already exists", name)
25 }
26
27 var builtin bool
28 if name == "skip" || name == "include" || name == "deprecated" {
29 builtin = true
30 }
31
32 var args []*FieldArgument
33 for _, arg := range dir.Arguments {
34 tr, err := b.Binder.TypeReference(arg.Type, nil)
35 if err != nil {
36 return nil, err
37 }
38
39 newArg := &FieldArgument{
40 ArgumentDefinition: arg,
41 TypeReference: tr,
42 VarName: templates.ToGoPrivate(arg.Name),
43 }
44
45 if arg.DefaultValue != nil {
46 var err error
47 newArg.Default, err = arg.DefaultValue.Value(nil)
48 if err != nil {
49 return nil, errors.Errorf("default value for directive argument %s(%s) is not valid: %s", dir.Name, arg.Name, err.Error())
50 }
51 }
52 args = append(args, newArg)
53 }
54
55 directives[name] = &Directive{
56 Name: name,
57 Args: args,
58 Builtin: builtin,
59 }
60 }
61
62 return directives, nil
63}
64
65func (b *builder) getDirectives(list ast.DirectiveList) ([]*Directive, error) {
66 dirs := make([]*Directive, len(list))
67 for i, d := range list {
68 argValues := make(map[string]interface{}, len(d.Arguments))
69 for _, da := range d.Arguments {
70 val, err := da.Value.Value(nil)
71 if err != nil {
72 return nil, err
73 }
74 argValues[da.Name] = val
75 }
76 def, ok := b.Directives[d.Name]
77 if !ok {
78 return nil, fmt.Errorf("directive %s not found", d.Name)
79 }
80
81 var args []*FieldArgument
82 for _, a := range def.Args {
83 value := a.Default
84 if argValue, ok := argValues[a.Name]; ok {
85 value = argValue
86 }
87 args = append(args, &FieldArgument{
88 ArgumentDefinition: a.ArgumentDefinition,
89 Value: value,
90 VarName: a.VarName,
91 TypeReference: a.TypeReference,
92 })
93 }
94 dirs[i] = &Directive{
95 Name: d.Name,
96 Args: args,
97 }
98
99 }
100
101 return dirs, nil
102}
103
104func (d *Directive) ArgsFunc() string {
105 if len(d.Args) == 0 {
106 return ""
107 }
108
109 return "dir_" + d.Name + "_args"
110}
111
112func (d *Directive) CallArgs() string {
113 args := []string{"ctx", "obj", "n"}
114
115 for _, arg := range d.Args {
116 args = append(args, "args["+strconv.Quote(arg.Name)+"].("+templates.CurrentImports.LookupType(arg.TypeReference.GO)+")")
117 }
118
119 return strings.Join(args, ", ")
120}
121
122func (d *Directive) ResolveArgs(obj string, next string) string {
123 args := []string{"ctx", obj, next}
124
125 for _, arg := range d.Args {
126 dArg := "&" + arg.VarName
127 if !arg.TypeReference.IsPtr() {
128 if arg.Value != nil {
129 dArg = templates.Dump(arg.Value)
130 } else {
131 dArg = templates.Dump(arg.Default)
132 }
133 } else if arg.Value == nil && arg.Default == nil {
134 dArg = "nil"
135 }
136
137 args = append(args, dArg)
138 }
139
140 return strings.Join(args, ", ")
141}
142
143func (d *Directive) Declaration() string {
144 res := ucFirst(d.Name) + " func(ctx context.Context, obj interface{}, next graphql.Resolver"
145
146 for _, arg := range d.Args {
147 res += fmt.Sprintf(", %s %s", arg.Name, templates.CurrentImports.LookupType(arg.TypeReference.GO))
148 }
149
150 res += ") (res interface{}, err error)"
151 return res
152}