object.go

  1package codegen
  2
  3import (
  4	"bytes"
  5	"fmt"
  6	"strconv"
  7	"strings"
  8	"text/template"
  9	"unicode"
 10
 11	"github.com/vektah/gqlparser/ast"
 12)
 13
 14type GoFieldType int
 15
 16const (
 17	GoFieldUndefined GoFieldType = iota
 18	GoFieldMethod
 19	GoFieldVariable
 20)
 21
// Object is a GraphQL object type together with the data the code generator
// needs to emit its Go marshalling and resolver code.
type Object struct {
	*NamedType

	// Fields are the GraphQL fields of this object.
	Fields             []Field
	// Satisfies lists extra type names that Implementors emits after the
	// object's own GQLType (presumably the interfaces/unions it implements).
	Satisfies          []string
	// ResolverInterface refers to the generated resolver interface for this
	// object. NOTE(review): not read in this file; confirm against callers.
	ResolverInterface  *Ref
	// Root, when true, omits the `obj` parent parameter from resolver
	// signatures and call arguments.
	Root               bool
	// DisableConcurrency, when true, makes Field.IsConcurrent report false
	// for every field of this object.
	DisableConcurrency bool
	// Stream, when true, makes resolver declarations return `<-chan T`
	// instead of `T`.
	Stream             bool
}
 32
// Field is a single GraphQL field of an Object, plus the information needed
// to generate the Go code that reads or resolves it.
type Field struct {
	*Type
	Description    string          // Description of the field
	GQLName        string          // The name of the field in graphql
	GoFieldType    GoFieldType     // How the field is bound in go (method, variable, or undefined)
	GoReceiverName string          // The name of the method & var receiver in go, if any
	GoFieldName    string          // The name of the method or var in go; empty means the field needs a resolver (see IsResolver)
	Args           []FieldArgument // The arguments passed to this field
	ForceResolver  bool            // Whether a resolver method should be emitted for this field. NOTE(review): IsResolver does not consult it
	NoErr          bool            // For fields bound to a go method: records whether that method has an error as its second result. NOTE(review): name suggests true means "no error result" — confirm
	Object         *Object         // A link back to the parent object
	Default        interface{}     // The default value
}
 46
// FieldArgument is one argument accepted by a Field.
type FieldArgument struct {
	*Type

	GQLName   string      // The name of the argument in graphql
	GoVarName string      // The name of the var in go
	Object    *Object     // A link back to the parent object; Stream tolerates nil
	Default   interface{} // The default value
}

// Objects is a list of objects that can be looked up with ByName.
type Objects []*Object
 57
 58func (o *Object) Implementors() string {
 59	satisfiedBy := strconv.Quote(o.GQLType)
 60	for _, s := range o.Satisfies {
 61		satisfiedBy += ", " + strconv.Quote(s)
 62	}
 63	return "[]string{" + satisfiedBy + "}"
 64}
 65
 66func (o *Object) HasResolvers() bool {
 67	for _, f := range o.Fields {
 68		if f.IsResolver() {
 69			return true
 70		}
 71	}
 72	return false
 73}
 74
 75func (o *Object) IsConcurrent() bool {
 76	for _, f := range o.Fields {
 77		if f.IsConcurrent() {
 78			return true
 79		}
 80	}
 81	return false
 82}
 83
 84func (o *Object) IsReserved() bool {
 85	return strings.HasPrefix(o.GQLType, "__")
 86}
 87
 88func (f *Field) IsResolver() bool {
 89	return f.GoFieldName == ""
 90}
 91
 92func (f *Field) IsReserved() bool {
 93	return strings.HasPrefix(f.GQLName, "__")
 94}
 95
 96func (f *Field) IsMethod() bool {
 97	return f.GoFieldType == GoFieldMethod
 98}
 99
100func (f *Field) IsVariable() bool {
101	return f.GoFieldType == GoFieldVariable
102}
103
104func (f *Field) IsConcurrent() bool {
105	return f.IsResolver() && !f.Object.DisableConcurrency
106}
107
108func (f *Field) GoNameExported() string {
109	return lintName(ucFirst(f.GQLName))
110}
111
112func (f *Field) GoNameUnexported() string {
113	return lintName(f.GQLName)
114}
115
116func (f *Field) ShortInvocation() string {
117	if !f.IsResolver() {
118		return ""
119	}
120
121	return fmt.Sprintf("%s().%s(%s)", f.Object.GQLType, f.GoNameExported(), f.CallArgs())
122}
123
124func (f *Field) ArgsFunc() string {
125	if len(f.Args) == 0 {
126		return ""
127	}
128
129	return "field_" + f.Object.GQLType + "_" + f.GQLName + "_args"
130}
131
132func (f *Field) ResolverType() string {
133	if !f.IsResolver() {
134		return ""
135	}
136
137	return fmt.Sprintf("%s().%s(%s)", f.Object.GQLType, f.GoNameExported(), f.CallArgs())
138}
139
140func (f *Field) ShortResolverDeclaration() string {
141	if !f.IsResolver() {
142		return ""
143	}
144	res := fmt.Sprintf("%s(ctx context.Context", f.GoNameExported())
145
146	if !f.Object.Root {
147		res += fmt.Sprintf(", obj *%s", f.Object.FullName())
148	}
149	for _, arg := range f.Args {
150		res += fmt.Sprintf(", %s %s", arg.GoVarName, arg.Signature())
151	}
152
153	result := f.Signature()
154	if f.Object.Stream {
155		result = "<-chan " + result
156	}
157
158	res += fmt.Sprintf(") (%s, error)", result)
159	return res
160}
161
162func (f *Field) ResolverDeclaration() string {
163	if !f.IsResolver() {
164		return ""
165	}
166	res := fmt.Sprintf("%s_%s(ctx context.Context", f.Object.GQLType, f.GoNameUnexported())
167
168	if !f.Object.Root {
169		res += fmt.Sprintf(", obj *%s", f.Object.FullName())
170	}
171	for _, arg := range f.Args {
172		res += fmt.Sprintf(", %s %s", arg.GoVarName, arg.Signature())
173	}
174
175	result := f.Signature()
176	if f.Object.Stream {
177		result = "<-chan " + result
178	}
179
180	res += fmt.Sprintf(") (%s, error)", result)
181	return res
182}
183
184func (f *Field) ComplexitySignature() string {
185	res := fmt.Sprintf("func(childComplexity int")
186	for _, arg := range f.Args {
187		res += fmt.Sprintf(", %s %s", arg.GoVarName, arg.Signature())
188	}
189	res += ") int"
190	return res
191}
192
193func (f *Field) ComplexityArgs() string {
194	var args []string
195	for _, arg := range f.Args {
196		args = append(args, "args["+strconv.Quote(arg.GQLName)+"].("+arg.Signature()+")")
197	}
198
199	return strings.Join(args, ", ")
200}
201
202func (f *Field) CallArgs() string {
203	var args []string
204
205	if f.IsResolver() {
206		args = append(args, "ctx")
207
208		if !f.Object.Root {
209			args = append(args, "obj")
210		}
211	}
212
213	for _, arg := range f.Args {
214		args = append(args, "args["+strconv.Quote(arg.GQLName)+"].("+arg.Signature()+")")
215	}
216
217	return strings.Join(args, ", ")
218}
219
220// should be in the template, but its recursive and has a bunch of args
221func (f *Field) WriteJson() string {
222	return f.doWriteJson("res", f.Type.Modifiers, f.ASTType, false, 1)
223}
224
// doWriteJson recursively generates the Go source that turns val into a
// graphql.Marshaler. It consumes one entry of remainingMods per call:
//
//   - modPtr: emit a nil check that returns graphql.Null, then recurse with
//     isPtr=true. When astType is NonNull, a "must not be null" error is
//     recorded first unless ec.HasError(rctx) already reports one.
//   - modList: emit a loop that fills a graphql.Array with one entry per
//     element of val, then recurse on val[idxN] with astType.Elem.
//   - otherwise, for a scalar field: emit f.Marshal on the (dereferenced) value.
//   - otherwise: emit a call to ec._<GQLType> with a pointer to the value.
//
// depth starts at 1 and is appended to the generated variable names (arrN,
// idxN) so nested lists do not collide. The generated code refers to ec, ctx,
// rctx, field, graphql and sync, which the surrounding generated code must
// provide.
func (f *Field) doWriteJson(val string, remainingMods []string, astType *ast.Type, isPtr bool, depth int) string {
	switch {
	case len(remainingMods) > 0 && remainingMods[0] == modPtr:
		return tpl(`
			if {{.val}} == nil {
				{{- if .nonNull }}
					if !ec.HasError(rctx) {
						ec.Errorf(ctx, "must not be null")
					}
				{{- end }}
				return graphql.Null
			}
			{{.next }}`, map[string]interface{}{
			"val":     val,
			"nonNull": astType.NonNull,
			// The pointer has been nil-checked above, so the next level sees isPtr=true.
			"next":    f.doWriteJson(val, remainingMods[1:], astType, true, depth+1),
		})

	case len(remainingMods) > 0 && remainingMods[0] == modList:
		// A pointer-to-list was nil-checked by the modPtr level; dereference it.
		if isPtr {
			val = "*" + val
		}
		var arr = "arr" + strconv.Itoa(depth)
		var index = "idx" + strconv.Itoa(depth)
		// usePtr makes ResolverContext.Result hold &val[idx] instead of val[idx].
		// It is set only when this list is the last modifier (so elements are
		// not pointers) and the list itself was not reached through a pointer.
		var usePtr bool
		if len(remainingMods) == 1 && !isPtr {
			usePtr = true
		}

		// Non-scalar elements are marshalled in one goroutine each, unless the
		// list has exactly one element, in which case it runs inline. Only the
		// top-level list (depth == 1) declares and waits on the WaitGroup;
		// nested lists add to that same wg.
		return tpl(`
			{{.arr}} := make(graphql.Array, len({{.val}}))
			{{ if and .top (not .isScalar) }} var wg sync.WaitGroup {{ end }}
			{{ if not .isScalar }}
				isLen1 := len({{.val}}) == 1
				if !isLen1 {
					wg.Add(len({{.val}}))
				}
			{{ end }}
			for {{.index}} := range {{.val}} {
				{{- if not .isScalar }}
					{{.index}} := {{.index}}
					rctx := &graphql.ResolverContext{
						Index: &{{.index}},
						Result: {{ if .usePtr }}&{{end}}{{.val}}[{{.index}}],
					}
					ctx := graphql.WithResolverContext(ctx, rctx)
					f := func({{.index}} int) {
						if !isLen1 {
							defer wg.Done()
						}
						{{.arr}}[{{.index}}] = func() graphql.Marshaler {
							{{ .next }}
						}()
					}
					if isLen1 {
						f({{.index}})
					} else {
						go f({{.index}})
					}
				{{ else }}
					{{.arr}}[{{.index}}] = func() graphql.Marshaler {
						{{ .next }}
					}()
				{{- end}}
			}
			{{ if and .top (not .isScalar) }} wg.Wait() {{ end }}
			return {{.arr}}`, map[string]interface{}{
			"val":      val,
			"arr":      arr,
			"index":    index,
			"top":      depth == 1,
			// NOTE(review): "arrayLen" is not referenced by the template above.
			"arrayLen": len(val),
			"isScalar": f.IsScalar,
			"usePtr":   usePtr,
			"next":     f.doWriteJson(val+"["+index+"]", remainingMods[1:], astType.Elem, false, depth+1),
		})

	case f.IsScalar:
		if isPtr {
			val = "*" + val
		}
		return f.Marshal(val)

	default:
		// Object types are marshalled via ec._<Type>, passing a pointer to the value.
		if !isPtr {
			val = "&" + val
		}
		return tpl(`
			return ec._{{.type}}(ctx, field.Selections, {{.val}})`, map[string]interface{}{
			"type": f.GQLType,
			"val":  val,
		})
	}
}
319
320func (f *FieldArgument) Stream() bool {
321	return f.Object != nil && f.Object.Stream
322}
323
324func (os Objects) ByName(name string) *Object {
325	for i, o := range os {
326		if strings.EqualFold(o.GQLType, name) {
327			return os[i]
328		}
329	}
330	return nil
331}
332
// tpl parses text as a text/template named "inline", executes it with vars as
// the data, and returns the rendered output. It panics if the template fails
// to parse or execute. Callers in this file pass string literals, so a failure
// indicates a programming error.
func tpl(text string, vars map[string]interface{}) string {
	parsed := template.Must(template.New("inline").Parse(text))

	var out bytes.Buffer
	if err := parsed.Execute(&out, vars); err != nil {
		panic(err)
	}
	return out.String()
}
341
// ucFirst returns s with its first rune upper-cased. The empty string is
// returned unchanged.
func ucFirst(s string) string {
	runes := []rune(s)
	if len(runes) == 0 {
		return ""
	}
	runes[0] = unicode.ToUpper(runes[0])
	return string(runes)
}
351
// lintName is copied from golint:
// https://github.com/golang/lint/blob/06c8688daad7faa9da5a0c2f163a3d14aac986ca/lint.go#L679

// lintName rewrites name into golint-preferred Go naming and returns the result.
//
// It splits name into words at underscores and at lower-to-non-lower
// transitions. Runs of underscores are removed, except that a single
// underscore is kept between two digits. A word found in commonInitialisms is
// written in one case: upper case, or lower case when it is the first word
// and started lower case. Any later all-lower-case word gets its first rune
// upper-cased.
//
// "_" and names made only of lower-case letters (including "") are returned
// unchanged. The named result `should` is never assigned; it is kept from the
// upstream copy.
func lintName(name string) (should string) {
	// Fast path for simple cases: "_" and all lowercase.
	if name == "_" {
		return name
	}
	allLower := true
	for _, r := range name {
		if !unicode.IsLower(r) {
			allLower = false
			break
		}
	}
	if allLower {
		return name
	}

	// Split camelCase at any lower->non-lower transition, and split on underscores.
	// Check each word for common initialisms.
	runes := []rune(name)
	w, i := 0, 0 // index of start of word, scan
	for i+1 <= len(runes) {
		eow := false // whether we hit the end of a word
		if i+1 == len(runes) {
			eow = true
		} else if runes[i+1] == '_' {
			// underscore; shift the remainder forward over any run of underscores
			eow = true
			n := 1
			for i+n+1 < len(runes) && runes[i+n+1] == '_' {
				n++
			}

			// Leave at most one underscore if the underscore is between two digits
			if i+n+1 < len(runes) && unicode.IsDigit(runes[i]) && unicode.IsDigit(runes[i+n+1]) {
				n--
			}

			// Drop the n underscores in place by shifting the tail left.
			copy(runes[i+1:], runes[i+n+1:])
			runes = runes[:len(runes)-n]
		} else if unicode.IsLower(runes[i]) && !unicode.IsLower(runes[i+1]) {
			// lower->non-lower
			eow = true
		}
		i++
		if !eow {
			continue
		}

		// [w,i) is a word.
		word := string(runes[w:i])
		if u := strings.ToUpper(word); commonInitialisms[u] {
			// Keep consistent case, which is lowercase only at the start.
			if w == 0 && unicode.IsLower(runes[w]) {
				u = strings.ToLower(u)
			}
			// All the common initialisms are ASCII,
			// so we can replace the bytes exactly.
			copy(runes[w:], []rune(u))
		} else if w > 0 && strings.ToLower(word) == word {
			// already all lowercase, and not the first word, so uppercase the first character.
			runes[w] = unicode.ToUpper(runes[w])
		}
		w = i
	}
	return string(runes)
}
421
// commonInitialisms is the set of initialisms that lintName keeps in a single
// case (e.g. "ID", "URL"). Only add words that are very unlikely to be
// ordinary words: "ID" is safe (Freudian code is rare), but "AND" is not.
var commonInitialisms = func() map[string]bool {
	words := []string{
		"ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID",
		"HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS",
		"RAM", "RHS", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP",
		"TLS", "TTL", "UDP", "UI", "UID", "UUID", "URI", "URL",
		"UTF8", "VM", "XML", "XMPP", "XSRF", "XSS",
	}
	set := make(map[string]bool, len(words))
	for _, word := range words {
		set[word] = true
	}
	return set
}()