Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Oct 31, 2023
1 parent 014c97e commit 09435a8
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions optimum_benchmark/backends/pytorch/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,13 +238,12 @@ def load_model_from_config(self) -> None:

def prepare_for_inference(self, input_shapes: Dict[str, int], **kwargs) -> None:
super().prepare_for_inference(input_shapes=input_shapes, **kwargs)

- if self.config.quantization_scheme == "gptq" or (
+ if (self.config.quantization_scheme == "gptq" and self.config.quantization_config["desc_act"]) or (
hasattr(self.pretrained_config, "quantization_config")
and self.pretrained_config.quantization_config["desc_act"]
and self.pretrained_config.quantization_config["quant_method"] == "gptq"
):
- LOGGER.info("\t+ Setting GPTQ max_input_length")
+ LOGGER.info("\t+ Setting GPTQ's max_input_length")
from auto_gptq import exllama_set_max_input_length

max_input_length = to_pow2(input_shapes["batch_size"] * input_shapes["sequence_length"])
Expand Down

0 comments on commit 09435a8

Please sign in to comment.