
Commit

add gc and fix typo
REASON: Adding gc.collect() and torch.cuda.empty_cache() after del to make sure device memory is freed immediately.
BeingGod committed May 20, 2024
1 parent b903a31 commit a956a8d
Showing 2 changed files with 12 additions and 5 deletions.
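
For readers skimming the diff, here is a minimal sketch of the pattern this commit applies in block_layer.py. The tensor names and shapes are illustrative only, not the actual BMTrain code, which slices views out of a flat parameter storage:

```python
import gc

import torch

# Illustrative shapes only; the real code copies into a slice of a
# pre-allocated flat parameter storage.
contiguous_param = torch.randn(4096, 4096, device="cuda")
param_storage = torch.empty_like(contiguous_param)

# Copy the temporary buffer into its long-lived storage slot ...
param_storage.copy_(contiguous_param)

# ... then drop the temporary and release its device memory right away,
# instead of leaving it to a later, allocator-chosen point.
del contiguous_param
gc.collect()               # collect any lingering reference cycles
torch.cuda.empty_cache()   # hand cached, unused blocks back to the driver
```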
7 changes: 7 additions & 0 deletions bmtrain/block_layer.py
@@ -1,4 +1,5 @@
from typing import Dict, Iterable, Iterator, Union, List
+import gc

from .utils import (round_up, tp_split_tensor)
from .global_var import config
@@ -215,6 +216,9 @@ def init_param_storage(self):
param.data[:] = \
torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_param.storage(), offset_st, (offset_end - offset_st,))[:]
del contiguous_param
+
+gc.collect()
+torch.cuda.empty_cache()
else:
param.data = torch.tensor([], dtype=param.dtype, device=param.device)
setattr(param, "_start_partition", None)
@@ -374,6 +378,9 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
torch.tensor([], dtype=d_dtype, device=d_device).set_(self._storage_params[kw_name].storage(), to_offset_st, (to_offset_end - to_offset_st,))[:] = \
torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_param.storage(), offset_st, (offset_end - offset_st,))[:]
del contiguous_param
+
+gc.collect()
+torch.cuda.empty_cache()
elif strict:
missing_keys.append(key)

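One way to observe the effect of the added calls (this snippet is not part of the commit; it assumes a CUDA device and uses only standard PyTorch allocator counters):

```python
import gc

import torch

def report(tag: str) -> None:
    # Hypothetical helper: print the caching allocator's counters in MiB.
    alloc = torch.cuda.memory_allocated() / 2**20
    reserved = torch.cuda.memory_reserved() / 2**20
    print(f"{tag}: allocated={alloc:.1f} MiB, reserved={reserved:.1f} MiB")

x = torch.empty(256, 1024, 1024, device="cuda")   # ~1 GiB of float32
report("after alloc")

del x
report("after del")           # 'allocated' drops; 'reserved' usually does not

gc.collect()
torch.cuda.empty_cache()
report("after empty_cache")   # cached blocks are returned to the driver
```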
10 changes: 5 additions & 5 deletions bmtrain/nccl/__init__.py
@@ -119,7 +119,7 @@ def allReduce(
If src == dst, the operation is performed in-place.
"""
-assert src.dtype == dst.dtype, "send and recv buffers must be the same time"
+assert src.dtype == dst.dtype, "send and recv buffers must be the same type"
assert src.is_cuda and dst.is_cuda

sendbuff = src.data_ptr()
@@ -197,7 +197,7 @@ def broadcast(
"""

-assert src.dtype == dst.dtype, "send and recv buffers must be the same time"
+assert src.dtype == dst.dtype, "send and recv buffers must be the same type"
assert src.is_cuda and dst.is_cuda

sendbuff = src.data_ptr()
@@ -237,7 +237,7 @@ def reduce(
If src == dst, the operation is performed in-place.
"""
-assert src.dtype == dst.dtype, "send and recv buffers must be the same time"
+assert src.dtype == dst.dtype, "send and recv buffers must be the same type"
assert src.is_cuda and dst.is_cuda

sendbuff = src.data_ptr()
@@ -266,7 +266,7 @@ def allGather(
The dst buffer is only used on rank root.
"""
-assert src.dtype == dst.dtype, "send and recv buffers must be the same time"
+assert src.dtype == dst.dtype, "send and recv buffers must be the same type"
assert src.is_cuda and dst.is_cuda

sendbuff = src.data_ptr()
@@ -303,7 +303,7 @@ def reduceScatter(
The dst buffer on rank `i` will contain the i-th block of the reduced result.
"""
-assert src.dtype == dst.dtype, "send and recv buffers must be the same time"
+assert src.dtype == dst.dtype, "send and recv buffers must be the same type"
assert src.is_cuda and dst.is_cuda

sendbuff = src.data_ptr()
