Skip to content

Commit

Permalink
add cases len=0 or len=1 to collate fn (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
LBerth authored Nov 13, 2024
1 parent 4eb556c commit 38586e2
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions mfai/torch/namedtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def squeeze_(self, dim_name: Union[List[str], str]):
self.names.remove(name)

def unsqueeze_(self, dim_name: str, dim_index: int):
""" "
"""
Insert a new dimension dim_name of size 1 at dim_index
"""
self.tensor = torch.unsqueeze(self.tensor, dim_index)
Expand Down Expand Up @@ -432,6 +432,12 @@ def to_(self, *args, **kwargs):
@staticmethod
def collate_fn(batch: List["NamedTensor"]) -> "NamedTensor":
"""
Collate a list of NamedTensors into a single NamedTensor.
Collate a list of NamedTensors into a batched single NamedTensor.
"""
if len(batch) == 0:
raise ValueError("Cannot collate an empty list of NamedTensors")
if len(batch) == 1:
# add batch dim to the single namedtensor (in place operation)
batch[0].unsqueeze_(dim_name="batch", dim_index=0)
return batch[0]
return NamedTensor.stack(batch, dim_name="batch", dim=0)

0 comments on commit 38586e2

Please sign in to comment.