1package codegen
2
3import (
4 "bytes"
5 "fmt"
6 "strconv"
7 "strings"
8 "text/template"
9 "unicode"
10)
11
// Object is the code generator's model of one GraphQL object type together
// with the fields that will get generated marshalling/resolver code.
type Object struct {
	*NamedType

	// Fields of this object, in declaration order.
	Fields []Field
	// Extra type names emitted by Implementors after the object's own GQLType
	// (presumably the interfaces/unions it satisfies — confirm against builder).
	Satisfies []string
	// Root objects' resolvers take no parent `obj` argument
	// (see ResolverDeclaration, ShortInvocation, CallArgs).
	Root bool
	// When set, IsConcurrent reports false for every field of this object.
	DisableConcurrency bool
	// When set, resolver declarations return a `<-chan` of the field type.
	Stream bool
}
21
// Field describes one field of an Object as seen by the code generator: its
// GraphQL name, how it is bound in Go (method, variable, or generated resolver),
// and its arguments.
type Field struct {
	*Type

	GQLName       string          // The name of the field in graphql
	GoMethodName  string          // The name of the method in go, if any
	GoVarName     string          // The name of the var in go, if any
	Args          []FieldArgument // A list of arguments to be passed to this field
	ForceResolver bool            // Whether a resolver method is emitted even when the field is bound to a go method or var
	NoErr         bool            // If bound to a go method: presumably true when that method has NO trailing error return value (name-based; not used in this chunk — confirm)
	Object        *Object         // A link back to the parent object
	Default       interface{}     // The default value
}
34
// FieldArgument describes one GraphQL argument accepted by a Field.
type FieldArgument struct {
	*Type

	GQLName   string      // The name of the argument in graphql; also the key looked up in the generated args map by Field.CallArgs
	GoVarName string      // The name of the var in go; used as the parameter name in resolver declarations
	Object    *Object     // A link back to the parent object
	Default   interface{} // The default value
}
43
// Objects is a list of generated objects, searchable by GraphQL name via ByName.
type Objects []*Object
45
46func (o *Object) Implementors() string {
47 satisfiedBy := strconv.Quote(o.GQLType)
48 for _, s := range o.Satisfies {
49 satisfiedBy += ", " + strconv.Quote(s)
50 }
51 return "[]string{" + satisfiedBy + "}"
52}
53
54func (o *Object) HasResolvers() bool {
55 for _, f := range o.Fields {
56 if f.IsResolver() {
57 return true
58 }
59 }
60 return false
61}
62
63func (f *Field) IsResolver() bool {
64 return f.ForceResolver || f.GoMethodName == "" && f.GoVarName == ""
65}
66
67func (f *Field) IsConcurrent() bool {
68 return f.IsResolver() && !f.Object.DisableConcurrency
69}
70func (f *Field) ShortInvocation() string {
71 if !f.IsResolver() {
72 return ""
73 }
74 shortName := strings.ToUpper(f.GQLName[:1]) + f.GQLName[1:]
75 res := fmt.Sprintf("%s().%s(ctx", f.Object.GQLType, shortName)
76 if !f.Object.Root {
77 res += fmt.Sprintf(", obj")
78 }
79 for _, arg := range f.Args {
80 res += fmt.Sprintf(", %s", arg.GoVarName)
81 }
82 res += ")"
83 return res
84}
85func (f *Field) ShortResolverDeclaration() string {
86 if !f.IsResolver() {
87 return ""
88 }
89 decl := strings.TrimPrefix(f.ResolverDeclaration(), f.Object.GQLType+"_")
90 return strings.ToUpper(decl[:1]) + decl[1:]
91}
92
93func (f *Field) ResolverDeclaration() string {
94 if !f.IsResolver() {
95 return ""
96 }
97 res := fmt.Sprintf("%s_%s(ctx context.Context", f.Object.GQLType, f.GQLName)
98
99 if !f.Object.Root {
100 res += fmt.Sprintf(", obj *%s", f.Object.FullName())
101 }
102 for _, arg := range f.Args {
103 res += fmt.Sprintf(", %s %s", arg.GoVarName, arg.Signature())
104 }
105
106 result := f.Signature()
107 if f.Object.Stream {
108 result = "<-chan " + result
109 }
110
111 res += fmt.Sprintf(") (%s, error)", result)
112 return res
113}
114
115func (f *Field) CallArgs() string {
116 var args []string
117
118 if f.GoMethodName == "" {
119 args = append(args, "ctx")
120
121 if !f.Object.Root {
122 args = append(args, "obj")
123 }
124 }
125
126 for _, arg := range f.Args {
127 args = append(args, "args["+strconv.Quote(arg.GQLName)+"].("+arg.Signature()+")")
128 }
129
130 return strings.Join(args, ", ")
131}
132
133// should be in the template, but its recursive and has a bunch of args
134func (f *Field) WriteJson() string {
135 return f.doWriteJson("res", f.Type.Modifiers, false, 1)
136}
137
138func (f *Field) doWriteJson(val string, remainingMods []string, isPtr bool, depth int) string {
139 switch {
140 case len(remainingMods) > 0 && remainingMods[0] == modPtr:
141 return fmt.Sprintf("if %s == nil { return graphql.Null }\n%s", val, f.doWriteJson(val, remainingMods[1:], true, depth+1))
142
143 case len(remainingMods) > 0 && remainingMods[0] == modList:
144 if isPtr {
145 val = "*" + val
146 }
147 var arr = "arr" + strconv.Itoa(depth)
148 var index = "idx" + strconv.Itoa(depth)
149
150 return tpl(`{{.arr}} := graphql.Array{}
151 for {{.index}} := range {{.val}} {
152 {{.arr}} = append({{.arr}}, func() graphql.Marshaler {
153 rctx := graphql.GetResolverContext(ctx)
154 rctx.PushIndex({{.index}})
155 defer rctx.Pop()
156 {{ .next }}
157 }())
158 }
159 return {{.arr}}`, map[string]interface{}{
160 "val": val,
161 "arr": arr,
162 "index": index,
163 "next": f.doWriteJson(val+"["+index+"]", remainingMods[1:], false, depth+1),
164 })
165
166 case f.IsScalar:
167 if isPtr {
168 val = "*" + val
169 }
170 return f.Marshal(val)
171
172 default:
173 if !isPtr {
174 val = "&" + val
175 }
176 return fmt.Sprintf("return ec._%s(ctx, field.Selections, %s)", f.GQLType, val)
177 }
178}
179
180func (os Objects) ByName(name string) *Object {
181 for i, o := range os {
182 if strings.EqualFold(o.GQLType, name) {
183 return os[i]
184 }
185 }
186 return nil
187}
188
// tpl renders the inline text/template source tpl against vars and returns the
// output. It panics if the template fails to parse or execute; the templates
// are authored in this package, so a failure is a programming error.
func tpl(tpl string, vars map[string]interface{}) string {
	parsed := template.Must(template.New("inline").Parse(tpl))

	var out bytes.Buffer
	if err := parsed.Execute(&out, vars); err != nil {
		panic(err)
	}
	return out.String()
}
197
// ucFirst returns s with its first rune upper-cased. The string is decoded to
// runes, so invalid UTF-8 bytes come back as U+FFFD. An empty string is
// returned unchanged.
func ucFirst(s string) string {
	runes := []rune(s)
	if len(runes) == 0 {
		return s
	}
	runes[0] = unicode.ToUpper(runes[0])
	return string(runes)
}