walk.go

// Copyright (c) 2016, Daniel Martí <mvdan@mvdan.cc>
  2// See LICENSE for licensing information
  3
  4package syntax
  5
  6import (
  7	"fmt"
  8	"io"
  9	"reflect"
 10)
 11
 12// Walk traverses a syntax tree in depth-first order: It starts by calling
 13// f(node); node must not be nil. If f returns true, Walk invokes f
 14// recursively for each of the non-nil children of node, followed by
 15// f(nil).
 16func Walk(node Node, f func(Node) bool) {
 17	if !f(node) {
 18		return
 19	}
 20
 21	switch node := node.(type) {
 22	case *File:
 23		walkList(node.Stmts, f)
 24		walkComments(node.Last, f)
 25	case *Comment:
 26	case *Stmt:
 27		for _, c := range node.Comments {
 28			if !node.End().After(c.Pos()) {
 29				defer Walk(&c, f)
 30				break
 31			}
 32			Walk(&c, f)
 33		}
 34		if node.Cmd != nil {
 35			Walk(node.Cmd, f)
 36		}
 37		walkList(node.Redirs, f)
 38	case *Assign:
 39		walkNilable(node.Name, f)
 40		walkNilable(node.Value, f)
 41		walkNilable(node.Index, f)
 42		walkNilable(node.Array, f)
 43	case *Redirect:
 44		walkNilable(node.N, f)
 45		Walk(node.Word, f)
 46		walkNilable(node.Hdoc, f)
 47	case *CallExpr:
 48		walkList(node.Assigns, f)
 49		walkList(node.Args, f)
 50	case *Subshell:
 51		walkList(node.Stmts, f)
 52		walkComments(node.Last, f)
 53	case *Block:
 54		walkList(node.Stmts, f)
 55		walkComments(node.Last, f)
 56	case *IfClause:
 57		walkList(node.Cond, f)
 58		walkComments(node.CondLast, f)
 59		walkList(node.Then, f)
 60		walkComments(node.ThenLast, f)
 61		walkNilable(node.Else, f)
 62	case *WhileClause:
 63		walkList(node.Cond, f)
 64		walkComments(node.CondLast, f)
 65		walkList(node.Do, f)
 66		walkComments(node.DoLast, f)
 67	case *ForClause:
 68		Walk(node.Loop, f)
 69		walkList(node.Do, f)
 70		walkComments(node.DoLast, f)
 71	case *WordIter:
 72		Walk(node.Name, f)
 73		walkList(node.Items, f)
 74	case *CStyleLoop:
 75		walkNilable(node.Init, f)
 76		walkNilable(node.Cond, f)
 77		walkNilable(node.Post, f)
 78	case *BinaryCmd:
 79		Walk(node.X, f)
 80		Walk(node.Y, f)
 81	case *FuncDecl:
 82		Walk(node.Name, f)
 83		Walk(node.Body, f)
 84	case *Word:
 85		walkList(node.Parts, f)
 86	case *Lit:
 87	case *SglQuoted:
 88	case *DblQuoted:
 89		walkList(node.Parts, f)
 90	case *CmdSubst:
 91		walkList(node.Stmts, f)
 92		walkComments(node.Last, f)
 93	case *ParamExp:
 94		Walk(node.Param, f)
 95		walkNilable(node.Index, f)
 96		if node.Repl != nil {
 97			walkNilable(node.Repl.Orig, f)
 98			walkNilable(node.Repl.With, f)
 99		}
100		if node.Exp != nil {
101			walkNilable(node.Exp.Word, f)
102		}
103	case *ArithmExp:
104		Walk(node.X, f)
105	case *ArithmCmd:
106		Walk(node.X, f)
107	case *BinaryArithm:
108		Walk(node.X, f)
109		Walk(node.Y, f)
110	case *BinaryTest:
111		Walk(node.X, f)
112		Walk(node.Y, f)
113	case *UnaryArithm:
114		Walk(node.X, f)
115	case *UnaryTest:
116		Walk(node.X, f)
117	case *ParenArithm:
118		Walk(node.X, f)
119	case *ParenTest:
120		Walk(node.X, f)
121	case *CaseClause:
122		Walk(node.Word, f)
123		walkList(node.Items, f)
124		walkComments(node.Last, f)
125	case *CaseItem:
126		for _, c := range node.Comments {
127			if c.Pos().After(node.Pos()) {
128				defer Walk(&c, f)
129				break
130			}
131			Walk(&c, f)
132		}
133		walkList(node.Patterns, f)
134		walkList(node.Stmts, f)
135		walkComments(node.Last, f)
136	case *TestClause:
137		Walk(node.X, f)
138	case *DeclClause:
139		walkList(node.Args, f)
140	case *ArrayExpr:
141		walkList(node.Elems, f)
142		walkComments(node.Last, f)
143	case *ArrayElem:
144		for _, c := range node.Comments {
145			if c.Pos().After(node.Pos()) {
146				defer Walk(&c, f)
147				break
148			}
149			Walk(&c, f)
150		}
151		walkNilable(node.Index, f)
152		walkNilable(node.Value, f)
153	case *ExtGlob:
154		Walk(node.Pattern, f)
155	case *ProcSubst:
156		walkList(node.Stmts, f)
157		walkComments(node.Last, f)
158	case *TimeClause:
159		walkNilable(node.Stmt, f)
160	case *CoprocClause:
161		walkNilable(node.Name, f)
162		Walk(node.Stmt, f)
163	case *LetClause:
164		walkList(node.Exprs, f)
165	case *TestDecl:
166		Walk(node.Description, f)
167		Walk(node.Body, f)
168	default:
169		panic(fmt.Sprintf("syntax.Walk: unexpected node type %T", node))
170	}
171
172	f(nil)
173}
174
175type nilableNode interface {
176	Node
177	comparable // pointer nodes, which can be compared to nil
178}
179
180func walkNilable[N nilableNode](node N, f func(Node) bool) {
181	var zero N // nil
182	if node != zero {
183		Walk(node, f)
184	}
185}
186
187func walkList[N Node](list []N, f func(Node) bool) {
188	for _, node := range list {
189		Walk(node, f)
190	}
191}
192
193func walkComments(list []Comment, f func(Node) bool) {
194	// Note that []Comment does not satisfy the generic constraint []Node.
195	for i := range list {
196		Walk(&list[i], f)
197	}
198}
199
200// DebugPrint prints the provided syntax tree, spanning multiple lines and with
201// indentation. Can be useful to investigate the content of a syntax tree.
202func DebugPrint(w io.Writer, node Node) error {
203	p := debugPrinter{out: w}
204	p.print(reflect.ValueOf(node))
205	p.printf("\n")
206	return p.err
207}
208
209type debugPrinter struct {
210	out   io.Writer
211	level int
212	err   error
213}
214
215func (p *debugPrinter) printf(format string, args ...any) {
216	_, err := fmt.Fprintf(p.out, format, args...)
217	if err != nil && p.err == nil {
218		p.err = err
219	}
220}
221
222func (p *debugPrinter) newline() {
223	p.printf("\n")
224	for range p.level {
225		p.printf(".  ")
226	}
227}
228
229func (p *debugPrinter) print(x reflect.Value) {
230	switch x.Kind() {
231	case reflect.Interface:
232		if x.IsNil() {
233			p.printf("nil")
234			return
235		}
236		p.print(x.Elem())
237	case reflect.Ptr:
238		if x.IsNil() {
239			p.printf("nil")
240			return
241		}
242		p.printf("*")
243		p.print(x.Elem())
244	case reflect.Slice:
245		p.printf("%s (len = %d) {", x.Type(), x.Len())
246		if x.Len() > 0 {
247			p.level++
248			p.newline()
249			for i := range x.Len() {
250				p.printf("%d: ", i)
251				p.print(x.Index(i))
252				if i == x.Len()-1 {
253					p.level--
254				}
255				p.newline()
256			}
257		}
258		p.printf("}")
259
260	case reflect.Struct:
261		if v, ok := x.Interface().(Pos); ok {
262			if v.IsRecovered() {
263				p.printf("<recovered>")
264				return
265			}
266			p.printf("%v:%v", v.Line(), v.Col())
267			return
268		}
269		t := x.Type()
270		p.printf("%s {", t)
271		p.level++
272		p.newline()
273		for i := range t.NumField() {
274			p.printf("%s: ", t.Field(i).Name)
275			p.print(x.Field(i))
276			if i == x.NumField()-1 {
277				p.level--
278			}
279			p.newline()
280		}
281		p.printf("}")
282	default:
283		if s, ok := x.Interface().(fmt.Stringer); ok && !x.IsZero() {
284			p.printf("%#v (%s)", x.Interface(), s)
285		} else {
286			p.printf("%#v", x.Interface())
287		}
288	}
289}