
Commit

Avoid using __torch_dispatch__ prior to v2.3.0 (#22)
There are bugs with `__torch_dispatch__` in earlier versions.
epwalsh authored May 20, 2024
1 parent 22db20c commit 93ec91a
Showing 3 changed files with 41 additions and 33 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -19,6 +19,7 @@ dependencies = [
     "pydantic>=2.0,<3.0",
     "cached-path",
     "requests",
+    "packaging",
 ]
 
 [project.optional-dependencies]
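
The `packaging` dependency is added so the library can compare the installed PyTorch version at runtime. A minimal sketch of that pattern, separate from the repository code (the constant name `TORCH_HAS_STABLE_DISPATCH` is invented here purely for illustration):

# Sketch: gate a feature on the installed torch version using `packaging`.
# The "2.3.0" threshold mirrors the check introduced in this commit.
import torch
from packaging import version

TORCH_HAS_STABLE_DISPATCH = version.parse(torch.__version__) >= version.parse("2.3.0")

if TORCH_HAS_STABLE_DISPATCH:
    print("__torch_dispatch__ support will be enabled")
else:
    print("skipping __torch_dispatch__ on torch", torch.__version__)

`version.parse` handles local version suffixes such as "2.3.0+cu121", so the comparison behaves the same for CUDA builds.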
41 changes: 24 additions & 17 deletions src/olmo_core/distributed/tensors/sharded_flat_tensor.py
@@ -8,6 +8,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
+from packaging import version
 
 try:
     from torch.utils import _cxx_pytree as pytree
@@ -105,30 +106,36 @@ def __new__(cls, data: torch.Tensor, requires_grad: bool = False) -> ShardedFlat
 
         return tensor
 
-    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
-        del types
-        kwargs = kwargs or {}
-
-        def unwrap(x):
-            if isinstance(x, ShardedFlatTensor):
-                return x._global_tensor if x._global_tensor is not None else x._local_tensor
-            else:
-                return x
-
-        def wrap(x):
-            if isinstance(x, torch.Tensor):
-                if x.shape == self.shape:
-                    return self.wrap(x, requires_grad=x.requires_grad)
-            return x
-
-        out = func(*pytree.tree_map(unwrap, args), **pytree.tree_map(unwrap, kwargs))
-
-        if func in {torch.ops.aten.empty_like.default, torch.ops.aten.zeros_like.default, torch.ops.aten.ones_like.default}:  # type: ignore
-            out = pytree.tree_map(wrap, out)
-
-        return out
+    if version.parse(torch.__version__) >= version.parse("2.3.0"):
+        # There are some bugs with __torch_dispatch__ in earlier versions.
+
+        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+            del types
+            kwargs = kwargs or {}
+
+            def unwrap(x):
+                if isinstance(x, ShardedFlatTensor):
+                    return x._global_tensor if x._global_tensor is not None else x._local_tensor
+                else:
+                    return x
+
+            def wrap(x):
+                if isinstance(x, torch.Tensor):
+                    if x.shape == self.shape:
+                        return self.wrap(x, requires_grad=x.requires_grad)
+                return x
+
+            out = func(*pytree.tree_map(unwrap, args), **pytree.tree_map(unwrap, kwargs))
+
+            if func in {torch.ops.aten.empty_like.default, torch.ops.aten.zeros_like.default, torch.ops.aten.ones_like.default}:  # type: ignore
+                out = pytree.tree_map(wrap, out)
+
+            return out
 
     def __repr__(self) -> str:
         if not self.metadata_set:
             return super().__repr__()
 
         if self._global_tensor is not None:
             return f"ShardedFlatTensor(local_tensor={self._local_tensor}, global_tensor={self._global_tensor})"
         else:
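
For readers unfamiliar with the mechanism being gated here: `__torch_dispatch__` lets a tensor subclass intercept ATen ops, unwrap its arguments into plain tensors, run the op, and selectively re-wrap the results. The following is a minimal, self-contained sketch of that unwrap/wrap pattern. `TaggedTensor` and its `_plain` attribute are invented stand-ins, not olmo_core's `ShardedFlatTensor`, and it uses the stable `torch.utils._pytree` module rather than the `_cxx_pytree` import above; per the version gate in this commit, assume torch >= 2.3.0.

# Minimal illustration of the unwrap/wrap __torch_dispatch__ pattern.
import torch
from torch.utils import _pytree as pytree


class TaggedTensor(torch.Tensor):
    """Toy subclass that keeps a plain-tensor handle, analogous to _local_tensor above."""

    # Disable the default __torch_function__ wrapping so only __torch_dispatch__
    # decides which outputs keep the subclass type.
    __torch_function__ = torch._C._disabled_torch_function_impl

    @staticmethod
    def __new__(cls, data: torch.Tensor, requires_grad: bool = False):
        tensor = torch.Tensor._make_subclass(cls, data, requires_grad)
        tensor._plain = data  # plain alias (shared storage) used for unwrapping inside dispatch
        return tensor

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        def unwrap(x):
            # Hand the op a plain tensor so it does not re-enter this dispatch.
            return x._plain if isinstance(x, TaggedTensor) else x

        def wrap(x):
            # Re-wrap plain tensor outputs so the subclass type survives the op.
            return TaggedTensor(x) if isinstance(x, torch.Tensor) and not isinstance(x, TaggedTensor) else x

        out = func(*pytree.tree_map(unwrap, args), **pytree.tree_map(unwrap, kwargs))
        # Only the *_like factory ops get re-wrapped, mirroring the set used above.
        if func in {torch.ops.aten.empty_like.default, torch.ops.aten.zeros_like.default, torch.ops.aten.ones_like.default}:
            out = pytree.tree_map(wrap, out)
        return out


t = TaggedTensor(torch.randn(4))
print(type(torch.zeros_like(t)).__name__)  # TaggedTensor -- re-wrapped by `wrap`
print(type(t + 1).__name__)                # Tensor -- other ops return plain tensors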
32 changes: 16 additions & 16 deletions src/test/distributed/checkpoint_test.py
@@ -205,6 +205,22 @@ def test_tensor_shard_spec_for_dtensor_2D_rowwise():
 ]
 
 
+def run_get_local_tensor_data_with_dtensor():
+    mesh = init_device_mesh("cuda", (dist.get_world_size(),))
+    dtensor = distribute_tensor(torch.randn(16, device=get_default_device()), mesh, [Shard(dim=0)])
+
+    # Make sure modifying the data returned from `_get_local_tensor_data` will modify the data
+    # in the actual tensor.
+    _get_local_tensor_data(dtensor).fill_(torch.nan)
+    assert _get_local_tensor_data(dtensor).isnan().all()
+    assert dtensor.full_tensor().isnan().all()
+
+
+@requires_multi_gpu
+def test_get_local_tensor_data_with_dtensor():
+    run_distributed_test(run_get_local_tensor_data_with_dtensor, backend="nccl")
+
+
 def save_and_load_checkpoint_with_regular_and_sharded_tensors(dir):
     checkpointer = Checkpointer()
 
@@ -251,22 +267,6 @@ def save_and_load_checkpoint_with_regular_and_sharded_tensors(dir):
     assert full_state_dict["y"].shape == (2, 3)
 
 
-def run_get_local_tensor_data_with_dtensor():
-    mesh = init_device_mesh("cuda", (dist.get_world_size(),))
-    dtensor = distribute_tensor(torch.randn(16, device=get_default_device()), mesh, [Shard(dim=0)])
-
-    # Make sure modifying the data returned from `_get_local_tensor_data` will modify the data
-    # in the actual tensor.
-    _get_local_tensor_data(dtensor).fill_(torch.nan)
-    assert _get_local_tensor_data(dtensor).isnan().all()
-    assert dtensor.full_tensor().isnan().all()
-
-
-@requires_multi_gpu
-def test_get_local_tensor_data_with_dtensor():
-    run_distributed_test(run_get_local_tensor_data_with_dtensor, backend="nccl")
-
-
 @pytest.mark.parametrize("backend", BACKENDS)
 def test_save_and_load_checkpoint_with_regular_and_sharded_tensors(backend, tmp_path):
     run_distributed_test(
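
The relocated test above verifies that the local data handle returned for a DTensor aliases the DTensor's own storage. Below is a standalone sketch of the same check using only public DTensor APIs: it swaps `DTensor.to_local()` for the repository's private `_get_local_tensor_data` helper, CPU/gloo for the multi-GPU harness, and the function name is invented for illustration. Launch it with something like `torchrun --nproc-per-node=2 script.py`.

# Sketch: writes through a DTensor's local shard are visible in the gathered full tensor.
# Assumes the script is launched by torchrun so the process-group env vars are already set.
import torch
import torch.distributed as dist
from torch.distributed._tensor import Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh


def check_local_shard_aliases_dtensor() -> None:
    mesh = init_device_mesh("cpu", (dist.get_world_size(),))
    dtensor = distribute_tensor(torch.randn(16), mesh, [Shard(dim=0)])

    # `to_local()` returns this rank's shard, which shares storage with the DTensor,
    # so an in-place fill shows up when the full tensor is gathered.
    dtensor.to_local().fill_(float("nan"))
    assert dtensor.full_tensor().isnan().all()


if __name__ == "__main__":
    dist.init_process_group("gloo")
    check_local_shard_aliases_dtensor()
    dist.destroy_process_group()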
