Skip to content

Commit

Permalink
Replace numpy transpose with torch permute to speed-up (#9533)
Browse files Browse the repository at this point in the history
  • Loading branch information
Min-Sheng authored Jan 4, 2023
1 parent dcac7dd commit cf43a1b
Showing 1 changed file with 17 additions and 6 deletions.
23 changes: 17 additions & 6 deletions mmdet/datasets/pipelines/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,20 +80,20 @@ def __init__(self, keys):

def __call__(self, results):
"""Call function to convert image in results to :obj:`torch.Tensor` and
transpose the channel order.
permute the channel order.
Args:
results (dict): Result dict contains the image data to convert.
Returns:
dict: The result dict contains the image converted
to :obj:`torch.Tensor` and transposed to (C, H, W) order.
to :obj:`torch.Tensor` and permuted to (C, H, W) order.
"""
for key in self.keys:
img = results[key]
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
results[key] = (to_tensor(img.transpose(2, 0, 1))).contiguous()
results[key] = to_tensor(img).permute(2, 0, 1).contiguous()
return results

def __repr__(self):
Expand Down Expand Up @@ -179,7 +179,7 @@ class DefaultFormatBundle:
"proposals", "gt_bboxes", "gt_labels", "gt_masks" and "gt_semantic_seg".
These fields are formatted as follows.
- img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True)
- img: (1)transpose & to tensor, (2)to DataContainer (stack=True)
- proposals: (1)to tensor, (2)to DataContainer
- gt_bboxes: (1)to tensor, (2)to DataContainer
- gt_bboxes_ignore: (1)to tensor, (2)to DataContainer
Expand Down Expand Up @@ -226,9 +226,20 @@ def __call__(self, results):
results = self._add_default_meta_keys(results)
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
img = np.ascontiguousarray(img.transpose(2, 0, 1))
# To improve the computational speed by by 3-5 times, apply:
# If image is not contiguous, use
# `numpy.transpose()` followed by `numpy.ascontiguousarray()`
# If image is already contiguous, use
# `torch.permute()` followed by `torch.contiguous()`
# Refer to https://github.com/open-mmlab/mmdetection/pull/9533
# for more details
if not img.flags.c_contiguous:
img = np.ascontiguousarray(img.transpose(2, 0, 1))
img = to_tensor(img)
else:
img = to_tensor(img).permute(2, 0, 1).contiguous()
results['img'] = DC(
to_tensor(img), padding_value=self.pad_val['img'], stack=True)
img, padding_value=self.pad_val['img'], stack=True)
for key in ['proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels']:
if key not in results:
continue
Expand Down

0 comments on commit cf43a1b

Please sign in to comment.