1// Copyright (c) 2016, Daniel MartΓ <mvdan@mvdan.cc>
2// See LICENSE for licensing information
3
4package syntax
5
6import (
7 "fmt"
8 "io"
9 "reflect"
10)
11
12// Walk traverses a syntax tree in depth-first order: It starts by calling
13// f(node); node must not be nil. If f returns true, Walk invokes f
14// recursively for each of the non-nil children of node, followed by
15// f(nil).
16func Walk(node Node, f func(Node) bool) {
17 if !f(node) {
18 return
19 }
20
21 switch node := node.(type) {
22 case *File:
23 walkList(node.Stmts, f)
24 walkComments(node.Last, f)
25 case *Comment:
26 case *Stmt:
27 for _, c := range node.Comments {
28 if !node.End().After(c.Pos()) {
29 defer Walk(&c, f)
30 break
31 }
32 Walk(&c, f)
33 }
34 if node.Cmd != nil {
35 Walk(node.Cmd, f)
36 }
37 walkList(node.Redirs, f)
38 case *Assign:
39 walkNilable(node.Name, f)
40 walkNilable(node.Value, f)
41 walkNilable(node.Index, f)
42 walkNilable(node.Array, f)
43 case *Redirect:
44 walkNilable(node.N, f)
45 Walk(node.Word, f)
46 walkNilable(node.Hdoc, f)
47 case *CallExpr:
48 walkList(node.Assigns, f)
49 walkList(node.Args, f)
50 case *Subshell:
51 walkList(node.Stmts, f)
52 walkComments(node.Last, f)
53 case *Block:
54 walkList(node.Stmts, f)
55 walkComments(node.Last, f)
56 case *IfClause:
57 walkList(node.Cond, f)
58 walkComments(node.CondLast, f)
59 walkList(node.Then, f)
60 walkComments(node.ThenLast, f)
61 walkNilable(node.Else, f)
62 case *WhileClause:
63 walkList(node.Cond, f)
64 walkComments(node.CondLast, f)
65 walkList(node.Do, f)
66 walkComments(node.DoLast, f)
67 case *ForClause:
68 Walk(node.Loop, f)
69 walkList(node.Do, f)
70 walkComments(node.DoLast, f)
71 case *WordIter:
72 Walk(node.Name, f)
73 walkList(node.Items, f)
74 case *CStyleLoop:
75 walkNilable(node.Init, f)
76 walkNilable(node.Cond, f)
77 walkNilable(node.Post, f)
78 case *BinaryCmd:
79 Walk(node.X, f)
80 Walk(node.Y, f)
81 case *FuncDecl:
82 Walk(node.Name, f)
83 Walk(node.Body, f)
84 case *Word:
85 walkList(node.Parts, f)
86 case *Lit:
87 case *SglQuoted:
88 case *DblQuoted:
89 walkList(node.Parts, f)
90 case *CmdSubst:
91 walkList(node.Stmts, f)
92 walkComments(node.Last, f)
93 case *ParamExp:
94 Walk(node.Param, f)
95 walkNilable(node.Index, f)
96 if node.Repl != nil {
97 walkNilable(node.Repl.Orig, f)
98 walkNilable(node.Repl.With, f)
99 }
100 if node.Exp != nil {
101 walkNilable(node.Exp.Word, f)
102 }
103 case *ArithmExp:
104 Walk(node.X, f)
105 case *ArithmCmd:
106 Walk(node.X, f)
107 case *BinaryArithm:
108 Walk(node.X, f)
109 Walk(node.Y, f)
110 case *BinaryTest:
111 Walk(node.X, f)
112 Walk(node.Y, f)
113 case *UnaryArithm:
114 Walk(node.X, f)
115 case *UnaryTest:
116 Walk(node.X, f)
117 case *ParenArithm:
118 Walk(node.X, f)
119 case *ParenTest:
120 Walk(node.X, f)
121 case *CaseClause:
122 Walk(node.Word, f)
123 walkList(node.Items, f)
124 walkComments(node.Last, f)
125 case *CaseItem:
126 for _, c := range node.Comments {
127 if c.Pos().After(node.Pos()) {
128 defer Walk(&c, f)
129 break
130 }
131 Walk(&c, f)
132 }
133 walkList(node.Patterns, f)
134 walkList(node.Stmts, f)
135 walkComments(node.Last, f)
136 case *TestClause:
137 Walk(node.X, f)
138 case *DeclClause:
139 walkList(node.Args, f)
140 case *ArrayExpr:
141 walkList(node.Elems, f)
142 walkComments(node.Last, f)
143 case *ArrayElem:
144 for _, c := range node.Comments {
145 if c.Pos().After(node.Pos()) {
146 defer Walk(&c, f)
147 break
148 }
149 Walk(&c, f)
150 }
151 walkNilable(node.Index, f)
152 walkNilable(node.Value, f)
153 case *ExtGlob:
154 Walk(node.Pattern, f)
155 case *ProcSubst:
156 walkList(node.Stmts, f)
157 walkComments(node.Last, f)
158 case *TimeClause:
159 walkNilable(node.Stmt, f)
160 case *CoprocClause:
161 walkNilable(node.Name, f)
162 Walk(node.Stmt, f)
163 case *LetClause:
164 walkList(node.Exprs, f)
165 case *TestDecl:
166 Walk(node.Description, f)
167 Walk(node.Body, f)
168 default:
169 panic(fmt.Sprintf("syntax.Walk: unexpected node type %T", node))
170 }
171
172 f(nil)
173}
174
175type nilableNode interface {
176 Node
177 comparable // pointer nodes, which can be compared to nil
178}
179
180func walkNilable[N nilableNode](node N, f func(Node) bool) {
181 var zero N // nil
182 if node != zero {
183 Walk(node, f)
184 }
185}
186
187func walkList[N Node](list []N, f func(Node) bool) {
188 for _, node := range list {
189 Walk(node, f)
190 }
191}
192
193func walkComments(list []Comment, f func(Node) bool) {
194 // Note that []Comment does not satisfy the generic constraint []Node.
195 for i := range list {
196 Walk(&list[i], f)
197 }
198}
199
200// DebugPrint prints the provided syntax tree, spanning multiple lines and with
201// indentation. Can be useful to investigate the content of a syntax tree.
202func DebugPrint(w io.Writer, node Node) error {
203 p := debugPrinter{out: w}
204 p.print(reflect.ValueOf(node))
205 p.printf("\n")
206 return p.err
207}
208
// debugPrinter holds the output state used by DebugPrint.
type debugPrinter struct {
	out   io.Writer // destination for everything printed
	level int       // current indentation depth used by newline
	err   error     // first error returned by out, if any
}

// printf formats its arguments according to format and writes the result
// to p.out. Only the first write error is recorded in p.err; later writes
// are still attempted and their errors are discarded.
func (p *debugPrinter) printf(format string, args ...any) {
	if _, werr := fmt.Fprintf(p.out, format, args...); p.err == nil && werr != nil {
		p.err = werr
	}
}

// newline ends the current line and starts the next one with a ". "
// marker for each indentation level (none when level is zero or negative).
func (p *debugPrinter) newline() {
	p.printf("\n")
	for remaining := p.level; remaining > 0; remaining-- {
		p.printf(". ")
	}
}
228
229func (p *debugPrinter) print(x reflect.Value) {
230 switch x.Kind() {
231 case reflect.Interface:
232 if x.IsNil() {
233 p.printf("nil")
234 return
235 }
236 p.print(x.Elem())
237 case reflect.Ptr:
238 if x.IsNil() {
239 p.printf("nil")
240 return
241 }
242 p.printf("*")
243 p.print(x.Elem())
244 case reflect.Slice:
245 p.printf("%s (len = %d) {", x.Type(), x.Len())
246 if x.Len() > 0 {
247 p.level++
248 p.newline()
249 for i := range x.Len() {
250 p.printf("%d: ", i)
251 p.print(x.Index(i))
252 if i == x.Len()-1 {
253 p.level--
254 }
255 p.newline()
256 }
257 }
258 p.printf("}")
259
260 case reflect.Struct:
261 if v, ok := x.Interface().(Pos); ok {
262 if v.IsRecovered() {
263 p.printf("<recovered>")
264 return
265 }
266 p.printf("%v:%v", v.Line(), v.Col())
267 return
268 }
269 t := x.Type()
270 p.printf("%s {", t)
271 p.level++
272 p.newline()
273 for i := range t.NumField() {
274 p.printf("%s: ", t.Field(i).Name)
275 p.print(x.Field(i))
276 if i == x.NumField()-1 {
277 p.level--
278 }
279 p.newline()
280 }
281 p.printf("}")
282 default:
283 if s, ok := x.Interface().(fmt.Stringer); ok && !x.IsZero() {
284 p.printf("%#v (%s)", x.Interface(), s)
285 } else {
286 p.printf("%#v", x.Interface())
287 }
288 }
289}