1package codegen
2
3import (
4 "bytes"
5 "fmt"
6 "strconv"
7 "strings"
8 "text/template"
9 "unicode"
10
11 "github.com/vektah/gqlparser/ast"
12)
13
// GoFieldType records which kind of Go construct, if any, a GraphQL field is
// bound to (see Field.IsMethod and Field.IsVariable).
type GoFieldType int

// Values of GoFieldType. The numbering is explicit and matches declaration
// order, so the zero value is GoFieldUndefined.
const (
	// GoFieldUndefined is the zero value: neither a method nor a variable.
	GoFieldUndefined GoFieldType = 0
	// GoFieldMethod marks a field bound to a Go method.
	GoFieldMethod GoFieldType = 1
	// GoFieldVariable marks a field bound to a Go variable (struct field).
	GoFieldVariable GoFieldType = 2
)
21
// Object describes a GraphQL object type (named by the embedded NamedType's
// GQLType) together with the fields generated for it.
type Object struct {
	*NamedType

	Fields             []Field
	Satisfies          []string     // Extra GraphQL type names, listed after GQLType by Implementors
	Implements         []*NamedType // NOTE(review): not read in this file; presumably the implemented interfaces — confirm
	ResolverInterface  *Ref         // NOTE(review): not read in this file — confirm its role
	Root               bool         // Resolvers of root objects take no `obj` parameter
	DisableConcurrency bool         // When set, no field of this object is treated as concurrent
	Stream             bool         // Resolvers return a `<-chan` of the field type instead of a plain value
}
33
// Field describes one field of an Object: its GraphQL name and arguments, and
// whether it is served by a generated resolver (GoFieldName empty) or by a bound
// Go method or variable.
type Field struct {
	*Type
	Description      string          // Description of a field
	GQLName          string          // The name of the field in graphql
	GoFieldType      GoFieldType     // The kind of Go binding, if any (see IsMethod / IsVariable)
	GoReceiverName   string          // The name of method & var receiver in go, if any
	GoFieldName      string          // The name of the method or var in go, if any; empty means the field needs a resolver (see IsResolver)
	Args             []FieldArgument // A list of arguments to be passed to this field
	ForceResolver    bool            // Whether a resolver method should be emitted for this field
	MethodHasContext bool            // If this is bound to a go method, whether the method also takes a context
	// NOTE(review): the name suggests true means the bound method does NOT
	// return an error; confirm against the code that sets this flag.
	NoErr   bool        // If this is bound to a go method, whether that method lacks an error as its second return value
	Object  *Object     // A link back to the parent object
	Default interface{} // The default value
}
48
// FieldArgument describes one GraphQL argument of a Field.
type FieldArgument struct {
	*Type

	GQLName   string      // The name of the argument in graphql; also the key used in the generated `args` map
	GoVarName string      // The name of the var in go; used as the parameter name in generated resolver signatures
	Object    *Object     // A link back to the parent object; may be nil (Stream checks for this)
	Default   interface{} // The default value
}
57
// Objects is a list of generated objects; ByName looks one up by GraphQL name.
type Objects []*Object
59
60func (o *Object) Implementors() string {
61 satisfiedBy := strconv.Quote(o.GQLType)
62 for _, s := range o.Satisfies {
63 satisfiedBy += ", " + strconv.Quote(s)
64 }
65 return "[]string{" + satisfiedBy + "}"
66}
67
68func (o *Object) HasResolvers() bool {
69 for _, f := range o.Fields {
70 if f.IsResolver() {
71 return true
72 }
73 }
74 return false
75}
76
77func (o *Object) IsConcurrent() bool {
78 for _, f := range o.Fields {
79 if f.IsConcurrent() {
80 return true
81 }
82 }
83 return false
84}
85
86func (o *Object) IsReserved() bool {
87 return strings.HasPrefix(o.GQLType, "__")
88}
89
90func (f *Field) IsResolver() bool {
91 return f.GoFieldName == ""
92}
93
94func (f *Field) IsReserved() bool {
95 return strings.HasPrefix(f.GQLName, "__")
96}
97
98func (f *Field) IsMethod() bool {
99 return f.GoFieldType == GoFieldMethod
100}
101
102func (f *Field) IsVariable() bool {
103 return f.GoFieldType == GoFieldVariable
104}
105
106func (f *Field) IsConcurrent() bool {
107 if f.Object.DisableConcurrency {
108 return false
109 }
110 return f.MethodHasContext || f.IsResolver()
111}
112
113func (f *Field) GoNameExported() string {
114 return lintName(ucFirst(f.GQLName))
115}
116
117func (f *Field) GoNameUnexported() string {
118 return lintName(f.GQLName)
119}
120
121func (f *Field) ShortInvocation() string {
122 if !f.IsResolver() {
123 return ""
124 }
125
126 return fmt.Sprintf("%s().%s(%s)", f.Object.GQLType, f.GoNameExported(), f.CallArgs())
127}
128
129func (f *Field) ArgsFunc() string {
130 if len(f.Args) == 0 {
131 return ""
132 }
133
134 return "field_" + f.Object.GQLType + "_" + f.GQLName + "_args"
135}
136
137func (f *Field) ResolverType() string {
138 if !f.IsResolver() {
139 return ""
140 }
141
142 return fmt.Sprintf("%s().%s(%s)", f.Object.GQLType, f.GoNameExported(), f.CallArgs())
143}
144
145func (f *Field) ShortResolverDeclaration() string {
146 if !f.IsResolver() {
147 return ""
148 }
149 res := fmt.Sprintf("%s(ctx context.Context", f.GoNameExported())
150
151 if !f.Object.Root {
152 res += fmt.Sprintf(", obj *%s", f.Object.FullName())
153 }
154 for _, arg := range f.Args {
155 res += fmt.Sprintf(", %s %s", arg.GoVarName, arg.Signature())
156 }
157
158 result := f.Signature()
159 if f.Object.Stream {
160 result = "<-chan " + result
161 }
162
163 res += fmt.Sprintf(") (%s, error)", result)
164 return res
165}
166
167func (f *Field) ResolverDeclaration() string {
168 if !f.IsResolver() {
169 return ""
170 }
171 res := fmt.Sprintf("%s_%s(ctx context.Context", f.Object.GQLType, f.GoNameUnexported())
172
173 if !f.Object.Root {
174 res += fmt.Sprintf(", obj *%s", f.Object.FullName())
175 }
176 for _, arg := range f.Args {
177 res += fmt.Sprintf(", %s %s", arg.GoVarName, arg.Signature())
178 }
179
180 result := f.Signature()
181 if f.Object.Stream {
182 result = "<-chan " + result
183 }
184
185 res += fmt.Sprintf(") (%s, error)", result)
186 return res
187}
188
189func (f *Field) ComplexitySignature() string {
190 res := fmt.Sprintf("func(childComplexity int")
191 for _, arg := range f.Args {
192 res += fmt.Sprintf(", %s %s", arg.GoVarName, arg.Signature())
193 }
194 res += ") int"
195 return res
196}
197
198func (f *Field) ComplexityArgs() string {
199 var args []string
200 for _, arg := range f.Args {
201 args = append(args, "args["+strconv.Quote(arg.GQLName)+"].("+arg.Signature()+")")
202 }
203
204 return strings.Join(args, ", ")
205}
206
207func (f *Field) CallArgs() string {
208 var args []string
209
210 if f.IsResolver() {
211 args = append(args, "rctx")
212
213 if !f.Object.Root {
214 args = append(args, "obj")
215 }
216 } else {
217 if f.MethodHasContext {
218 args = append(args, "ctx")
219 }
220 }
221
222 for _, arg := range f.Args {
223 args = append(args, "args["+strconv.Quote(arg.GQLName)+"].("+arg.Signature()+")")
224 }
225
226 return strings.Join(args, ", ")
227}
228
229// should be in the template, but its recursive and has a bunch of args
230func (f *Field) WriteJson() string {
231 return f.doWriteJson("res", f.Type.Modifiers, f.ASTType, false, 1)
232}
233
234func (f *Field) doWriteJson(val string, remainingMods []string, astType *ast.Type, isPtr bool, depth int) string {
235 switch {
236 case len(remainingMods) > 0 && remainingMods[0] == modPtr:
237 return tpl(`
238 if {{.val}} == nil {
239 {{- if .nonNull }}
240 if !ec.HasError(rctx) {
241 ec.Errorf(ctx, "must not be null")
242 }
243 {{- end }}
244 return graphql.Null
245 }
246 {{.next }}`, map[string]interface{}{
247 "val": val,
248 "nonNull": astType.NonNull,
249 "next": f.doWriteJson(val, remainingMods[1:], astType, true, depth+1),
250 })
251
252 case len(remainingMods) > 0 && remainingMods[0] == modList:
253 if isPtr {
254 val = "*" + val
255 }
256 var arr = "arr" + strconv.Itoa(depth)
257 var index = "idx" + strconv.Itoa(depth)
258 var usePtr bool
259 if len(remainingMods) == 1 && !isPtr {
260 usePtr = true
261 }
262
263 return tpl(`
264 {{.arr}} := make(graphql.Array, len({{.val}}))
265 {{ if and .top (not .isScalar) }} var wg sync.WaitGroup {{ end }}
266 {{ if not .isScalar }}
267 isLen1 := len({{.val}}) == 1
268 if !isLen1 {
269 wg.Add(len({{.val}}))
270 }
271 {{ end }}
272 for {{.index}} := range {{.val}} {
273 {{- if not .isScalar }}
274 {{.index}} := {{.index}}
275 rctx := &graphql.ResolverContext{
276 Index: &{{.index}},
277 Result: {{ if .usePtr }}&{{end}}{{.val}}[{{.index}}],
278 }
279 ctx := graphql.WithResolverContext(ctx, rctx)
280 f := func({{.index}} int) {
281 if !isLen1 {
282 defer wg.Done()
283 }
284 {{.arr}}[{{.index}}] = func() graphql.Marshaler {
285 {{ .next }}
286 }()
287 }
288 if isLen1 {
289 f({{.index}})
290 } else {
291 go f({{.index}})
292 }
293 {{ else }}
294 {{.arr}}[{{.index}}] = func() graphql.Marshaler {
295 {{ .next }}
296 }()
297 {{- end}}
298 }
299 {{ if and .top (not .isScalar) }} wg.Wait() {{ end }}
300 return {{.arr}}`, map[string]interface{}{
301 "val": val,
302 "arr": arr,
303 "index": index,
304 "top": depth == 1,
305 "arrayLen": len(val),
306 "isScalar": f.IsScalar,
307 "usePtr": usePtr,
308 "next": f.doWriteJson(val+"["+index+"]", remainingMods[1:], astType.Elem, false, depth+1),
309 })
310
311 case f.IsScalar:
312 if isPtr {
313 val = "*" + val
314 }
315 return f.Marshal(val)
316
317 default:
318 if !isPtr {
319 val = "&" + val
320 }
321 return tpl(`
322 return ec._{{.type}}(ctx, field.Selections, {{.val}})`, map[string]interface{}{
323 "type": f.GQLType,
324 "val": val,
325 })
326 }
327}
328
329func (f *FieldArgument) Stream() bool {
330 return f.Object != nil && f.Object.Stream
331}
332
333func (os Objects) ByName(name string) *Object {
334 for i, o := range os {
335 if strings.EqualFold(o.GQLType, name) {
336 return os[i]
337 }
338 }
339 return nil
340}
341
// tpl parses tpl as a text/template, executes it against vars and returns the
// rendered text. A parse or execution error causes a panic, because it
// indicates a bug in the generator's own inline templates.
func tpl(tpl string, vars map[string]interface{}) string {
	parsed := template.Must(template.New("inline").Parse(tpl))

	var out bytes.Buffer
	if err := parsed.Execute(&out, vars); err != nil {
		panic(err)
	}
	return out.String()
}
350
// ucFirst returns s with its first rune upper-cased via unicode.ToUpper. The
// string is round-tripped through []rune, so invalid UTF-8 bytes come back as
// U+FFFD.
func ucFirst(s string) string {
	runes := []rune(s)
	if len(runes) == 0 {
		return ""
	}
	runes[0] = unicode.ToUpper(runes[0])
	return string(runes)
}
360
// Copied from golint:
// https://github.com/golang/lint/blob/06c8688daad7faa9da5a0c2f163a3d14aac986ca/lint.go#L679

// lintName returns name rewritten to follow Go naming conventions, or name
// itself when nothing needs to change:
//   - "_" and names made only of lowercase letters are returned unchanged;
//   - words are split at every lower->non-lower transition and at underscores;
//   - runs of underscores are removed, except that one underscore is kept
//     between two digits;
//   - words listed in commonInitialisms are written fully upper-case, or fully
//     lower-case when they are the first word and the name started lowercase;
//   - other all-lowercase words after the first get an upper-case first letter.
func lintName(name string) (should string) {
	// Fast path for simple cases: "_" and all lowercase.
	if name == "_" {
		return name
	}
	allLower := true
	for _, r := range name {
		if !unicode.IsLower(r) {
			allLower = false
			break
		}
	}
	if allLower {
		return name
	}

	// Split camelCase at any lower->non-lower transition, and split on underscores.
	// Check each word for common initialisms.
	runes := []rune(name)
	w, i := 0, 0 // index of start of word, scan
	for i+1 <= len(runes) {
		eow := false // whether we hit the end of a word
		if i+1 == len(runes) {
			eow = true
		} else if runes[i+1] == '_' {
			// underscore; shift the remainder forward over any run of underscores
			eow = true
			n := 1
			for i+n+1 < len(runes) && runes[i+n+1] == '_' {
				n++
			}

			// Leave at most one underscore if the underscore is between two digits
			if i+n+1 < len(runes) && unicode.IsDigit(runes[i]) && unicode.IsDigit(runes[i+n+1]) {
				n--
			}

			// Drop the n underscores in place; runes shrinks, so the loop bound
			// and later indices see the compacted slice.
			copy(runes[i+1:], runes[i+n+1:])
			runes = runes[:len(runes)-n]
		} else if unicode.IsLower(runes[i]) && !unicode.IsLower(runes[i+1]) {
			// lower->non-lower
			eow = true
		}
		i++
		if !eow {
			continue
		}

		// [w,i) is a word.
		word := string(runes[w:i])
		if u := strings.ToUpper(word); commonInitialisms[u] {
			// Keep consistent case, which is lowercase only at the start.
			if w == 0 && unicode.IsLower(runes[w]) {
				u = strings.ToLower(u)
			}
			// All the common initialisms are ASCII,
			// so we can replace the bytes exactly.
			copy(runes[w:], []rune(u))
		} else if w > 0 && strings.ToLower(word) == word {
			// already all lowercase, and not the first word, so uppercase the first character.
			runes[w] = unicode.ToUpper(runes[w])
		}
		w = i
	}
	return string(runes)
}
430
// commonInitialisms is the set of upper-case initialisms that lintName keeps
// in a single consistent case (e.g. "Id" -> "ID"). Only add entries that are
// highly unlikely to be non-initialisms. "ID" is fine because Freudian code is
// rare, but "AND" is not.
var commonInitialisms = func() map[string]bool {
	words := []string{
		"ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID",
		"HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS",
		"RAM", "RHS", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP",
		"TLS", "TTL", "UDP", "UI", "UID", "UUID", "URI", "URL",
		"UTF8", "VM", "XML", "XMPP", "XSRF", "XSS",
	}
	set := make(map[string]bool, len(words))
	for _, w := range words {
		set[w] = true
	}
	return set
}()