pass_cfg.go

  1package ssa
  2
  3import (
  4	"fmt"
  5	"math"
  6	"strings"
  7
  8	"github.com/tetratelabs/wazero/internal/engine/wazevo/wazevoapi"
  9)
 10
 11// passCalculateImmediateDominators calculates immediate dominators for each basic block.
 12// The result is stored in b.dominators. This make it possible for the following passes to
 13// use builder.isDominatedBy to check if a block is dominated by another block.
 14//
 15// At the last of pass, this function also does the loop detection and sets the basicBlock.loop flag.
 16func passCalculateImmediateDominators(b *builder) {
 17	reversePostOrder := b.reversePostOrderedBasicBlocks[:0]
 18
 19	// Store the reverse postorder from the entrypoint into reversePostOrder slice.
 20	// This calculation of reverse postorder is not described in the paper,
 21	// so we use heuristic to calculate it so that we could potentially handle arbitrary
 22	// complex CFGs under the assumption that success is sorted in program's natural order.
 23	// That means blk.success[i] always appears before blk.success[i+1] in the source program,
 24	// which is a reasonable assumption as long as SSA Builder is properly used.
 25	//
 26	// First we push blocks in postorder iteratively visit successors of the entry block.
 27	entryBlk := b.entryBlk()
 28	exploreStack := append(b.blkStack[:0], entryBlk)
 29	// These flags are used to track the state of the block in the DFS traversal.
 30	// We temporarily use the reversePostOrder field to store the state.
 31	const visitStateUnseen, visitStateSeen, visitStateDone = 0, 1, 2
 32	entryBlk.visited = visitStateSeen
 33	for len(exploreStack) > 0 {
 34		tail := len(exploreStack) - 1
 35		blk := exploreStack[tail]
 36		exploreStack = exploreStack[:tail]
 37		switch blk.visited {
 38		case visitStateUnseen:
 39			// This is likely a bug in the frontend.
 40			panic("BUG: unsupported CFG")
 41		case visitStateSeen:
 42			// This is the first time to pop this block, and we have to see the successors first.
 43			// So push this block again to the stack.
 44			exploreStack = append(exploreStack, blk)
 45			// And push the successors to the stack if necessary.
 46			for _, succ := range blk.success {
 47				if succ.ReturnBlock() || succ.invalid {
 48					continue
 49				}
 50				if succ.visited == visitStateUnseen {
 51					succ.visited = visitStateSeen
 52					exploreStack = append(exploreStack, succ)
 53				}
 54			}
 55			// Finally, we could pop this block once we pop all of its successors.
 56			blk.visited = visitStateDone
 57		case visitStateDone:
 58			// Note: at this point we push blk in postorder despite its name.
 59			reversePostOrder = append(reversePostOrder, blk)
 60		default:
 61			panic("BUG")
 62		}
 63	}
 64	// At this point, reversePostOrder has postorder actually, so we reverse it.
 65	for i := len(reversePostOrder)/2 - 1; i >= 0; i-- {
 66		j := len(reversePostOrder) - 1 - i
 67		reversePostOrder[i], reversePostOrder[j] = reversePostOrder[j], reversePostOrder[i]
 68	}
 69
 70	for i, blk := range reversePostOrder {
 71		blk.reversePostOrder = int32(i)
 72	}
 73
 74	// Reuse the dominators slice if possible from the previous computation of function.
 75	b.dominators = b.dominators[:cap(b.dominators)]
 76	if len(b.dominators) < b.basicBlocksPool.Allocated() {
 77		// Generously reserve space in the slice because the slice will be reused future allocation.
 78		b.dominators = append(b.dominators, make([]*basicBlock, b.basicBlocksPool.Allocated())...)
 79	}
 80	calculateDominators(reversePostOrder, b.dominators)
 81
 82	// Reuse the slices for the future use.
 83	b.blkStack = exploreStack
 84
 85	// For the following passes.
 86	b.reversePostOrderedBasicBlocks = reversePostOrder
 87
 88	// Ready to detect loops!
 89	subPassLoopDetection(b)
 90}
 91
 92// calculateDominators calculates the immediate dominator of each node in the CFG, and store the result in `doms`.
 93// The algorithm is based on the one described in the paper "A Simple, Fast Dominance Algorithm"
 94// https://www.cs.rice.edu/~keith/EMBED/dom.pdf which is a faster/simple alternative to the well known Lengauer-Tarjan algorithm.
 95//
 96// The following code almost matches the pseudocode in the paper with one exception (see the code comment below).
 97//
 98// The result slice `doms` must be pre-allocated with the size larger than the size of dfsBlocks.
 99func calculateDominators(reversePostOrderedBlks []*basicBlock, doms []*basicBlock) {
100	entry, reversePostOrderedBlks := reversePostOrderedBlks[0], reversePostOrderedBlks[1: /* skips entry point */]
101	for _, blk := range reversePostOrderedBlks {
102		doms[blk.id] = nil
103	}
104	doms[entry.id] = entry
105
106	changed := true
107	for changed {
108		changed = false
109		for _, blk := range reversePostOrderedBlks {
110			var u *basicBlock
111			for i := range blk.preds {
112				pred := blk.preds[i].blk
113				// Skip if this pred is not reachable yet. Note that this is not described in the paper,
114				// but it is necessary to handle nested loops etc.
115				if doms[pred.id] == nil {
116					continue
117				}
118
119				if u == nil {
120					u = pred
121					continue
122				} else {
123					u = intersect(doms, u, pred)
124				}
125			}
126			if doms[blk.id] != u {
127				doms[blk.id] = u
128				changed = true
129			}
130		}
131	}
132}
133
134// intersect returns the common dominator of blk1 and blk2.
135//
136// This is the `intersect` function in the paper.
137func intersect(doms []*basicBlock, blk1 *basicBlock, blk2 *basicBlock) *basicBlock {
138	finger1, finger2 := blk1, blk2
139	for finger1 != finger2 {
140		// Move the 'finger1' upwards to its immediate dominator.
141		for finger1.reversePostOrder > finger2.reversePostOrder {
142			finger1 = doms[finger1.id]
143		}
144		// Move the 'finger2' upwards to its immediate dominator.
145		for finger2.reversePostOrder > finger1.reversePostOrder {
146			finger2 = doms[finger2.id]
147		}
148	}
149	return finger1
150}
151
152// subPassLoopDetection detects loops in the function using the immediate dominators.
153//
154// This is run at the last of passCalculateImmediateDominators.
155func subPassLoopDetection(b *builder) {
156	for blk := b.blockIteratorBegin(); blk != nil; blk = b.blockIteratorNext() {
157		for i := range blk.preds {
158			pred := blk.preds[i].blk
159			if pred.invalid {
160				continue
161			}
162			if b.isDominatedBy(pred, blk) {
163				blk.loopHeader = true
164			}
165		}
166	}
167}
168
169// buildLoopNestingForest builds the loop nesting forest for the function.
170// This must be called after branch splitting since it relies on the CFG.
171func passBuildLoopNestingForest(b *builder) {
172	ent := b.entryBlk()
173	doms := b.dominators
174	for _, blk := range b.reversePostOrderedBasicBlocks {
175		n := doms[blk.id]
176		for !n.loopHeader && n != ent {
177			n = doms[n.id]
178		}
179
180		if n == ent && blk.loopHeader {
181			b.loopNestingForestRoots = append(b.loopNestingForestRoots, blk)
182		} else if n == ent {
183		} else if n.loopHeader {
184			n.loopNestingForestChildren = n.loopNestingForestChildren.Append(&b.varLengthBasicBlockPool, blk)
185		}
186	}
187
188	if wazevoapi.SSALoggingEnabled {
189		for _, root := range b.loopNestingForestRoots {
190			printLoopNestingForest(root.(*basicBlock), 0)
191		}
192	}
193}
194
195func printLoopNestingForest(root *basicBlock, depth int) {
196	fmt.Println(strings.Repeat("\t", depth), "loop nesting forest root:", root.ID())
197	for _, child := range root.loopNestingForestChildren.View() {
198		fmt.Println(strings.Repeat("\t", depth+1), "child:", child.ID())
199		if child.LoopHeader() {
200			printLoopNestingForest(child.(*basicBlock), depth+2)
201		}
202	}
203}
204
205type dominatorSparseTree struct {
206	time         int32
207	euler        []*basicBlock
208	first, depth []int32
209	table        [][]int32
210}
211
212// passBuildDominatorTree builds the dominator tree for the function, and constructs builder.sparseTree.
213func passBuildDominatorTree(b *builder) {
214	// First we materialize the children of each node in the dominator tree.
215	idoms := b.dominators
216	for _, blk := range b.reversePostOrderedBasicBlocks {
217		parent := idoms[blk.id]
218		if parent == nil {
219			panic("BUG")
220		} else if parent == blk {
221			// This is the entry block.
222			continue
223		}
224		if prev := parent.child; prev == nil {
225			parent.child = blk
226		} else {
227			parent.child = blk
228			blk.sibling = prev
229		}
230	}
231
232	// Reset the state from the previous computation.
233	n := b.basicBlocksPool.Allocated()
234	st := &b.sparseTree
235	st.euler = append(st.euler[:0], make([]*basicBlock, 2*n-1)...)
236	st.first = append(st.first[:0], make([]int32, n)...)
237	for i := range st.first {
238		st.first[i] = -1
239	}
240	st.depth = append(st.depth[:0], make([]int32, 2*n-1)...)
241	st.time = 0
242
243	// Start building the sparse tree.
244	st.eulerTour(b.entryBlk(), 0)
245	st.buildSparseTable()
246}
247
248func (dt *dominatorSparseTree) eulerTour(node *basicBlock, height int32) {
249	if wazevoapi.SSALoggingEnabled {
250		fmt.Println(strings.Repeat("\t", int(height)), "euler tour:", node.ID())
251	}
252	dt.euler[dt.time] = node
253	dt.depth[dt.time] = height
254	if dt.first[node.id] == -1 {
255		dt.first[node.id] = dt.time
256	}
257	dt.time++
258
259	for child := node.child; child != nil; child = child.sibling {
260		dt.eulerTour(child, height+1)
261		dt.euler[dt.time] = node // add the current node again after visiting a child
262		dt.depth[dt.time] = height
263		dt.time++
264	}
265}
266
267// buildSparseTable builds a sparse table for RMQ queries.
268func (dt *dominatorSparseTree) buildSparseTable() {
269	n := len(dt.depth)
270	k := int(math.Log2(float64(n))) + 1
271	table := dt.table
272
273	if n >= len(table) {
274		table = append(table, make([][]int32, n-len(table)+1)...)
275	}
276	for i := range table {
277		if len(table[i]) < k {
278			table[i] = append(table[i], make([]int32, k-len(table[i]))...)
279		}
280		table[i][0] = int32(i)
281	}
282
283	for j := 1; 1<<j <= n; j++ {
284		for i := 0; i+(1<<j)-1 < n; i++ {
285			if dt.depth[table[i][j-1]] < dt.depth[table[i+(1<<(j-1))][j-1]] {
286				table[i][j] = table[i][j-1]
287			} else {
288				table[i][j] = table[i+(1<<(j-1))][j-1]
289			}
290		}
291	}
292	dt.table = table
293}
294
295// rmq performs a range minimum query on the sparse table.
296func (dt *dominatorSparseTree) rmq(l, r int32) int32 {
297	table := dt.table
298	depth := dt.depth
299	j := int(math.Log2(float64(r - l + 1)))
300	if depth[table[l][j]] <= depth[table[r-(1<<j)+1][j]] {
301		return table[l][j]
302	}
303	return table[r-(1<<j)+1][j]
304}
305
306// findLCA finds the LCA using the Euler tour and RMQ.
307func (dt *dominatorSparseTree) findLCA(u, v BasicBlockID) *basicBlock {
308	first := dt.first
309	if first[u] > first[v] {
310		u, v = v, u
311	}
312	return dt.euler[dt.rmq(first[u], first[v])]
313}