Use faster operations on packed-quantized, add tests (#211)
* bitwise op to make it faster, add tests

* Update tests/test_compressors/quantized_compressors/test_pack_quant.py

Co-authored-by: Michael Goin <[email protected]>

* rahul-tuli comment

* comments

---------

Co-authored-by: Michael Goin <[email protected]>
horheynm and mgoin authored Jan 31, 2025
1 parent 890608d commit b9c536d
Showing 2 changed files with 185 additions and 6 deletions.
@@ -138,8 +138,20 @@ def pack_to_int32(value: torch.Tensor, num_bits: int) -> torch.Tensor:
"""
Packs a tensor of quantized weights stored in int8 into int32s with padding
Pseudocode:
1. Shift wrt num_bits to convert to unsigned. num_bits=8
[1,2] -> [129, 130]
2. Pad to fill in 32 bits
[129, 130] -> [129, 130, 0, 0]
3. convert to binary align in order
[129, 130, 0, 0] -> 00000000 00000000 10000010 10000001
4. convert aligned binary to number
00000000000000001000001010000001 -> 33409
5. covert back to uint32
33409 -> 33409
:param value: tensor to pack
:param num_bits: number of bits used to store underlying data
:param num_bits: number of bits used to store underlying data, must be at least 1
:returns: packed int32 tensor
"""
if value.dtype is not torch.int8:
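To make the docstring's five steps concrete, here is a minimal sketch (an editorial illustration in plain Python, not part of the commit) that reproduces the num_bits=8 example by hand:

num_bits = 8
offset = 1 << (num_bits - 1)                     # 128: step 1, shift to unsigned
values = [1, 2]                                  # int8 weights to pack
unsigned = [v + offset for v in values]          # [129, 130]

pack_factor = 32 // num_bits                     # 4 values fit in one int32
unsigned += [0] * (pack_factor - len(unsigned))  # step 2: [129, 130, 0, 0]

packed = 0
for i, v in enumerate(unsigned):                 # steps 3-5: value i occupies
    packed |= v << (num_bits * i)                # bits [num_bits*i, num_bits*(i+1))
assert packed == 33409                           # 0b00000000_00000000_10000010_10000001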
@@ -148,19 +160,22 @@ def pack_to_int32(value: torch.Tensor, num_bits: int) -> torch.Tensor:
     if num_bits > 8:
         raise ValueError("Packing is only supported for less than 8 bits")

+    if num_bits < 1:
+        raise ValueError(f"num_bits must be at least 1, got {num_bits}")
+
     # convert to unsigned for packing
-    offset = pow(2, num_bits) // 2
+    offset = 1 << (num_bits - 1)
     value = (value + offset).to(torch.uint8)
     value = value.cpu().numpy().astype(np.uint32)
     pack_factor = 32 // num_bits

     # pad input tensor and initialize packed output
     packed_size = math.ceil(value.shape[1] / pack_factor)
-    packed = np.zeros((value.shape[0], packed_size), dtype=np.uint32)
-    padding = packed.shape[1] * pack_factor - value.shape[1]
+    padding = packed_size * pack_factor - value.shape[1]
     value = np.pad(value, pad_width=[(0, 0), (0, padding)], constant_values=0)

     # pack values
+    packed = np.zeros((value.shape[0], packed_size), dtype=np.uint32)
     for i in range(pack_factor):
         packed |= value[:, i::pack_factor] << num_bits * i
@@ -174,7 +189,9 @@ def unpack_from_int32(
 ) -> torch.Tensor:
     """
     Unpacks a tensor of packed int32 weights into individual int8s, maintaining the
-    original their bit range
+    original bit range.
+
+    Returns tensors in int8

     :param value: tensor to unpack
     :param num_bits: number of bits to unpack each data point into
@@ -192,7 +209,7 @@
     pack_factor = 32 // num_bits

     # unpack
-    mask = pow(2, num_bits) - 1
+    mask = (1 << num_bits) - 1
     unpacked = torch.zeros(
         (value.shape[0], value.shape[1] * pack_factor),
         device=value.device,
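For the reverse direction, a minimal sketch (plain Python, editorial; the real function operates on whole torch tensors) of how the mask and shift recover the 4-bit fields packed above:

num_bits = 4
mask = (1 << num_bits) - 1      # 0b1111
offset = 1 << (num_bits - 1)    # 8
packed = 52137                  # 0xCBA9, i.e. the packed form of [1, 2, 3, 4]

fields = [(packed >> (num_bits * i)) & mask for i in range(32 // num_bits)]
signed = [f - offset for f in fields]   # shift back to the signed range
assert signed[:4] == [1, 2, 3, 4]       # trailing entries are padding (-8)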
tests/test_compressors/quantized_compressors/test_pack_quant.py (162 additions, 0 deletions)
@@ -250,3 +250,165 @@ def test_actorder_reload_match(actorder, tmp_path, mock_per_group_calibration):
assert torch.equal(fake_quant_dummy, reconstructed_dense["dummy.weight"])

shutil.rmtree(tmp_path)


@pytest.mark.parametrize(
"num_bits,values,expected_values",
[
(
4,
torch.tensor([[1]]),
torch.tensor([[9]], dtype=torch.int32),
),
(
8,
torch.tensor([[1]]),
torch.tensor([[129]], dtype=torch.int32),
),
# 0000 0000 0000 0000 1100 1011 1010 1001
(4, torch.tensor([[1, 2, 3, 4]]), torch.tensor([[52137]], dtype=torch.int32)),
# 0111 0110 0101 0100 0011 0010 0001 0000
(
4,
torch.tensor([[-8, -7, -6, -5, -4, -3, -2, -1]]),
torch.tensor([[1985229328]], dtype=torch.int32),
),
# 10000100 10000011 10000010 10000001
(
8,
torch.tensor([[1, 2, 3, 4]]),
torch.tensor([[-2071756159]], dtype=torch.int32),
),
# 00000011 00000010 00000001 00000000
(
8,
torch.tensor([[-128, -127, -126, -125]]),
torch.tensor([[50462976]], dtype=torch.int32),
),
(
4,
torch.tensor([[-8, -7, -6, -5, -4, -3, -2, -1, 1, 2, 3, 4]]),
torch.tensor([[1985229328, 52137]], dtype=torch.int32),
),
(
4,
torch.tensor(
[
[-8, -7, -6, -5, -4, -3, -2, -1, 1, 2, 3, 4, -8, -8, -8, -8],
[1, 2, 3, 4, -8, -8, -8, -8, -8, -7, -6, -5, -4, -3, -2, -1],
]
),
torch.tensor([[1985229328, 52137], [52137, 1985229328]], dtype=torch.int32),
),
(
8,
torch.tensor(
[
[1, 2, 3, 4],
[-128, -127, -126, -125],
]
),
torch.tensor([[-2071756159], [50462976]], dtype=torch.int32),
),
(
8,
torch.tensor(
[
[1, 2, 3, 4, -128, -127, -126, -125],
[-128, -127, -126, -125, 1, 2, 3, 4],
]
),
torch.tensor(
[[-2071756159, 50462976], [50462976, -2071756159]], dtype=torch.int32
),
),
],
)
def test_pack_to_int32(num_bits, values, expected_values):
values = values.to(torch.int8)
packed_values = pack_to_int32(values, num_bits)
assert torch.equal(packed_values, expected_values)
assert packed_values.dtype == expected_values.dtype


@pytest.mark.parametrize(
"num_bits,values,expected_tensor",
[
(
4,
torch.tensor([[9]], dtype=torch.int32),
torch.tensor([[1]], dtype=torch.int8),
),
(
8,
torch.tensor([[129]], dtype=torch.int32),
torch.tensor([[1]], dtype=torch.int8),
),
(
4,
torch.tensor([[52137]], dtype=torch.int32),
torch.tensor([[1, 2, 3, 4]], dtype=torch.int8),
),
(
4,
torch.tensor([[1985229328]], dtype=torch.int32),
torch.tensor([[-8, -7, -6, -5, -4, -3, -2, -1]], dtype=torch.int8),
),
(
8,
torch.tensor([[-2071756159]], dtype=torch.int32),
torch.tensor([[1, 2, 3, 4]], dtype=torch.int8),
),
(
8,
torch.tensor([[50462976]], dtype=torch.int32),
torch.tensor([[-128, -127, -126, -125]], dtype=torch.int8),
),
(
4,
torch.tensor([[1985229328, 52137]], dtype=torch.int32),
torch.tensor(
[[-8, -7, -6, -5, -4, -3, -2, -1, 1, 2, 3, 4]], dtype=torch.int8
),
),
(
4,
torch.tensor([[1985229328, 52137], [52137, 1985229328]], dtype=torch.int32),
torch.tensor(
[
[-8, -7, -6, -5, -4, -3, -2, -1, 1, 2, 3, 4, -8, -8, -8, -8],
[1, 2, 3, 4, -8, -8, -8, -8, -8, -7, -6, -5, -4, -3, -2, -1],
],
dtype=torch.int8,
),
),
(
8,
torch.tensor([[-2071756159], [50462976]], dtype=torch.int32),
torch.tensor(
[
[1, 2, 3, 4],
[-128, -127, -126, -125],
],
dtype=torch.int8,
),
),
(
8,
torch.tensor(
[[-2071756159, 50462976], [50462976, -2071756159]], dtype=torch.int32
),
torch.tensor(
[
[1, 2, 3, 4, -128, -127, -126, -125],
[-128, -127, -126, -125, 1, 2, 3, 4],
],
dtype=torch.int8,
),
),
],
)
def test_unpack_from_int32(num_bits, values, expected_tensor):
unpacked_tensor = unpack_from_int32(values, num_bits, expected_tensor.shape)
assert torch.equal(unpacked_tensor, expected_tensor)
assert unpacked_tensor.dtype == expected_tensor.dtype
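Since the two parameter tables mirror each other, they compose into a round-trip property. A hedged sketch of such a test (the name and random-data choice are mine, not part of the commit; it assumes the same imports as the tests above):

@pytest.mark.parametrize("num_bits", [4, 8])
def test_pack_unpack_roundtrip(num_bits):
    # any int8 value representable in num_bits survives a pack/unpack cycle
    low, high = -(1 << (num_bits - 1)), (1 << (num_bits - 1)) - 1
    values = torch.randint(low, high + 1, (3, 16), dtype=torch.int8)
    packed = pack_to_int32(values, num_bits)
    unpacked = unpack_from_int32(packed, num_bits, values.shape)
    assert torch.equal(unpacked, values)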
