walk.go

  1package validator
  2
  3import (
  4	"context"
  5	"fmt"
  6
  7	"github.com/vektah/gqlparser/ast"
  8)
  9
 10type Events struct {
 11	operationVisitor []func(walker *Walker, operation *ast.OperationDefinition)
 12	field            []func(walker *Walker, field *ast.Field)
 13	fragment         []func(walker *Walker, fragment *ast.FragmentDefinition)
 14	inlineFragment   []func(walker *Walker, inlineFragment *ast.InlineFragment)
 15	fragmentSpread   []func(walker *Walker, fragmentSpread *ast.FragmentSpread)
 16	directive        []func(walker *Walker, directive *ast.Directive)
 17	directiveList    []func(walker *Walker, directives []*ast.Directive)
 18	value            []func(walker *Walker, value *ast.Value)
 19}
 20
 21func (o *Events) OnOperation(f func(walker *Walker, operation *ast.OperationDefinition)) {
 22	o.operationVisitor = append(o.operationVisitor, f)
 23}
 24func (o *Events) OnField(f func(walker *Walker, field *ast.Field)) {
 25	o.field = append(o.field, f)
 26}
 27func (o *Events) OnFragment(f func(walker *Walker, fragment *ast.FragmentDefinition)) {
 28	o.fragment = append(o.fragment, f)
 29}
 30func (o *Events) OnInlineFragment(f func(walker *Walker, inlineFragment *ast.InlineFragment)) {
 31	o.inlineFragment = append(o.inlineFragment, f)
 32}
 33func (o *Events) OnFragmentSpread(f func(walker *Walker, fragmentSpread *ast.FragmentSpread)) {
 34	o.fragmentSpread = append(o.fragmentSpread, f)
 35}
 36func (o *Events) OnDirective(f func(walker *Walker, directive *ast.Directive)) {
 37	o.directive = append(o.directive, f)
 38}
 39func (o *Events) OnDirectiveList(f func(walker *Walker, directives []*ast.Directive)) {
 40	o.directiveList = append(o.directiveList, f)
 41}
 42func (o *Events) OnValue(f func(walker *Walker, value *ast.Value)) {
 43	o.value = append(o.value, f)
 44}
 45
 46func Walk(schema *ast.Schema, document *ast.QueryDocument, observers *Events) {
 47	w := Walker{
 48		Observers: observers,
 49		Schema:    schema,
 50		Document:  document,
 51	}
 52
 53	w.walk()
 54}
 55
// Walker carries the state of a single Walk over a query document. The same
// *Walker is passed to every Events callback, so observers can read the
// schema, the document and the operation currently being walked.
type Walker struct {
	// Context is not populated by Walk. It is nil unless a caller sets it.
	// NOTE(review): confirm whether any observer relies on it.
	Context context.Context
	// Observers holds the callbacks invoked during the walk.
	Observers *Events
	// Schema is used to resolve type, field and directive definitions.
	Schema *ast.Schema
	// Document is the query document being walked.
	Document *ast.QueryDocument

	// validatedFragmentSpreads records the names of fragments whose selection
	// sets have already been walked through a spread. The walk resets it
	// before each top-level operation and fragment, so recursive spreads
	// terminate.
	validatedFragmentSpreads map[string]bool
	// CurrentOperation is the operation being walked. It is nil while
	// top-level fragment definitions are walked, and variable references
	// are only resolved when it is set.
	CurrentOperation *ast.OperationDefinition
}
 65
 66func (w *Walker) walk() {
 67	for _, child := range w.Document.Operations {
 68		w.validatedFragmentSpreads = make(map[string]bool)
 69		w.walkOperation(child)
 70	}
 71	for _, child := range w.Document.Fragments {
 72		w.validatedFragmentSpreads = make(map[string]bool)
 73		w.walkFragment(child)
 74	}
 75}
 76
 77func (w *Walker) walkOperation(operation *ast.OperationDefinition) {
 78	w.CurrentOperation = operation
 79	for _, varDef := range operation.VariableDefinitions {
 80		varDef.Definition = w.Schema.Types[varDef.Type.Name()]
 81
 82		if varDef.DefaultValue != nil {
 83			varDef.DefaultValue.ExpectedType = varDef.Type
 84			varDef.DefaultValue.Definition = w.Schema.Types[varDef.Type.Name()]
 85		}
 86	}
 87
 88	var def *ast.Definition
 89	var loc ast.DirectiveLocation
 90	switch operation.Operation {
 91	case ast.Query, "":
 92		def = w.Schema.Query
 93		loc = ast.LocationQuery
 94	case ast.Mutation:
 95		def = w.Schema.Mutation
 96		loc = ast.LocationMutation
 97	case ast.Subscription:
 98		def = w.Schema.Subscription
 99		loc = ast.LocationSubscription
100	}
101
102	w.walkDirectives(def, operation.Directives, loc)
103
104	for _, varDef := range operation.VariableDefinitions {
105		if varDef.DefaultValue != nil {
106			w.walkValue(varDef.DefaultValue)
107		}
108	}
109
110	w.walkSelectionSet(def, operation.SelectionSet)
111
112	for _, v := range w.Observers.operationVisitor {
113		v(w, operation)
114	}
115	w.CurrentOperation = nil
116}
117
118func (w *Walker) walkFragment(it *ast.FragmentDefinition) {
119	def := w.Schema.Types[it.TypeCondition]
120
121	it.Definition = def
122
123	w.walkDirectives(def, it.Directives, ast.LocationFragmentDefinition)
124	w.walkSelectionSet(def, it.SelectionSet)
125
126	for _, v := range w.Observers.fragment {
127		v(w, it)
128	}
129}
130
131func (w *Walker) walkDirectives(parentDef *ast.Definition, directives []*ast.Directive, location ast.DirectiveLocation) {
132	for _, dir := range directives {
133		def := w.Schema.Directives[dir.Name]
134		dir.Definition = def
135		dir.ParentDefinition = parentDef
136		dir.Location = location
137
138		for _, arg := range dir.Arguments {
139			var argDef *ast.ArgumentDefinition
140			if def != nil {
141				argDef = def.Arguments.ForName(arg.Name)
142			}
143
144			w.walkArgument(argDef, arg)
145		}
146
147		for _, v := range w.Observers.directive {
148			v(w, dir)
149		}
150	}
151
152	for _, v := range w.Observers.directiveList {
153		v(w, directives)
154	}
155}
156
157func (w *Walker) walkValue(value *ast.Value) {
158	if value.Kind == ast.Variable && w.CurrentOperation != nil {
159		value.VariableDefinition = w.CurrentOperation.VariableDefinitions.ForName(value.Raw)
160		if value.VariableDefinition != nil {
161			value.VariableDefinition.Used = true
162		}
163	}
164
165	if value.Kind == ast.ObjectValue {
166		for _, child := range value.Children {
167			if value.Definition != nil {
168				fieldDef := value.Definition.Fields.ForName(child.Name)
169				if fieldDef != nil {
170					child.Value.ExpectedType = fieldDef.Type
171					child.Value.Definition = w.Schema.Types[fieldDef.Type.Name()]
172				}
173			}
174			w.walkValue(child.Value)
175		}
176	}
177
178	if value.Kind == ast.ListValue {
179		for _, child := range value.Children {
180			if value.ExpectedType != nil && value.ExpectedType.Elem != nil {
181				child.Value.ExpectedType = value.ExpectedType.Elem
182				child.Value.Definition = value.Definition
183			}
184
185			w.walkValue(child.Value)
186		}
187	}
188
189	for _, v := range w.Observers.value {
190		v(w, value)
191	}
192}
193
194func (w *Walker) walkArgument(argDef *ast.ArgumentDefinition, arg *ast.Argument) {
195	if argDef != nil {
196		arg.Value.ExpectedType = argDef.Type
197		arg.Value.Definition = w.Schema.Types[argDef.Type.Name()]
198	}
199
200	w.walkValue(arg.Value)
201}
202
203func (w *Walker) walkSelectionSet(parentDef *ast.Definition, it ast.SelectionSet) {
204	for _, child := range it {
205		w.walkSelection(parentDef, child)
206	}
207}
208
209func (w *Walker) walkSelection(parentDef *ast.Definition, it ast.Selection) {
210	switch it := it.(type) {
211	case *ast.Field:
212		var def *ast.FieldDefinition
213		if it.Name == "__typename" {
214			def = &ast.FieldDefinition{
215				Name: "__typename",
216				Type: ast.NamedType("String", nil),
217			}
218		} else if parentDef != nil {
219			def = parentDef.Fields.ForName(it.Name)
220		}
221
222		it.Definition = def
223		it.ObjectDefinition = parentDef
224
225		var nextParentDef *ast.Definition
226		if def != nil {
227			nextParentDef = w.Schema.Types[def.Type.Name()]
228		}
229
230		for _, arg := range it.Arguments {
231			var argDef *ast.ArgumentDefinition
232			if def != nil {
233				argDef = def.Arguments.ForName(arg.Name)
234			}
235
236			w.walkArgument(argDef, arg)
237		}
238
239		w.walkDirectives(nextParentDef, it.Directives, ast.LocationField)
240		w.walkSelectionSet(nextParentDef, it.SelectionSet)
241
242		for _, v := range w.Observers.field {
243			v(w, it)
244		}
245
246	case *ast.InlineFragment:
247		it.ObjectDefinition = parentDef
248
249		nextParentDef := parentDef
250		if it.TypeCondition != "" {
251			nextParentDef = w.Schema.Types[it.TypeCondition]
252		}
253
254		w.walkDirectives(nextParentDef, it.Directives, ast.LocationInlineFragment)
255		w.walkSelectionSet(nextParentDef, it.SelectionSet)
256
257		for _, v := range w.Observers.inlineFragment {
258			v(w, it)
259		}
260
261	case *ast.FragmentSpread:
262		def := w.Document.Fragments.ForName(it.Name)
263		it.Definition = def
264		it.ObjectDefinition = parentDef
265
266		var nextParentDef *ast.Definition
267		if def != nil {
268			nextParentDef = w.Schema.Types[def.TypeCondition]
269		}
270
271		w.walkDirectives(nextParentDef, it.Directives, ast.LocationFragmentSpread)
272
273		if def != nil && !w.validatedFragmentSpreads[def.Name] {
274			// prevent inifinite recursion
275			w.validatedFragmentSpreads[def.Name] = true
276			w.walkSelectionSet(nextParentDef, def.SelectionSet)
277		}
278
279		for _, v := range w.Observers.fragmentSpread {
280			v(w, it)
281		}
282
283	default:
284		panic(fmt.Errorf("unsupported %T", it))
285	}
286}