This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Adapt quant lm head #1671

Merged · 7 commits · Aug 1, 2024
@@ -96,12 +96,6 @@
     help="Use determined group to do quantization",
 )
 # ============AutoRound==================
-parser.add_argument(
-    "--autoround_iters",
-    default=2048,
-    type=int,
-    help="Calibration dataset max or padding max length for AutoRound.",
-)
 parser.add_argument(
     "--lr",
     type=float,
@@ -172,7 +166,6 @@
         bits=args.bits,
         sym=True if args.scheme == "sym" else False,
         group_size=args.group_size,
-        seq_len=args.seq_len,
         compute_dtype=args.compute_dtype,
         scale_dtype=args.compute_dtype,
         weight_dtype=args.weight_dtype,
@@ -136,11 +136,12 @@ def replace_linear(
     if modules_to_not_convert is None:
         # output_layer is chatglm last layer name
         # embed_out is dolly_v2 last layer name
-        modules_to_not_convert = ["lm_head", "output_layer", "embed_out"]
+        modules_to_not_convert = []
     if quantization_config.llm_int8_skip_modules:
-        modules_to_not_convert = modules_to_not_convert.extend(
+        modules_to_not_convert.extend(
             quantization_config.llm_int8_skip_modules
         )
+    modules_to_not_convert = list(set(modules_to_not_convert))
     model, is_replaced = _replace_linear(
         model,
         modules_to_not_convert,
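Reviewer note on the replace_linear hunk above: the old line `modules_to_not_convert = modules_to_not_convert.extend(...)` rebound the variable to None, since `list.extend` mutates in place and returns None; the new code calls `extend` without reassigning and then deduplicates. A minimal, self-contained sketch of the corrected pattern (the concrete skip list here is hypothetical, standing in for `quantization_config.llm_int8_skip_modules`):

```python
# Corrected skip-module merging, sketched outside the library.
modules_to_not_convert = []          # new default: start empty
user_skips = ["lm_head", "lm_head"]  # hypothetical llm_int8_skip_modules

if user_skips:
    # Do NOT write `modules_to_not_convert = modules_to_not_convert.extend(...)`:
    # list.extend returns None, which would wipe out the list.
    modules_to_not_convert.extend(user_skips)
modules_to_not_convert = list(set(modules_to_not_convert))  # drop duplicates

print(modules_to_not_convert)  # ['lm_head']
```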
@@ -559,9 +560,11 @@ def convert_to_quantized_model(model, config, device="cpu"):
             group_size=config.group_size,
             use_layer_wise=config.layer_wise,
         )
-        quant_config.set_local(".*lm_head", RTNConfig(dtype="fp32"))
-        quant_config.set_local(".*output_layer", RTNConfig(dtype="fp32"))
-        quant_config.set_local(".*embed_out", RTNConfig(dtype="fp32"))
+        if config.llm_int8_skip_modules != []:
+            for module in config.llm_int8_skip_modules:
+                module_name = ".*" + module
+                quant_config.set_local(module_name, RTNConfig(dtype="fp32"))
         logger.info(f"Do RTN algorithm with config {quant_config}")
         model = prepare(model, quant_config)
         model = convert(model)
     elif config.quant_method.value == "awq":
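The hunk above replaces the three hard-coded `set_local` calls with a loop over `config.llm_int8_skip_modules`, pinning every listed module to fp32 through a `.*<name>` regex before prepare/convert run; the same pattern is repeated for AWQ, TEQ, GPTQ, and AutoRound below. A minimal sketch of that pattern on a toy model, assuming the neural_compressor 3.x torch API used in this file (`RTNConfig`, `prepare`, `convert`); the skip list and model are illustrative only:

```python
# Sketch of the skip-module loop introduced above; values are illustrative.
import torch
from neural_compressor.torch.quantization import RTNConfig, convert, prepare

llm_int8_skip_modules = ["lm_head", "output_layer", "embed_out"]

quant_config = RTNConfig(bits=4, group_size=32)
if llm_int8_skip_modules != []:
    for module in llm_int8_skip_modules:
        # ".*lm_head" matches any module path ending in lm_head
        quant_config.set_local(".*" + module, RTNConfig(dtype="fp32"))

model = torch.nn.Sequential(torch.nn.Linear(64, 64))  # toy stand-in model
model = prepare(model, quant_config)
model = convert(model)
```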
@@ -575,9 +578,10 @@ def convert_to_quantized_model(model, config, device="cpu"):
             use_auto_clip=config.auto_clip,
             folding=True,
         )
-        quant_config.set_local(".*lm_head", AWQConfig(dtype="fp32"))
-        quant_config.set_local(".*output_layer", AWQConfig(dtype="fp32"))
-        quant_config.set_local(".*embed_out", AWQConfig(dtype="fp32"))
+        if config.llm_int8_skip_modules != []:
+            for module in config.llm_int8_skip_modules:
+                module_name = ".*" + module
+                quant_config.set_local(module_name, AWQConfig(dtype="fp32"))
         logger.info(f"Do AWQ algorithm with config {quant_config}")
         run_fn = default_run_fn
         run_args = (
@@ -601,9 +605,10 @@ def convert_to_quantized_model(model, config, device="cpu"):
             use_layer_wise=config.layer_wise,
             absorb_to_layer=config.absorb_to_layer
         )
-        quant_config.set_local(".*lm_head", TEQConfig(dtype="fp32"))
-        quant_config.set_local(".*output_layer", TEQConfig(dtype="fp32"))
-        quant_config.set_local(".*embed_out", TEQConfig(dtype="fp32"))
+        if config.llm_int8_skip_modules != []:
+            for module in config.llm_int8_skip_modules:
+                module_name = ".*" + module
+                quant_config.set_local(module_name, TEQConfig(dtype="fp32"))
         logger.info(f"Do TEQ algorithm with config {quant_config}")
         run_fn = default_run_fn
         run_args = (
@@ -632,9 +637,10 @@ def convert_to_quantized_model(model, config, device="cpu"):
             block_size=config.blocksize,
             static_groups=config.static_groups,
         )
-        quant_config.set_local(".*lm_head", GPTQConfig(dtype="fp32"))
-        quant_config.set_local(".*output_layer", GPTQConfig(dtype="fp32"))
-        quant_config.set_local(".*embed_out", GPTQConfig(dtype="fp32"))
+        if config.llm_int8_skip_modules != []:
+            for module in config.llm_int8_skip_modules:
+                module_name = ".*" + module
+                quant_config.set_local(module_name, GPTQConfig(dtype="fp32"))
         logger.info(f"Do GPTQ algorithm with config {quant_config}")
         run_fn = default_run_fn
         run_args = (
@@ -662,10 +668,10 @@ def convert_to_quantized_model(model, config, device="cpu"):
             iters=config.iters,
             scale_dtype=config.scale_dtype,
         )
-        if config.quant_lm_head is False:
-            quant_config.set_local(".*lm_head", AutoRoundConfig(dtype="fp32"))
-            quant_config.set_local(".*output_layer", AutoRoundConfig(dtype="fp32"))
-            quant_config.set_local(".*embed_out", AutoRoundConfig(dtype="fp32"))
+        if config.llm_int8_skip_modules != []:
+            for module in config.llm_int8_skip_modules:
+                module_name = ".*" + module
+                quant_config.set_local(module_name, AutoRoundConfig(dtype="fp32"))
         logger.info(f"Do AutoRound algorithm with config {quant_config}")
         dataloader = get_autoround_dataloader(tokenizer=config.tokenizer,
                                               seqlen=config.seq_len,
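Note that the AutoRound branch no longer checks `config.quant_lm_head` directly; that decision now lives in the config object (see the config.py hunks below), where `quant_lm_head=True` simply empties `llm_int8_skip_modules`, so the loop above becomes a no-op and the lm_head gets quantized. A small stand-in class showing that interaction (`FakeAutoRoundConfig` is hypothetical; ITREX's AutoRoundConfig implements the same logic in its `__init__`):

```python
# Stand-in showing how quant_lm_head and llm_int8_skip_modules interact
# after this change; the real logic is in AutoRoundConfig.__init__ below.
class FakeAutoRoundConfig:
    def __init__(self, quant_lm_head=False, **kwargs):
        self.quant_lm_head = quant_lm_head
        self.llm_int8_skip_modules = kwargs.get(
            "llm_int8_skip_modules", ["lm_head", "output_layer", "embed_out"]
        )
        if self.quant_lm_head:
            self.llm_int8_skip_modules = []

keep_fp32 = FakeAutoRoundConfig()                    # lm_head stays fp32
quant_all = FakeAutoRoundConfig(quant_lm_head=True)  # lm_head is quantized
print(keep_fp32.llm_int8_skip_modules)  # ['lm_head', 'output_layer', 'embed_out']
print(quant_all.llm_int8_skip_modules)  # []
```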
@@ -161,7 +161,7 @@ def build_woq_model(model, quantization_config):
     from neural_compressor.adaptor.torch_utils.util import set_module
     weight_dtype = quantization_config.weight_dtype
     for n, m in model.named_modules():
-        if "lm_head" in n or "output_layer" in n or "embed_out" in n:
+        if n in quantization_config.llm_int8_skip_modules:
             continue
         if isinstance(m, torch.nn.Linear):
             zp = getattr(
@@ -883,6 +883,7 @@ def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]:
                 hasattr(torch, "xpu") and torch.xpu.is_available()
             ), "There is no xpu device in this system!"
             quantization_config.update(**{"device": "xpu"})
+            quantization_config.post_init_xpu()
             if (
                 not torch.cuda.is_available()
                 or device_map == "cpu"
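One behavioral detail of the build_woq_model change above: the old code skipped any module whose name contained "lm_head", "output_layer", or "embed_out" as a substring, while the new code skips only names listed verbatim in `quantization_config.llm_int8_skip_modules`. A small sketch of the difference, using hypothetical module names:

```python
# Substring match (old) vs. exact membership (new); module names are made up.
skip_modules = ["lm_head", "output_layer", "embed_out"]
names = ["lm_head", "transformer.output_layer", "model.layers.0.mlp.down_proj"]

old_skipped = [n for n in names
               if "lm_head" in n or "output_layer" in n or "embed_out" in n]
new_skipped = [n for n in names if n in skip_modules]

print(old_skipped)  # ['lm_head', 'transformer.output_layer']
print(new_skipped)  # ['lm_head'] -- nested paths must now be listed exactly
```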
intel_extension_for_transformers/transformers/utils/config.py (12 changes: 7 additions & 5 deletions)
@@ -831,7 +831,7 @@ def __init__(
         self.double_quant_bits = double_quant_bits
         self.double_quant_use_sym = double_quant_use_sym
         self.double_quant_group_size = double_quant_group_size
-        self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", [])
+        self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", ["lm_head", "output_layer", "embed_out"])
         self.use_ggml = use_ggml
         self.use_quant = use_quant
         self.use_neural_speed = use_neural_speed
@@ -911,7 +911,7 @@ def __init__(
         self.true_sequential = true_sequential
         self.layer_wise = layer_wise
         self.seq_len = seq_len
-        self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", [])
+        self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", ["lm_head", "output_layer", "embed_out"])
         self.use_ggml = use_ggml
         self.use_quant = use_quant
         self.use_neural_speed = use_neural_speed
@@ -1009,7 +1009,7 @@ def __init__(
         self.seq_len = seq_len
         self.use_double_quant = use_double_quant
         self.double_quant_scale_dtype = double_quant_scale_dtype
-        self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", [])
+        self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", ["lm_head", "output_layer", "embed_out"])
         self.use_ggml = use_ggml
         self.use_quant = use_quant
         self.use_neural_speed = use_neural_speed
@@ -1078,7 +1078,7 @@ def __init__(
         self.seq_len = seq_len
         self.use_double_quant = use_double_quant
         self.double_quant_scale_dtype = double_quant_scale_dtype
-        self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", [])
+        self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", ["lm_head", "output_layer", "embed_out"])
         self.use_ggml = use_ggml
         self.use_neural_speed = use_neural_speed
         self.device = kwargs.get("device", "auto")
@@ -1154,7 +1154,9 @@ def __init__(
         self.iters = iters
         self.seq_len = seq_len
         self.quant_lm_head = quant_lm_head
-        self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", [])
+        self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", ["lm_head", "output_layer", "embed_out"])
+        if self.quant_lm_head:
+            self.llm_int8_skip_modules = []
         self.use_ggml = use_ggml
         self.use_neural_speed = use_neural_speed
         self.batch_size = kwargs.pop("batch_size", 8)
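Putting the config changes together: `llm_int8_skip_modules` now defaults to ["lm_head", "output_layer", "embed_out"] across the WOQ config classes (RtnConfig, AwqConfig, TeqConfig, GPTQConfig, AutoRoundConfig), so those layers stay fp32 unless the caller overrides the list, and AutoRoundConfig's `quant_lm_head=True` clears it to opt back in. A hedged usage sketch (the import path follows ITREX examples, but the model id and keyword values are placeholders, not taken from this PR):

```python
# Hypothetical usage; model id and kwargs are examples only.
from intel_extension_for_transformers.transformers import (
    AutoModelForCausalLM,
    AutoRoundConfig,
)

# Default: lm_head / output_layer / embed_out are kept in fp32.
woq_config = AutoRoundConfig(bits=4, iters=200, seq_len=512)

# Opt in to quantizing the lm_head as well.
woq_config_lm_head = AutoRoundConfig(bits=4, iters=200, seq_len=512,
                                     quant_lm_head=True)

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",  # placeholder model id
    quantization_config=woq_config_lm_head,
)
```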