compiler_lower.go

  1package backend
  2
  3import (
  4	"github.com/tetratelabs/wazero/internal/engine/wazevo/backend/regalloc"
  5	"github.com/tetratelabs/wazero/internal/engine/wazevo/ssa"
  6)
  7
  8// Lower implements Compiler.Lower.
  9func (c *compiler) Lower() {
 10	c.assignVirtualRegisters()
 11	c.mach.SetCurrentABI(c.GetFunctionABI(c.ssaBuilder.Signature()))
 12	c.mach.StartLoweringFunction(c.ssaBuilder.BlockIDMax())
 13	c.lowerBlocks()
 14}
 15
 16// lowerBlocks lowers each block in the ssa.Builder.
 17func (c *compiler) lowerBlocks() {
 18	builder := c.ssaBuilder
 19	for blk := builder.BlockIteratorReversePostOrderBegin(); blk != nil; blk = builder.BlockIteratorReversePostOrderNext() {
 20		c.lowerBlock(blk)
 21	}
 22
 23	// After lowering all blocks, we need to link adjacent blocks to layout one single instruction list.
 24	var prev ssa.BasicBlock
 25	for next := builder.BlockIteratorReversePostOrderBegin(); next != nil; next = builder.BlockIteratorReversePostOrderNext() {
 26		if prev != nil {
 27			c.mach.LinkAdjacentBlocks(prev, next)
 28		}
 29		prev = next
 30	}
 31}
 32
 33func (c *compiler) lowerBlock(blk ssa.BasicBlock) {
 34	mach := c.mach
 35	mach.StartBlock(blk)
 36
 37	// We traverse the instructions in reverse order because we might want to lower multiple
 38	// instructions together.
 39	cur := blk.Tail()
 40
 41	// First gather the branching instructions at the end of the blocks.
 42	var br0, br1 *ssa.Instruction
 43	if cur.IsBranching() {
 44		br0 = cur
 45		cur = cur.Prev()
 46		if cur != nil && cur.IsBranching() {
 47			br1 = cur
 48			cur = cur.Prev()
 49		}
 50	}
 51
 52	if br0 != nil {
 53		c.lowerBranches(br0, br1)
 54	}
 55
 56	if br1 != nil && br0 == nil {
 57		panic("BUG? when a block has conditional branch but doesn't end with an unconditional branch?")
 58	}
 59
 60	// Now start lowering the non-branching instructions.
 61	for ; cur != nil; cur = cur.Prev() {
 62		c.setCurrentGroupID(cur.GroupID())
 63		if cur.Lowered() {
 64			continue
 65		}
 66
 67		switch cur.Opcode() {
 68		case ssa.OpcodeReturn:
 69			rets := cur.ReturnVals()
 70			if len(rets) > 0 {
 71				c.mach.LowerReturns(rets)
 72			}
 73			c.mach.InsertReturn()
 74		default:
 75			mach.LowerInstr(cur)
 76		}
 77		mach.FlushPendingInstructions()
 78	}
 79
 80	// Finally, if this is the entry block, we have to insert copies of arguments from the real location to the VReg.
 81	if blk.EntryBlock() {
 82		c.lowerFunctionArguments(blk)
 83	}
 84
 85	mach.EndBlock()
 86}
 87
 88// lowerBranches is called right after StartBlock and before any LowerInstr call if
 89// there are branches to the given block. br0 is the very end of the block and b1 is the before the br0 if it exists.
 90// At least br0 is not nil, but br1 can be nil if there's no branching before br0.
 91//
 92// See ssa.Instruction IsBranching, and the comment on ssa.BasicBlock.
 93func (c *compiler) lowerBranches(br0, br1 *ssa.Instruction) {
 94	mach := c.mach
 95
 96	c.setCurrentGroupID(br0.GroupID())
 97	c.mach.LowerSingleBranch(br0)
 98	mach.FlushPendingInstructions()
 99	if br1 != nil {
100		c.setCurrentGroupID(br1.GroupID())
101		c.mach.LowerConditionalBranch(br1)
102		mach.FlushPendingInstructions()
103	}
104
105	if br0.Opcode() == ssa.OpcodeJump {
106		_, args, targetBlockID := br0.BranchData()
107		argExists := len(args) != 0
108		if argExists && br1 != nil {
109			panic("BUG: critical edge split failed")
110		}
111		target := c.ssaBuilder.BasicBlock(targetBlockID)
112		if argExists && target.ReturnBlock() {
113			if len(args) > 0 {
114				c.mach.LowerReturns(args)
115			}
116		} else if argExists {
117			c.lowerBlockArguments(args, target)
118		}
119	}
120	mach.FlushPendingInstructions()
121}
122
123func (c *compiler) lowerFunctionArguments(entry ssa.BasicBlock) {
124	mach := c.mach
125
126	c.tmpVals = c.tmpVals[:0]
127	data := c.ssaBuilder.ValuesInfo()
128	for i := 0; i < entry.Params(); i++ {
129		p := entry.Param(i)
130		if data[p.ID()].RefCount > 0 {
131			c.tmpVals = append(c.tmpVals, p)
132		} else {
133			// If the argument is not used, we can just pass an invalid value.
134			c.tmpVals = append(c.tmpVals, ssa.ValueInvalid)
135		}
136	}
137	mach.LowerParams(c.tmpVals)
138	mach.FlushPendingInstructions()
139}
140
141// lowerBlockArguments lowers how to pass arguments to the given successor block.
142func (c *compiler) lowerBlockArguments(args []ssa.Value, succ ssa.BasicBlock) {
143	if len(args) != succ.Params() {
144		panic("BUG: mismatched number of arguments")
145	}
146
147	c.varEdges = c.varEdges[:0]
148	c.varEdgeTypes = c.varEdgeTypes[:0]
149	c.constEdges = c.constEdges[:0]
150	for i := 0; i < len(args); i++ {
151		dst := succ.Param(i)
152		src := args[i]
153
154		dstReg := c.VRegOf(dst)
155		srcInstr := c.ssaBuilder.InstructionOfValue(src)
156		if srcInstr != nil && srcInstr.Constant() {
157			c.constEdges = append(c.constEdges, struct {
158				cInst *ssa.Instruction
159				dst   regalloc.VReg
160			}{cInst: srcInstr, dst: dstReg})
161		} else {
162			srcReg := c.VRegOf(src)
163			// Even when the src=dst, insert the move so that we can keep such registers keep-alive.
164			c.varEdges = append(c.varEdges, [2]regalloc.VReg{srcReg, dstReg})
165			c.varEdgeTypes = append(c.varEdgeTypes, src.Type())
166		}
167	}
168
169	// Check if there's an overlap among the dsts and srcs in varEdges.
170	c.vRegIDs = c.vRegIDs[:0]
171	for _, edge := range c.varEdges {
172		src := edge[0].ID()
173		if int(src) >= len(c.vRegSet) {
174			c.vRegSet = append(c.vRegSet, make([]bool, src+1)...)
175		}
176		c.vRegSet[src] = true
177		c.vRegIDs = append(c.vRegIDs, src)
178	}
179	separated := true
180	for _, edge := range c.varEdges {
181		dst := edge[1].ID()
182		if int(dst) >= len(c.vRegSet) {
183			c.vRegSet = append(c.vRegSet, make([]bool, dst+1)...)
184		} else {
185			if c.vRegSet[dst] {
186				separated = false
187				break
188			}
189		}
190	}
191	for _, id := range c.vRegIDs {
192		c.vRegSet[id] = false // reset for the next use.
193	}
194
195	if separated {
196		// If there's no overlap, we can simply move the source to destination.
197		for i, edge := range c.varEdges {
198			src, dst := edge[0], edge[1]
199			c.mach.InsertMove(dst, src, c.varEdgeTypes[i])
200		}
201	} else {
202		// Otherwise, we allocate a temporary registers and move the source to the temporary register,
203		//
204		// First move all of them to temporary registers.
205		c.tempRegs = c.tempRegs[:0]
206		for i, edge := range c.varEdges {
207			src := edge[0]
208			typ := c.varEdgeTypes[i]
209			temp := c.AllocateVReg(typ)
210			c.tempRegs = append(c.tempRegs, temp)
211			c.mach.InsertMove(temp, src, typ)
212		}
213		// Then move the temporary registers to the destination.
214		for i, edge := range c.varEdges {
215			temp := c.tempRegs[i]
216			dst := edge[1]
217			c.mach.InsertMove(dst, temp, c.varEdgeTypes[i])
218		}
219	}
220
221	// Finally, move the constants.
222	for _, edge := range c.constEdges {
223		cInst, dst := edge.cInst, edge.dst
224		c.mach.InsertLoadConstantBlockArg(cInst, dst)
225	}
226}