exec.go

  1package graphql
  2
  3import (
  4	"context"
  5	"fmt"
  6
  7	"github.com/vektah/gqlparser/ast"
  8)
  9
 10type ExecutableSchema interface {
 11	Schema() *ast.Schema
 12
 13	Complexity(typeName, fieldName string, childComplexity int, args map[string]interface{}) (int, bool)
 14	Query(ctx context.Context, op *ast.OperationDefinition) *Response
 15	Mutation(ctx context.Context, op *ast.OperationDefinition) *Response
 16	Subscription(ctx context.Context, op *ast.OperationDefinition) func() *Response
 17}
 18
 19// CollectFields returns the set of fields from an ast.SelectionSet where all collected fields satisfy at least one of the GraphQL types
 20// passed through satisfies. Providing an empty or nil slice for satisfies will return collect all fields regardless of fragment
 21// type conditions.
 22func CollectFields(reqCtx *RequestContext, selSet ast.SelectionSet, satisfies []string) []CollectedField {
 23	return collectFields(reqCtx, selSet, satisfies, map[string]bool{})
 24}
 25
 26func collectFields(reqCtx *RequestContext, selSet ast.SelectionSet, satisfies []string, visited map[string]bool) []CollectedField {
 27	groupedFields := make([]CollectedField, 0, len(selSet))
 28
 29	for _, sel := range selSet {
 30		switch sel := sel.(type) {
 31		case *ast.Field:
 32			if !shouldIncludeNode(sel.Directives, reqCtx.Variables) {
 33				continue
 34			}
 35			f := getOrCreateAndAppendField(&groupedFields, sel.Alias, func() CollectedField {
 36				return CollectedField{Field: sel}
 37			})
 38
 39			f.Selections = append(f.Selections, sel.SelectionSet...)
 40		case *ast.InlineFragment:
 41			if !shouldIncludeNode(sel.Directives, reqCtx.Variables) {
 42				continue
 43			}
 44			if len(satisfies) > 0 && !instanceOf(sel.TypeCondition, satisfies) {
 45				continue
 46			}
 47			for _, childField := range collectFields(reqCtx, sel.SelectionSet, satisfies, visited) {
 48				f := getOrCreateAndAppendField(&groupedFields, childField.Name, func() CollectedField { return childField })
 49				f.Selections = append(f.Selections, childField.Selections...)
 50			}
 51
 52		case *ast.FragmentSpread:
 53			if !shouldIncludeNode(sel.Directives, reqCtx.Variables) {
 54				continue
 55			}
 56			fragmentName := sel.Name
 57			if _, seen := visited[fragmentName]; seen {
 58				continue
 59			}
 60			visited[fragmentName] = true
 61
 62			fragment := reqCtx.Doc.Fragments.ForName(fragmentName)
 63			if fragment == nil {
 64				// should never happen, validator has already run
 65				panic(fmt.Errorf("missing fragment %s", fragmentName))
 66			}
 67
 68			if len(satisfies) > 0 && !instanceOf(fragment.TypeCondition, satisfies) {
 69				continue
 70			}
 71
 72			for _, childField := range collectFields(reqCtx, fragment.SelectionSet, satisfies, visited) {
 73				f := getOrCreateAndAppendField(&groupedFields, childField.Name, func() CollectedField { return childField })
 74				f.Selections = append(f.Selections, childField.Selections...)
 75			}
 76		default:
 77			panic(fmt.Errorf("unsupported %T", sel))
 78		}
 79	}
 80
 81	return groupedFields
 82}
 83
 84type CollectedField struct {
 85	*ast.Field
 86
 87	Selections ast.SelectionSet
 88}
 89
 90func instanceOf(val string, satisfies []string) bool {
 91	for _, s := range satisfies {
 92		if val == s {
 93			return true
 94		}
 95	}
 96	return false
 97}
 98
 99func getOrCreateAndAppendField(c *[]CollectedField, name string, creator func() CollectedField) *CollectedField {
100	for i, cf := range *c {
101		if cf.Alias == name {
102			return &(*c)[i]
103		}
104	}
105
106	f := creator()
107
108	*c = append(*c, f)
109	return &(*c)[len(*c)-1]
110}
111
112func shouldIncludeNode(directives ast.DirectiveList, variables map[string]interface{}) bool {
113	if len(directives) == 0 {
114		return true
115	}
116
117	skip, include := false, true
118
119	if d := directives.ForName("skip"); d != nil {
120		skip = resolveIfArgument(d, variables)
121	}
122
123	if d := directives.ForName("include"); d != nil {
124		include = resolveIfArgument(d, variables)
125	}
126
127	return !skip && include
128}
129
130func resolveIfArgument(d *ast.Directive, variables map[string]interface{}) bool {
131	arg := d.Arguments.ForName("if")
132	if arg == nil {
133		panic(fmt.Sprintf("%s: argument 'if' not defined", d.Name))
134	}
135	value, err := arg.Value.Value(variables)
136	if err != nil {
137		panic(err)
138	}
139	ret, ok := value.(bool)
140	if !ok {
141		panic(fmt.Sprintf("%s: argument 'if' is not a boolean", d.Name))
142	}
143	return ret
144}