This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit
Support ipex cpu WOQ backend (#1546)
* support ipex cpu woq

Signed-off-by: changwangss <[email protected]>

---------

Signed-off-by: changwangss <[email protected]>
Signed-off-by: Dong, Bo <[email protected]>
Co-authored-by: Dong, Bo <[email protected]>
changwangss and a32543254 authored May 16, 2024
1 parent f8a7723 commit 008492d
Showing 6 changed files with 222 additions and 55 deletions.
@@ -104,7 +104,7 @@ pip install -r requirements_cpu_woq.txt
> ```

### Run
We provide compression technologies such as `WeightOnlyQuant` with `Rtn/Awq/Teq/GPTQ/AutoRound` algorithms and `BitsandBytes`; `load_in_4bit` and `load_in_8bit` work on the CPU device. In addition, pass `--use_neural_speed` to accelerate the optimized model with [neural-speed](https://github.com/intel/neural-speed); the supported model list is [here](https://github.com/intel/neural-speed/blob/main/docs/supported_models.md).
We provide compression technologies such as `WeightOnlyQuant` with `Rtn/Awq/Teq/GPTQ/AutoRound` algorithms and `BitsandBytes`; `load_in_4bit` and `load_in_8bit` work on the CPU device. In addition, pass `--use_ipex` to accelerate the model with Intel Extension for PyTorch (IPEX), or `--use_neural_speed` to accelerate the optimized model with [neural-speed](https://github.com/intel/neural-speed); the supported model list is [here](https://github.com/intel/neural-speed/blob/main/docs/supported_models.md).
The following commands show how to use it.
#### Performance
```shell
@@ -28,6 +28,7 @@
"--max_new_tokens", default=32, type=int, help="output max new tokens"
)
parser.add_argument("--output_dir", nargs="?", default="./saved_results")
parser.add_argument("--use_ipex", action="store_true")
# ============Benchmark configs==============
parser.add_argument("--benchmark", action="store_true")
parser.add_argument("--iters", default=100, type=int, help="num iter")
@@ -207,14 +208,14 @@
if args.woq:
if args.woq_algo == "Rtn":
quantization_config = RtnConfig(
tokenizer=tokenizer,
bits=args.bits,
sym=True if args.scheme == "sym" else False,
group_size=args.group_size,
compute_dtype=args.compute_dtype,
scale_dtype=args.scale_dtype,
weight_dtype=args.weight_dtype,
layer_wise=args.layer_wise,
use_ipex=args.use_ipex,
)
elif args.woq_algo == "Awq":
quantization_config = AwqConfig(
@@ -228,6 +229,7 @@
scale_dtype=args.scale_dtype,
weight_dtype=args.weight_dtype,
calib_iters=args.calib_iters,
use_ipex=args.use_ipex,
)
elif args.woq_algo == "Teq":
quantization_config = TeqConfig(
@@ -241,6 +243,7 @@
scale_dtype=args.scale_dtype,
weight_dtype=args.weight_dtype,
calib_iters=args.calib_iters,
use_ipex=args.use_ipex,
)
elif args.woq_algo == "GPTQ":
quantization_config = GPTQConfig(
@@ -260,6 +263,7 @@
weight_dtype=args.weight_dtype,
calib_iters=args.calib_iters,
layer_wise=args.layer_wise,
use_ipex=args.use_ipex,
)
elif args.woq_algo == "AutoRound":
quantization_config = AutoRoundConfig(
@@ -277,6 +281,7 @@
lr=args.lr,
minmax_lr=args.minmax_lr,
use_quant_input=args.use_quant_input,
use_ipex=args.use_ipex,
)
else:
assert False, "Please set the correct '--woq_algo'"
@@ -388,6 +393,8 @@
model_args += ",model_format=neural_speed"
args = LMEvalParser(model = "hf",
model_args=model_args,
#user_model=user_model,
#tokenizer=tokenizer,
tasks = args.tasks,
device = "cpu",
batch_size = args.batch_size)
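
For orientation, the snippet below sketches how the new `use_ipex` knob travels from the CLI into the quantization path: the script forwards `args.use_ipex` into the chosen WOQ config (for example `RtnConfig`), and ITREX's `AutoModelForCausalLM.from_pretrained` then builds IPEX-backed quantized layers. This is an assumed, minimal usage sketch rather than code from the commit; the model name is a placeholder.

```python
# Minimal sketch (assumed usage, not from this commit) of the path the new
# --use_ipex flag enables: a WOQ config with use_ipex=True is handed to
# ITREX's AutoModelForCausalLM, which swaps Linear layers for IPEX kernels.
from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, RtnConfig

model_name = "facebook/opt-125m"  # placeholder model for illustration
tokenizer = AutoTokenizer.from_pretrained(model_name)

woq_config = RtnConfig(bits=4, group_size=128, use_ipex=True)
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config)

inputs = tokenizer("Once upon a time", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=32)[0]))
```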
@@ -183,34 +183,79 @@ def _replace_linear(
or device == torch.device("cpu")
or device == "auto"
):
from .nn.modules import (
QuantizedLinearQBits,
) # TODO: QuantizedLinearINT4, QuantizedLinearINT8

use_optimum_format = getattr(module, "use_optimum_format", False) or \
quantization_config.weight_dtype not in [
"fp8_e5m2",
"fp8_e4m3",
"fp4",
"nf4",
"int4_fullrange",
]

model._modules[name] = QuantizedLinearQBits(
in_features,
out_features,
module.bias is not None,
compute_dtype=quantization_config.compute_dtype,
compress_statistics=False,
weight_dtype=quantization_config.weight_dtype,
scale_dtype=quantization_config.scale_dtype,
blocksize=quantization_config.group_size,
scheme=quantization_config.scheme,
compression_dtype=getattr(module, "compression_dtype", torch.int32),
compression_dim=getattr(module, "compression_dim", 1),
device=device,
use_optimum_format=use_optimum_format,
)
if is_ipex_available() and quantization_config.use_ipex:
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.nn.modules import WeightOnlyQuantizedLinear as ipex_linear
from intel_extension_for_pytorch.utils.weight_only_quantization import \
_convert_optimum_format_to_desired

qweight, scales, qzeros = _convert_optimum_format_to_desired(module.qweight,
module.scales,
module.qzeros)

weight_dtype = {
4: ipex.quantization.WoqWeightDtype.INT4,
8: ipex.quantization.WoqWeightDtype.INT8,
}
compute_dtype = {
"fp32": ipex.quantization.WoqLowpMode.NONE, # follow the activation datatype.
"bf16": ipex.quantization.WoqLowpMode.BF16,
"fp16": ipex.quantization.WoqLowpMode.FP16,
"int8": ipex.quantization.WoqLowpMode.INT8,

}

ipex_qconfig_mapping = (
ipex.quantization.get_weight_only_quant_qconfig_mapping(
weight_dtype=weight_dtype[quantization_config.bits],
lowp_mode=compute_dtype[quantization_config.compute_dtype],
act_quant_mode=ipex.quantization.WoqActQuantMode.PER_IC_BLOCK,
group_size=quantization_config.group_size,
)
)
tmp_linear = torch.nn.Linear(
in_features,
out_features,
getattr(module, "bias", None) is not None
)
tmp_linear.qconfig = ipex_qconfig_mapping.global_qconfig
model._modules[name] = ipex_linear.from_float_and_int4_weight(
mod = tmp_linear,
qweight = qweight,
scales = scales,
zero_points = qzeros,
bias = getattr(module, "bias", None),
group_size = quantization_config.group_size,
g_idx = module.g_idx if hasattr(module, "g_idx") else None,
)
else:
from .nn.modules import (
QuantizedLinearQBits,
) # TODO: QuantizedLinearINT4, QuantizedLinearINT8

use_optimum_format = getattr(module, "use_optimum_format", False) or \
quantization_config.weight_dtype not in [
"fp8_e5m2",
"fp8_e4m3",
"fp4",
"nf4",
"int4_fullrange",
]

model._modules[name] = QuantizedLinearQBits(
in_features,
out_features,
module.bias is not None,
compute_dtype=quantization_config.compute_dtype,
compress_statistics=False,
weight_dtype=quantization_config.weight_dtype,
scale_dtype=quantization_config.scale_dtype,
blocksize=quantization_config.group_size,
scheme=quantization_config.scheme,
compression_dtype=getattr(module, "compression_dtype", torch.int32),
compression_dim=getattr(module, "compression_dim", 1),
device=device,
use_optimum_format=use_optimum_format,
)
elif device == "xpu" or device == torch.device("xpu"):
from intel_extension_for_pytorch.nn.utils._quantize_convert \
import WeightOnlyQuantizedLinear as ipex_linear # pylint: disable=E0401
@@ -265,7 +310,9 @@ def _replace_linear(
model._modules[name].source_cls = type(module)
# Force requires grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False)
if device == "cpu" or device == torch.device("cpu") or device == "auto":
if quantization_config.use_ipex:
pass
elif (device == "cpu" or device == torch.device("cpu") or device == "auto"):
if quantization_config.weight_dtype in [
"fp8_e5m2",
"fp8_e4m3",
@@ -560,7 +607,11 @@ def default_calib_func(model):
if config.weight_dtype not in ["nf4", "fp4", "int4_fullrange"]:
inc_model = inc_model.export_compressed_model(use_optimum_format=True)
inc_model.eval()
if config.use_ipex:
optimum_format_state_dict = inc_model.state_dict()
q_model = replace_linear(inc_model, None, None, config, device=device)
if config.use_ipex:
setattr(q_model, "optimum_format_state_dict", optimum_format_state_dict)
else:
q_model = replace_linear(
inc_model.model, None, None, config, device=device
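
To make the new `_replace_linear` branch easier to follow, here is the same conversion distilled into a standalone sketch. It only uses the IPEX APIs already referenced in the diff (`get_weight_only_quant_qconfig_mapping`, `WeightOnlyQuantizedLinear.from_float_and_int4_weight`, `_convert_optimum_format_to_desired`); `module` is assumed to be an optimum-format `WeightOnlyLinear` produced by Neural Compressor, and `quantization_config` a WOQ config carrying `bits`, `compute_dtype`, and `group_size`. The helper itself is illustrative, not part of the commit.

```python
# Standalone sketch of the per-module conversion _replace_linear performs
# when use_ipex=True on CPU.
import torch
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.nn.modules import WeightOnlyQuantizedLinear
from intel_extension_for_pytorch.utils.weight_only_quantization import (
    _convert_optimum_format_to_desired,
)

def to_ipex_woq_linear(module, quantization_config, in_features, out_features):
    # Re-pack optimum/GPTQ-style qweight/scales/qzeros into the layout IPEX expects.
    qweight, scales, qzeros = _convert_optimum_format_to_desired(
        module.qweight, module.scales, module.qzeros
    )

    weight_dtype = {
        4: ipex.quantization.WoqWeightDtype.INT4,
        8: ipex.quantization.WoqWeightDtype.INT8,
    }[quantization_config.bits]
    lowp_mode = {
        "fp32": ipex.quantization.WoqLowpMode.NONE,  # compute follows the activation dtype
        "bf16": ipex.quantization.WoqLowpMode.BF16,
        "fp16": ipex.quantization.WoqLowpMode.FP16,
        "int8": ipex.quantization.WoqLowpMode.INT8,
    }[quantization_config.compute_dtype]

    qconfig_mapping = ipex.quantization.get_weight_only_quant_qconfig_mapping(
        weight_dtype=weight_dtype,
        lowp_mode=lowp_mode,
        act_quant_mode=ipex.quantization.WoqActQuantMode.PER_IC_BLOCK,
        group_size=quantization_config.group_size,
    )

    # A throwaway float Linear only carries the shape/bias signature and the qconfig.
    tmp_linear = torch.nn.Linear(
        in_features, out_features, bias=getattr(module, "bias", None) is not None
    )
    tmp_linear.qconfig = qconfig_mapping.global_qconfig

    return WeightOnlyQuantizedLinear.from_float_and_int4_weight(
        mod=tmp_linear,
        qweight=qweight,
        scales=scales,
        zero_points=qzeros,
        bias=getattr(module, "bias", None),
        group_size=quantization_config.group_size,
        g_idx=getattr(module, "g_idx", None),
    )
```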
131 changes: 107 additions & 24 deletions intel_extension_for_transformers/transformers/modeling/modeling_auto.py
@@ -186,6 +186,8 @@ def convert_model_to_public(model):
module.qweight.data = module.qweight.t_().contiguous()
module.scales.data = module.scales.t_().contiguous()
module.weight_transposed = False
elif model.quantization_config.use_ipex:
pass
elif model.quantization_config.weight_dtype not in [
"fp8_e5m2",
"fp8_e4m3",
@@ -195,7 +197,6 @@
]:
model = recover_export_model(model)


def make_contiguous(model):
for param in model.parameters():
if param.data.ndimension() > 1:
@@ -223,6 +224,7 @@ def save_low_bit(
os.path.abspath(os.path.expanduser(save_directory)), WEIGHTS_NAME)
torch.save(self.quantized_state_dict(), weights_file)
return

convert_model_to_public(self)
os.makedirs(save_directory, exist_ok=True)
# use transformers original `save_pretrained` function
@@ -231,6 +233,33 @@
self.save_pretrained(
save_directory=save_directory, push_to_hub=push_to_hub, **kwargs
)

if self.quantization_config.use_ipex:
def save_linear_parameters(model, save_directory):
# IPEX op contexts can only be serialized to pytorch_model.bin, so remove the safetensors file.
weights_file = os.path.join(
os.path.abspath(os.path.expanduser(save_directory)), SAFE_WEIGHTS_NAME)
os.remove(weights_file)
weights_file = os.path.join(
os.path.abspath(os.path.expanduser(save_directory)), WEIGHTS_NAME)
linear_parameters = {}
from intel_extension_for_pytorch.nn.modules import WeightOnlyQuantizedLinear as ipex_cpu_linear
for name, module in model.named_modules():
if isinstance(module, ipex_cpu_linear):
linear_parameters[name + ".ipex_scales"] = module._op_context.get_scales().contiguous()
linear_parameters[name + ".ipex_weight"] = \
module._op_context.to_public(module._op_context.get_weight()).contiguous()
linear_parameters[name + ".ipex_zeros"] = module._op_context.get_zero_points().contiguous()
if module._op_context.get_bias() is not None:
linear_parameters[name + ".ipex_bias"] = module._op_context.get_bias().contiguous()
if module._op_context.get_g_idx() is not None:
linear_parameters[name + ".ipex_g_idx"] = module._op_context.get_g_idx().contiguous()
others_parameters = model.state_dict()
linear_parameters.update(others_parameters)

torch.save(linear_parameters, weights_file)

save_linear_parameters(self, save_directory)
self.save_pretrained = types.MethodType(save_low_bit, self)
# We conveniently save all the keys of the model to have them on hand,
# so that when using 'low_cpumem load',
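
With `use_ipex=True`, saving therefore produces a `pytorch_model.bin` (the safetensors file is removed) in which every quantized Linear contributes `<module>.ipex_weight`, `<module>.ipex_scales`, `<module>.ipex_zeros`, and optionally `.ipex_bias` / `.ipex_g_idx` tensors. A hypothetical helper, not part of the commit, to inspect such a checkpoint:

```python
# Hypothetical inspection helper: list the IPEX-specific tensors that
# save_linear_parameters() stores in pytorch_model.bin.
import os
import torch

def list_ipex_woq_tensors(saved_dir: str) -> None:
    state_dict = torch.load(os.path.join(saved_dir, "pytorch_model.bin"), map_location="cpu")
    for key, value in state_dict.items():
        if ".ipex_" in key:
            print(f"{key}: shape={tuple(value.shape)}, dtype={value.dtype}")

# Example: list_ipex_woq_tensors("./saved_results")
```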
@@ -1814,42 +1843,96 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# restore default dtype
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)
(
model,
missing_keys,
unexpected_keys,
mismatched_keys,
offload_index,
error_msgs,
) = model_class._load_pretrained_model(
model,
None,
loaded_state_dict_keys, # XXX: rename?
resolved_archive_file,
pretrained_model_name_or_path,
sharded_metadata=sharded_metadata,
_fast_init=_fast_init,
low_cpu_mem_usage=True,
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
keep_in_fp32_modules=[],
)

if is_ipex_available() and quantization_config.use_ipex:
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.nn.modules import WeightOnlyQuantizedLinear as ipex_linear
def replace_ipex_cpu_woq_linear(model, current_name=[]):
for name, module in model.named_children():
current_name.append(name)
if isinstance(module, WeightOnlyLinear):
weight_dtype = {
4: ipex.quantization.WoqWeightDtype.INT4,
8: ipex.quantization.WoqWeightDtype.INT8,
}
compute_dtype = {
"fp32": ipex.quantization.WoqLowpMode.NONE, # follow the activation datatype.
"bf16": ipex.quantization.WoqLowpMode.BF16,
"fp16": ipex.quantization.WoqLowpMode.FP16,
"int8": ipex.quantization.WoqLowpMode.INT8,

}

ipex_qconfig_mapping = (
ipex.quantization.get_weight_only_quant_qconfig_mapping(
weight_dtype=weight_dtype[quantization_config.bits],
lowp_mode=compute_dtype[quantization_config.compute_dtype],
act_quant_mode=ipex.quantization.WoqActQuantMode.PER_IC_BLOCK,
group_size=quantization_config.group_size,
)
)
tmp_linear = torch.nn.Linear(
module.in_features,
module.out_features,
getattr(module, "bias", None) is not None
)
tmp_linear.qconfig = ipex_qconfig_mapping.global_qconfig
target_linear = ipex_linear.from_float_and_int4_weight(
mod = tmp_linear,
qweight = state_dict.pop('.'.join(current_name) + ".ipex_weight"),
scales = state_dict.pop('.'.join(current_name) + ".ipex_scales"),
zero_points = state_dict.pop('.'.join(current_name) + ".ipex_zeros"),
bias = state_dict.pop('.'.join(current_name) + ".ipex_bias") \
if '.'.join(current_name) + ".ipex_bias" in state_dict else None,
group_size = quantization_config.group_size,
g_idx = state_dict.pop('.'.join(current_name) + ".ipex_g_idx") \
if '.'.join(current_name) + ".ipex_g_idx" in state_dict else None,
)
setattr(model, name, target_linear)
else:
replace_ipex_cpu_woq_linear(module, current_name)
current_name.pop()

replace_ipex_cpu_woq_linear(model)
model.load_state_dict(state_dict, strict=False, assign=True)
else:
(
model,
missing_keys,
unexpected_keys,
mismatched_keys,
offload_index,
error_msgs,
) = model_class._load_pretrained_model(
model,
None,
loaded_state_dict_keys, # XXX: rename?
resolved_archive_file,
pretrained_model_name_or_path,
sharded_metadata=sharded_metadata,
_fast_init=_fast_init,
low_cpu_mem_usage=True,
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
keep_in_fp32_modules=[],
)

# make sure token embedding weights are still tied if needed
model.tie_weights()

# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()

if quantization_config.weight_dtype not in [
"fp8_e5m2",
"fp8_e4m3",
"nf4",
"fp4",
"int4_fullrange",
]:
] and not quantization_config.use_ipex:
model = replace_linear(
model.float(),
model,
quantization_config=quantization_config,
device="cpu" if device_map == "auto" else device_map,
empty_weights=True,
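
Putting the two halves together, the intended round-trip for the IPEX CPU WOQ backend looks roughly as follows. This is an assumed usage sketch: the quantize-and-save call goes through the patched `save_pretrained` (i.e. `save_low_bit`), and reloading the saved directory is expected to dispatch to `load_low_bit`, which rebuilds the IPEX `WeightOnlyQuantizedLinear` modules from the `.ipex_*` tensors instead of calling `replace_linear` again.

```python
# Assumed end-to-end round-trip for the new backend; model name and paths
# are placeholders.
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, RtnConfig

woq_config = RtnConfig(bits=4, group_size=128, use_ipex=True)

# Quantize on CPU with the IPEX backend, then save the low-bit checkpoint.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m", quantization_config=woq_config
)
model.save_pretrained("./saved_results")  # writes pytorch_model.bin with .ipex_* keys

# Reload: expected to go through load_low_bit and restore the IPEX linears.
reloaded = AutoModelForCausalLM.from_pretrained("./saved_results")
```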