Skip to content

Commit

Permalink
fix: output names in archive generation
Browse files Browse the repository at this point in the history
  • Loading branch information
sokovninn committed Oct 11, 2024
1 parent 12f326d commit efaf3dd
Showing 1 changed file with 10 additions and 24 deletions.
34 changes: 10 additions & 24 deletions luxonis_train/core/utils/archive_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,27 +186,7 @@ def _get_head_outputs(
if name == head_name:
output_names.append(output["name"])

if output_names:
return output_names

# TODO: Fix this, will require refactoring custom ONNX output names
logger.error(
"ONNX model uses custom output names, trying to determine outputs based on the head type. "
"This will likely result in incorrect archive for multi-head models. "
"You can ignore this error if your model has only one head."
)

if head_type == "ClassificationHead":
return [outputs[0]["name"]]
elif head_type == "EfficientBBoxHead":
return [output["name"] for output in outputs]
elif head_type in ["SegmentationHead", "BiSeNetHead"]:
return [outputs[0]["name"]]
elif head_type == "EfficientKeypointBBoxHead":
return [outputs[0]["name"]]
else:
raise ValueError("Unknown head name")

return output_names

def get_heads(
cfg: Config,
Expand Down Expand Up @@ -238,9 +218,15 @@ def get_heads(
task = str(next(iter(task.values())))

classes = _get_classes(node_name, task, class_dict)
head_outputs = _get_head_outputs(
outputs, node_alias, node_name
)

export_output_names = nodes[node_alias].export_output_names
if export_output_names is not None:
head_outputs = export_output_names
else:
head_outputs = _get_head_outputs(
outputs, node_alias, node_name
)

if node_alias in head_names:
curr_head_name = f"{node_alias}_{len(head_names)}" # add suffix if name is already present
else:
Expand Down

0 comments on commit efaf3dd

Please sign in to comment.