Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
… into main
  • Loading branch information
IlyasMoutawwakil committed Oct 31, 2023
2 parents b4160d3 + eb40d23 commit d5628f7
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion optimum_benchmark/backends/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,18 @@ def extract_shapes_from_model_artifacts(

# text input
shapes["vocab_size"] = artifacts_dict.get("vocab_size", 2)
if shapes["vocab_size"] == 0:
shapes["vocab_size"] = 2

shapes["type_vocab_size"] = artifacts_dict.get("type_vocab_size", 2)
if shapes["type_vocab_size"] == 0:
shapes["type_vocab_size"] = 2

# image input
shapes["num_channels"] = artifacts_dict.get("num_channels", None)
if shapes["num_channels"] is None:
# processors have different names for the number of channels
shapes["num_channels"] = artifacts_dict.get("channels", None)

image_size = artifacts_dict.get("image_size", None)
if image_size is None:
Expand All @@ -86,10 +94,15 @@ def extract_shapes_from_model_artifacts(
shapes["width"] = None

# classification labels (default to 2)
shapes["num_labels"] = len(artifacts_dict.get("id2label", {"0": "LABEL_0", "1": "LABEL_1"}))
id2label = artifacts_dict.get("id2label", {"0": "LABEL_0", "1": "LABEL_1"})
shapes["num_labels"] = len(id2label)
if shapes["num_labels"] == 0:
shapes["num_labels"] = 2

# object detection labels (default to 2)
shapes["num_queries"] = artifacts_dict.get("num_queries", 2)
if shapes["num_queries"] == 0:
shapes["num_queries"] = 2

return shapes

Expand Down

0 comments on commit d5628f7

Please sign in to comment.