1package graphql
2
3import (
4 "context"
5 "fmt"
6
7 "github.com/vektah/gqlparser/ast"
8)
9
10type ExecutableSchema interface {
11 Schema() *ast.Schema
12
13 Complexity(typeName, fieldName string, childComplexity int, args map[string]interface{}) (int, bool)
14 Query(ctx context.Context, op *ast.OperationDefinition) *Response
15 Mutation(ctx context.Context, op *ast.OperationDefinition) *Response
16 Subscription(ctx context.Context, op *ast.OperationDefinition) func() *Response
17}
18
19// CollectFields returns the set of fields from an ast.SelectionSet where all collected fields satisfy at least one of the GraphQL types
20// passed through satisfies. Providing an empty or nil slice for satisfies will return collect all fields regardless of fragment
21// type conditions.
22func CollectFields(ctx context.Context, selSet ast.SelectionSet, satisfies []string) []CollectedField {
23 return collectFields(GetRequestContext(ctx), selSet, satisfies, map[string]bool{})
24}
25
26func collectFields(reqCtx *RequestContext, selSet ast.SelectionSet, satisfies []string, visited map[string]bool) []CollectedField {
27 var groupedFields []CollectedField
28
29 for _, sel := range selSet {
30 switch sel := sel.(type) {
31 case *ast.Field:
32 if !shouldIncludeNode(sel.Directives, reqCtx.Variables) {
33 continue
34 }
35 f := getOrCreateField(&groupedFields, sel.Alias, func() CollectedField {
36 return CollectedField{Field: sel}
37 })
38
39 f.Selections = append(f.Selections, sel.SelectionSet...)
40 case *ast.InlineFragment:
41 if !shouldIncludeNode(sel.Directives, reqCtx.Variables) {
42 continue
43 }
44 if len(satisfies) > 0 && !instanceOf(sel.TypeCondition, satisfies) {
45 continue
46 }
47 for _, childField := range collectFields(reqCtx, sel.SelectionSet, satisfies, visited) {
48 f := getOrCreateField(&groupedFields, childField.Name, func() CollectedField { return childField })
49 f.Selections = append(f.Selections, childField.Selections...)
50 }
51
52 case *ast.FragmentSpread:
53 if !shouldIncludeNode(sel.Directives, reqCtx.Variables) {
54 continue
55 }
56 fragmentName := sel.Name
57 if _, seen := visited[fragmentName]; seen {
58 continue
59 }
60 visited[fragmentName] = true
61
62 fragment := reqCtx.Doc.Fragments.ForName(fragmentName)
63 if fragment == nil {
64 // should never happen, validator has already run
65 panic(fmt.Errorf("missing fragment %s", fragmentName))
66 }
67
68 if len(satisfies) > 0 && !instanceOf(fragment.TypeCondition, satisfies) {
69 continue
70 }
71
72 for _, childField := range collectFields(reqCtx, fragment.SelectionSet, satisfies, visited) {
73 f := getOrCreateField(&groupedFields, childField.Name, func() CollectedField { return childField })
74 f.Selections = append(f.Selections, childField.Selections...)
75 }
76
77 default:
78 panic(fmt.Errorf("unsupported %T", sel))
79 }
80 }
81
82 return groupedFields
83}
84
85type CollectedField struct {
86 *ast.Field
87
88 Selections ast.SelectionSet
89}
90
// instanceOf reports whether val is equal to any entry of satisfies. The
// comparison is an exact, case-sensitive string match; an empty or nil
// satisfies never matches.
func instanceOf(val string, satisfies []string) bool {
	for i := 0; i < len(satisfies); i++ {
		if satisfies[i] == val {
			return true
		}
	}
	return false
}
99
100func getOrCreateField(c *[]CollectedField, name string, creator func() CollectedField) *CollectedField {
101 for i, cf := range *c {
102 if cf.Alias == name {
103 return &(*c)[i]
104 }
105 }
106
107 f := creator()
108
109 *c = append(*c, f)
110 return &(*c)[len(*c)-1]
111}
112
113func shouldIncludeNode(directives ast.DirectiveList, variables map[string]interface{}) bool {
114 skip, include := false, true
115
116 if d := directives.ForName("skip"); d != nil {
117 skip = resolveIfArgument(d, variables)
118 }
119
120 if d := directives.ForName("include"); d != nil {
121 include = resolveIfArgument(d, variables)
122 }
123
124 return !skip && include
125}
126
127func resolveIfArgument(d *ast.Directive, variables map[string]interface{}) bool {
128 arg := d.Arguments.ForName("if")
129 if arg == nil {
130 panic(fmt.Sprintf("%s: argument 'if' not defined", d.Name))
131 }
132 value, err := arg.Value.Value(variables)
133 if err != nil {
134 panic(err)
135 }
136 ret, ok := value.(bool)
137 if !ok {
138 panic(fmt.Sprintf("%s: argument 'if' is not a boolean", d.Name))
139 }
140 return ret
141}