Skip to content

Commit

Permalink
fix mock batch
Browse files Browse the repository at this point in the history
  • Loading branch information
AkshitaB committed Jan 29, 2025
1 parent 65fab16 commit 5ae6342
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/olmo_core/data/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def reshuffle(self, epoch: Optional[int] = None, in_memory: bool = False, **kwar
self.build_and_save_global_indices(in_memory=in_memory)

def get_mock_batch(self) -> Dict[str, Any]:
rng = torch.Generator(device=get_default_device())
rng = torch.Generator(device=get_default_device(non_cuda=True))
rng.manual_seed(self.seed + self.dp_rank)
num_instances = self.rank_batch_size // self.dataset.max_sequence_length
input_ids = torch.randint(
Expand Down
7 changes: 5 additions & 2 deletions src/olmo_core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,15 @@ def mark_dynamic(x: torch.Tensor, dim: Union[int, Sequence[int]]):
torch._dynamo.mark_dynamic(x, dim)


def get_default_device() -> torch.device:
def get_default_device(non_cuda: bool = False) -> torch.device:
"""
Get the default device.
"""
if torch.cuda.is_available() and torch.cuda.is_initialized():
return torch.device("cuda")
if non_cuda:
return torch.device("cpu")
else:
return torch.device("cuda")
elif torch.mps.is_available():
return torch.device("mps")
else:
Expand Down

0 comments on commit 5ae6342

Please sign in to comment.