From cf43a1bc6ae7b4aca9a2e2e4927d7f2a3a6229f6 Mon Sep 17 00:00:00 2001 From: Min Sheng Wu <30727252+Min-Sheng@users.noreply.github.com> Date: Wed, 4 Jan 2023 23:52:32 +0800 Subject: [PATCH] Replace numpy transpose with torch permute to speed-up (#9533) --- mmdet/datasets/pipelines/formatting.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/mmdet/datasets/pipelines/formatting.py b/mmdet/datasets/pipelines/formatting.py index 45ca69cfc6f..2e07f3894f0 100644 --- a/mmdet/datasets/pipelines/formatting.py +++ b/mmdet/datasets/pipelines/formatting.py @@ -80,20 +80,20 @@ def __init__(self, keys): def __call__(self, results): """Call function to convert image in results to :obj:`torch.Tensor` and - transpose the channel order. + permute the channel order. Args: results (dict): Result dict contains the image data to convert. Returns: dict: The result dict contains the image converted - to :obj:`torch.Tensor` and transposed to (C, H, W) order. + to :obj:`torch.Tensor` and permuted to (C, H, W) order. """ for key in self.keys: img = results[key] if len(img.shape) < 3: img = np.expand_dims(img, -1) - results[key] = (to_tensor(img.transpose(2, 0, 1))).contiguous() + results[key] = to_tensor(img).permute(2, 0, 1).contiguous() return results def __repr__(self): @@ -179,7 +179,7 @@ class DefaultFormatBundle: "proposals", "gt_bboxes", "gt_labels", "gt_masks" and "gt_semantic_seg". These fields are formatted as follows. 
- - img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True) + - img: (1)transpose & to tensor, (2)to DataContainer (stack=True) - proposals: (1)to tensor, (2)to DataContainer - gt_bboxes: (1)to tensor, (2)to DataContainer - gt_bboxes_ignore: (1)to tensor, (2)to DataContainer @@ -226,9 +226,20 @@ def __call__(self, results): results = self._add_default_meta_keys(results) if len(img.shape) < 3: img = np.expand_dims(img, -1) - img = np.ascontiguousarray(img.transpose(2, 0, 1)) + # To improve the computational speed by 3-5 times, apply: + # If image is not contiguous, use + # `numpy.transpose()` followed by `numpy.ascontiguousarray()` + # If image is already contiguous, use + # `torch.permute()` followed by `torch.contiguous()` + # Refer to https://github.com/open-mmlab/mmdetection/pull/9533 + # for more details + if not img.flags.c_contiguous: + img = np.ascontiguousarray(img.transpose(2, 0, 1)) + img = to_tensor(img) + else: + img = to_tensor(img).permute(2, 0, 1).contiguous() results['img'] = DC( - to_tensor(img), padding_value=self.pad_val['img'], stack=True) + img, padding_value=self.pad_val['img'], stack=True) for key in ['proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels']: if key not in results: continue