object.go

  1package codegen
  2
  3import (
  4	"bytes"
  5	"fmt"
  6	"strconv"
  7	"strings"
  8	"text/template"
  9	"unicode"
 10
 11	"github.com/vektah/gqlparser/ast"
 12)
 13
 14type GoFieldType int
 15
 16const (
 17	GoFieldUndefined GoFieldType = iota
 18	GoFieldMethod
 19	GoFieldVariable
 20)
 21
 22type Object struct {
 23	*NamedType
 24
 25	Fields             []Field
 26	Satisfies          []string
 27	Implements         []*NamedType
 28	ResolverInterface  *Ref
 29	Root               bool
 30	DisableConcurrency bool
 31	Stream             bool
 32}
 33
 34type Field struct {
 35	*Type
 36	Description      string          // Description of a field
 37	GQLName          string          // The name of the field in graphql
 38	GoFieldType      GoFieldType     // The field type in go, if any
 39	GoReceiverName   string          // The name of method & var receiver in go, if any
 40	GoFieldName      string          // The name of the method or var in go, if any
 41	Args             []FieldArgument // A list of arguments to be passed to this field
 42	ForceResolver    bool            // Should be emit Resolver method
 43	MethodHasContext bool            // If this is bound to a go method, does the method also take a context
 44	NoErr            bool            // If this is bound to a go method, does that method have an error as the second argument
 45	Object           *Object         // A link back to the parent object
 46	Default          interface{}     // The default value
 47}
 48
 49type FieldArgument struct {
 50	*Type
 51
 52	GQLName   string      // The name of the argument in graphql
 53	GoVarName string      // The name of the var in go
 54	Object    *Object     // A link back to the parent object
 55	Default   interface{} // The default value
 56}
 57
 58type Objects []*Object
 59
 60func (o *Object) Implementors() string {
 61	satisfiedBy := strconv.Quote(o.GQLType)
 62	for _, s := range o.Satisfies {
 63		satisfiedBy += ", " + strconv.Quote(s)
 64	}
 65	return "[]string{" + satisfiedBy + "}"
 66}
 67
 68func (o *Object) HasResolvers() bool {
 69	for _, f := range o.Fields {
 70		if f.IsResolver() {
 71			return true
 72		}
 73	}
 74	return false
 75}
 76
 77func (o *Object) IsConcurrent() bool {
 78	for _, f := range o.Fields {
 79		if f.IsConcurrent() {
 80			return true
 81		}
 82	}
 83	return false
 84}
 85
 86func (o *Object) IsReserved() bool {
 87	return strings.HasPrefix(o.GQLType, "__")
 88}
 89
 90func (f *Field) IsResolver() bool {
 91	return f.GoFieldName == ""
 92}
 93
 94func (f *Field) IsReserved() bool {
 95	return strings.HasPrefix(f.GQLName, "__")
 96}
 97
 98func (f *Field) IsMethod() bool {
 99	return f.GoFieldType == GoFieldMethod
100}
101
102func (f *Field) IsVariable() bool {
103	return f.GoFieldType == GoFieldVariable
104}
105
106func (f *Field) IsConcurrent() bool {
107	if f.Object.DisableConcurrency {
108		return false
109	}
110	return f.MethodHasContext || f.IsResolver()
111}
112
113func (f *Field) GoNameExported() string {
114	return lintName(ucFirst(f.GQLName))
115}
116
117func (f *Field) GoNameUnexported() string {
118	return lintName(f.GQLName)
119}
120
121func (f *Field) ShortInvocation() string {
122	if !f.IsResolver() {
123		return ""
124	}
125
126	return fmt.Sprintf("%s().%s(%s)", f.Object.GQLType, f.GoNameExported(), f.CallArgs())
127}
128
129func (f *Field) ArgsFunc() string {
130	if len(f.Args) == 0 {
131		return ""
132	}
133
134	return "field_" + f.Object.GQLType + "_" + f.GQLName + "_args"
135}
136
137func (f *Field) ResolverType() string {
138	if !f.IsResolver() {
139		return ""
140	}
141
142	return fmt.Sprintf("%s().%s(%s)", f.Object.GQLType, f.GoNameExported(), f.CallArgs())
143}
144
145func (f *Field) ShortResolverDeclaration() string {
146	if !f.IsResolver() {
147		return ""
148	}
149	res := fmt.Sprintf("%s(ctx context.Context", f.GoNameExported())
150
151	if !f.Object.Root {
152		res += fmt.Sprintf(", obj *%s", f.Object.FullName())
153	}
154	for _, arg := range f.Args {
155		res += fmt.Sprintf(", %s %s", arg.GoVarName, arg.Signature())
156	}
157
158	result := f.Signature()
159	if f.Object.Stream {
160		result = "<-chan " + result
161	}
162
163	res += fmt.Sprintf(") (%s, error)", result)
164	return res
165}
166
167func (f *Field) ResolverDeclaration() string {
168	if !f.IsResolver() {
169		return ""
170	}
171	res := fmt.Sprintf("%s_%s(ctx context.Context", f.Object.GQLType, f.GoNameUnexported())
172
173	if !f.Object.Root {
174		res += fmt.Sprintf(", obj *%s", f.Object.FullName())
175	}
176	for _, arg := range f.Args {
177		res += fmt.Sprintf(", %s %s", arg.GoVarName, arg.Signature())
178	}
179
180	result := f.Signature()
181	if f.Object.Stream {
182		result = "<-chan " + result
183	}
184
185	res += fmt.Sprintf(") (%s, error)", result)
186	return res
187}
188
189func (f *Field) ComplexitySignature() string {
190	res := fmt.Sprintf("func(childComplexity int")
191	for _, arg := range f.Args {
192		res += fmt.Sprintf(", %s %s", arg.GoVarName, arg.Signature())
193	}
194	res += ") int"
195	return res
196}
197
198func (f *Field) ComplexityArgs() string {
199	var args []string
200	for _, arg := range f.Args {
201		args = append(args, "args["+strconv.Quote(arg.GQLName)+"].("+arg.Signature()+")")
202	}
203
204	return strings.Join(args, ", ")
205}
206
207func (f *Field) CallArgs() string {
208	var args []string
209
210	if f.IsResolver() {
211		args = append(args, "rctx")
212
213		if !f.Object.Root {
214			args = append(args, "obj")
215		}
216	} else {
217		if f.MethodHasContext {
218			args = append(args, "ctx")
219		}
220	}
221
222	for _, arg := range f.Args {
223		args = append(args, "args["+strconv.Quote(arg.GQLName)+"].("+arg.Signature()+")")
224	}
225
226	return strings.Join(args, ", ")
227}
228
229// should be in the template, but its recursive and has a bunch of args
230func (f *Field) WriteJson() string {
231	return f.doWriteJson("res", f.Type.Modifiers, f.ASTType, false, 1)
232}
233
234func (f *Field) doWriteJson(val string, remainingMods []string, astType *ast.Type, isPtr bool, depth int) string {
235	switch {
236	case len(remainingMods) > 0 && remainingMods[0] == modPtr:
237		return tpl(`
238			if {{.val}} == nil {
239				{{- if .nonNull }}
240					if !ec.HasError(rctx) {
241						ec.Errorf(ctx, "must not be null")
242					}
243				{{- end }}
244				return graphql.Null
245			}
246			{{.next }}`, map[string]interface{}{
247			"val":     val,
248			"nonNull": astType.NonNull,
249			"next":    f.doWriteJson(val, remainingMods[1:], astType, true, depth+1),
250		})
251
252	case len(remainingMods) > 0 && remainingMods[0] == modList:
253		if isPtr {
254			val = "*" + val
255		}
256		var arr = "arr" + strconv.Itoa(depth)
257		var index = "idx" + strconv.Itoa(depth)
258		var usePtr bool
259		if len(remainingMods) == 1 && !isPtr {
260			usePtr = true
261		}
262
263		return tpl(`
264			{{.arr}} := make(graphql.Array, len({{.val}}))
265			{{ if and .top (not .isScalar) }} var wg sync.WaitGroup {{ end }}
266			{{ if not .isScalar }}
267				isLen1 := len({{.val}}) == 1
268				if !isLen1 {
269					wg.Add(len({{.val}}))
270				}
271			{{ end }}
272			for {{.index}} := range {{.val}} {
273				{{- if not .isScalar }}
274					{{.index}} := {{.index}}
275					rctx := &graphql.ResolverContext{
276						Index: &{{.index}},
277						Result: {{ if .usePtr }}&{{end}}{{.val}}[{{.index}}],
278					}
279					ctx := graphql.WithResolverContext(ctx, rctx)
280					f := func({{.index}} int) {
281						if !isLen1 {
282							defer wg.Done()
283						}
284						{{.arr}}[{{.index}}] = func() graphql.Marshaler {
285							{{ .next }}
286						}()
287					}
288					if isLen1 {
289						f({{.index}})
290					} else {
291						go f({{.index}})
292					}
293				{{ else }}
294					{{.arr}}[{{.index}}] = func() graphql.Marshaler {
295						{{ .next }}
296					}()
297				{{- end}}
298			}
299			{{ if and .top (not .isScalar) }} wg.Wait() {{ end }}
300			return {{.arr}}`, map[string]interface{}{
301			"val":      val,
302			"arr":      arr,
303			"index":    index,
304			"top":      depth == 1,
305			"arrayLen": len(val),
306			"isScalar": f.IsScalar,
307			"usePtr":   usePtr,
308			"next":     f.doWriteJson(val+"["+index+"]", remainingMods[1:], astType.Elem, false, depth+1),
309		})
310
311	case f.IsScalar:
312		if isPtr {
313			val = "*" + val
314		}
315		return f.Marshal(val)
316
317	default:
318		if !isPtr {
319			val = "&" + val
320		}
321		return tpl(`
322			return ec._{{.type}}(ctx, field.Selections, {{.val}})`, map[string]interface{}{
323			"type": f.GQLType,
324			"val":  val,
325		})
326	}
327}
328
329func (f *FieldArgument) Stream() bool {
330	return f.Object != nil && f.Object.Stream
331}
332
333func (os Objects) ByName(name string) *Object {
334	for i, o := range os {
335		if strings.EqualFold(o.GQLType, name) {
336			return os[i]
337		}
338	}
339	return nil
340}
341
// tpl parses tpl as a text/template named "inline", executes it with vars as
// its data and returns the rendered text. It panics if the template fails to
// parse or to execute.
func tpl(tpl string, vars map[string]interface{}) string {
	parsed, err := template.New("inline").Parse(tpl)
	if err != nil {
		panic(err)
	}

	var out bytes.Buffer
	if err := parsed.Execute(&out, vars); err != nil {
		panic(err)
	}
	return out.String()
}
350
// ucFirst returns s with its first rune converted to upper case. The empty
// string is returned unchanged.
func ucFirst(s string) string {
	runes := []rune(s)
	if len(runes) == 0 {
		return ""
	}

	runes[0] = unicode.ToUpper(runes[0])
	return string(runes)
}
360
361// copy from https://github.com/golang/lint/blob/06c8688daad7faa9da5a0c2f163a3d14aac986ca/lint.go#L679
362
363// lintName returns a different name if it should be different.
364func lintName(name string) (should string) {
365	// Fast path for simple cases: "_" and all lowercase.
366	if name == "_" {
367		return name
368	}
369	allLower := true
370	for _, r := range name {
371		if !unicode.IsLower(r) {
372			allLower = false
373			break
374		}
375	}
376	if allLower {
377		return name
378	}
379
380	// Split camelCase at any lower->upper transition, and split on underscores.
381	// Check each word for common initialisms.
382	runes := []rune(name)
383	w, i := 0, 0 // index of start of word, scan
384	for i+1 <= len(runes) {
385		eow := false // whether we hit the end of a word
386		if i+1 == len(runes) {
387			eow = true
388		} else if runes[i+1] == '_' {
389			// underscore; shift the remainder forward over any run of underscores
390			eow = true
391			n := 1
392			for i+n+1 < len(runes) && runes[i+n+1] == '_' {
393				n++
394			}
395
396			// Leave at most one underscore if the underscore is between two digits
397			if i+n+1 < len(runes) && unicode.IsDigit(runes[i]) && unicode.IsDigit(runes[i+n+1]) {
398				n--
399			}
400
401			copy(runes[i+1:], runes[i+n+1:])
402			runes = runes[:len(runes)-n]
403		} else if unicode.IsLower(runes[i]) && !unicode.IsLower(runes[i+1]) {
404			// lower->non-lower
405			eow = true
406		}
407		i++
408		if !eow {
409			continue
410		}
411
412		// [w,i) is a word.
413		word := string(runes[w:i])
414		if u := strings.ToUpper(word); commonInitialisms[u] {
415			// Keep consistent case, which is lowercase only at the start.
416			if w == 0 && unicode.IsLower(runes[w]) {
417				u = strings.ToLower(u)
418			}
419			// All the common initialisms are ASCII,
420			// so we can replace the bytes exactly.
421			copy(runes[w:], []rune(u))
422		} else if w > 0 && strings.ToLower(word) == word {
423			// already all lowercase, and not the first word, so uppercase the first character.
424			runes[w] = unicode.ToUpper(runes[w])
425		}
426		w = i
427	}
428	return string(runes)
429}
430
// commonInitialisms is the set of initialisms that lintName writes in a
// single, uniform case. Keys are upper case and kept in alphabetical order.
//
// Only add entries that are highly unlikely to be ordinary words: "ID" is
// fine (Freudian code is rare), but "AND" is not.
var commonInitialisms = map[string]bool{
	"ACL":   true,
	"API":   true,
	"ASCII": true,
	"CPU":   true,
	"CSS":   true,
	"DNS":   true,
	"EOF":   true,
	"GUID":  true,
	"HTML":  true,
	"HTTP":  true,
	"HTTPS": true,
	"ID":    true,
	"IP":    true,
	"JSON":  true,
	"LHS":   true,
	"QPS":   true,
	"RAM":   true,
	"RHS":   true,
	"RPC":   true,
	"SLA":   true,
	"SMTP":  true,
	"SQL":   true,
	"SSH":   true,
	"TCP":   true,
	"TLS":   true,
	"TTL":   true,
	"UDP":   true,
	"UI":    true,
	"UID":   true,
	"URI":   true,
	"URL":   true,
	"UTF8":  true,
	"UUID":  true,
	"VM":    true,
	"XML":   true,
	"XMPP":  true,
	"XSRF":  true,
	"XSS":   true,
}