Fix bug of multiple pre-processing during segmentation inference (PyTorch) (#645)
Segmentation inference is very slow:
#531
#234

This is because the dataloader re-applies data preprocessing on every access when `self.cache_convert` is None.

https://github.com/isl-org/Open3D-ML/blob/fcf97c07bf7a113a47d0fcf63760b245c2a2784e/ml3d/torch/dataloaders/torch_dataloader.py#L77-L83
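For context, here is a paraphrased sketch of the branch at those lines (not verbatim; see the permalink for the exact code):

```python
# Paraphrased sketch of TorchDataloader.__getitem__ at the linked lines
# (not verbatim; see the permalink above for the exact code).
attr = dataset.get_attr(index)
if self.cache_convert:
    # cached: preprocessing ran once, results are looked up by name
    data = self.cache_convert(attr['name'])
elif self.preprocess:
    # no cache: preprocessing re-runs on EVERY __getitem__ call
    data = self.preprocess(dataset.get_data(index), attr)
else:
    data = dataset.get_data(index)
```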

When the run_inference method is called, the dataloader's cache_convert is None:

https://github.com/isl-org/Open3D-ML/blob/fcf97c07bf7a113a47d0fcf63760b245c2a2784e/ml3d/torch/pipelines/semantic_segmentation.py#L143-L147
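The pre-fix call site builds the dataloader with `use_cache=False`, so `cache_convert` stayed at its None default (reconstructed from the diff below):

```python
# Pre-fix call site in run_inference (reconstructed from the diff below).
# use_cache=False skips the caching branch in TorchDataloader.__init__,
# so cache_convert ended up None and every sampled sub-cloud of the scene
# was preprocessed from scratch.
infer_split = TorchDataloader(dataset=infer_dataset,
                              preprocess=model.preprocess,
                              transform=model.transform,
                              sampler=infer_sampler,
                              use_cache=False)
```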

This makes inference extremely slow.

I've added a get_cache function that supplies the cached, preprocessed data, avoiding the slowdown caused by repeated preprocessing during inference.
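Conceptually, the fix preprocesses the single inference scene once and hands the dataloader a trivial cache lookup; a minimal sketch of the idea (the exact change is in the diff below):

```python
# Minimal sketch of the fix (the exact change is in the diff below).
# Preprocess the one scene being inferred, once, up front...
processed_data = model.preprocess(data, {'split': 'test'})

# ...and pass a cache_convert hook that simply returns it. run_inference
# handles exactly one scene, so the sample name the dataloader passes in
# can be ignored.
def get_cache(attr):
    return processed_data
```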

I tested this with RandLA-Net on the Toronto3D dataset on a GV100 GPU. Inference for a single scene now takes only 2 minutes 37 seconds, considerably faster than before:

```bash
Before: test 0/1:   4%|██                                                     | 187127/4990714 [05:12<2:19:39, 573.27it/s]

After: test 0/1: 100%|██████████████████████████████████████████████████████| 4990714/4990714 [02:37<00:00, 31769.86it/s]
```
Lionelsy authored Jan 7, 2025
1 parent 3754ece commit b64b514
Showing 2 changed files with 10 additions and 5 deletions.
7 changes: 3 additions & 4 deletions ml3d/torch/dataloaders/torch_dataloader.py
```diff
@@ -22,6 +22,7 @@ def __init__(self,
                  sampler=None,
                  use_cache=True,
                  steps_per_epoch=None,
+                 cache_convert=None,
                  **kwargs):
         """Initialize.
@@ -38,6 +39,7 @@
         self.dataset = dataset
         self.preprocess = preprocess
         self.steps_per_epoch = steps_per_epoch
+        self.cache_convert = cache_convert
 
         if preprocess is not None and use_cache:
             cache_dir = getattr(dataset.cfg, 'cache_dir')
@@ -59,10 +61,7 @@
                     continue
                 data = dataset.get_data(idx)
                 # cache the data
-                self.cache_convert(name, data, attr)
-
-        else:
-            self.cache_convert = None
+                self.cache_convert(name, data, attr)
 
         self.transform = transform
```
8 changes: 7 additions & 1 deletion ml3d/torch/pipelines/semantic_segmentation.py
```diff
@@ -136,6 +136,11 @@ def run_inference(self, data):
         model.device = device
         model.eval()
 
+        preprocess_func = model.preprocess
+        processed_data = preprocess_func(data, {'split': 'test'})
+        def get_cache(attr):
+            return processed_data
+
         batcher = self.get_batcher(device)
         infer_dataset = InferenceDummySplit(data)
         self.dataset_split = infer_dataset
@@ -144,7 +149,8 @@ def run_inference(self, data):
                                       preprocess=model.preprocess,
                                       transform=model.transform,
                                       sampler=infer_sampler,
-                                      use_cache=False)
+                                      use_cache=False,
+                                      cache_convert=get_cache)
         infer_loader = DataLoader(infer_split,
                                   batch_size=cfg.batch_size,
                                   sampler=get_sampler(infer_sampler),
```
