diff --git a/src/olmo_core/utils.py b/src/olmo_core/utils.py index 64ba1193..d618e827 100644 --- a/src/olmo_core/utils.py +++ b/src/olmo_core/utils.py @@ -643,4 +643,4 @@ def cuda_sync_debug_mode(debug_mode: Union[int, str]): yield finally: if current_mode is not None: - torch.cuda.set_sync_debug_mode(debug_mode) + torch.cuda.set_sync_debug_mode(current_mode)