1package ssa
2
3import (
4 "fmt"
5 "math"
6 "strings"
7
8 "github.com/tetratelabs/wazero/internal/engine/wazevo/wazevoapi"
9)
10
11// passCalculateImmediateDominators calculates immediate dominators for each basic block.
12// The result is stored in b.dominators. This make it possible for the following passes to
13// use builder.isDominatedBy to check if a block is dominated by another block.
14//
15// At the last of pass, this function also does the loop detection and sets the basicBlock.loop flag.
16func passCalculateImmediateDominators(b *builder) {
17 reversePostOrder := b.reversePostOrderedBasicBlocks[:0]
18
19 // Store the reverse postorder from the entrypoint into reversePostOrder slice.
20 // This calculation of reverse postorder is not described in the paper,
21 // so we use heuristic to calculate it so that we could potentially handle arbitrary
22 // complex CFGs under the assumption that success is sorted in program's natural order.
23 // That means blk.success[i] always appears before blk.success[i+1] in the source program,
24 // which is a reasonable assumption as long as SSA Builder is properly used.
25 //
26 // First we push blocks in postorder iteratively visit successors of the entry block.
27 entryBlk := b.entryBlk()
28 exploreStack := append(b.blkStack[:0], entryBlk)
29 // These flags are used to track the state of the block in the DFS traversal.
30 // We temporarily use the reversePostOrder field to store the state.
31 const visitStateUnseen, visitStateSeen, visitStateDone = 0, 1, 2
32 entryBlk.visited = visitStateSeen
33 for len(exploreStack) > 0 {
34 tail := len(exploreStack) - 1
35 blk := exploreStack[tail]
36 exploreStack = exploreStack[:tail]
37 switch blk.visited {
38 case visitStateUnseen:
39 // This is likely a bug in the frontend.
40 panic("BUG: unsupported CFG")
41 case visitStateSeen:
42 // This is the first time to pop this block, and we have to see the successors first.
43 // So push this block again to the stack.
44 exploreStack = append(exploreStack, blk)
45 // And push the successors to the stack if necessary.
46 for _, succ := range blk.success {
47 if succ.ReturnBlock() || succ.invalid {
48 continue
49 }
50 if succ.visited == visitStateUnseen {
51 succ.visited = visitStateSeen
52 exploreStack = append(exploreStack, succ)
53 }
54 }
55 // Finally, we could pop this block once we pop all of its successors.
56 blk.visited = visitStateDone
57 case visitStateDone:
58 // Note: at this point we push blk in postorder despite its name.
59 reversePostOrder = append(reversePostOrder, blk)
60 default:
61 panic("BUG")
62 }
63 }
64 // At this point, reversePostOrder has postorder actually, so we reverse it.
65 for i := len(reversePostOrder)/2 - 1; i >= 0; i-- {
66 j := len(reversePostOrder) - 1 - i
67 reversePostOrder[i], reversePostOrder[j] = reversePostOrder[j], reversePostOrder[i]
68 }
69
70 for i, blk := range reversePostOrder {
71 blk.reversePostOrder = int32(i)
72 }
73
74 // Reuse the dominators slice if possible from the previous computation of function.
75 b.dominators = b.dominators[:cap(b.dominators)]
76 if len(b.dominators) < b.basicBlocksPool.Allocated() {
77 // Generously reserve space in the slice because the slice will be reused future allocation.
78 b.dominators = append(b.dominators, make([]*basicBlock, b.basicBlocksPool.Allocated())...)
79 }
80 calculateDominators(reversePostOrder, b.dominators)
81
82 // Reuse the slices for the future use.
83 b.blkStack = exploreStack
84
85 // For the following passes.
86 b.reversePostOrderedBasicBlocks = reversePostOrder
87
88 // Ready to detect loops!
89 subPassLoopDetection(b)
90}
91
92// calculateDominators calculates the immediate dominator of each node in the CFG, and store the result in `doms`.
93// The algorithm is based on the one described in the paper "A Simple, Fast Dominance Algorithm"
94// https://www.cs.rice.edu/~keith/EMBED/dom.pdf which is a faster/simple alternative to the well known Lengauer-Tarjan algorithm.
95//
96// The following code almost matches the pseudocode in the paper with one exception (see the code comment below).
97//
98// The result slice `doms` must be pre-allocated with the size larger than the size of dfsBlocks.
99func calculateDominators(reversePostOrderedBlks []*basicBlock, doms []*basicBlock) {
100 entry, reversePostOrderedBlks := reversePostOrderedBlks[0], reversePostOrderedBlks[1: /* skips entry point */]
101 for _, blk := range reversePostOrderedBlks {
102 doms[blk.id] = nil
103 }
104 doms[entry.id] = entry
105
106 changed := true
107 for changed {
108 changed = false
109 for _, blk := range reversePostOrderedBlks {
110 var u *basicBlock
111 for i := range blk.preds {
112 pred := blk.preds[i].blk
113 // Skip if this pred is not reachable yet. Note that this is not described in the paper,
114 // but it is necessary to handle nested loops etc.
115 if doms[pred.id] == nil {
116 continue
117 }
118
119 if u == nil {
120 u = pred
121 continue
122 } else {
123 u = intersect(doms, u, pred)
124 }
125 }
126 if doms[blk.id] != u {
127 doms[blk.id] = u
128 changed = true
129 }
130 }
131 }
132}
133
134// intersect returns the common dominator of blk1 and blk2.
135//
136// This is the `intersect` function in the paper.
137func intersect(doms []*basicBlock, blk1 *basicBlock, blk2 *basicBlock) *basicBlock {
138 finger1, finger2 := blk1, blk2
139 for finger1 != finger2 {
140 // Move the 'finger1' upwards to its immediate dominator.
141 for finger1.reversePostOrder > finger2.reversePostOrder {
142 finger1 = doms[finger1.id]
143 }
144 // Move the 'finger2' upwards to its immediate dominator.
145 for finger2.reversePostOrder > finger1.reversePostOrder {
146 finger2 = doms[finger2.id]
147 }
148 }
149 return finger1
150}
151
152// subPassLoopDetection detects loops in the function using the immediate dominators.
153//
154// This is run at the last of passCalculateImmediateDominators.
155func subPassLoopDetection(b *builder) {
156 for blk := b.blockIteratorBegin(); blk != nil; blk = b.blockIteratorNext() {
157 for i := range blk.preds {
158 pred := blk.preds[i].blk
159 if pred.invalid {
160 continue
161 }
162 if b.isDominatedBy(pred, blk) {
163 blk.loopHeader = true
164 }
165 }
166 }
167}
168
169// buildLoopNestingForest builds the loop nesting forest for the function.
170// This must be called after branch splitting since it relies on the CFG.
171func passBuildLoopNestingForest(b *builder) {
172 ent := b.entryBlk()
173 doms := b.dominators
174 for _, blk := range b.reversePostOrderedBasicBlocks {
175 n := doms[blk.id]
176 for !n.loopHeader && n != ent {
177 n = doms[n.id]
178 }
179
180 if n == ent && blk.loopHeader {
181 b.loopNestingForestRoots = append(b.loopNestingForestRoots, blk)
182 } else if n == ent {
183 } else if n.loopHeader {
184 n.loopNestingForestChildren = n.loopNestingForestChildren.Append(&b.varLengthBasicBlockPool, blk)
185 }
186 }
187
188 if wazevoapi.SSALoggingEnabled {
189 for _, root := range b.loopNestingForestRoots {
190 printLoopNestingForest(root.(*basicBlock), 0)
191 }
192 }
193}
194
195func printLoopNestingForest(root *basicBlock, depth int) {
196 fmt.Println(strings.Repeat("\t", depth), "loop nesting forest root:", root.ID())
197 for _, child := range root.loopNestingForestChildren.View() {
198 fmt.Println(strings.Repeat("\t", depth+1), "child:", child.ID())
199 if child.LoopHeader() {
200 printLoopNestingForest(child.(*basicBlock), depth+2)
201 }
202 }
203}
204
205type dominatorSparseTree struct {
206 time int32
207 euler []*basicBlock
208 first, depth []int32
209 table [][]int32
210}
211
212// passBuildDominatorTree builds the dominator tree for the function, and constructs builder.sparseTree.
213func passBuildDominatorTree(b *builder) {
214 // First we materialize the children of each node in the dominator tree.
215 idoms := b.dominators
216 for _, blk := range b.reversePostOrderedBasicBlocks {
217 parent := idoms[blk.id]
218 if parent == nil {
219 panic("BUG")
220 } else if parent == blk {
221 // This is the entry block.
222 continue
223 }
224 if prev := parent.child; prev == nil {
225 parent.child = blk
226 } else {
227 parent.child = blk
228 blk.sibling = prev
229 }
230 }
231
232 // Reset the state from the previous computation.
233 n := b.basicBlocksPool.Allocated()
234 st := &b.sparseTree
235 st.euler = append(st.euler[:0], make([]*basicBlock, 2*n-1)...)
236 st.first = append(st.first[:0], make([]int32, n)...)
237 for i := range st.first {
238 st.first[i] = -1
239 }
240 st.depth = append(st.depth[:0], make([]int32, 2*n-1)...)
241 st.time = 0
242
243 // Start building the sparse tree.
244 st.eulerTour(b.entryBlk(), 0)
245 st.buildSparseTable()
246}
247
248func (dt *dominatorSparseTree) eulerTour(node *basicBlock, height int32) {
249 if wazevoapi.SSALoggingEnabled {
250 fmt.Println(strings.Repeat("\t", int(height)), "euler tour:", node.ID())
251 }
252 dt.euler[dt.time] = node
253 dt.depth[dt.time] = height
254 if dt.first[node.id] == -1 {
255 dt.first[node.id] = dt.time
256 }
257 dt.time++
258
259 for child := node.child; child != nil; child = child.sibling {
260 dt.eulerTour(child, height+1)
261 dt.euler[dt.time] = node // add the current node again after visiting a child
262 dt.depth[dt.time] = height
263 dt.time++
264 }
265}
266
267// buildSparseTable builds a sparse table for RMQ queries.
268func (dt *dominatorSparseTree) buildSparseTable() {
269 n := len(dt.depth)
270 k := int(math.Log2(float64(n))) + 1
271 table := dt.table
272
273 if n >= len(table) {
274 table = append(table, make([][]int32, n-len(table)+1)...)
275 }
276 for i := range table {
277 if len(table[i]) < k {
278 table[i] = append(table[i], make([]int32, k-len(table[i]))...)
279 }
280 table[i][0] = int32(i)
281 }
282
283 for j := 1; 1<<j <= n; j++ {
284 for i := 0; i+(1<<j)-1 < n; i++ {
285 if dt.depth[table[i][j-1]] < dt.depth[table[i+(1<<(j-1))][j-1]] {
286 table[i][j] = table[i][j-1]
287 } else {
288 table[i][j] = table[i+(1<<(j-1))][j-1]
289 }
290 }
291 }
292 dt.table = table
293}
294
295// rmq performs a range minimum query on the sparse table.
296func (dt *dominatorSparseTree) rmq(l, r int32) int32 {
297 table := dt.table
298 depth := dt.depth
299 j := int(math.Log2(float64(r - l + 1)))
300 if depth[table[l][j]] <= depth[table[r-(1<<j)+1][j]] {
301 return table[l][j]
302 }
303 return table[r-(1<<j)+1][j]
304}
305
306// findLCA finds the LCA using the Euler tour and RMQ.
307func (dt *dominatorSparseTree) findLCA(u, v BasicBlockID) *basicBlock {
308 first := dt.first
309 if first[u] > first[v] {
310 u, v = v, u
311 }
312 return dt.euler[dt.rmq(first[u], first[v])]
313}