Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Sep 26, 2024
1 parent bbf94e4 commit 6593215
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions optimum_benchmark/backends/onnxruntime/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch
from hydra.utils import get_class
from onnxruntime import SessionOptions
from optimum.exporters.onnx.utils import MODEL_TYPES_REQUIRING_POSITION_IDS
from optimum.onnxruntime import (
ONNX_DECODER_NAME,
ONNX_DECODER_WITH_PAST_NAME,
Expand Down Expand Up @@ -299,9 +300,8 @@ def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:

if self.config.library == "transformers":
for key, value in list(inputs.items()):
if key in ["position_ids", "token_type_ids"]:
if key not in self.pretrained_model.input_names:
inputs.pop(key)
if key == "position_ids" and self.model_type not in MODEL_TYPES_REQUIRING_POSITION_IDS:
inputs.pop(key)

for key, value in inputs.items():
if isinstance(value, torch.Tensor):
Expand Down

0 comments on commit 6593215

Please sign in to comment.