rewrite.go

  1// Copyright 2017 The Go Authors. All rights reserved.
  2// Use of this source code is governed by a BSD-style
  3// license that can be found in the LICENSE file.
  4
  5package astutil
  6
  7import (
  8	"fmt"
  9	"go/ast"
 10	"reflect"
 11	"sort"
 12)
 13
 14// An ApplyFunc is invoked by Apply for each node n, even if n is nil,
 15// before and/or after the node's children, using a Cursor describing
 16// the current node and providing operations on it.
 17//
 18// The return value of ApplyFunc controls the syntax tree traversal.
 19// See Apply for details.
 20type ApplyFunc func(*Cursor) bool
 21
 22// Apply traverses a syntax tree recursively, starting with root,
 23// and calling pre and post for each node as described below.
 24// Apply returns the syntax tree, possibly modified.
 25//
 26// If pre is not nil, it is called for each node before the node's
 27// children are traversed (pre-order). If pre returns false, no
 28// children are traversed, and post is not called for that node.
 29//
 30// If post is not nil, and a prior call of pre didn't return false,
 31// post is called for each node after its children are traversed
 32// (post-order). If post returns false, traversal is terminated and
 33// Apply returns immediately.
 34//
 35// Only fields that refer to AST nodes are considered children;
 36// i.e., token.Pos, Scopes, Objects, and fields of basic types
 37// (strings, etc.) are ignored.
 38//
 39// Children are traversed in the order in which they appear in the
 40// respective node's struct definition. A package's files are
 41// traversed in the filenames' alphabetical order.
 42//
 43func Apply(root ast.Node, pre, post ApplyFunc) (result ast.Node) {
 44	parent := &struct{ ast.Node }{root}
 45	defer func() {
 46		if r := recover(); r != nil && r != abort {
 47			panic(r)
 48		}
 49		result = parent.Node
 50	}()
 51	a := &application{pre: pre, post: post}
 52	a.apply(parent, "Node", nil, root)
 53	return
 54}
 55
 56var abort = new(int) // singleton, to signal termination of Apply
 57
 58// A Cursor describes a node encountered during Apply.
 59// Information about the node and its parent is available
 60// from the Node, Parent, Name, and Index methods.
 61//
 62// If p is a variable of type and value of the current parent node
 63// c.Parent(), and f is the field identifier with name c.Name(),
 64// the following invariants hold:
 65//
 66//   p.f            == c.Node()  if c.Index() <  0
 67//   p.f[c.Index()] == c.Node()  if c.Index() >= 0
 68//
 69// The methods Replace, Delete, InsertBefore, and InsertAfter
 70// can be used to change the AST without disrupting Apply.
 71type Cursor struct {
 72	parent ast.Node
 73	name   string
 74	iter   *iterator // valid if non-nil
 75	node   ast.Node
 76}
 77
 78// Node returns the current Node.
 79func (c *Cursor) Node() ast.Node { return c.node }
 80
 81// Parent returns the parent of the current Node.
 82func (c *Cursor) Parent() ast.Node { return c.parent }
 83
 84// Name returns the name of the parent Node field that contains the current Node.
 85// If the parent is a *ast.Package and the current Node is a *ast.File, Name returns
 86// the filename for the current Node.
 87func (c *Cursor) Name() string { return c.name }
 88
 89// Index reports the index >= 0 of the current Node in the slice of Nodes that
 90// contains it, or a value < 0 if the current Node is not part of a slice.
 91// The index of the current node changes if InsertBefore is called while
 92// processing the current node.
 93func (c *Cursor) Index() int {
 94	if c.iter != nil {
 95		return c.iter.index
 96	}
 97	return -1
 98}
 99
100// field returns the current node's parent field value.
101func (c *Cursor) field() reflect.Value {
102	return reflect.Indirect(reflect.ValueOf(c.parent)).FieldByName(c.name)
103}
104
105// Replace replaces the current Node with n.
106// The replacement node is not walked by Apply.
107func (c *Cursor) Replace(n ast.Node) {
108	if _, ok := c.node.(*ast.File); ok {
109		file, ok := n.(*ast.File)
110		if !ok {
111			panic("attempt to replace *ast.File with non-*ast.File")
112		}
113		c.parent.(*ast.Package).Files[c.name] = file
114		return
115	}
116
117	v := c.field()
118	if i := c.Index(); i >= 0 {
119		v = v.Index(i)
120	}
121	v.Set(reflect.ValueOf(n))
122}
123
124// Delete deletes the current Node from its containing slice.
125// If the current Node is not part of a slice, Delete panics.
126// As a special case, if the current node is a package file,
127// Delete removes it from the package's Files map.
128func (c *Cursor) Delete() {
129	if _, ok := c.node.(*ast.File); ok {
130		delete(c.parent.(*ast.Package).Files, c.name)
131		return
132	}
133
134	i := c.Index()
135	if i < 0 {
136		panic("Delete node not contained in slice")
137	}
138	v := c.field()
139	l := v.Len()
140	reflect.Copy(v.Slice(i, l), v.Slice(i+1, l))
141	v.Index(l - 1).Set(reflect.Zero(v.Type().Elem()))
142	v.SetLen(l - 1)
143	c.iter.step--
144}
145
146// InsertAfter inserts n after the current Node in its containing slice.
147// If the current Node is not part of a slice, InsertAfter panics.
148// Apply does not walk n.
149func (c *Cursor) InsertAfter(n ast.Node) {
150	i := c.Index()
151	if i < 0 {
152		panic("InsertAfter node not contained in slice")
153	}
154	v := c.field()
155	v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
156	l := v.Len()
157	reflect.Copy(v.Slice(i+2, l), v.Slice(i+1, l))
158	v.Index(i + 1).Set(reflect.ValueOf(n))
159	c.iter.step++
160}
161
162// InsertBefore inserts n before the current Node in its containing slice.
163// If the current Node is not part of a slice, InsertBefore panics.
164// Apply will not walk n.
165func (c *Cursor) InsertBefore(n ast.Node) {
166	i := c.Index()
167	if i < 0 {
168		panic("InsertBefore node not contained in slice")
169	}
170	v := c.field()
171	v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
172	l := v.Len()
173	reflect.Copy(v.Slice(i+1, l), v.Slice(i, l))
174	v.Index(i).Set(reflect.ValueOf(n))
175	c.iter.index++
176}
177
178// application carries all the shared data so we can pass it around cheaply.
179type application struct {
180	pre, post ApplyFunc
181	cursor    Cursor
182	iter      iterator
183}
184
185func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast.Node) {
186	// convert typed nil into untyped nil
187	if v := reflect.ValueOf(n); v.Kind() == reflect.Ptr && v.IsNil() {
188		n = nil
189	}
190
191	// avoid heap-allocating a new cursor for each apply call; reuse a.cursor instead
192	saved := a.cursor
193	a.cursor.parent = parent
194	a.cursor.name = name
195	a.cursor.iter = iter
196	a.cursor.node = n
197
198	if a.pre != nil && !a.pre(&a.cursor) {
199		a.cursor = saved
200		return
201	}
202
203	// walk children
204	// (the order of the cases matches the order of the corresponding node types in go/ast)
205	switch n := n.(type) {
206	case nil:
207		// nothing to do
208
209	// Comments and fields
210	case *ast.Comment:
211		// nothing to do
212
213	case *ast.CommentGroup:
214		if n != nil {
215			a.applyList(n, "List")
216		}
217
218	case *ast.Field:
219		a.apply(n, "Doc", nil, n.Doc)
220		a.applyList(n, "Names")
221		a.apply(n, "Type", nil, n.Type)
222		a.apply(n, "Tag", nil, n.Tag)
223		a.apply(n, "Comment", nil, n.Comment)
224
225	case *ast.FieldList:
226		a.applyList(n, "List")
227
228	// Expressions
229	case *ast.BadExpr, *ast.Ident, *ast.BasicLit:
230		// nothing to do
231
232	case *ast.Ellipsis:
233		a.apply(n, "Elt", nil, n.Elt)
234
235	case *ast.FuncLit:
236		a.apply(n, "Type", nil, n.Type)
237		a.apply(n, "Body", nil, n.Body)
238
239	case *ast.CompositeLit:
240		a.apply(n, "Type", nil, n.Type)
241		a.applyList(n, "Elts")
242
243	case *ast.ParenExpr:
244		a.apply(n, "X", nil, n.X)
245
246	case *ast.SelectorExpr:
247		a.apply(n, "X", nil, n.X)
248		a.apply(n, "Sel", nil, n.Sel)
249
250	case *ast.IndexExpr:
251		a.apply(n, "X", nil, n.X)
252		a.apply(n, "Index", nil, n.Index)
253
254	case *ast.SliceExpr:
255		a.apply(n, "X", nil, n.X)
256		a.apply(n, "Low", nil, n.Low)
257		a.apply(n, "High", nil, n.High)
258		a.apply(n, "Max", nil, n.Max)
259
260	case *ast.TypeAssertExpr:
261		a.apply(n, "X", nil, n.X)
262		a.apply(n, "Type", nil, n.Type)
263
264	case *ast.CallExpr:
265		a.apply(n, "Fun", nil, n.Fun)
266		a.applyList(n, "Args")
267
268	case *ast.StarExpr:
269		a.apply(n, "X", nil, n.X)
270
271	case *ast.UnaryExpr:
272		a.apply(n, "X", nil, n.X)
273
274	case *ast.BinaryExpr:
275		a.apply(n, "X", nil, n.X)
276		a.apply(n, "Y", nil, n.Y)
277
278	case *ast.KeyValueExpr:
279		a.apply(n, "Key", nil, n.Key)
280		a.apply(n, "Value", nil, n.Value)
281
282	// Types
283	case *ast.ArrayType:
284		a.apply(n, "Len", nil, n.Len)
285		a.apply(n, "Elt", nil, n.Elt)
286
287	case *ast.StructType:
288		a.apply(n, "Fields", nil, n.Fields)
289
290	case *ast.FuncType:
291		a.apply(n, "Params", nil, n.Params)
292		a.apply(n, "Results", nil, n.Results)
293
294	case *ast.InterfaceType:
295		a.apply(n, "Methods", nil, n.Methods)
296
297	case *ast.MapType:
298		a.apply(n, "Key", nil, n.Key)
299		a.apply(n, "Value", nil, n.Value)
300
301	case *ast.ChanType:
302		a.apply(n, "Value", nil, n.Value)
303
304	// Statements
305	case *ast.BadStmt:
306		// nothing to do
307
308	case *ast.DeclStmt:
309		a.apply(n, "Decl", nil, n.Decl)
310
311	case *ast.EmptyStmt:
312		// nothing to do
313
314	case *ast.LabeledStmt:
315		a.apply(n, "Label", nil, n.Label)
316		a.apply(n, "Stmt", nil, n.Stmt)
317
318	case *ast.ExprStmt:
319		a.apply(n, "X", nil, n.X)
320
321	case *ast.SendStmt:
322		a.apply(n, "Chan", nil, n.Chan)
323		a.apply(n, "Value", nil, n.Value)
324
325	case *ast.IncDecStmt:
326		a.apply(n, "X", nil, n.X)
327
328	case *ast.AssignStmt:
329		a.applyList(n, "Lhs")
330		a.applyList(n, "Rhs")
331
332	case *ast.GoStmt:
333		a.apply(n, "Call", nil, n.Call)
334
335	case *ast.DeferStmt:
336		a.apply(n, "Call", nil, n.Call)
337
338	case *ast.ReturnStmt:
339		a.applyList(n, "Results")
340
341	case *ast.BranchStmt:
342		a.apply(n, "Label", nil, n.Label)
343
344	case *ast.BlockStmt:
345		a.applyList(n, "List")
346
347	case *ast.IfStmt:
348		a.apply(n, "Init", nil, n.Init)
349		a.apply(n, "Cond", nil, n.Cond)
350		a.apply(n, "Body", nil, n.Body)
351		a.apply(n, "Else", nil, n.Else)
352
353	case *ast.CaseClause:
354		a.applyList(n, "List")
355		a.applyList(n, "Body")
356
357	case *ast.SwitchStmt:
358		a.apply(n, "Init", nil, n.Init)
359		a.apply(n, "Tag", nil, n.Tag)
360		a.apply(n, "Body", nil, n.Body)
361
362	case *ast.TypeSwitchStmt:
363		a.apply(n, "Init", nil, n.Init)
364		a.apply(n, "Assign", nil, n.Assign)
365		a.apply(n, "Body", nil, n.Body)
366
367	case *ast.CommClause:
368		a.apply(n, "Comm", nil, n.Comm)
369		a.applyList(n, "Body")
370
371	case *ast.SelectStmt:
372		a.apply(n, "Body", nil, n.Body)
373
374	case *ast.ForStmt:
375		a.apply(n, "Init", nil, n.Init)
376		a.apply(n, "Cond", nil, n.Cond)
377		a.apply(n, "Post", nil, n.Post)
378		a.apply(n, "Body", nil, n.Body)
379
380	case *ast.RangeStmt:
381		a.apply(n, "Key", nil, n.Key)
382		a.apply(n, "Value", nil, n.Value)
383		a.apply(n, "X", nil, n.X)
384		a.apply(n, "Body", nil, n.Body)
385
386	// Declarations
387	case *ast.ImportSpec:
388		a.apply(n, "Doc", nil, n.Doc)
389		a.apply(n, "Name", nil, n.Name)
390		a.apply(n, "Path", nil, n.Path)
391		a.apply(n, "Comment", nil, n.Comment)
392
393	case *ast.ValueSpec:
394		a.apply(n, "Doc", nil, n.Doc)
395		a.applyList(n, "Names")
396		a.apply(n, "Type", nil, n.Type)
397		a.applyList(n, "Values")
398		a.apply(n, "Comment", nil, n.Comment)
399
400	case *ast.TypeSpec:
401		a.apply(n, "Doc", nil, n.Doc)
402		a.apply(n, "Name", nil, n.Name)
403		a.apply(n, "Type", nil, n.Type)
404		a.apply(n, "Comment", nil, n.Comment)
405
406	case *ast.BadDecl:
407		// nothing to do
408
409	case *ast.GenDecl:
410		a.apply(n, "Doc", nil, n.Doc)
411		a.applyList(n, "Specs")
412
413	case *ast.FuncDecl:
414		a.apply(n, "Doc", nil, n.Doc)
415		a.apply(n, "Recv", nil, n.Recv)
416		a.apply(n, "Name", nil, n.Name)
417		a.apply(n, "Type", nil, n.Type)
418		a.apply(n, "Body", nil, n.Body)
419
420	// Files and packages
421	case *ast.File:
422		a.apply(n, "Doc", nil, n.Doc)
423		a.apply(n, "Name", nil, n.Name)
424		a.applyList(n, "Decls")
425		// Don't walk n.Comments; they have either been walked already if
426		// they are Doc comments, or they can be easily walked explicitly.
427
428	case *ast.Package:
429		// collect and sort names for reproducible behavior
430		var names []string
431		for name := range n.Files {
432			names = append(names, name)
433		}
434		sort.Strings(names)
435		for _, name := range names {
436			a.apply(n, name, nil, n.Files[name])
437		}
438
439	default:
440		panic(fmt.Sprintf("Apply: unexpected node type %T", n))
441	}
442
443	if a.post != nil && !a.post(&a.cursor) {
444		panic(abort)
445	}
446
447	a.cursor = saved
448}
449
// An iterator records where applyList is within a slice of nodes:
// index is the position of the element being visited, and step is how far
// to advance afterwards (normally 1; Cursor edits adjust it).
type iterator struct {
	index int
	step  int
}
454
455func (a *application) applyList(parent ast.Node, name string) {
456	// avoid heap-allocating a new iterator for each applyList call; reuse a.iter instead
457	saved := a.iter
458	a.iter.index = 0
459	for {
460		// must reload parent.name each time, since cursor modifications might change it
461		v := reflect.Indirect(reflect.ValueOf(parent)).FieldByName(name)
462		if a.iter.index >= v.Len() {
463			break
464		}
465
466		// element x may be nil in a bad AST - be cautious
467		var x ast.Node
468		if e := v.Index(a.iter.index); e.IsValid() {
469			x = e.Interface().(ast.Node)
470		}
471
472		a.iter.step = 1
473		a.apply(parent, name, &a.iter, x)
474		a.iter.index += a.iter.step
475	}
476	a.iter = saved
477}