integrate new float8 quantization primitives into AQT #1598

Status: Open. Wants to merge 11 commits into base: gh/danielvegamyhre/23/head.

Changes from 2 commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
torchao/dtypes/affine_quantized_tensor.py (53 changes: 41 additions & 12 deletions)
```diff
@@ -4,27 +4,22 @@

 import torch

-from torchao.dtypes.utils import (
-    AQTTensorImpl,
-    Layout,
-    PlainLayout,
-)
+from torchao.dtypes.utils import AQTTensorImpl, Layout, PlainLayout
 from torchao.quantization.quant_primitives import (
+    FP8_TYPES,
+    MappingType,
+    ZeroPointDomain,
     choose_qparams_affine,
+    choose_qparams_affine_float8,
     choose_qparams_affine_floatx,
     choose_qparams_and_quantize_affine_hqq,
     dequantize_affine,
     dequantize_affine_floatx,
-    FP8_TYPES,
-    MappingType,
     quantize_affine,
+    quantize_affine_float8,
     quantize_affine_floatx,
-    ZeroPointDomain,
 )
-from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_5,
-    TorchAOBaseTensor,
-)
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor

 logger = logging.getLogger(__name__)
 aten = torch.ops.aten
```
```diff
@@ -422,6 +417,39 @@ def from_hp_to_fpx(
         tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout)
         return cls(tensor_impl, block_size, original_shape, dtype=input_float.dtype)

+    @classmethod
+    def from_hp_to_float8(
+        cls,
+        input_float: torch.Tensor,
+        target_dtype: torch.dtype,
+        block_size: Tuple[int, ...],
+        _layout: Layout = PlainLayout(),
+    ):
+        assert target_dtype in FP8_TYPES, f"Unsupported dtype {target_dtype} for float8"
+
+        # to avoid circular dependency
+        from torchao.dtypes.floatx import Float8AQTTensorImpl
+
+        original_shape = input_float.shape
+        scale = choose_qparams_affine_float8(
+            input_float,
+            target_dtype,
+            target_dtype,
+        )
+        fp8_data = quantize_affine_float8(
+            input_float,
+            scale,
+            target_dtype,
+        )
+        fp8_data = _layout.post_process(fp8_data)
+        tensor_impl = Float8AQTTensorImpl(fp8_data, scale, None, _layout)
+        return cls(
+            tensor_impl,
+            block_size,
+            original_shape,
+            dtype=input_float.dtype,
+        )
+
     @property
     def _layout(self) -> Layout:
         return self.tensor_impl._layout
```

Contributor (review comment on `def from_hp_to_float8`):

> Update `from_hp_to_floatx` with the new float8 logic. For fp1-fp7, we're using `from_hp_to_floatx`.

Contributor (review comment on the new signature):

> A docblock here should explain the difference between `from_hp_to_floatx`, `from_hp_to_fpx`, and `from_hp_to_float8`.
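To make the reviewer's request concrete, a docstring along the following lines could draw the distinction. This is a hypothetical sketch, not text from the PR; the descriptions of the sibling constructors are inferred from the review comments and the surrounding diff, not quoted from the codebase.

```python
# Hypothetical docstring for from_hp_to_float8 (illustration only).
"""Create a float8 AffineQuantizedTensor from a high-precision tensor.

How this differs from its sibling constructors (inferred, not quoted):
  * from_hp_to_fpx    -- experimental packed sub-byte floats (fp1-fp7),
                         slated to be merged into floatx.
  * from_hp_to_floatx -- general floatx entry point; per the review
                         discussion it may absorb this float8 logic.
  * from_hp_to_float8 -- native torch.float8 dtypes (FP8_TYPES), using
                         the dedicated choose_qparams_affine_float8 /
                         quantize_affine_float8 primitives added here.
"""
```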
```diff
@@ -477,6 +505,7 @@ def _apply_fn_to_data(self, fn):
 to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static
 to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx
 to_affine_quantized_floatx_static = AffineQuantizedTensor.from_hp_to_floatx_static
+to_affine_quantized_float8 = AffineQuantizedTensor.from_hp_to_float8
 # experimental will be merged in to floatx
 to_affine_quantized_fpx = AffineQuantizedTensor.from_hp_to_fpx
```
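For reference, a minimal usage sketch of the new `to_affine_quantized_float8` alias. No call site appears in this diff, so the tensor shape and the full-tensor `block_size` convention below are assumptions for illustration, not taken from the PR.

```python
import torch

from torchao.dtypes.affine_quantized_tensor import to_affine_quantized_float8

# A high-precision weight tensor to quantize (shape chosen arbitrarily).
weight = torch.randn(128, 256, dtype=torch.bfloat16)

# A block_size spanning the whole tensor implies per-tensor scaling
# (an assumed convention; the constructor stores block_size as metadata).
fp8_weight = to_affine_quantized_float8(
    weight,
    target_dtype=torch.float8_e4m3fn,  # must be one of FP8_TYPES
    block_size=weight.shape,
)

# Per the constructor in the diff, the wrapper preserves the original
# shape and high-precision dtype, while the Float8AQTTensorImpl holds
# the float8 data and its scale.
assert fp8_weight.shape == weight.shape
assert fp8_weight.dtype == torch.bfloat16
```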