Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid using __torch_dispatch__ prior to v2.3.0 #22

Merged
merged 1 commit into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies = [
"pydantic>=2.0,<3.0",
"cached-path",
"requests",
"packaging",
]

[project.optional-dependencies]
Expand Down
41 changes: 24 additions & 17 deletions src/olmo_core/distributed/tensors/sharded_flat_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch
import torch.distributed as dist
import torch.nn.functional as F
from packaging import version

try:
from torch.utils import _cxx_pytree as pytree
Expand Down Expand Up @@ -105,30 +106,36 @@ def __new__(cls, data: torch.Tensor, requires_grad: bool = False) -> ShardedFlat

return tensor

def __torch_dispatch__(self, func, types, args=(), kwargs=None):
del types
kwargs = kwargs or {}
if version.parse(torch.__version__) >= version.parse("2.3.0"):
# There are some bugs with __torch_dispatch__ in earlier versions.

def unwrap(x):
if isinstance(x, ShardedFlatTensor):
return x._global_tensor if x._global_tensor is not None else x._local_tensor
else:
return x
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
del types
kwargs = kwargs or {}

def unwrap(x):
if isinstance(x, ShardedFlatTensor):
return x._global_tensor if x._global_tensor is not None else x._local_tensor
else:
return x

def wrap(x):
if isinstance(x, torch.Tensor):
if x.shape == self.shape:
return self.wrap(x, requires_grad=x.requires_grad)
return x
def wrap(x):
if isinstance(x, torch.Tensor):
if x.shape == self.shape:
return self.wrap(x, requires_grad=x.requires_grad)
return x

out = func(*pytree.tree_map(unwrap, args), **pytree.tree_map(unwrap, kwargs))
out = func(*pytree.tree_map(unwrap, args), **pytree.tree_map(unwrap, kwargs))

if func in {torch.ops.aten.empty_like.default, torch.ops.aten.zeros_like.default, torch.ops.aten.ones_like.default}: # type: ignore
out = pytree.tree_map(wrap, out)
if func in {torch.ops.aten.empty_like.default, torch.ops.aten.zeros_like.default, torch.ops.aten.ones_like.default}: # type: ignore
out = pytree.tree_map(wrap, out)

return out
return out

def __repr__(self) -> str:
if not self.metadata_set:
return super().__repr__()

if self._global_tensor is not None:
return f"ShardedFlatTensor(local_tensor={self._local_tensor}, global_tensor={self._global_tensor})"
else:
Expand Down
32 changes: 16 additions & 16 deletions src/test/distributed/checkpoint_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,22 @@ def test_tensor_shard_spec_for_dtensor_2D_rowwise():
]


def run_get_local_tensor_data_with_dtensor():
mesh = init_device_mesh("cuda", (dist.get_world_size(),))
dtensor = distribute_tensor(torch.randn(16, device=get_default_device()), mesh, [Shard(dim=0)])

# Make sure modifying the data returned from `_get_local_tensor_data` will modify the data
# in the actual tensor.
_get_local_tensor_data(dtensor).fill_(torch.nan)
assert _get_local_tensor_data(dtensor).isnan().all()
assert dtensor.full_tensor().isnan().all()


@requires_multi_gpu
def test_get_local_tensor_data_with_dtensor():
run_distributed_test(run_get_local_tensor_data_with_dtensor, backend="nccl")


def save_and_load_checkpoint_with_regular_and_sharded_tensors(dir):
checkpointer = Checkpointer()

Expand Down Expand Up @@ -251,22 +267,6 @@ def save_and_load_checkpoint_with_regular_and_sharded_tensors(dir):
assert full_state_dict["y"].shape == (2, 3)


def run_get_local_tensor_data_with_dtensor():
mesh = init_device_mesh("cuda", (dist.get_world_size(),))
dtensor = distribute_tensor(torch.randn(16, device=get_default_device()), mesh, [Shard(dim=0)])

# Make sure modifying the data returned from `_get_local_tensor_data` will modify the data
# in the actual tensor.
_get_local_tensor_data(dtensor).fill_(torch.nan)
assert _get_local_tensor_data(dtensor).isnan().all()
assert dtensor.full_tensor().isnan().all()


@requires_multi_gpu
def test_get_local_tensor_data_with_dtensor():
run_distributed_test(run_get_local_tensor_data_with_dtensor, backend="nccl")


@pytest.mark.parametrize("backend", BACKENDS)
def test_save_and_load_checkpoint_with_regular_and_sharded_tensors(backend, tmp_path):
run_distributed_test(
Expand Down
Loading