integrate new float8 quantization primitives into AQT
ghstack-source-id: 982ea07cc649320a57bf47120489bac86f2e900b
ghstack-comment-id: 2608090492
Pull Request resolved: #1598
danielvegamyhre committed Jan 24, 2025
1 parent ee34c68 · commit be2d094
Showing 1 changed file with 38 additions and 9 deletions.
torchao/dtypes/affine_quantized_tensor.py
@@ -4,27 +4,22 @@
 
 import torch
 
-from torchao.dtypes.utils import (
-    AQTTensorImpl,
-    Layout,
-    PlainLayout,
-)
+from torchao.dtypes.utils import AQTTensorImpl, Layout, PlainLayout
 from torchao.quantization.quant_primitives import (
     FP8_TYPES,
     MappingType,
     ZeroPointDomain,
     choose_qparams_affine,
+    choose_qparams_affine_float8,
     choose_qparams_affine_floatx,
     choose_qparams_and_quantize_affine_hqq,
     dequantize_affine,
     dequantize_affine_floatx,
     quantize_affine,
+    quantize_affine_float8,
     quantize_affine_floatx,
 )
-from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_5,
-    TorchAOBaseTensor,
-)
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor
 
 logger = logging.getLogger(__name__)
 aten = torch.ops.aten
@@ -422,6 +417,39 @@ def from_hp_to_fpx(
         tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout)
         return cls(tensor_impl, block_size, original_shape, dtype=input_float.dtype)
 
+    @classmethod
+    def from_hp_to_float8(
+        cls,
+        input_float: torch.Tensor,
+        target_dtype: torch.dtype,
+        block_size: Tuple[int, ...],
+        _layout: Layout = PlainLayout(),
+    ):
+        assert target_dtype in FP8_TYPES, f"Unsupported dtype {target_dtype} for float8"
+
+        # to avoid circular dependency
+        from torchao.dtypes.floatx import Float8AQTTensorImpl
+
+        original_shape = input_float.shape
+        scale = choose_qparams_affine_float8(
+            input_float,
+            target_dtype,
+            target_dtype,
+        )
+        fp8_data = quantize_affine_float8(
+            input_float,
+            scale,
+            target_dtype,
+        )
+        fp8_data = _layout.post_process(fp8_data)
+        tensor_impl = Float8AQTTensorImpl(fp8_data, scale, None, _layout)
+        return cls(
+            tensor_impl,
+            block_size,
+            original_shape,
+            dtype=input_float.dtype,
+        )
+
     @property
     def _layout(self) -> Layout:
         return self.tensor_impl._layout
@@ -477,6 +505,7 @@ def _apply_fn_to_data(self, fn):
 to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static
 to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx
 to_affine_quantized_floatx_static = AffineQuantizedTensor.from_hp_to_floatx_static
+to_affine_quantized_float8 = AffineQuantizedTensor.from_hp_to_float8
 # experimental will be merged in to floatx
 to_affine_quantized_fpx = AffineQuantizedTensor.from_hp_to_fpx
 
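For context, a minimal usage sketch of the new entry point. This is not part of the commit: the input shape, the choice of torch.float8_e4m3fn (one member of FP8_TYPES), and the per-tensor block_size are illustrative assumptions; only the function name and its parameters come from the diff above.

# Usage sketch (illustrative, not from the commit). Assumes a torch build
# with float8 dtypes available.
import torch
from torchao.dtypes.affine_quantized_tensor import to_affine_quantized_float8

x = torch.randn(128, 256, dtype=torch.bfloat16)

# A block_size equal to the full shape requests a single per-tensor scale.
x_fp8 = to_affine_quantized_float8(
    input_float=x,
    target_dtype=torch.float8_e4m3fn,
    block_size=x.shape,
)

# The wrapper keeps the original dtype externally while storing float8 data
# plus a scale in its Float8AQTTensorImpl (dtype=input_float.dtype above).
print(x_fp8.shape, x_fp8.dtype)  # torch.Size([128, 256]) torch.bfloat16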
