1package graphql
2
3import (
4 "context"
5 "fmt"
6
7 "github.com/vektah/gqlparser/ast"
8)
9
10type ExecutableSchema interface {
11 Schema() *ast.Schema
12
13 Complexity(typeName, fieldName string, childComplexity int, args map[string]interface{}) (int, bool)
14 Query(ctx context.Context, op *ast.OperationDefinition) *Response
15 Mutation(ctx context.Context, op *ast.OperationDefinition) *Response
16 Subscription(ctx context.Context, op *ast.OperationDefinition) func() *Response
17}
18
19// CollectFields returns the set of fields from an ast.SelectionSet where all collected fields satisfy at least one of the GraphQL types
20// passed through satisfies. Providing an empty or nil slice for satisfies will return collect all fields regardless of fragment
21// type conditions.
22func CollectFields(reqCtx *RequestContext, selSet ast.SelectionSet, satisfies []string) []CollectedField {
23 return collectFields(reqCtx, selSet, satisfies, map[string]bool{})
24}
25
26func collectFields(reqCtx *RequestContext, selSet ast.SelectionSet, satisfies []string, visited map[string]bool) []CollectedField {
27 groupedFields := make([]CollectedField, 0, len(selSet))
28
29 for _, sel := range selSet {
30 switch sel := sel.(type) {
31 case *ast.Field:
32 if !shouldIncludeNode(sel.Directives, reqCtx.Variables) {
33 continue
34 }
35 f := getOrCreateAndAppendField(&groupedFields, sel.Alias, func() CollectedField {
36 return CollectedField{Field: sel}
37 })
38
39 f.Selections = append(f.Selections, sel.SelectionSet...)
40 case *ast.InlineFragment:
41 if !shouldIncludeNode(sel.Directives, reqCtx.Variables) {
42 continue
43 }
44 if len(satisfies) > 0 && !instanceOf(sel.TypeCondition, satisfies) {
45 continue
46 }
47 for _, childField := range collectFields(reqCtx, sel.SelectionSet, satisfies, visited) {
48 f := getOrCreateAndAppendField(&groupedFields, childField.Name, func() CollectedField { return childField })
49 f.Selections = append(f.Selections, childField.Selections...)
50 }
51
52 case *ast.FragmentSpread:
53 if !shouldIncludeNode(sel.Directives, reqCtx.Variables) {
54 continue
55 }
56 fragmentName := sel.Name
57 if _, seen := visited[fragmentName]; seen {
58 continue
59 }
60 visited[fragmentName] = true
61
62 fragment := reqCtx.Doc.Fragments.ForName(fragmentName)
63 if fragment == nil {
64 // should never happen, validator has already run
65 panic(fmt.Errorf("missing fragment %s", fragmentName))
66 }
67
68 if len(satisfies) > 0 && !instanceOf(fragment.TypeCondition, satisfies) {
69 continue
70 }
71
72 for _, childField := range collectFields(reqCtx, fragment.SelectionSet, satisfies, visited) {
73 f := getOrCreateAndAppendField(&groupedFields, childField.Name, func() CollectedField { return childField })
74 f.Selections = append(f.Selections, childField.Selections...)
75 }
76 default:
77 panic(fmt.Errorf("unsupported %T", sel))
78 }
79 }
80
81 return groupedFields
82}
83
// CollectedField is one entry produced by CollectFields: the selected field
// plus the sub-selections gathered for it.
type CollectedField struct {
	*ast.Field

	// Selections holds the child selections for this field. When collectFields
	// merges several selections into a single CollectedField, each one's
	// sub-selections are appended here.
	Selections ast.SelectionSet
}
89
// instanceOf reports whether val appears in satisfies. A nil or empty
// satisfies never matches.
func instanceOf(val string, satisfies []string) bool {
	found := false
	for i := 0; i < len(satisfies) && !found; i++ {
		found = satisfies[i] == val
	}
	return found
}
98
99func getOrCreateAndAppendField(c *[]CollectedField, name string, creator func() CollectedField) *CollectedField {
100 for i, cf := range *c {
101 if cf.Alias == name {
102 return &(*c)[i]
103 }
104 }
105
106 f := creator()
107
108 *c = append(*c, f)
109 return &(*c)[len(*c)-1]
110}
111
112func shouldIncludeNode(directives ast.DirectiveList, variables map[string]interface{}) bool {
113 if len(directives) == 0 {
114 return true
115 }
116
117 skip, include := false, true
118
119 if d := directives.ForName("skip"); d != nil {
120 skip = resolveIfArgument(d, variables)
121 }
122
123 if d := directives.ForName("include"); d != nil {
124 include = resolveIfArgument(d, variables)
125 }
126
127 return !skip && include
128}
129
130func resolveIfArgument(d *ast.Directive, variables map[string]interface{}) bool {
131 arg := d.Arguments.ForName("if")
132 if arg == nil {
133 panic(fmt.Sprintf("%s: argument 'if' not defined", d.Name))
134 }
135 value, err := arg.Value.Value(variables)
136 if err != nil {
137 panic(err)
138 }
139 ret, ok := value.(bool)
140 if !ok {
141 panic(fmt.Sprintf("%s: argument 'if' is not a boolean", d.Name))
142 }
143 return ret
144}