Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
AdeelH committed Dec 29, 2023
1 parent c29ebe1 commit 044ebac
Showing 1 changed file with 15 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,19 @@ def __init__(self,

self.device = torch.device(device)

self.distributed = rv_config.get_namespace_option(
'rastervision', 'USE_DDP', as_bool=True)
if dist.is_initialized():
self.distributed = True
elif device == 'cuda':
ddp_allowed = rv_config.get_namespace_option(
'rastervision', 'USE_DDP', True, as_bool=True)
dist_available = dist.is_available()
gpu_count = torch.cuda.device_count()
multi_gpus = torch.cuda.device_count() > 1
self.distributed = ddp_allowed and dist_available and multi_gpus
log.info(f'Multiple GPUs detected ({gpu_count}), will use DDP '
'for training.')
else:
self.distributed = False
self.ddp_rank = get_env_var('RANK', None, int)
self.ddp_local_rank = get_env_var('LOCAL_RANK', None, int)
self.ddp_world_size = get_env_var('WORLD_SIZE', None, int)
Expand Down Expand Up @@ -1117,6 +1128,8 @@ def build_dataloader(self, split: str) -> DataLoader:
f'than total batch size ({batch_sz}).')
batch_sz //= world_sz

log.info(f'Per GPU batch size: {batch_sz}')

args = dict(
batch_size=batch_sz,
num_workers=num_workers,
Expand Down

0 comments on commit 044ebac

Please sign in to comment.