exec.go

  1package graphql
  2
  3import (
  4	"context"
  5	"fmt"
  6
  7	"github.com/vektah/gqlparser/ast"
  8)
  9
 10type ExecutableSchema interface {
 11	Schema() *ast.Schema
 12
 13	Complexity(typeName, fieldName string, childComplexity int, args map[string]interface{}) (int, bool)
 14	Query(ctx context.Context, op *ast.OperationDefinition) *Response
 15	Mutation(ctx context.Context, op *ast.OperationDefinition) *Response
 16	Subscription(ctx context.Context, op *ast.OperationDefinition) func() *Response
 17}
 18
 19func CollectFields(ctx context.Context, selSet ast.SelectionSet, satisfies []string) []CollectedField {
 20	return collectFields(GetRequestContext(ctx), selSet, satisfies, map[string]bool{})
 21}
 22
 23func collectFields(reqCtx *RequestContext, selSet ast.SelectionSet, satisfies []string, visited map[string]bool) []CollectedField {
 24	var groupedFields []CollectedField
 25
 26	for _, sel := range selSet {
 27		switch sel := sel.(type) {
 28		case *ast.Field:
 29			if !shouldIncludeNode(sel.Directives, reqCtx.Variables) {
 30				continue
 31			}
 32			f := getOrCreateField(&groupedFields, sel.Alias, func() CollectedField {
 33				return CollectedField{Field: sel}
 34			})
 35
 36			f.Selections = append(f.Selections, sel.SelectionSet...)
 37		case *ast.InlineFragment:
 38			if !shouldIncludeNode(sel.Directives, reqCtx.Variables) || !instanceOf(sel.TypeCondition, satisfies) {
 39				continue
 40			}
 41			for _, childField := range collectFields(reqCtx, sel.SelectionSet, satisfies, visited) {
 42				f := getOrCreateField(&groupedFields, childField.Name, func() CollectedField { return childField })
 43				f.Selections = append(f.Selections, childField.Selections...)
 44			}
 45
 46		case *ast.FragmentSpread:
 47			if !shouldIncludeNode(sel.Directives, reqCtx.Variables) {
 48				continue
 49			}
 50			fragmentName := sel.Name
 51			if _, seen := visited[fragmentName]; seen {
 52				continue
 53			}
 54			visited[fragmentName] = true
 55
 56			fragment := reqCtx.Doc.Fragments.ForName(fragmentName)
 57			if fragment == nil {
 58				// should never happen, validator has already run
 59				panic(fmt.Errorf("missing fragment %s", fragmentName))
 60			}
 61
 62			if !instanceOf(fragment.TypeCondition, satisfies) {
 63				continue
 64			}
 65
 66			for _, childField := range collectFields(reqCtx, fragment.SelectionSet, satisfies, visited) {
 67				f := getOrCreateField(&groupedFields, childField.Name, func() CollectedField { return childField })
 68				f.Selections = append(f.Selections, childField.Selections...)
 69			}
 70
 71		default:
 72			panic(fmt.Errorf("unsupported %T", sel))
 73		}
 74	}
 75
 76	return groupedFields
 77}
 78
 79type CollectedField struct {
 80	*ast.Field
 81
 82	Selections ast.SelectionSet
 83}
 84
 85func instanceOf(val string, satisfies []string) bool {
 86	for _, s := range satisfies {
 87		if val == s {
 88			return true
 89		}
 90	}
 91	return false
 92}
 93
 94func getOrCreateField(c *[]CollectedField, name string, creator func() CollectedField) *CollectedField {
 95	for i, cf := range *c {
 96		if cf.Alias == name {
 97			return &(*c)[i]
 98		}
 99	}
100
101	f := creator()
102
103	*c = append(*c, f)
104	return &(*c)[len(*c)-1]
105}
106
107func shouldIncludeNode(directives ast.DirectiveList, variables map[string]interface{}) bool {
108	skip, include := false, true
109
110	if d := directives.ForName("skip"); d != nil {
111		skip = resolveIfArgument(d, variables)
112	}
113
114	if d := directives.ForName("include"); d != nil {
115		include = resolveIfArgument(d, variables)
116	}
117
118	return !skip && include
119}
120
121func resolveIfArgument(d *ast.Directive, variables map[string]interface{}) bool {
122	arg := d.Arguments.ForName("if")
123	if arg == nil {
124		panic(fmt.Sprintf("%s: argument 'if' not defined", d.Name))
125	}
126	value, err := arg.Value.Value(variables)
127	if err != nil {
128		panic(err)
129	}
130	ret, ok := value.(bool)
131	if !ok {
132		panic(fmt.Sprintf("%s: argument 'if' is not a boolean", d.Name))
133	}
134	return ret
135}