From 93ec91ab30f40d888f00ec9e58eec9ad1125f50e Mon Sep 17 00:00:00 2001
From: Pete
Date: Mon, 20 May 2024 11:18:57 -0700
Subject: [PATCH] Avoid using `__torch_dispatch__` prior to v2.3.0 (#22)

There are bugs with `__torch_dispatch__` in earlier versions.
---
 pyproject.toml                           |  1 +
 .../tensors/sharded_flat_tensor.py       | 41 +++++++++++--------
 src/test/distributed/checkpoint_test.py  | 32 +++++++--------
 3 files changed, 41 insertions(+), 33 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index c165b938..9d8138d7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,6 +19,7 @@ dependencies = [
     "pydantic>=2.0,<3.0",
     "cached-path",
     "requests",
+    "packaging",
 ]
 
 [project.optional-dependencies]
diff --git a/src/olmo_core/distributed/tensors/sharded_flat_tensor.py b/src/olmo_core/distributed/tensors/sharded_flat_tensor.py
index d537894d..13e39af8 100644
--- a/src/olmo_core/distributed/tensors/sharded_flat_tensor.py
+++ b/src/olmo_core/distributed/tensors/sharded_flat_tensor.py
@@ -8,6 +8,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
+from packaging import version
 
 try:
     from torch.utils import _cxx_pytree as pytree
@@ -105,30 +106,36 @@ def __new__(cls, data: torch.Tensor, requires_grad: bool = False) -> ShardedFlat
         return tensor
 
-    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
-        del types
-        kwargs = kwargs or {}
+    if version.parse(torch.__version__) >= version.parse("2.3.0"):
+        # There are some bugs with __torch_dispatch__ in earlier versions.
 
-        def unwrap(x):
-            if isinstance(x, ShardedFlatTensor):
-                return x._global_tensor if x._global_tensor is not None else x._local_tensor
-            else:
-                return x
+        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+            del types
+            kwargs = kwargs or {}
+
+            def unwrap(x):
+                if isinstance(x, ShardedFlatTensor):
+                    return x._global_tensor if x._global_tensor is not None else x._local_tensor
+                else:
+                    return x
 
-        def wrap(x):
-            if isinstance(x, torch.Tensor):
-                if x.shape == self.shape:
-                    return self.wrap(x, requires_grad=x.requires_grad)
-            return x
+            def wrap(x):
+                if isinstance(x, torch.Tensor):
+                    if x.shape == self.shape:
+                        return self.wrap(x, requires_grad=x.requires_grad)
+                return x
 
-        out = func(*pytree.tree_map(unwrap, args), **pytree.tree_map(unwrap, kwargs))
+            out = func(*pytree.tree_map(unwrap, args), **pytree.tree_map(unwrap, kwargs))
 
-        if func in {torch.ops.aten.empty_like.default, torch.ops.aten.zeros_like.default, torch.ops.aten.ones_like.default}:  # type: ignore
-            out = pytree.tree_map(wrap, out)
+            if func in {torch.ops.aten.empty_like.default, torch.ops.aten.zeros_like.default, torch.ops.aten.ones_like.default}:  # type: ignore
+                out = pytree.tree_map(wrap, out)
 
-        return out
+            return out
 
     def __repr__(self) -> str:
+        if not self.metadata_set:
+            return super().__repr__()
+
         if self._global_tensor is not None:
             return f"ShardedFlatTensor(local_tensor={self._local_tensor}, global_tensor={self._global_tensor})"
         else:
diff --git a/src/test/distributed/checkpoint_test.py b/src/test/distributed/checkpoint_test.py
index 1c21b1f5..56f2eb92 100644
--- a/src/test/distributed/checkpoint_test.py
+++ b/src/test/distributed/checkpoint_test.py
@@ -205,6 +205,22 @@ def test_tensor_shard_spec_for_dtensor_2D_rowwise():
     ]
 
 
+def run_get_local_tensor_data_with_dtensor():
+    mesh = init_device_mesh("cuda", (dist.get_world_size(),))
+    dtensor = distribute_tensor(torch.randn(16, device=get_default_device()), mesh, [Shard(dim=0)])
+
+    # Make sure modifying the data returned from `_get_local_tensor_data` will modify the data
+    # in the actual tensor.
+    _get_local_tensor_data(dtensor).fill_(torch.nan)
+    assert _get_local_tensor_data(dtensor).isnan().all()
+    assert dtensor.full_tensor().isnan().all()
+
+
+@requires_multi_gpu
+def test_get_local_tensor_data_with_dtensor():
+    run_distributed_test(run_get_local_tensor_data_with_dtensor, backend="nccl")
+
+
 def save_and_load_checkpoint_with_regular_and_sharded_tensors(dir):
     checkpointer = Checkpointer()
 
@@ -251,22 +267,6 @@ def save_and_load_checkpoint_with_regular_and_sharded_tensors(dir):
     assert full_state_dict["y"].shape == (2, 3)
 
 
-def run_get_local_tensor_data_with_dtensor():
-    mesh = init_device_mesh("cuda", (dist.get_world_size(),))
-    dtensor = distribute_tensor(torch.randn(16, device=get_default_device()), mesh, [Shard(dim=0)])
-
-    # Make sure modifying the data returned from `_get_local_tensor_data` will modify the data
-    # in the actual tensor.
-    _get_local_tensor_data(dtensor).fill_(torch.nan)
-    assert _get_local_tensor_data(dtensor).isnan().all()
-    assert dtensor.full_tensor().isnan().all()
-
-
-@requires_multi_gpu
-def test_get_local_tensor_data_with_dtensor():
-    run_distributed_test(run_get_local_tensor_data_with_dtensor, backend="nccl")
-
-
 @pytest.mark.parametrize("backend", BACKENDS)
 def test_save_and_load_checkpoint_with_regular_and_sharded_tensors(backend, tmp_path):
     run_distributed_test(
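
Note: the sketch below is illustrative only and not part of the patch. It isolates the version-gating pattern the change applies inside `ShardedFlatTensor`: the `__torch_dispatch__` override is defined only when the installed PyTorch is at least 2.3.0, so older releases fall back to stock `torch.Tensor` behavior. The class name `VersionGatedTensor` is hypothetical, and the sketch uses the documented classmethod form of `__torch_dispatch__` rather than the instance-method form seen in the patch.

# Illustrative sketch (hypothetical class, not from the patch): gate the
# __torch_dispatch__ override on the installed PyTorch version.
import torch
from packaging import version


class VersionGatedTensor(torch.Tensor):
    # Only define __torch_dispatch__ on PyTorch >= 2.3.0; earlier releases have
    # bugs with it, so on older versions the class keeps default Tensor behavior.
    if version.parse(torch.__version__) >= version.parse("2.3.0"):

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            # Pass calls through unchanged; a real subclass (like
            # ShardedFlatTensor) would unwrap args and re-wrap results here.
            kwargs = kwargs or {}
            return func(*args, **kwargs)


# Quick check: the override exists only when the version gate is satisfied.
print(torch.__version__, "__torch_dispatch__" in VersionGatedTensor.__dict__)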