exec.go

  1package graphql
  2
  3import (
  4	"context"
  5	"fmt"
  6
  7	"github.com/vektah/gqlparser/ast"
  8)
  9
 10type ExecutableSchema interface {
 11	Schema() *ast.Schema
 12
 13	Complexity(typeName, fieldName string, childComplexity int, args map[string]interface{}) (int, bool)
 14	Query(ctx context.Context, op *ast.OperationDefinition) *Response
 15	Mutation(ctx context.Context, op *ast.OperationDefinition) *Response
 16	Subscription(ctx context.Context, op *ast.OperationDefinition) func() *Response
 17}
 18
 19// CollectFields returns the set of fields from an ast.SelectionSet where all collected fields satisfy at least one of the GraphQL types
 20// passed through satisfies. Providing an empty or nil slice for satisfies will return collect all fields regardless of fragment
 21// type conditions.
 22func CollectFields(ctx context.Context, selSet ast.SelectionSet, satisfies []string) []CollectedField {
 23	return collectFields(GetRequestContext(ctx), selSet, satisfies, map[string]bool{})
 24}
 25
 26func collectFields(reqCtx *RequestContext, selSet ast.SelectionSet, satisfies []string, visited map[string]bool) []CollectedField {
 27	var groupedFields []CollectedField
 28
 29	for _, sel := range selSet {
 30		switch sel := sel.(type) {
 31		case *ast.Field:
 32			if !shouldIncludeNode(sel.Directives, reqCtx.Variables) {
 33				continue
 34			}
 35			f := getOrCreateField(&groupedFields, sel.Alias, func() CollectedField {
 36				return CollectedField{Field: sel}
 37			})
 38
 39			f.Selections = append(f.Selections, sel.SelectionSet...)
 40		case *ast.InlineFragment:
 41			if !shouldIncludeNode(sel.Directives, reqCtx.Variables) {
 42				continue
 43			}
 44			if len(satisfies) > 0 && !instanceOf(sel.TypeCondition, satisfies) {
 45				continue
 46			}
 47			for _, childField := range collectFields(reqCtx, sel.SelectionSet, satisfies, visited) {
 48				f := getOrCreateField(&groupedFields, childField.Name, func() CollectedField { return childField })
 49				f.Selections = append(f.Selections, childField.Selections...)
 50			}
 51
 52		case *ast.FragmentSpread:
 53			if !shouldIncludeNode(sel.Directives, reqCtx.Variables) {
 54				continue
 55			}
 56			fragmentName := sel.Name
 57			if _, seen := visited[fragmentName]; seen {
 58				continue
 59			}
 60			visited[fragmentName] = true
 61
 62			fragment := reqCtx.Doc.Fragments.ForName(fragmentName)
 63			if fragment == nil {
 64				// should never happen, validator has already run
 65				panic(fmt.Errorf("missing fragment %s", fragmentName))
 66			}
 67
 68			if len(satisfies) > 0 && !instanceOf(fragment.TypeCondition, satisfies) {
 69				continue
 70			}
 71
 72			for _, childField := range collectFields(reqCtx, fragment.SelectionSet, satisfies, visited) {
 73				f := getOrCreateField(&groupedFields, childField.Name, func() CollectedField { return childField })
 74				f.Selections = append(f.Selections, childField.Selections...)
 75			}
 76
 77		default:
 78			panic(fmt.Errorf("unsupported %T", sel))
 79		}
 80	}
 81
 82	return groupedFields
 83}
 84
// CollectedField is one entry of the result of CollectFields. It embeds the
// *ast.Field of the first occurrence of a response name (alias) in the
// selection set.
type CollectedField struct {
	*ast.Field

	// Selections concatenates the sub-selection sets of every occurrence of
	// this field that was merged into this entry.
	Selections ast.SelectionSet
}
 90
 91func instanceOf(val string, satisfies []string) bool {
 92	for _, s := range satisfies {
 93		if val == s {
 94			return true
 95		}
 96	}
 97	return false
 98}
 99
100func getOrCreateField(c *[]CollectedField, name string, creator func() CollectedField) *CollectedField {
101	for i, cf := range *c {
102		if cf.Alias == name {
103			return &(*c)[i]
104		}
105	}
106
107	f := creator()
108
109	*c = append(*c, f)
110	return &(*c)[len(*c)-1]
111}
112
113func shouldIncludeNode(directives ast.DirectiveList, variables map[string]interface{}) bool {
114	skip, include := false, true
115
116	if d := directives.ForName("skip"); d != nil {
117		skip = resolveIfArgument(d, variables)
118	}
119
120	if d := directives.ForName("include"); d != nil {
121		include = resolveIfArgument(d, variables)
122	}
123
124	return !skip && include
125}
126
127func resolveIfArgument(d *ast.Directive, variables map[string]interface{}) bool {
128	arg := d.Arguments.ForName("if")
129	if arg == nil {
130		panic(fmt.Sprintf("%s: argument 'if' not defined", d.Name))
131	}
132	value, err := arg.Value.Value(variables)
133	if err != nil {
134		panic(err)
135	}
136	ret, ok := value.(bool)
137	if !ok {
138		panic(fmt.Sprintf("%s: argument 'if' is not a boolean", d.Name))
139	}
140	return ret
141}