Skip to content

Commit

Permalink
Fix inconsistency with varying store/load widths on simd pointers.
Browse files Browse the repository at this point in the history
Signed-off-by: Max Brylski <[email protected]>
  • Loading branch information
helehex committed Nov 22, 2024
1 parent 4cd0762 commit 669ed15
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
28 changes: 28 additions & 0 deletions stdlib/src/memory/unsafe_pointer.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,23 @@ struct UnsafePointer[
"both volatile and invariant cannot be set at the same time",
]()

# bool load unpacks bits differently depending on the width.
# This causes issues when doing stores and loads of different widths.
# Load through a uint8 pointer and cast back to bool to keep a consistent representation in memory.
# TODO: decide how/whether to pack SIMD[bool] bits in memory
@parameter
if type == DType.bool:
return (
self.bitcast[UInt8]()
.load[
width=width,
alignment=alignment,
volatile=volatile,
invariant=invariant,
]()
.cast[type]()
)

@parameter
if is_nvidia_gpu() and sizeof[type]() == 1 and alignment == 1:
# LLVM lowering to PTX incorrectly vectorizes loads for 1-byte types
Expand Down Expand Up @@ -785,6 +802,17 @@ struct UnsafePointer[
alignment > 0, "alignment must be a positive integer value"
]()

# bool store packs bits differently depending on the width.
# This causes issues when doing stores and loads of different widths.
# Cast to uint8 and store through a uint8 pointer to keep a consistent representation in memory.
# TODO: decide how/whether to pack SIMD[bool] bits in memory
@parameter
if type == DType.bool:
self.bitcast[UInt8]()._store[
alignment=alignment, volatile=volatile
](val.cast[DType.uint8]())
return

@parameter
if volatile:
__mlir_op.`pop.store`[
Expand Down
22 changes: 22 additions & 0 deletions stdlib/test/memory/test_unsafepointer.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -316,12 +316,34 @@ def test_volatile_load_and_store_simd():
for i in range(0, 16, 4):
var vec = ptr.load[width=4, volatile=True](i)
assert_equal(vec, SIMD[DType.int8, 4](i, i + 1, i + 2, i + 3))
ptr.free()

var ptr2 = UnsafePointer[Int8].alloc(16)
for i in range(0, 16, 4):
ptr2.store[volatile=True](i, SIMD[DType.int8, 4](i))
for i in range(16):
assert_equal(ptr2[i], i // 4 * 4)
ptr2.free()

# test for bool store/load consistency with different widths
var ptr3 = UnsafePointer[Scalar[DType.bool]].alloc(4)
ptr3.store[volatile=True](SIMD[DType.bool, 2](True, False))
ptr3[2] = False
ptr3[3] = True
assert_equal(
ptr3.load[width=4, volatile=True](),
SIMD[DType.bool, 4](True, False, False, True),
)
ptr3.free()

# Store a width-4 bool vector, then read it back at a narrower width and
# element by element; every read must agree with the stored values.
var ptr4 = UnsafePointer[Scalar[DType.bool]].alloc(4)
ptr4.store[volatile=True](SIMD[DType.bool, 4](True, False, False, True))
# Load from ptr4: ptr3 has already been freed, so reading it here would be
# a use-after-free and would not exercise the width-4 store just made.
assert_equal(
    ptr4.load[width=2, volatile=True](), SIMD[DType.bool, 2](True, False)
)
assert_equal(ptr4[2], False)
assert_equal(ptr4[3], True)
ptr4.free()


def main():
Expand Down

0 comments on commit 669ed15

Please sign in to comment.