From a956a8dfca11c6b49469b3dcc63276f8e347fb9b Mon Sep 17 00:00:00 2001 From: Ruibin Cheung Date: Mon, 20 May 2024 17:17:27 +0800 Subject: [PATCH] add gc and fix typo REASON: Adding gc collect and torch.cuda.empty_cache() after del to make sure device memory is freed immediately --- bmtrain/block_layer.py | 7 +++++++ bmtrain/nccl/__init__.py | 10 +++++----- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 98200465..aa154b96 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -1,4 +1,5 @@ from typing import Dict, Iterable, Iterator, Union, List +import gc from .utils import (round_up, tp_split_tensor) from .global_var import config @@ -215,6 +216,9 @@ def init_param_storage(self): param.data[:] = \ torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_param.storage(), offset_st, (offset_end - offset_st,))[:] del contiguous_param + + gc.collect() + torch.cuda.empty_cache() else: param.data = torch.tensor([], dtype=param.dtype, device=param.device) setattr(param, "_start_partition", None) @@ -374,6 +378,9 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, torch.tensor([], dtype=d_dtype, device=d_device).set_(self._storage_params[kw_name].storage(), to_offset_st, (to_offset_end - to_offset_st,))[:] = \ torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_param.storage(), offset_st, (offset_end - offset_st,))[:] del contiguous_param + + gc.collect() + torch.cuda.empty_cache() elif strict: missing_keys.append(key) diff --git a/bmtrain/nccl/__init__.py b/bmtrain/nccl/__init__.py index 0f4129d5..544dd3e3 100644 --- a/bmtrain/nccl/__init__.py +++ b/bmtrain/nccl/__init__.py @@ -119,7 +119,7 @@ def allReduce( If src == dst, the operation is performed in-place. """ - assert src.dtype == dst.dtype, "send and recv buffers must be the same time" + assert src.dtype == dst.dtype, "send and recv buffers must be the same type" assert src.is_cuda and dst.is_cuda sendbuff = src.data_ptr() @@ -197,7 +197,7 @@ def broadcast( """ - assert src.dtype == dst.dtype, "send and recv buffers must be the same time" + assert src.dtype == dst.dtype, "send and recv buffers must be the same type" assert src.is_cuda and dst.is_cuda sendbuff = src.data_ptr() @@ -237,7 +237,7 @@ def reduce( If src == dst, the operation is performed in-place. """ - assert src.dtype == dst.dtype, "send and recv buffers must be the same time" + assert src.dtype == dst.dtype, "send and recv buffers must be the same type" assert src.is_cuda and dst.is_cuda sendbuff = src.data_ptr() @@ -266,7 +266,7 @@ def allGather( The dst buffer is only used on rank root. """ - assert src.dtype == dst.dtype, "send and recv buffers must be the same time" + assert src.dtype == dst.dtype, "send and recv buffers must be the same type" assert src.is_cuda and dst.is_cuda sendbuff = src.data_ptr() @@ -303,7 +303,7 @@ def reduceScatter( The dst buffer on rank `i` will contail the i-th block of the reduced result. """ - assert src.dtype == dst.dtype, "send and recv buffers must be the same time" + assert src.dtype == dst.dtype, "send and recv buffers must be the same type" assert src.is_cuda and dst.is_cuda sendbuff = src.data_ptr()