Add gc and fix typo #202

Closed · wants to merge 1 commit
7 changes: 7 additions & 0 deletions bmtrain/block_layer.py
@@ -1,4 +1,5 @@
 from typing import Dict, Iterable, Iterator, Union, List
+import gc
 
 from .utils import (round_up, tp_split_tensor)
 from .global_var import config
@@ -215,6 +216,9 @@ def init_param_storage(self):
                 param.data[:] = \
                     torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_param.storage(), offset_st, (offset_end - offset_st,))[:]
                 del contiguous_param
+
+                gc.collect()
+                torch.cuda.empty_cache()
             else:
                 param.data = torch.tensor([], dtype=param.dtype, device=param.device)
             setattr(param, "_start_partition", None)
@@ -374,6 +378,9 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                 torch.tensor([], dtype=d_dtype, device=d_device).set_(self._storage_params[kw_name].storage(), to_offset_st, (to_offset_end - to_offset_st,))[:] = \
                     torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_param.storage(), offset_st, (offset_end - offset_st,))[:]
                 del contiguous_param
+
+                gc.collect()
+                torch.cuda.empty_cache()
             elif strict:
                 missing_keys.append(key)
 
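The `gc.collect()` / `torch.cuda.empty_cache()` calls added after `del contiguous_param` follow a common PyTorch pattern for temporary staging tensors: dropping the last Python reference only returns the blocks to PyTorch's caching allocator, so an explicit collection (to break any reference cycles) followed by `empty_cache()` is what actually hands the memory back to the CUDA driver. A minimal standalone sketch of that pattern, outside of BMTrain; the tensor name and size here are illustrative, and it assumes a CUDA device is available:

```python
import gc

import torch

# Illustrative staging buffer standing in for `contiguous_param` in this PR.
staging = torch.empty(1024 * 1024, dtype=torch.float16, device="cuda")

# ... the buffer's contents would be copied into the flattened parameter
# storage here ...

del staging                # drop the last Python reference to the buffer
gc.collect()               # break any reference cycles that still pin it
torch.cuda.empty_cache()   # release the now-unused cached blocks to the driver

# reserved memory should shrink once the cache is emptied
print(torch.cuda.memory_allocated(), torch.cuda.memory_reserved())
```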
10 changes: 5 additions & 5 deletions bmtrain/nccl/__init__.py
@@ -119,7 +119,7 @@ def allReduce(
     If src == dst, the operation is performed in-place.
 
     """
-    assert src.dtype == dst.dtype, "send and recv buffers must be the same time"
+    assert src.dtype == dst.dtype, "send and recv buffers must be the same type"
     assert src.is_cuda and dst.is_cuda
 
     sendbuff = src.data_ptr()
@@ -197,7 +197,7 @@ def broadcast(
 
     """
 
-    assert src.dtype == dst.dtype, "send and recv buffers must be the same time"
+    assert src.dtype == dst.dtype, "send and recv buffers must be the same type"
     assert src.is_cuda and dst.is_cuda
 
     sendbuff = src.data_ptr()
@@ -237,7 +237,7 @@ def reduce(
     If src == dst, the operation is performed in-place.
 
     """
-    assert src.dtype == dst.dtype, "send and recv buffers must be the same time"
+    assert src.dtype == dst.dtype, "send and recv buffers must be the same type"
     assert src.is_cuda and dst.is_cuda
 
     sendbuff = src.data_ptr()
@@ -266,7 +266,7 @@ def allGather(
     The dst buffer is only used on rank root.
 
     """
-    assert src.dtype == dst.dtype, "send and recv buffers must be the same time"
+    assert src.dtype == dst.dtype, "send and recv buffers must be the same type"
     assert src.is_cuda and dst.is_cuda
 
     sendbuff = src.data_ptr()
@@ -303,7 +303,7 @@ def reduceScatter(
     The dst buffer on rank `i` will contail the i-th block of the reduced result.
 
     """
-    assert src.dtype == dst.dtype, "send and recv buffers must be the same time"
+    assert src.dtype == dst.dtype, "send and recv buffers must be the same type"
     assert src.is_cuda and dst.is_cuda
 
     sendbuff = src.data_ptr()
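The other half of the PR corrects the assert message ("time" → "type") repeated in every collective wrapper in `bmtrain/nccl/__init__.py`. For context, a small standalone sketch of that guard pattern; the helper name below is hypothetical and not part of BMTrain's API: both buffers must share a dtype and live on the GPU before their raw data pointers are handed to NCCL.

```python
import torch

def check_nccl_buffers(src: torch.Tensor, dst: torch.Tensor) -> tuple:
    # Hypothetical helper mirroring the checks these wrappers perform before
    # calling into NCCL: matching dtypes, CUDA placement, then raw pointers.
    assert src.dtype == dst.dtype, "send and recv buffers must be the same type"
    assert src.is_cuda and dst.is_cuda, "send and recv buffers must be CUDA tensors"
    return src.data_ptr(), dst.data_ptr()
```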