1package graphql
2
3import (
4 "context"
5 "fmt"
6
7 "github.com/vektah/gqlparser/ast"
8)
9
10type ExecutableSchema interface {
11 Schema() *ast.Schema
12
13 Complexity(typeName, fieldName string, childComplexity int, args map[string]interface{}) (int, bool)
14 Query(ctx context.Context, op *ast.OperationDefinition) *Response
15 Mutation(ctx context.Context, op *ast.OperationDefinition) *Response
16 Subscription(ctx context.Context, op *ast.OperationDefinition) func() *Response
17}
18
19func CollectFields(ctx context.Context, selSet ast.SelectionSet, satisfies []string) []CollectedField {
20 return collectFields(GetRequestContext(ctx), selSet, satisfies, map[string]bool{})
21}
22
23func collectFields(reqCtx *RequestContext, selSet ast.SelectionSet, satisfies []string, visited map[string]bool) []CollectedField {
24 var groupedFields []CollectedField
25
26 for _, sel := range selSet {
27 switch sel := sel.(type) {
28 case *ast.Field:
29 if !shouldIncludeNode(sel.Directives, reqCtx.Variables) {
30 continue
31 }
32 f := getOrCreateField(&groupedFields, sel.Alias, func() CollectedField {
33 return CollectedField{Field: sel}
34 })
35
36 f.Selections = append(f.Selections, sel.SelectionSet...)
37 case *ast.InlineFragment:
38 if !shouldIncludeNode(sel.Directives, reqCtx.Variables) || !instanceOf(sel.TypeCondition, satisfies) {
39 continue
40 }
41 for _, childField := range collectFields(reqCtx, sel.SelectionSet, satisfies, visited) {
42 f := getOrCreateField(&groupedFields, childField.Name, func() CollectedField { return childField })
43 f.Selections = append(f.Selections, childField.Selections...)
44 }
45
46 case *ast.FragmentSpread:
47 if !shouldIncludeNode(sel.Directives, reqCtx.Variables) {
48 continue
49 }
50 fragmentName := sel.Name
51 if _, seen := visited[fragmentName]; seen {
52 continue
53 }
54 visited[fragmentName] = true
55
56 fragment := reqCtx.Doc.Fragments.ForName(fragmentName)
57 if fragment == nil {
58 // should never happen, validator has already run
59 panic(fmt.Errorf("missing fragment %s", fragmentName))
60 }
61
62 if !instanceOf(fragment.TypeCondition, satisfies) {
63 continue
64 }
65
66 for _, childField := range collectFields(reqCtx, fragment.SelectionSet, satisfies, visited) {
67 f := getOrCreateField(&groupedFields, childField.Name, func() CollectedField { return childField })
68 f.Selections = append(f.Selections, childField.Selections...)
69 }
70
71 default:
72 panic(fmt.Errorf("unsupported %T", sel))
73 }
74 }
75
76 return groupedFields
77}
78
// CollectedField is one entry of the result of CollectFields: a field
// together with the selections gathered for it across every place (direct
// selection or fragment) where its response key appeared.
type CollectedField struct {
	// Field is the *ast.Field first seen for this response key; its fields
	// (Name, Alias, Directives, ...) are promoted onto CollectedField.
	*ast.Field

	// Selections accumulates the sub-selection sets of every merged
	// occurrence; they are not yet de-duplicated or filtered by directives.
	Selections ast.SelectionSet
}
84
// instanceOf reports whether the type name val is one of the names listed
// in satisfies. The comparison is exact and case-sensitive; an empty or nil
// satisfies never matches.
func instanceOf(val string, satisfies []string) bool {
	for i := range satisfies {
		if satisfies[i] == val {
			return true
		}
	}
	return false
}
93
94func getOrCreateField(c *[]CollectedField, name string, creator func() CollectedField) *CollectedField {
95 for i, cf := range *c {
96 if cf.Alias == name {
97 return &(*c)[i]
98 }
99 }
100
101 f := creator()
102
103 *c = append(*c, f)
104 return &(*c)[len(*c)-1]
105}
106
107func shouldIncludeNode(directives ast.DirectiveList, variables map[string]interface{}) bool {
108 skip, include := false, true
109
110 if d := directives.ForName("skip"); d != nil {
111 skip = resolveIfArgument(d, variables)
112 }
113
114 if d := directives.ForName("include"); d != nil {
115 include = resolveIfArgument(d, variables)
116 }
117
118 return !skip && include
119}
120
121func resolveIfArgument(d *ast.Directive, variables map[string]interface{}) bool {
122 arg := d.Arguments.ForName("if")
123 if arg == nil {
124 panic(fmt.Sprintf("%s: argument 'if' not defined", d.Name))
125 }
126 value, err := arg.Value.Value(variables)
127 if err != nil {
128 panic(err)
129 }
130 ret, ok := value.(bool)
131 if !ok {
132 panic(fmt.Sprintf("%s: argument 'if' is not a boolean", d.Name))
133 }
134 return ret
135}