1package codegen
2
3import (
4 "bytes"
5 "fmt"
6 "strconv"
7 "strings"
8 "text/template"
9 "unicode"
10
11 "github.com/vektah/gqlparser/ast"
12)
13
// GoFieldType identifies what kind of Go member, if any, a GraphQL field is
// bound to.
type GoFieldType int

// Possible GoFieldType values, numbered explicitly from zero.
const (
	// GoFieldUndefined is the zero value: no Go member kind has been recorded.
	GoFieldUndefined GoFieldType = 0
	// GoFieldMethod marks a field bound to a Go method (see Field.IsMethod).
	GoFieldMethod GoFieldType = 1
	// GoFieldVariable marks a field bound to a Go var (see Field.IsVariable).
	GoFieldVariable GoFieldType = 2
)
21
22type Object struct {
23 *NamedType
24
25 Fields []Field
26 Satisfies []string
27 ResolverInterface *Ref
28 Root bool
29 DisableConcurrency bool
30 Stream bool
31}
32
33type Field struct {
34 *Type
35 Description string // Description of a field
36 GQLName string // The name of the field in graphql
37 GoFieldType GoFieldType // The field type in go, if any
38 GoReceiverName string // The name of method & var receiver in go, if any
39 GoFieldName string // The name of the method or var in go, if any
40 Args []FieldArgument // A list of arguments to be passed to this field
41 ForceResolver bool // Should be emit Resolver method
42 NoErr bool // If this is bound to a go method, does that method have an error as the second argument
43 Object *Object // A link back to the parent object
44 Default interface{} // The default value
45}
46
47type FieldArgument struct {
48 *Type
49
50 GQLName string // The name of the argument in graphql
51 GoVarName string // The name of the var in go
52 Object *Object // A link back to the parent object
53 Default interface{} // The default value
54}
55
56type Objects []*Object
57
58func (o *Object) Implementors() string {
59 satisfiedBy := strconv.Quote(o.GQLType)
60 for _, s := range o.Satisfies {
61 satisfiedBy += ", " + strconv.Quote(s)
62 }
63 return "[]string{" + satisfiedBy + "}"
64}
65
66func (o *Object) HasResolvers() bool {
67 for _, f := range o.Fields {
68 if f.IsResolver() {
69 return true
70 }
71 }
72 return false
73}
74
75func (o *Object) IsConcurrent() bool {
76 for _, f := range o.Fields {
77 if f.IsConcurrent() {
78 return true
79 }
80 }
81 return false
82}
83
84func (o *Object) IsReserved() bool {
85 return strings.HasPrefix(o.GQLType, "__")
86}
87
88func (f *Field) IsResolver() bool {
89 return f.GoFieldName == ""
90}
91
92func (f *Field) IsReserved() bool {
93 return strings.HasPrefix(f.GQLName, "__")
94}
95
96func (f *Field) IsMethod() bool {
97 return f.GoFieldType == GoFieldMethod
98}
99
100func (f *Field) IsVariable() bool {
101 return f.GoFieldType == GoFieldVariable
102}
103
104func (f *Field) IsConcurrent() bool {
105 return f.IsResolver() && !f.Object.DisableConcurrency
106}
107
108func (f *Field) GoNameExported() string {
109 return lintName(ucFirst(f.GQLName))
110}
111
112func (f *Field) GoNameUnexported() string {
113 return lintName(f.GQLName)
114}
115
116func (f *Field) ShortInvocation() string {
117 if !f.IsResolver() {
118 return ""
119 }
120
121 return fmt.Sprintf("%s().%s(%s)", f.Object.GQLType, f.GoNameExported(), f.CallArgs())
122}
123
124func (f *Field) ArgsFunc() string {
125 if len(f.Args) == 0 {
126 return ""
127 }
128
129 return "field_" + f.Object.GQLType + "_" + f.GQLName + "_args"
130}
131
132func (f *Field) ResolverType() string {
133 if !f.IsResolver() {
134 return ""
135 }
136
137 return fmt.Sprintf("%s().%s(%s)", f.Object.GQLType, f.GoNameExported(), f.CallArgs())
138}
139
140func (f *Field) ShortResolverDeclaration() string {
141 if !f.IsResolver() {
142 return ""
143 }
144 res := fmt.Sprintf("%s(ctx context.Context", f.GoNameExported())
145
146 if !f.Object.Root {
147 res += fmt.Sprintf(", obj *%s", f.Object.FullName())
148 }
149 for _, arg := range f.Args {
150 res += fmt.Sprintf(", %s %s", arg.GoVarName, arg.Signature())
151 }
152
153 result := f.Signature()
154 if f.Object.Stream {
155 result = "<-chan " + result
156 }
157
158 res += fmt.Sprintf(") (%s, error)", result)
159 return res
160}
161
162func (f *Field) ResolverDeclaration() string {
163 if !f.IsResolver() {
164 return ""
165 }
166 res := fmt.Sprintf("%s_%s(ctx context.Context", f.Object.GQLType, f.GoNameUnexported())
167
168 if !f.Object.Root {
169 res += fmt.Sprintf(", obj *%s", f.Object.FullName())
170 }
171 for _, arg := range f.Args {
172 res += fmt.Sprintf(", %s %s", arg.GoVarName, arg.Signature())
173 }
174
175 result := f.Signature()
176 if f.Object.Stream {
177 result = "<-chan " + result
178 }
179
180 res += fmt.Sprintf(") (%s, error)", result)
181 return res
182}
183
184func (f *Field) ComplexitySignature() string {
185 res := fmt.Sprintf("func(childComplexity int")
186 for _, arg := range f.Args {
187 res += fmt.Sprintf(", %s %s", arg.GoVarName, arg.Signature())
188 }
189 res += ") int"
190 return res
191}
192
193func (f *Field) ComplexityArgs() string {
194 var args []string
195 for _, arg := range f.Args {
196 args = append(args, "args["+strconv.Quote(arg.GQLName)+"].("+arg.Signature()+")")
197 }
198
199 return strings.Join(args, ", ")
200}
201
202func (f *Field) CallArgs() string {
203 var args []string
204
205 if f.IsResolver() {
206 args = append(args, "ctx")
207
208 if !f.Object.Root {
209 args = append(args, "obj")
210 }
211 }
212
213 for _, arg := range f.Args {
214 args = append(args, "args["+strconv.Quote(arg.GQLName)+"].("+arg.Signature()+")")
215 }
216
217 return strings.Join(args, ", ")
218}
219
220// should be in the template, but its recursive and has a bunch of args
221func (f *Field) WriteJson() string {
222 return f.doWriteJson("res", f.Type.Modifiers, f.ASTType, false, 1)
223}
224
225func (f *Field) doWriteJson(val string, remainingMods []string, astType *ast.Type, isPtr bool, depth int) string {
226 switch {
227 case len(remainingMods) > 0 && remainingMods[0] == modPtr:
228 return tpl(`
229 if {{.val}} == nil {
230 {{- if .nonNull }}
231 if !ec.HasError(rctx) {
232 ec.Errorf(ctx, "must not be null")
233 }
234 {{- end }}
235 return graphql.Null
236 }
237 {{.next }}`, map[string]interface{}{
238 "val": val,
239 "nonNull": astType.NonNull,
240 "next": f.doWriteJson(val, remainingMods[1:], astType, true, depth+1),
241 })
242
243 case len(remainingMods) > 0 && remainingMods[0] == modList:
244 if isPtr {
245 val = "*" + val
246 }
247 var arr = "arr" + strconv.Itoa(depth)
248 var index = "idx" + strconv.Itoa(depth)
249 var usePtr bool
250 if len(remainingMods) == 1 && !isPtr {
251 usePtr = true
252 }
253
254 return tpl(`
255 {{.arr}} := make(graphql.Array, len({{.val}}))
256 {{ if and .top (not .isScalar) }} var wg sync.WaitGroup {{ end }}
257 {{ if not .isScalar }}
258 isLen1 := len({{.val}}) == 1
259 if !isLen1 {
260 wg.Add(len({{.val}}))
261 }
262 {{ end }}
263 for {{.index}} := range {{.val}} {
264 {{- if not .isScalar }}
265 {{.index}} := {{.index}}
266 rctx := &graphql.ResolverContext{
267 Index: &{{.index}},
268 Result: {{ if .usePtr }}&{{end}}{{.val}}[{{.index}}],
269 }
270 ctx := graphql.WithResolverContext(ctx, rctx)
271 f := func({{.index}} int) {
272 if !isLen1 {
273 defer wg.Done()
274 }
275 {{.arr}}[{{.index}}] = func() graphql.Marshaler {
276 {{ .next }}
277 }()
278 }
279 if isLen1 {
280 f({{.index}})
281 } else {
282 go f({{.index}})
283 }
284 {{ else }}
285 {{.arr}}[{{.index}}] = func() graphql.Marshaler {
286 {{ .next }}
287 }()
288 {{- end}}
289 }
290 {{ if and .top (not .isScalar) }} wg.Wait() {{ end }}
291 return {{.arr}}`, map[string]interface{}{
292 "val": val,
293 "arr": arr,
294 "index": index,
295 "top": depth == 1,
296 "arrayLen": len(val),
297 "isScalar": f.IsScalar,
298 "usePtr": usePtr,
299 "next": f.doWriteJson(val+"["+index+"]", remainingMods[1:], astType.Elem, false, depth+1),
300 })
301
302 case f.IsScalar:
303 if isPtr {
304 val = "*" + val
305 }
306 return f.Marshal(val)
307
308 default:
309 if !isPtr {
310 val = "&" + val
311 }
312 return tpl(`
313 return ec._{{.type}}(ctx, field.Selections, {{.val}})`, map[string]interface{}{
314 "type": f.GQLType,
315 "val": val,
316 })
317 }
318}
319
320func (f *FieldArgument) Stream() bool {
321 return f.Object != nil && f.Object.Stream
322}
323
324func (os Objects) ByName(name string) *Object {
325 for i, o := range os {
326 if strings.EqualFold(o.GQLType, name) {
327 return os[i]
328 }
329 }
330 return nil
331}
332
// tpl parses text as an anonymous text/template and executes it with vars,
// returning the rendered output.
//
// Parse errors (via template.Must) and execution errors both panic. The
// callers in this file pass literal templates, so such an error is a
// programming bug.
func tpl(text string, vars map[string]interface{}) string {
	parsed := template.Must(template.New("inline").Parse(text))

	var out bytes.Buffer
	if err := parsed.Execute(&out, vars); err != nil {
		panic(err)
	}
	return out.String()
}
341
// ucFirst returns s with its first rune upper-cased; "" stays "".
func ucFirst(s string) string {
	runes := []rune(s)
	if len(runes) == 0 {
		return ""
	}
	runes[0] = unicode.ToUpper(runes[0])
	return string(runes)
}
351
// copy from https://github.com/golang/lint/blob/06c8688daad7faa9da5a0c2f163a3d14aac986ca/lint.go#L679

// lintName returns a different name if it should be different.
//
// It applies golint's naming rules to name:
//   - "_" and names whose every rune is lowercase (including "") are
//     returned unchanged;
//   - the name is split into words at each lowercase->non-lowercase
//     transition and at underscores;
//   - runs of underscores that follow a character are dropped, except that a
//     single underscore is kept between two digits;
//   - a word found in commonInitialisms is rewritten in upper case, or in
//     lower case when it is the first word and starts with a lowercase letter
//     (e.g. "userId" -> "userID");
//   - any other word after the first whose lower-cased form equals itself
//     gets its first rune upper-cased.
func lintName(name string) (should string) {
	// Fast path for simple cases: "_" and all lowercase.
	if name == "_" {
		return name
	}
	allLower := true
	for _, r := range name {
		if !unicode.IsLower(r) {
			allLower = false
			break
		}
	}
	if allLower {
		return name
	}

	// Split camelCase at any lower->upper transition, and split on underscores.
	// Check each word for common initialisms.
	runes := []rune(name)
	w, i := 0, 0 // index of start of word, scan
	// len(runes) shrinks as underscores are removed, so it is re-read each pass.
	for i+1 <= len(runes) {
		eow := false // whether we hit the end of a word
		if i+1 == len(runes) {
			eow = true
		} else if runes[i+1] == '_' {
			// underscore; shift the remainder forward over any run of underscores
			eow = true
			n := 1
			for i+n+1 < len(runes) && runes[i+n+1] == '_' {
				n++
			}

			// Leave at most one underscore if the underscore is between two digits
			if i+n+1 < len(runes) && unicode.IsDigit(runes[i]) && unicode.IsDigit(runes[i+n+1]) {
				n--
			}

			copy(runes[i+1:], runes[i+n+1:])
			// The shift left n stale runes at the end; trim them off.
			runes = runes[:len(runes)-n]
		} else if unicode.IsLower(runes[i]) && !unicode.IsLower(runes[i+1]) {
			// lower->non-lower
			eow = true
		}
		i++
		if !eow {
			continue
		}

		// [w,i) is a word.
		word := string(runes[w:i])
		if u := strings.ToUpper(word); commonInitialisms[u] {
			// Keep consistent case, which is lowercase only at the start.
			if w == 0 && unicode.IsLower(runes[w]) {
				u = strings.ToLower(u)
			}
			// All the common initialisms are ASCII,
			// so we can replace the bytes exactly.
			copy(runes[w:], []rune(u))
		} else if w > 0 && strings.ToLower(word) == word {
			// already all lowercase, and not the first word, so uppercase the first character.
			runes[w] = unicode.ToUpper(runes[w])
		}
		// The next word starts right after this one.
		w = i
	}
	return string(runes)
}
421
// commonInitialisms is a set of common initialisms.
// Only add entries that are highly unlikely to be non-initialisms.
// For instance, "ID" is fine (Freudian code is rare), but "AND" is not.
//
// lintName renders any word found here entirely in upper case. The exception is
// a leading word that starts lowercase, which it renders entirely in lower case.
var commonInitialisms = map[string]bool{
	"ACL": true, "API": true, "ASCII": true, "CPU": true, "CSS": true,
	"DNS": true, "EOF": true, "GUID": true, "HTML": true, "HTTP": true,
	"HTTPS": true, "ID": true, "IP": true, "JSON": true, "LHS": true,
	"QPS": true, "RAM": true, "RHS": true, "RPC": true, "SLA": true,
	"SMTP": true, "SQL": true, "SSH": true, "TCP": true, "TLS": true,
	"TTL": true, "UDP": true, "UI": true, "UID": true, "UUID": true,
	"URI": true, "URL": true, "UTF8": true, "VM": true, "XML": true,
	"XMPP": true, "XSRF": true, "XSS": true,
}