
[rfc] enable direct configuration in quantize_ #1585

Closed · 1 commit

Conversation

@vkuzo (Contributor) commented Jan 17, 2025

summary

This PR is a POC for passing per-workflow arguments to quantize_ directly, without wrapping them in a Callable.

High level motivation: passing configuration directly is intuitive and widely used in similar contexts across various projects. Passing configuration wrapped in a Callable is, IMO, not intuitive and hard to understand and debug, and we have evidence that it pushes a significant portion of users away from building on top of torchao.

user facing API proposed changes

signature of quantize_

#
# before
#
def quantize_(
    model: torch.nn.Module,
    apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module],
    ...,
): ...

#
# after - intermediate state (for the time that we need to keep the old syntax around)
# we need to clarify whether the intermediate state is needed for 1 to 2 releases, or if we should
# just go directly from before to after
#
def quantize_(
    model: torch.nn.Module,
    apply_tensor_subclass: Union[Callable[[torch.nn.Module], torch.nn.Module], AOBaseWorkflowConfig],
    ...,
): ...

#
# after - long term state
#
def quantize_(
    model: torch.nn.Module,
    config: AOBaseWorkflowConfig,
    ...,
): ...

usage example

An example for int4_weight_only

#
# before
#
quantize_(m, int4_weight_only(group_size=32))

#
# after, with new user facing names
#
quantize_(m, Int4WeightOnlyWorkflowConfig(group_size=32))

#
# AND, after, with BC names
#
quantize_(m, int4_weight_only(group_size=32))

developer facing proposed changes

See the PR details for examples, but they can be summarized as:

#
# old
#

# quantize_ calls the Callable returned by this function on each module of the model
def int4_weight_only(group_size: int, ...) -> Callable:

    def new_callable(weight: torch.Tensor):
        # configuration is captured here via local variables
        ...
        
    # return type is a Callable
    return _get_linear_subclass_inserter(new_callable)

#
# new
#

# config base class
class AOBaseWorkflowConfig(abc.ABC):
    pass

# user facing configuration of a workflow
@dataclass
class Int4WeightOnlyWorkflowConfig(AOBaseWorkflowConfig):
    group_size: int = 128
    ...

# not user facing: transform of a module according to a workflow's configuration
@register_quantize_module_handler(Int4WeightOnlyWorkflowConfig)
def _int4_weight_only_transform(
    module: torch.nn.Module,
    config: Int4WeightOnlyWorkflowConfig,
) -> torch.nn.Module:
    # map to AQT, not user facing
    ...
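
The register_quantize_module_handler mechanism is not spelled out above; below is a minimal sketch of what such a registry could look like. This is an illustration only, with assumed names; the actual implementation in the PR may differ.

# Sketch of a possible registry implementation; names and details are
# assumptions for illustration, not necessarily what the PR does.
from typing import Callable, Dict, Type

import torch

# maps each config type to the function that transforms a module for it
_CONFIG_TYPE_TO_HANDLER: Dict[Type, Callable] = {}

def register_quantize_module_handler(config_type: Type):
    # decorator: associate a module transform with a config type
    def decorator(handler: Callable) -> Callable:
        _CONFIG_TYPE_TO_HANDLER[config_type] = handler
        return handler
    return decorator

def _transform_module(module: torch.nn.Module, config) -> torch.nn.Module:
    # quantize_ would call this on each matching module of the model
    handler = _CONFIG_TYPE_TO_HANDLER[type(config)]
    return handler(module, config)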

current status

At this point this is a POC, and I've shown how it can work on three user facing workflows:

  • PTQ's int4_weight_only
  • QAT's intx_quantization_aware_training and from_intx_quantization_aware_training

next steps

discuss more broadly

Test Plan:

pytest test/quantization/test_quant_api.py -s -x -k test_int4_weight_only_numerics
pytest test/quantization/test_qat.py -s -x -k test_quantize_api_standalone
pytest test/quantization/test_qat.py -s -x -k test_quantize_api_convert_path

Reviewers:

Subscribers:

Tasks:

Tags:

pytorch-bot (bot) commented Jan 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1585

❌ 10 New Failures as of commit 997f715 with merge base 32d9b0b.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jan 17, 2025
@vkuzo vkuzo force-pushed the 20250117_torchao_config_example branch 2 times, most recently from 14a21bb to d7cdbb4 Compare January 17, 2025 22:21
@vkuzo vkuzo changed the title [wip] configs configs configs! [wip] enable direct configuration in torchao, instead of tensor_subclass_inserter Jan 17, 2025
@vkuzo vkuzo changed the title [wip] enable direct configuration in torchao, instead of tensor_subclass_inserter [wip] enable direct configuration in quantize_, without Callable wrapping Jan 17, 2025
@vkuzo vkuzo changed the title [wip] enable direct configuration in quantize_, without Callable wrapping [wip] enable direct configuration in quantize_ Jan 17, 2025
@vkuzo vkuzo changed the title [wip] enable direct configuration in quantize_ [rfc] enable direct configuration in quantize_ Jan 17, 2025
@drisspg (Contributor) commented Jan 17, 2025

Is there someone on the quanty/core side of things who can chime in on this API? cc @andrewor14

Tbh this feels almost too similar to the original, since this isn't a pure dataclass and has the _transform func.

I imagine people want:

quantize_(model, THE_CONFIG_TYPE(weight_scheme="f38", a_scheme="a12", ...))

FWIW:

ao/scripts/hf_eval.py

Lines 68 to 147 in 1240b19

if quantization == "autoquant" and compile:
    model = torch.compile(model, mode="max-autotune", fullgraph=True)
if quantization == "int8dq":
    quantize_(model, int8_dynamic_activation_int8_weight())
elif quantization == "int8wo":
    quantize_(model, int8_weight_only())
elif quantization == "int4wo":
    # note cannot quantize this model on cpu and run it on cuda at this time
    quantize_(model.to(device=device), int4_weight_only())
elif quantization == "fp6":
    quantize_(model, fpx_weight_only(3, 2))
elif quantization == "autoquant":
    model = autoquant(model.to(device=device))
elif quantization == "awq":
    from torchao.prototype.awq.example import get_calib_dataset
    from torchao.utils import TORCH_VERSION_AT_LEAST_2_3

    if not TORCH_VERSION_AT_LEAST_2_3:
        print("AWQ quantization requires torch2.3+")
        exit()
    from torchao.prototype.awq import (
        AWQObservedLinear,
        awq_uintx,
        insert_awq_observer_,
    )

    quant_dtype = torch.uint4
    group_size = 64
    calibration_limit = 10
    calibration_seq_length = 1024
    model = model.to(device)
    insert_awq_observer_(
        model,
        calibration_limit,
        calibration_seq_length,
        quant_dtype=quant_dtype,
        group_size=group_size,
    )
    with torch.no_grad():
        calibration_data = get_calib_dataset(
            tokenizer=tokenizer,
            n_samples=calibration_limit,
            block_size=calibration_seq_length,
        )
        for batch in calibration_data:
            model(batch.to(device))
            del batch
    is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
    quantize_(
        model,
        awq_uintx(quant_dtype=quant_dtype, group_size=group_size),
        is_observed_linear,
    )
if quantization != "autoquant" and compile:
    model = torch.compile(model, mode="max-autotune", fullgraph=True)

if sparsity == "semi_sparse":

    def all_linear(mod, name):
        if isinstance(mod, torch.nn.Linear) and "lm_head" not in name:
            return True
        return False

    torch.sparse.semi_structured._FORCE_CUTLASS = False
    sparsify_(model, semi_sparse_weight(), filter_fn=all_linear)
elif sparsity == "semi_sparse_mlp_only":

    def all_linear(mod, name):
        if (
            isinstance(mod, torch.nn.Linear)
            and "lm_head" not in name
            and "mlp" in name
        ):
            return True
        return False

    torch.sparse.semi_structured._FORCE_CUTLASS = False
    sparsify_(model, semi_sparse_weight(), filter_fn=all_linear)

I think this code is a good example of the downsides of the current approach. It has high cognitive load IMO, and it's hard to tell what exactly is being chosen when. A good sign that the new API is better would be if it could also clean up this type of code.
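
For illustration only, a hypothetical sketch of how the name-to-scheme dispatch above could look if quantize_ accepted config objects; the config class names here are assumptions following the naming proposed in this RFC, not existing torchao API:

# Hypothetical: CLI choice -> config object, instead of choice -> callable.
# Config class names are assumptions based on the naming proposed in this RFC.
QUANTIZATION_CONFIGS = {
    "int8dq": Int8DynamicActivationInt8WeightWorkflowConfig(),
    "int8wo": Int8WeightOnlyWorkflowConfig(),
    "int4wo": Int4WeightOnlyWorkflowConfig(group_size=128),
    "fp6": FpxWeightOnlyWorkflowConfig(ebits=3, mbits=2),
}

if quantization in QUANTIZATION_CONFIGS:
    quantize_(model, QUANTIZATION_CONFIGS[quantization])

A table like this can also be printed or logged directly, which is hard to do with opaque callables.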

@vkuzo (Contributor, Author) commented Jan 17, 2025

Tbh this feels almost too similar to the original, since this isn't a pure dataclass and has the _transform func. [...]

I think this code is a good example of the downsides of the current approach. [...] A good sign that the new API is better would be if it could also clean up this type of code.

So, I'm actually highly in favor of moving the transform function out of user facing land to keep the config clean. However, we don't have consensus on this on the team, with @HDCharles liking the colocation. I'm a big fan of easy-to-think-about incremental changes moving in the right direction, so currently I'm just punting this discussion till later, and "enforcing" that by keeping the transform function private with an underscore and "scary" docblocks. Since this is private, we could do it at any future time.

I personally care a lot about "don't wrap config in a Callable" and consider that a blocker for moving training use cases here. I'm mildly annoyed by keeping the transform here, but it's not a hill I'd die on personally.

@andrewor14 (Contributor) commented:

Is there someone on the quanty/core side of things that can chime in on this API

Yeah, the API they want is something like this:

@dataclass
class QuantizationRecipe:
    # weights
    weight_bits: int = 4
    weight_group_size: int | str = 32
    weight_quantization: bool = True
    dynamic_weights: bool = False

    # activations
    activation_bits: int = 8
    activation_group_size: int | str = "per_token"
    activation_quantization: bool = True
    dynamic_activations: bool = True

def quantize(model: nn.Module, recipe: QuantizationRecipe):
    ...

which I think is simple and pretty reasonable. Given a model and a config or recipe, quantize the model according to the settings configured in it. I think this kind of API is also pretty common. E.g.

  • vllm uses "Modifier" recipes to configure the quantization scheme
  • bitsandbytes allows users to set optimizer bits by overriding a config
  • HF transformers uses our TorchAoConfig for example, and they have different configs for all other quantization libraries too

As far as I know, passing in a function as the main API isn't common at all. I would also argue it's better practice in general to pass in clearly defined types instead of arbitrary functions that let the user do anything they want. Users can easily write the latter themselves if they want: just search for all nn.Linear modules in the model and apply a custom function to them, for example as sketched below. They don't need quantize_ to do that.
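
For example, a minimal sketch of that do-it-yourself path (hypothetical helper, not a torchao API):

# Hypothetical sketch: apply a custom weight transform to every nn.Linear
# directly, without going through quantize_. Not a torchao API.
import torch

def apply_to_linear_weights(model: torch.nn.Module, fn) -> None:
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            # fn maps a weight tensor to its replacement
            # (e.g. a quantized tensor subclass)
            module.weight = torch.nn.Parameter(
                fn(module.weight), requires_grad=False
            )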

@andrewor14 (Contributor) commented:

So, I'm actually highly in favor of moving the transform function out of user facing land to keep the config clean. However, we don't have consensus on this on the team, with @HDCharles liking the colocation.

I think another way of addressing this is to just register a handler for every config type we want to handle in quantize_. E.g.:

def quantize_(model: nn.Module, config: BaseConfig):  # plus the existing args
    if type(config) in _QUANTIZE_CONFIG_HANDLER:
        _QUANTIZE_CONFIG_HANDLER[type(config)](model, config)
    else:
        raise ValueError("Unknown config type")

@register_quantize_handler(AffineQuantizationConfig)
def _do_affine_quantization(model: nn.Module, config: AffineQuantizationConfig):
    ...

@register_quantize_handler(QATConfig)
def _do_qat(model: nn.Module, config: QATConfig):
    ...

This will keep quantize_ small, and we won't need a separate transform_fn argument, so the signature will be simpler. Configs will also be simpler and just be configs; they don't need to know anything about how they will be used.

@vkuzo (Contributor, Author) commented Jan 21, 2025

summary of offline discussion:

  1. will change to a registration API for transforms, as suggested by @andrewor14
  2. will describe BC plan in more detail
  3. will get some user feedback

Summary:

POC for:

* decoupling configuration from transformation
* no longer passing obscure stateful callables around
* enabling printing of configuration
* reducing the amount of context switching needed to navigate the logic from
  `quantize_` to quantizing a single module

TODO: more polish before wider discussion.

Test Plan:

```
pytest test/quantization/test_quant_api.py -s -x -k test_int4_weight_only_numerics
pytest test/quantization/test_qat.py -s -x -k test_quantize_api_standalone
pytest test/quantization/test_qat.py -s -x -k test_quantize_api_convert_path
```

Reviewers:

Subscribers:

Tasks:

Tags:
@vkuzo vkuzo force-pushed the 20250117_torchao_config_example branch from d7cdbb4 to 997f715 Compare January 21, 2025 21:42
@vkuzo vkuzo requested a review from andrewor14 January 21, 2025 22:05
@andrewor14 (Contributor) left a comment:

New proposal looks great to me!

weight_config: Optional[FakeQuantizeConfig] = None,
) -> Callable:
@dataclass
class IntXQuantizationAwareTrainingWorkflowConfig(AOBaseWorkflowConfig):
Contributor commented:

nit: maybe drop Workflow from the name? Seems a bit long

Contributor Author commented:

I think keeping the base class as AOBaseWorkflowConfig and removing Workflow from the child classes could work here.

apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module],
# apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module],
apply_tensor_subclass: Union[
Callable[[torch.nn.Module], torch.nn.Module], AOBaseWorkflowConfig
Contributor commented:

I think there are two BC breaking steps:

  1. Change the arg name to "config"
  2. Change the arg type to AOBaseWorkflowConfig only

Based on my quick search, I feel we can just do (1) now. I haven't found any use cases yet that call this with a keyword argument, so it seems relatively safe. For (2), maybe we should deprecate with a warning first and remove the Callable path after a couple of releases (see the sketch below)? I feel this use case (users passing in arbitrary functions) is harder to grep for.

Curious what others think as well @drisspg @HDCharles
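
A minimal sketch of what the deprecation path in (2) could look like; the helper names (_QUANTIZE_CONFIG_HANDLER, _apply_legacy_callable) are assumptions for illustration only, not the PR's actual code:

# Hypothetical sketch of the intermediate state: accept both the old Callable
# and the new config type, and warn when the old form is used.
import warnings

def quantize_(model, config_or_callable, filter_fn=None):  # other existing args omitted
    if isinstance(config_or_callable, AOBaseWorkflowConfig):
        # new path: dispatch on the config type via the handler registry
        _QUANTIZE_CONFIG_HANDLER[type(config_or_callable)](model, config_or_callable)
    else:
        warnings.warn(
            "Passing a Callable to quantize_ is deprecated; "
            "pass an AOBaseWorkflowConfig instead.",
            DeprecationWarning,
        )
        # old path: keep the existing Callable-based behavior working
        _apply_legacy_callable(model, config_or_callable, filter_fn)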

zero_point_domain=None,
):
@dataclass
class Int4WeightOnlyWorkflowConfig(AOBaseWorkflowConfig):
Contributor commented:

Looks great. I feel eventually we do want an intermediate AffineQuantizationConfig that lets users set the different dtypes, symmetry, dynamic settings, etc., but I think we can add that later.

def _int4_weight_only_transform(
module: torch.nn.Module, config: Int4WeightOnlyWorkflowConfig
) -> torch.nn.Module:
# TODO(future PR): perhaps move this logic to a different file, to keep the API
Contributor commented:

Agree with this. We should move all configs and associated transforms to a separate file.

@vkuzo vkuzo added the topic: bc-breaking Use this tag if this PR breaks backward compatibility label Jan 22, 2025
@vkuzo (Contributor, Author) commented Jan 22, 2025

Moving development to #1595 to switch to stacked PRs.

@vkuzo vkuzo closed this Jan 22, 2025