1package visitor
2
3import (
4 "encoding/json"
5 "fmt"
6 "github.com/graphql-go/graphql/language/ast"
7 "github.com/graphql-go/graphql/language/typeInfo"
8 "reflect"
9)
10
// Actions a VisitFunc returns to tell Visit how to proceed.
const (
	// ActionNoChange continues the traversal without recording an edit.
	ActionNoChange = ""
	// ActionBreak stops the traversal immediately.
	ActionBreak = "BREAK"
	// ActionSkip, when returned on enter, skips the node's children (the
	// node's leave callback is not invoked either).
	ActionSkip = "SKIP"
	// ActionUpdate records an edit replacing the current node with the value
	// returned alongside the action. Visit handles any action other than the
	// three above in the same way.
	ActionUpdate = "UPDATE"
)
17
// KeyMap maps a node kind to the ordered list of struct field names that
// Visit descends into for nodes of that kind.
type KeyMap map[string][]string
19
// QueryDocumentKeys is the default KeyMap used by Visit when it is called
// with a nil keyMap.
//
// Each map key is a node kind, as returned by ast.Node.GetKind, and each
// value lists, in visiting order, the names of that node's struct fields to
// descend into. The field names are capitalised because Visit looks them up
// by reflection (getFieldValue), so they must match the Go field names on
// the ast structs exactly. Kinds mapped to an empty list have no children
// to visit.
var QueryDocumentKeys = KeyMap{
	"Name":     []string{},
	"Document": []string{"Definitions"},
	"OperationDefinition": []string{
		"Name",
		"VariableDefinitions",
		"Directives",
		"SelectionSet",
	},
	"VariableDefinition": []string{
		"Variable",
		"Type",
		"DefaultValue",
	},
	"Variable":     []string{"Name"},
	"SelectionSet": []string{"Selections"},
	"Field": []string{
		"Alias",
		"Name",
		"Arguments",
		"Directives",
		"SelectionSet",
	},
	"Argument": []string{
		"Name",
		"Value",
	},

	"FragmentSpread": []string{
		"Name",
		"Directives",
	},
	"InlineFragment": []string{
		"TypeCondition",
		"Directives",
		"SelectionSet",
	},
	"FragmentDefinition": []string{
		"Name",
		"TypeCondition",
		"Directives",
		"SelectionSet",
	},

	"IntValue":     []string{},
	"FloatValue":   []string{},
	"StringValue":  []string{},
	"BooleanValue": []string{},
	"EnumValue":    []string{},
	"ListValue":    []string{"Values"},
	"ObjectValue":  []string{"Fields"},
	"ObjectField": []string{
		"Name",
		"Value",
	},

	"Directive": []string{
		"Name",
		"Arguments",
	},

	"Named":   []string{"Name"},
	"List":    []string{"Type"},
	"NonNull": []string{"Type"},

	"SchemaDefinition": []string{
		"Directives",
		"OperationTypes",
	},
	"OperationTypeDefinition": []string{"Type"},

	"ScalarDefinition": []string{
		"Name",
		"Directives",
	},
	"ObjectDefinition": []string{
		"Name",
		"Interfaces",
		"Directives",
		"Fields",
	},
	"FieldDefinition": []string{
		"Name",
		"Arguments",
		"Type",
		"Directives",
	},
	"InputValueDefinition": []string{
		"Name",
		"Type",
		"DefaultValue",
		"Directives",
	},
	"InterfaceDefinition": []string{
		"Name",
		"Directives",
		"Fields",
	},
	"UnionDefinition": []string{
		"Name",
		"Directives",
		"Types",
	},
	"EnumDefinition": []string{
		"Name",
		"Directives",
		"Values",
	},
	"EnumValueDefinition": []string{
		"Name",
		"Directives",
	},
	"InputObjectDefinition": []string{
		"Name",
		"Directives",
		"Fields",
	},

	"TypeExtensionDefinition": []string{"Definition"},

	"DirectiveDefinition": []string{"Name", "Arguments", "Locations"},
}
143
// stack is one saved level of Visit's explicit traversal stack. Visit pushes
// a frame when it descends into a node or slice and restores the frame's
// fields when it has finished with that node or slice.
type stack struct {
	// Index is the position, within Keys, of the entry being descended into.
	Index int
	// Keys are the entries iterated at the saved level: struct field names,
	// slice elements, or (at the top level) the root node itself.
	Keys []interface{}
	// Edits are the edits collected so far at the saved level.
	Edits []*edit
	// inSlice reports whether the saved level iterates a slice.
	inSlice bool
	// Prev is the frame below this one; nil for the outermost frame.
	Prev *stack
}

// edit is a pending replacement recorded by Visit. Key is the struct field
// name (string) or slice index (int) being replaced and Value is the new
// value, which may be nil.
type edit struct {
	Key   interface{}
	Value interface{}
}
155
// VisitFuncParams is the argument Visit passes to every VisitFunc.
type VisitFuncParams struct {
	// Node is the value being visited: usually an ast.Node, but Visit passes a
	// map[string]interface{} when the current node has been turned into a map
	// (e.g. because a visitor returned one).
	Node interface{}
	// Key is the struct field name (string) or slice index (int) under which
	// Node was found; nil for the root.
	Key interface{}
	// Parent is the enclosing node, or nil when it is not an ast.Node.
	Parent ast.Node
	// Path holds the keys from the root down to and including Key.
	Path []interface{}
	// Ancestors are the enclosing levels from the outermost inwards; levels
	// that are not an ast.Node (such as slice levels) appear as nil.
	Ancestors []ast.Node
}

// VisitFunc is a visitor callback. It returns one of the Action* values and,
// for edit actions, the replacement value for the current node.
type VisitFunc func(p VisitFuncParams) (string, interface{})
165
// NamedVisitFuncs holds the callbacks registered for a single node kind in
// VisitorOptions.KindFuncMap.
type NamedVisitFuncs struct {
	Kind  VisitFunc // called when entering a node of this kind; takes precedence over Enter
	Leave VisitFunc // called when leaving a node of this kind
	Enter VisitFunc // called when entering a node of this kind if Kind is nil
}
171
// VisitorOptions configures the callbacks used by Visit. GetVisitFn decides
// which callback applies to a given node kind and direction; an entry in
// KindFuncMap for a kind takes precedence over every other field.
type VisitorOptions struct {
	// KindFuncMap holds per-kind enter/leave callbacks.
	KindFuncMap map[string]NamedVisitFuncs
	// Enter and Leave are generic callbacks used for nodes of any kind that
	// has no KindFuncMap entry (see GetVisitFn for the exact lookup order).
	Enter VisitFunc
	Leave VisitFunc

	// EnterKindMap and LeaveKindMap hold per-kind callbacks that GetVisitFn
	// consults when the generic Enter/Leave callback is nil.
	EnterKindMap map[string]VisitFunc
	LeaveKindMap map[string]VisitFunc
}
180
// Visit walks the AST rooted at root depth-first, using an explicit stack
// instead of recursion, and calls the callback chosen by GetVisitFn when it
// enters and when it leaves each node it reaches.
//
// Parameters:
//   - root: the node to start from.
//   - visitorOpts: the callbacks; nil means no callbacks are found
//     (GetVisitFn returns nil for a nil *VisitorOptions).
//   - keyMap: for each node kind, the ordered struct field names to descend
//     into; QueryDocumentKeys is used when keyMap is nil.
//
// Callback actions:
//   - ActionBreak stops the traversal immediately.
//   - ActionSkip returned on enter skips the node's children and its leave.
//     On leave it is not special-cased and is treated like an edit action.
//   - ActionNoChange records nothing.
//   - Any other action records an edit replacing the node with the returned
//     value. On enter, if that value is node-like (isNode) traversal continues
//     into it; otherwise the node's children are not visited.
//
// Edits are collected per level and applied when the enclosing node or slice
// is left:
//   - slice entries whose new value is nil-like (isNilNode) are removed and
//     other slice entries are replaced;
//   - struct fields are written via updateNodeField, which mutates a pointer
//     node in place, or, for values that are not struct nodes, by converting
//     the enclosing node to a map[string]interface{} through a JSON
//     round-trip.
//
// The updated node or slice is then itself recorded as an edit one level up.
//
// Returns the Value of the last edit in the edit list that is current when
// the loop ends (normally the top level, i.e. the edited root), or nil when
// that list is empty. Panics if a visited value is not node-like or if a JSON
// round-trip fails.
func Visit(root ast.Node, visitorOpts *VisitorOptions, keyMap KeyMap) interface{} {
	visitorKeys := keyMap
	if visitorKeys == nil {
		visitorKeys = QueryDocumentKeys
	}

	// Traversal state. keys/index iterate the entries of the current level,
	// sstack saves the enclosing levels, edits collects the replacements made
	// at the current level, and parent/parentSlice hold the node or slice
	// whose entries are being iterated.
	var result interface{}
	var newRoot = root
	var sstack *stack
	var parent interface{}
	var parentSlice []interface{}
	inSlice := false
	// prevInSlice remembers whether the level whose edits were applied last
	// iterated a slice, so the propagated edit carries the right value.
	prevInSlice := false
	keys := []interface{}{newRoot}
	index := -1
	edits := []*edit{}
	path := []interface{}{}
	ancestors := []interface{}{}
	ancestorsSlice := [][]interface{}{}
Loop:
	for {
		index = index + 1

		// Every entry of the current level has been handled: leave it.
		isLeaving := (len(keys) == index)
		var key interface{}  // string for structs or int for slices
		var node interface{} // ast.Node or can be anything
		var nodeSlice []interface{}
		isEdited := (isLeaving && len(edits) != 0)

		if isLeaving {
			// Recover the key under which the level being left was found
			// (nil for the root).
			if !inSlice {
				if len(ancestors) == 0 {
					key = nil
				} else {
					key, path = pop(path)
				}
			} else {
				if len(ancestorsSlice) == 0 {
					key = nil
				} else {
					key, path = pop(path)
				}
			}

			// Move one level up.
			node = parent
			parent, ancestors = pop(ancestors)
			nodeSlice = parentSlice
			parentSlice, ancestorsSlice = popNodeSlice(ancestorsSlice)

			// Apply the edits collected for the entries of the level being left.
			if isEdited {
				prevInSlice = inSlice
				editOffset := 0
				for _, edit := range edits {
					arrayEditKey := 0
					if inSlice {
						// Earlier deletions shifted the remaining entries left.
						keyInt := edit.Key.(int)
						edit.Key = keyInt - editOffset
						arrayEditKey = edit.Key.(int)
					}
					if inSlice && isNilNode(edit.Value) {
						// A nil-like value removes the slice entry.
						nodeSlice = spliceNode(nodeSlice, arrayEditKey)
						editOffset = editOffset + 1
					} else {
						if inSlice {
							nodeSlice[arrayEditKey] = edit.Value
						} else {
							key, _ := edit.Key.(string)

							var updatedNode interface{}
							if !isSlice(edit.Value) {
								if isStructNode(edit.Value) {
									updatedNode = updateNodeField(node, key, edit.Value)
								} else {
									// Not a struct node: rebuild node as a map so the field
									// can hold an arbitrary value.
									var todoNode map[string]interface{}
									b, err := json.Marshal(node)
									if err != nil {
										panic(fmt.Sprintf("Invalid root AST Node: %v", root))
									}
									err = json.Unmarshal(b, &todoNode)
									if err != nil {
										panic(fmt.Sprintf("Invalid root AST Node (2): %v", root))
									}
									todoNode[key] = edit.Value
									updatedNode = todoNode
								}
							} else {
								isSliceOfNodes := true

								// The slice can be written via reflection only if every
								// element is a struct node.
								switch reflect.TypeOf(edit.Value).Kind() {
								case reflect.Slice:
									s := reflect.ValueOf(edit.Value)
									for i := 0; i < s.Len(); i++ {
										elem := s.Index(i)
										if !isStructNode(elem.Interface()) {
											isSliceOfNodes = false
										}
									}
								}

								// is a slice of real nodes
								if isSliceOfNodes {
									// the node we are writing to is an ast.Node
									updatedNode = updateNodeField(node, key, edit.Value)
								} else {
									// Fall back to a map representation of node.
									var todoNode map[string]interface{}
									b, err := json.Marshal(node)
									if err != nil {
										panic(fmt.Sprintf("Invalid root AST Node: %v", root))
									}
									err = json.Unmarshal(b, &todoNode)
									if err != nil {
										panic(fmt.Sprintf("Invalid root AST Node (2): %v", root))
									}
									todoNode[key] = edit.Value
									updatedNode = todoNode
								}

							}
							node = updatedNode
						}
					}
				}
			}
			// Restore the iteration state of the enclosing level.
			index = sstack.Index
			keys = sstack.Keys
			edits = sstack.Edits
			inSlice = sstack.inSlice
			sstack = sstack.Prev
		} else {
			// get key: a field name from keys (struct level) or the index
			// (slice level); nil for the root.
			if !inSlice {
				if !isNilNode(parent) {
					key = getFieldValue(keys, index)
				} else {
					// initial conditions
					key = nil
				}
			} else {
				key = index
			}
			// get node: the child stored under key, either as a node or as a
			// slice of nodes.
			if !inSlice {
				if !isNilNode(parent) {
					fieldValue := getFieldValue(parent, key)
					if isNode(fieldValue) {
						// NOTE(review): isNode also accepts maps with a "Kind" key,
						// which would make this assertion panic — confirm maps
						// cannot reach this point.
						node = fieldValue.(ast.Node)
					}
					if isSlice(fieldValue) {
						nodeSlice = toSliceInterfaces(fieldValue)
					}
				} else {
					// initial conditions
					node = newRoot
				}
			} else {
				if len(parentSlice) != 0 {
					fieldValue := getFieldValue(parentSlice, key)
					if isNode(fieldValue) {
						// NOTE(review): same map caveat as above.
						node = fieldValue.(ast.Node)
					}
					if isSlice(fieldValue) {
						nodeSlice = toSliceInterfaces(fieldValue)
					}
				} else {
					// initial conditions
					nodeSlice = []interface{}{}
				}
			}

			// Absent children and empty slices are not visited.
			if isNilNode(node) && len(nodeSlice) == 0 {
				continue
			}

			// Extend the path with the key (the root has none).
			if !inSlice {
				if !isNilNode(parent) {
					path = append(path, key)
				}
			} else {
				if len(parentSlice) != 0 {
					path = append(path, key)
				}
			}
		}

		// get result from visitFn for a node if set
		// (shadows the outer result, which is only assigned after the loop)
		var result interface{}
		resultIsUndefined := true
		if !isNilNode(node) {
			if !isNode(node) { // is node-ish.
				panic(fmt.Sprintf("Invalid AST Node (4): %v", node))
			}

			// A map node is handed to the visitor as a copy made by a JSON
			// round-trip; unmarshalling into an interface{} yields a
			// map[string]interface{} again, not an ast.Node. Other nodes are
			// passed as they are.
			var nodeIn interface{}
			if _, ok := node.(map[string]interface{}); ok {
				b, err := json.Marshal(node)
				if err != nil {
					panic(fmt.Sprintf("Invalid root AST Node: %v", root))
				}
				err = json.Unmarshal(b, &nodeIn)
				if err != nil {
					panic(fmt.Sprintf("Invalid root AST Node (2a): %v", root))
				}
			} else {
				nodeIn = node
			}
			parentConcrete, _ := parent.(ast.Node)
			// ancestorsConcrete slice may contain nil values
			ancestorsConcrete := []ast.Node{}
			for _, ancestor := range ancestors {
				if ancestorConcrete, ok := ancestor.(ast.Node); ok {
					ancestorsConcrete = append(ancestorsConcrete, ancestorConcrete)
				} else {
					ancestorsConcrete = append(ancestorsConcrete, nil)
				}
			}

			// The kind comes from the "Kind" key of a map node or from GetKind.
			kind := ""
			if node, ok := node.(map[string]interface{}); ok {
				kind, _ = node["Kind"].(string)
			}
			if node, ok := node.(ast.Node); ok {
				kind = node.GetKind()
			}

			visitFn := GetVisitFn(visitorOpts, kind, isLeaving)
			if visitFn != nil {
				p := VisitFuncParams{
					Node:      nodeIn,
					Key:       key,
					Parent:    parentConcrete,
					Path:      path,
					Ancestors: ancestorsConcrete,
				}
				// The initial value is always overwritten by visitFn's result.
				action := ActionUpdate
				action, result = visitFn(p)
				if action == ActionBreak {
					break Loop
				}
				if action == ActionSkip {
					if !isLeaving {
						// Skip: drop this key from the path and do not descend.
						_, path = pop(path)
						continue
					}
				}
				if action != ActionNoChange {
					resultIsUndefined = false
					edits = append(edits, &edit{
						Key:   key,
						Value: result,
					})
					if !isLeaving {
						if isNode(result) {
							// Continue into the replacement node.
							node = result
						} else {
							// The node was replaced by a non-node (or removed).
							_, path = pop(path)
							continue
						}
					}
				} else {
					resultIsUndefined = true
				}
			}

		}

		// collect back edits on the way out: when a level was edited and the
		// visitor did not itself replace it, the updated node (or slice)
		// becomes an edit of the enclosing level.
		if resultIsUndefined && isEdited {
			if !prevInSlice {
				edits = append(edits, &edit{
					Key:   key,
					Value: node,
				})
			} else {
				edits = append(edits, &edit{
					Key:   key,
					Value: nodeSlice,
				})
			}
		}
		if !isLeaving {

			// add to stack: save the current level before descending.
			prevStack := sstack
			sstack = &stack{
				inSlice: inSlice,
				Index:   index,
				Keys:    keys,
				Edits:   edits,
				Prev:    prevStack,
			}

			// replace keys: the slice elements, or the field names listed in
			// visitorKeys for the node's kind.
			inSlice = false
			if len(nodeSlice) > 0 {
				inSlice = true
			}
			keys = []interface{}{}

			if inSlice {
				// get keys
				for _, m := range nodeSlice {
					keys = append(keys, m)
				}
			} else {
				if !isNilNode(node) {
					if node, ok := node.(ast.Node); ok {
						kind := node.GetKind()
						if n, ok := visitorKeys[kind]; ok {
							for _, m := range n {
								keys = append(keys, m)
							}
						}
					}

				}

			}
			index = -1
			edits = []*edit{}

			ancestors = append(ancestors, parent)
			parent = node
			ancestorsSlice = append(ancestorsSlice, parentSlice)
			parentSlice = nodeSlice

		}

		// loop guard: stop once the outermost level has been left.
		if sstack == nil {
			break Loop
		}
	}
	if len(edits) != 0 {
		result = edits[len(edits)-1].Value
	}
	return result
}
522
// pop removes the last element of a and returns it along with the shortened
// slice. An empty or nil input yields (nil, nil).
func pop(a []interface{}) (x interface{}, aa []interface{}) {
	n := len(a)
	if n == 0 {
		return nil, nil
	}
	last := n - 1
	return a[last], a[:last]
}
// popNodeSlice removes the last slice from a and returns it along with the
// shortened stack. An empty or nil input yields (nil, nil).
func popNodeSlice(a [][]interface{}) (x []interface{}, aa [][]interface{}) {
	n := len(a)
	if n == 0 {
		return nil, nil
	}
	last := n - 1
	return a[last], a[:last]
}
// spliceNode copies the elements of the slice a into a new []interface{} and
// removes the element at index i.
//
// If i is beyond the end, the full copy is returned unchanged. A negative i,
// a nil a, a non-slice a, or an empty slice yields nil.
func spliceNode(a interface{}, i int) (result []interface{}) {
	if i < 0 {
		return nil
	}
	sliceType := reflect.TypeOf(a)
	if sliceType == nil || sliceType.Kind() != reflect.Slice {
		return nil
	}

	items := reflect.ValueOf(a)
	count := items.Len()
	for pos := 0; pos < count; pos++ {
		result = append(result, items.Index(pos).Interface())
	}
	if i >= count {
		return result
	}
	return append(result[:i], result[i+1:]...)
}
561
// getFieldValue looks key up inside obj using reflection and returns the value
// found, or nil when the lookup is not possible.
//
//   - struct or pointer to struct: key must be a string naming a field; only
//     fields whose value can be exposed (exported fields) are returned.
//   - slice: key must be an int in [0, len).
//   - map: key must be a comparable value assignable to the map's key type.
//
// A nil obj, a nil pointer, an unexported field, an out-of-range (including
// negative) index, or a mismatched map key all yield nil instead of panicking.
func getFieldValue(obj interface{}, key interface{}) interface{} {
	val := reflect.ValueOf(obj)
	if val.IsValid() && val.Kind() == reflect.Ptr {
		val = val.Elem()
	}
	if !val.IsValid() {
		// nil obj or nil pointer: reflect.Value.Type would panic here.
		return nil
	}

	var found reflect.Value
	switch val.Kind() {
	case reflect.Struct:
		name, ok := key.(string)
		if !ok {
			return nil
		}
		found = val.FieldByName(name)
	case reflect.Slice:
		idx, ok := key.(int)
		if !ok || idx < 0 || idx >= val.Len() {
			return nil
		}
		found = val.Index(idx)
	case reflect.Map:
		keyVal := reflect.ValueOf(key)
		// MapIndex panics on invalid, unassignable or unhashable keys.
		if !keyVal.IsValid() || !keyVal.Type().Comparable() ||
			!keyVal.Type().AssignableTo(val.Type().Key()) {
			return nil
		}
		found = val.MapIndex(keyVal)
	default:
		return nil
	}

	// Interface panics for values obtained through unexported struct fields.
	if !found.IsValid() || !found.CanInterface() {
		return nil
	}
	return found.Interface()
}
602
// updateNodeField sets the struct field fieldName of value to fieldValue and
// returns the updated node.
//
// value is expected to be a pointer to a struct (an AST node). In that case
// the field is set in place and the same pointer is returned. When value is
// not a pointer to a struct, or the field does not exist or is unexported,
// value is returned unchanged.
//
// Assignment rules:
//   - a nil fieldValue resets the field to its zero value;
//   - a slice assigned to a slice-typed field is copied element by element
//     into a new slice of the field's type. Elements that are nil or not
//     assignable to the element type are left as the zero value;
//   - any other fieldValue is assigned only if its type is assignable to the
//     field's type, so a concrete node can be stored in an interface-typed
//     field such as ast.Value. Otherwise the field is left untouched.
func updateNodeField(value interface{}, fieldName string, fieldValue interface{}) (retVal interface{}) {
	retVal = value
	val := reflect.ValueOf(value)

	isPtr := false
	if val.IsValid() && val.Kind() == reflect.Ptr {
		val = val.Elem()
		isPtr = true
	}
	if !val.IsValid() || val.Kind() != reflect.Struct {
		return retVal
	}

	for i := 0; i < val.NumField(); i++ {
		if val.Type().Field(i).Name != fieldName {
			continue
		}
		valueField := val.Field(i)
		if !valueField.CanSet() {
			// Unexported field, or value was passed by value (not addressable).
			return retVal
		}
		assignNodeFieldValue(valueField, reflect.ValueOf(fieldValue))
		if isPtr {
			return val.Addr().Interface()
		}
		return val.Interface()
	}
	return retVal
}

// assignNodeFieldValue stores src into the settable field dst following the
// assignment rules documented on updateNodeField.
func assignNodeFieldValue(dst, src reflect.Value) {
	switch {
	case !src.IsValid():
		dst.Set(reflect.Zero(dst.Type()))
	case src.Kind() == reflect.Slice && dst.Kind() == reflect.Slice:
		elemType := dst.Type().Elem()
		newSlice := reflect.MakeSlice(dst.Type(), src.Len(), src.Len())
		for i := 0; i < src.Len(); i++ {
			// Unwrap interface elements (e.g. []interface{}) to their dynamic value.
			elem := reflect.ValueOf(src.Index(i).Interface())
			if elem.IsValid() && elem.Type().AssignableTo(elemType) {
				newSlice.Index(i).Set(elem)
			}
		}
		dst.Set(newSlice)
	case src.Type().AssignableTo(dst.Type()):
		dst.Set(src)
	}
}
659func toSliceInterfaces(slice interface{}) (result []interface{}) {
660 switch reflect.TypeOf(slice).Kind() {
661 case reflect.Slice:
662 s := reflect.ValueOf(slice)
663 for i := 0; i < s.Len(); i++ {
664 elem := s.Index(i)
665 elemInterface := elem.Interface()
666 if elem, ok := elemInterface.(ast.Node); ok {
667 result = append(result, elem)
668 }
669 }
670 return result
671 default:
672 return result
673 }
674}
675
// isSlice reports whether value holds a slice of any element type. A nil
// interface yields false; a typed nil slice yields true.
func isSlice(value interface{}) bool {
	if value == nil {
		return false
	}
	return reflect.TypeOf(value).Kind() == reflect.Slice
}
// isNode reports whether node looks like an AST node: a struct, or a pointer
// to one, with a Kind field, or a map containing a "Kind" key. nil and nil
// pointers are not nodes.
func isNode(node interface{}) bool {
	// Indirect follows a single pointer level; a nil pointer becomes invalid.
	target := reflect.Indirect(reflect.ValueOf(node))
	switch target.Kind() {
	case reflect.Map:
		return target.MapIndex(reflect.ValueOf("Kind")).IsValid()
	case reflect.Struct:
		return target.FieldByName("Kind").IsValid()
	default:
		return false
	}
}
// isStructNode reports whether node is a struct, or a non-nil pointer to one,
// that has a Kind field. Unlike isNode, maps never qualify.
func isStructNode(node interface{}) bool {
	target := reflect.Indirect(reflect.ValueOf(node))
	return target.Kind() == reflect.Struct && target.FieldByName("Kind").IsValid()
}
716
// isNilNode reports whether node should be treated as "no node":
//   - a nil interface or a nil pointer;
//   - an empty slice or map;
//   - a bool whose value is true.
//
// Every other value is reported as non-nil.
//
// NOTE(review): a true bool counts as nil while false does not; this is kept
// exactly as found — confirm the intent before relying on it.
func isNilNode(node interface{}) bool {
	v := reflect.ValueOf(node)
	if !v.IsValid() {
		return true
	}
	switch v.Kind() {
	case reflect.Ptr:
		return v.IsNil()
	case reflect.Slice, reflect.Map:
		return v.Len() == 0
	case reflect.Bool:
		return v.Interface().(bool)
	}
	return v.Interface() == nil
}
736
737// VisitInParallel Creates a new visitor instance which delegates to many visitors to run in
738// parallel. Each visitor will be visited for each node before moving on.
739//
740// If a prior visitor edits a node, no following visitors will see that node.
741func VisitInParallel(visitorOptsSlice ...*VisitorOptions) *VisitorOptions {
742 skipping := map[int]interface{}{}
743
744 return &VisitorOptions{
745 Enter: func(p VisitFuncParams) (string, interface{}) {
746 for i, visitorOpts := range visitorOptsSlice {
747 if _, ok := skipping[i]; !ok {
748 switch node := p.Node.(type) {
749 case ast.Node:
750 kind := node.GetKind()
751 fn := GetVisitFn(visitorOpts, kind, false)
752 if fn != nil {
753 action, result := fn(p)
754 if action == ActionSkip {
755 skipping[i] = node
756 } else if action == ActionBreak {
757 skipping[i] = ActionBreak
758 } else if action == ActionUpdate {
759 return ActionUpdate, result
760 }
761 }
762 }
763 }
764 }
765 return ActionNoChange, nil
766 },
767 Leave: func(p VisitFuncParams) (string, interface{}) {
768 for i, visitorOpts := range visitorOptsSlice {
769 skippedNode, ok := skipping[i]
770 if !ok {
771 switch node := p.Node.(type) {
772 case ast.Node:
773 kind := node.GetKind()
774 fn := GetVisitFn(visitorOpts, kind, true)
775 if fn != nil {
776 action, result := fn(p)
777 if action == ActionBreak {
778 skipping[i] = ActionBreak
779 } else if action == ActionUpdate {
780 return ActionUpdate, result
781 }
782 }
783 }
784 } else if skippedNode == p.Node {
785 delete(skipping, i)
786 }
787 }
788 return ActionNoChange, nil
789 },
790 }
791}
792
793// VisitWithTypeInfo Creates a new visitor instance which maintains a provided TypeInfo instance
794// along with visiting visitor.
795func VisitWithTypeInfo(ttypeInfo typeInfo.TypeInfoI, visitorOpts *VisitorOptions) *VisitorOptions {
796 return &VisitorOptions{
797 Enter: func(p VisitFuncParams) (string, interface{}) {
798 if node, ok := p.Node.(ast.Node); ok {
799 ttypeInfo.Enter(node)
800 fn := GetVisitFn(visitorOpts, node.GetKind(), false)
801 if fn != nil {
802 action, result := fn(p)
803 if action == ActionUpdate {
804 ttypeInfo.Leave(node)
805 if isNode(result) {
806 if result, ok := result.(ast.Node); ok {
807 ttypeInfo.Enter(result)
808 }
809 }
810 }
811 return action, result
812 }
813 }
814 return ActionNoChange, nil
815 },
816 Leave: func(p VisitFuncParams) (string, interface{}) {
817 action := ActionNoChange
818 var result interface{}
819 if node, ok := p.Node.(ast.Node); ok {
820 fn := GetVisitFn(visitorOpts, node.GetKind(), true)
821 if fn != nil {
822 action, result = fn(p)
823 }
824 ttypeInfo.Leave(node)
825 }
826 return action, result
827 },
828 }
829}
830
831// GetVisitFn Given a visitor instance, if it is leaving or not, and a node kind, return
832// the function the visitor runtime should call.
833func GetVisitFn(visitorOpts *VisitorOptions, kind string, isLeaving bool) VisitFunc {
834 if visitorOpts == nil {
835 return nil
836 }
837 kindVisitor, ok := visitorOpts.KindFuncMap[kind]
838 if ok {
839 if !isLeaving && kindVisitor.Kind != nil {
840 // { Kind() {} }
841 return kindVisitor.Kind
842 }
843 if isLeaving {
844 // { Kind: { leave() {} } }
845 return kindVisitor.Leave
846 }
847 // { Kind: { enter() {} } }
848 return kindVisitor.Enter
849
850 }
851 if isLeaving {
852 // { enter() {} }
853 specificVisitor := visitorOpts.Leave
854 if specificVisitor != nil {
855 return specificVisitor
856 }
857 if specificKindVisitor, ok := visitorOpts.LeaveKindMap[kind]; ok {
858 // { leave: { Kind() {} } }
859 return specificKindVisitor
860 }
861
862 }
863 // { leave() {} }
864 specificVisitor := visitorOpts.Enter
865 if specificVisitor != nil {
866 return specificVisitor
867 }
868 if specificKindVisitor, ok := visitorOpts.EnterKindMap[kind]; ok {
869 // { enter: { Kind() {} } }
870 return specificKindVisitor
871 }
872 return nil
873}