From 659321552dd4aff6d94b83de45e88091c9eb36cf Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Thu, 26 Sep 2024 19:51:33 +0200 Subject: [PATCH] fix --- optimum_benchmark/backends/onnxruntime/backend.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/optimum_benchmark/backends/onnxruntime/backend.py b/optimum_benchmark/backends/onnxruntime/backend.py index 8fb69254..87f9a765 100644 --- a/optimum_benchmark/backends/onnxruntime/backend.py +++ b/optimum_benchmark/backends/onnxruntime/backend.py @@ -6,6 +6,7 @@ import torch from hydra.utils import get_class from onnxruntime import SessionOptions +from optimum.exporters.onnx.utils import MODEL_TYPES_REQUIRING_POSITION_IDS from optimum.onnxruntime import ( ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, @@ -299,9 +300,8 @@ def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]: if self.config.library == "transformers": for key, value in list(inputs.items()): - if key in ["position_ids", "token_type_ids"]: - if key not in self.pretrained_model.input_names: - inputs.pop(key) + if key == "position_ids" and self.model_type not in MODEL_TYPES_REQUIRING_POSITION_IDS: + inputs.pop(key) for key, value in inputs.items(): if isinstance(value, torch.Tensor):