Skip to content

Commit

Permalink
wazevo(amd64): SIMD lane load/store (tetratelabs#2045)
Browse files Browse the repository at this point in the history
Signed-off-by: Takeshi Yoneda <[email protected]>
  • Loading branch information
mathetake authored Feb 12, 2024
1 parent 6eb0ab4 commit 44bc48f
Show file tree
Hide file tree
Showing 6 changed files with 222 additions and 23 deletions.
11 changes: 10 additions & 1 deletion internal/engine/wazevo/backend/isa/amd64/instr.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,10 @@ func (i *instruction) String() string {
case xmmCmpRmR:
return fmt.Sprintf("%s %s, %s", sseOpcode(i.u1), i.op1.format(false), i.op2.format(false))
case xmmRmRImm:
return fmt.Sprintf("%s $%d, %s, %s", sseOpcode(i.u1), i.u2, i.op1.format(false), i.op2.format(false))
op := sseOpcode(i.u1)
r1, r2 := i.op1.format(op == sseOpcodePextrq || op == sseOpcodePinsrq),
i.op2.format(op == sseOpcodePextrq || op == sseOpcodePinsrq)
return fmt.Sprintf("%s $%d, %s, %s", op, i.u2, r1, r2)
case jmp:
return fmt.Sprintf("jmp %s", i.op1.format(true))
case jmpIf:
Expand Down Expand Up @@ -1512,9 +1515,11 @@ const (
sseOpcodePextrb
sseOpcodePextrw
sseOpcodePextrd
sseOpcodePextrq
sseOpcodePinsrb
sseOpcodePinsrw
sseOpcodePinsrd
sseOpcodePinsrq
sseOpcodePmaddwd
sseOpcodePmaxsb
sseOpcodePmaxsw
Expand Down Expand Up @@ -1769,12 +1774,16 @@ func (s sseOpcode) String() string {
return "pextrw"
case sseOpcodePextrd:
return "pextrd"
case sseOpcodePextrq:
return "pextrq"
case sseOpcodePinsrb:
return "pinsrb"
case sseOpcodePinsrw:
return "pinsrw"
case sseOpcodePinsrd:
return "pinsrd"
case sseOpcodePinsrq:
return "pinsrq"
case sseOpcodePmaddwd:
return "pmaddwd"
case sseOpcodePmaxsb:
Expand Down
42 changes: 41 additions & 1 deletion internal/engine/wazevo/backend/isa/amd64/instr_encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -1051,23 +1051,63 @@ func (i *instruction) encode(c backend.Compiler) (needsLabelResolution bool) {
var legPrex legacyPrefixes
var opcode uint32
var opcodeNum uint32
var swap bool
switch op {
case sseOpcodeCmpps:
legPrex, opcode, opcodeNum = legacyPrefixesNone, 0x0FC2, 2
case sseOpcodeCmppd:
legPrex, opcode, opcodeNum = legacyPrefixes0x66, 0x0FC2, 2
case sseOpcodeCmpss:
legPrex, opcode, opcodeNum = legacyPrefixes0xF3, 0x0FC2, 2
case sseOpcodeCmpsd:
legPrex, opcode, opcodeNum = legacyPrefixes0xF2, 0x0FC2, 2
case sseOpcodeInsertps:
legPrex, opcode, opcodeNum = legacyPrefixes0x66, 0x0F3A21, 3
case sseOpcodePalignr:
legPrex, opcode, opcodeNum = legacyPrefixes0x66, 0x0F3A0F, 3
case sseOpcodePinsrb:
legPrex, opcode, opcodeNum = legacyPrefixes0x66, 0x0F3A20, 3
case sseOpcodePinsrw:
legPrex, opcode, opcodeNum = legacyPrefixes0x66, 0x0FC4, 2
case sseOpcodePinsrd, sseOpcodePinsrq:
legPrex, opcode, opcodeNum = legacyPrefixes0x66, 0x0F3A22, 3
case sseOpcodePextrb:
swap = true
legPrex, opcode, opcodeNum = legacyPrefixes0x66, 0x0F3A14, 3
case sseOpcodePextrw:
legPrex, opcode, opcodeNum = legacyPrefixes0x66, 0x0FC5, 2
case sseOpcodePextrd, sseOpcodePextrq:
swap = true
legPrex, opcode, opcodeNum = legacyPrefixes0x66, 0x0F3A16, 3
case sseOpcodePshufd:
legPrex, opcode, opcodeNum = legacyPrefixes0x66, 0x0F70, 2
case sseOpcodeRoundps:
legPrex, opcode, opcodeNum = legacyPrefixes0x66, 0x0F3A08, 3
case sseOpcodeRoundpd:
legPrex, opcode, opcodeNum = legacyPrefixes0x66, 0x0F3A09, 3
default:
panic(fmt.Sprintf("Unsupported sseOpcode: %s", op))
}

dst := regEncodings[i.op2.reg().RealReg()]

rex := rexInfo(0).clearW()
var rex rexInfo
if op == sseOpcodePextrq || op == sseOpcodePinsrq {
rex = rexInfo(0).setW()
} else {
rex = rexInfo(0).clearW()
}
op1 := i.op1
if op1.kind == operandKindReg {
src := regEncodings[op1.reg().RealReg()]
if swap {
src, dst = dst, src
}
encodeRegReg(c, legPrex, opcode, opcodeNum, dst, src, rex)
} else if i.op1.kind == operandKindMem {
if swap {
panic("BUG: this is not possible to encode")
}
m := i.op1.addressMode()
encodeRegMem(c, legPrex, opcode, opcodeNum, dst, m, rex)
} else {
Expand Down
82 changes: 75 additions & 7 deletions internal/engine/wazevo/backend/isa/amd64/instr_encoding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4050,6 +4050,81 @@ func TestInstruction_format_encode(t *testing.T) {
want: "66450fefe4",
wantFormat: "xor %xmm12, %xmm12",
},
{
setup: func(i *instruction) { i.asXmmRmRImm(sseOpcodeCmpss, uint8(25), newOperandReg(xmm1VReg), xmm0VReg) },
want: "f30fc2c119",
wantFormat: "cmpss $25, %xmm1, %xmm0",
},
{
setup: func(i *instruction) { i.asXmmRmRImm(sseOpcodeCmpsd, uint8(25), newOperandReg(xmm1VReg), xmm0VReg) },
want: "f20fc2c119",
wantFormat: "cmpsd $25, %xmm1, %xmm0",
},
{
setup: func(i *instruction) { i.asXmmRmRImm(sseOpcodeInsertps, uint8(25), newOperandReg(xmm1VReg), xmm0VReg) },
want: "660f3a21c119",
wantFormat: "insertps $25, %xmm1, %xmm0",
},
{
setup: func(i *instruction) { i.asXmmRmRImm(sseOpcodePalignr, uint8(25), newOperandReg(xmm1VReg), xmm0VReg) },
want: "660f3a0fc119",
wantFormat: "palignr $25, %xmm1, %xmm0",
},
{
setup: func(i *instruction) { i.asXmmRmRImm(sseOpcodePinsrb, uint8(25), newOperandReg(r14VReg), xmm1VReg) },
want: "66410f3a20ce19",
wantFormat: "pinsrb $25, %r14d, %xmm1",
},
{
setup: func(i *instruction) { i.asXmmRmRImm(sseOpcodePinsrw, uint8(25), newOperandReg(r14VReg), xmm1VReg) },
want: "66410fc4ce19",
wantFormat: "pinsrw $25, %r14d, %xmm1",
},
{
setup: func(i *instruction) { i.asXmmRmRImm(sseOpcodePinsrd, uint8(25), newOperandReg(r14VReg), xmm1VReg) },
want: "66410f3a22ce19",
wantFormat: "pinsrd $25, %r14d, %xmm1",
},
{
setup: func(i *instruction) { i.asXmmRmRImm(sseOpcodePinsrq, uint8(25), newOperandReg(r14VReg), xmm1VReg) },
want: "66490f3a22ce19",
wantFormat: "pinsrq $25, %r14, %xmm1",
},
{
setup: func(i *instruction) { i.asXmmRmRImm(sseOpcodePextrb, uint8(25), newOperandReg(xmm1VReg), r14VReg) },
want: "66410f3a14ce19",
wantFormat: "pextrb $25, %xmm1, %r14d",
},
{
setup: func(i *instruction) { i.asXmmRmRImm(sseOpcodePextrw, uint8(25), newOperandReg(xmm1VReg), r14VReg) },
want: "66440fc5f119",
wantFormat: "pextrw $25, %xmm1, %r14d",
},
{
setup: func(i *instruction) { i.asXmmRmRImm(sseOpcodePextrd, uint8(25), newOperandReg(xmm1VReg), rbxVReg) },
want: "660f3a16cb19",
wantFormat: "pextrd $25, %xmm1, %ebx",
},
{
setup: func(i *instruction) { i.asXmmRmRImm(sseOpcodePextrq, uint8(25), newOperandReg(xmm1VReg), rdxVReg) },
want: "66480f3a16ca19",
wantFormat: "pextrq $25, %xmm1, %rdx",
},
{
setup: func(i *instruction) { i.asXmmRmRImm(sseOpcodePshufd, uint8(25), newOperandReg(xmm1VReg), xmm0VReg) },
want: "660f70c119",
wantFormat: "pshufd $25, %xmm1, %xmm0",
},
{
setup: func(i *instruction) { i.asXmmRmRImm(sseOpcodeRoundps, uint8(25), newOperandReg(xmm1VReg), xmm0VReg) },
want: "660f3a08c119",
wantFormat: "roundps $25, %xmm1, %xmm0",
},
{
setup: func(i *instruction) { i.asXmmRmRImm(sseOpcodeRoundpd, uint8(25), newOperandReg(xmm1VReg), xmm0VReg) },
want: "660f3a09c119",
wantFormat: "roundpd $25, %xmm1, %xmm0",
},
} {
tc := tc
t.Run(tc.wantFormat, func(t *testing.T) {
Expand All @@ -4062,13 +4137,6 @@ func TestInstruction_format_encode(t *testing.T) {
m := &machine{c: mc}
i.encode(m.c)
require.Equal(t, tc.want, hex.EncodeToString(mc.buf))

// TODO: verify the size of the encoded instructions.
//var actualSize int
//for cur := i; cur != nil; cur = cur.next {
// actualSize += int(cur.size())
//}
//require.Equal(t, len(tc.want)/2, actualSize)
})
}
}
84 changes: 84 additions & 0 deletions internal/engine/wazevo/backend/isa/amd64/machine.go
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,14 @@ func (m *machine) LowerInstr(instr *ssa.Instruction) {
x, y, c, lane := instr.VFcmpData()
m.lowerVFcmp(x, y, c, instr.Return(), lane)

case ssa.OpcodeExtractlane:
x, index, signed, lane := instr.ExtractlaneData()
m.lowerExtractLane(x, index, signed, instr.Return(), lane)

case ssa.OpcodeInsertlane:
x, y, index, lane := instr.InsertlaneData()
m.lowerInsertLane(x, y, index, instr.Return(), lane)

case ssa.OpcodeVIabs:
m.lowerVIabs(instr)
case ssa.OpcodeVIpopcnt:
Expand Down Expand Up @@ -2846,6 +2854,82 @@ func (m *machine) lowerVbnot(instr *ssa.Instruction) {
m.copyTo(tmp, rd)
}

func (m *machine) lowerInsertLane(x, y ssa.Value, index byte, ret ssa.Value, lane ssa.VecLane) {
// Copy x to tmp.
tmpDst := m.c.AllocateVReg(ssa.TypeV128)
m.insert(m.allocateInstr().asXmmUnaryRmR(sseOpcodeMovdqu, m.getOperand_Mem_Reg(m.c.ValueDefinition(x)), tmpDst))

yy := m.getOperand_Reg(m.c.ValueDefinition(y))
switch lane {
case ssa.VecLaneI8x16:
m.insert(m.allocateInstr().asXmmRmRImm(sseOpcodePinsrb, index, yy, tmpDst))
case ssa.VecLaneI16x8:
m.insert(m.allocateInstr().asXmmRmRImm(sseOpcodePinsrw, index, yy, tmpDst))
case ssa.VecLaneI32x4:
m.insert(m.allocateInstr().asXmmRmRImm(sseOpcodePinsrd, index, yy, tmpDst))
case ssa.VecLaneI64x2:
m.insert(m.allocateInstr().asXmmRmRImm(sseOpcodePinsrq, index, yy, tmpDst))
case ssa.VecLaneF32x4:
// In INSERTPS instruction, the destination index is encoded at 4 and 5 bits of the argument.
// See https://www.felixcloutier.com/x86/insertps
m.insert(m.allocateInstr().asXmmRmRImm(sseOpcodeInsertps, index<<4, yy, tmpDst))
case ssa.VecLaneF64x2:
if index == 0 {
m.insert(m.allocateInstr().asXmmUnaryRmR(sseOpcodeMovsd, yy, tmpDst))
} else {
m.insert(m.allocateInstr().asXmmRmR(sseOpcodeMovlhps, yy, tmpDst))
}
default:
panic(fmt.Sprintf("invalid lane type: %s", lane))
}

m.copyTo(tmpDst, m.c.VRegOf(ret))
}

func (m *machine) lowerExtractLane(x ssa.Value, index byte, signed bool, ret ssa.Value, lane ssa.VecLane) {
// Pextr variants are used to extract a lane from a vector register.
xx := m.getOperand_Reg(m.c.ValueDefinition(x))

tmpDst := m.c.AllocateVReg(ret.Type())
m.insert(m.allocateInstr().asDefineUninitializedReg(tmpDst))
switch lane {
case ssa.VecLaneI8x16:
m.insert(m.allocateInstr().asXmmRmRImm(sseOpcodePextrb, index, xx, tmpDst))
if signed {
m.insert(m.allocateInstr().asMovsxRmR(extModeBL, newOperandReg(tmpDst), tmpDst))
} else {
m.insert(m.allocateInstr().asMovzxRmR(extModeBL, newOperandReg(tmpDst), tmpDst))
}
case ssa.VecLaneI16x8:
m.insert(m.allocateInstr().asXmmRmRImm(sseOpcodePextrw, index, xx, tmpDst))
if signed {
m.insert(m.allocateInstr().asMovsxRmR(extModeWL, newOperandReg(tmpDst), tmpDst))
} else {
m.insert(m.allocateInstr().asMovzxRmR(extModeWL, newOperandReg(tmpDst), tmpDst))
}
case ssa.VecLaneI32x4:
m.insert(m.allocateInstr().asXmmRmRImm(sseOpcodePextrd, index, xx, tmpDst))
case ssa.VecLaneI64x2:
m.insert(m.allocateInstr().asXmmRmRImm(sseOpcodePextrq, index, xx, tmpDst))
case ssa.VecLaneF32x4:
if index == 0 {
m.insert(m.allocateInstr().asXmmUnaryRmR(sseOpcodeMovss, xx, tmpDst))
} else {
m.insert(m.allocateInstr().asXmmRmRImm(sseOpcodePshufd, index, xx, tmpDst))
}
case ssa.VecLaneF64x2:
if index == 0 {
m.allocateInstr().asXmmUnaryRmR(sseOpcodeMovsd, xx, tmpDst)
} else {
m.insert(m.allocateInstr().asXmmRmRImm(sseOpcodePshufd, index, xx, tmpDst))
}
default:
panic(fmt.Sprintf("invalid lane type: %s", lane))
}

m.copyTo(tmpDst, m.c.VRegOf(ret))
}

func (m *machine) lowerVbBinOp(op sseOpcode, x, y, ret ssa.Value) {
rn := m.getOperand_Reg(m.c.ValueDefinition(x))
rm := m.getOperand_Mem_Reg(m.c.ValueDefinition(y))
Expand Down
20 changes: 10 additions & 10 deletions internal/engine/wazevo/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1539,18 +1539,18 @@ func TestSpectestV2(t *testing.T) {
//{"simd_load_extend"},
//{"simd_load_splat"},
//{"simd_load_zero"},
//{"simd_load8_lane"},
//{"simd_load16_lane"},
//{"simd_load32_lane"},
//{"simd_load64_lane"},
{"simd_load8_lane"},
{"simd_load16_lane"},
{"simd_load32_lane"},
{"simd_load64_lane"},
//{"simd_lane"},
//{"simd_linking"},
{"simd_linking"},
//{"simd_splat"},
//{"simd_store"},
//{"simd_store8_lane"},
//{"simd_store16_lane"},
//{"simd_store32_lane"},
//{"simd_store64_lane"},
{"simd_store"},
{"simd_store8_lane"},
{"simd_store16_lane"},
{"simd_store32_lane"},
{"simd_store64_lane"},
} {
t.Run(tc.name, func(t *testing.T) {
ctx := context.Background()
Expand Down
6 changes: 2 additions & 4 deletions internal/engine/wazevo/ssa/instructions.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,12 +218,10 @@ const (
// OpcodeSwizzle performs a vector swizzle operation: `v = Swizzle.lane x, y`.
OpcodeSwizzle

// OpcodeInsertlane ...
// `v = insertlane x, y, Idx`. (TernaryImm8)
// OpcodeInsertlane inserts a lane value into a vector: `v = InsertLane x, y, Idx`.
OpcodeInsertlane

// OpcodeExtractlane ...
// `v = extractlane x, Idx`. (BinaryImm8)
// OpcodeExtractlane extracts a lane value from a vector: `v = ExtractLane x, Idx`.
OpcodeExtractlane

// OpcodeLoad loads a Type value from the [base + offset] address: `v = Load base, offset`.
Expand Down

0 comments on commit 44bc48f

Please sign in to comment.