diff --git a/optimum_benchmark/backends/utils.py b/optimum_benchmark/backends/utils.py index 25bca8ef..fc9348c1 100644 --- a/optimum_benchmark/backends/utils.py +++ b/optimum_benchmark/backends/utils.py @@ -59,10 +59,18 @@ def extract_shapes_from_model_artifacts( # text input shapes["vocab_size"] = artifacts_dict.get("vocab_size", 2) + if shapes["vocab_size"] == 0: + shapes["vocab_size"] = 2 + shapes["type_vocab_size"] = artifacts_dict.get("type_vocab_size", 2) + if shapes["type_vocab_size"] == 0: + shapes["type_vocab_size"] = 2 # image input shapes["num_channels"] = artifacts_dict.get("num_channels", None) + if shapes["num_channels"] is None: + # processors have different names for the number of channels + shapes["num_channels"] = artifacts_dict.get("channels", None) image_size = artifacts_dict.get("image_size", None) if image_size is None: @@ -86,10 +94,15 @@ def extract_shapes_from_model_artifacts( shapes["width"] = None # classification labels (default to 2) - shapes["num_labels"] = len(artifacts_dict.get("id2label", {"0": "LABEL_0", "1": "LABEL_1"})) + id2label = artifacts_dict.get("id2label", {"0": "LABEL_0", "1": "LABEL_1"}) + shapes["num_labels"] = len(id2label) + if shapes["num_labels"] == 0: + shapes["num_labels"] = 2 # object detection labels (default to 2) shapes["num_queries"] = artifacts_dict.get("num_queries", 2) + if shapes["num_queries"] == 0: + shapes["num_queries"] = 2 return shapes