Refactor ShardedFlatTensor class (#21)
* Refactor `ShardedFlatTensor` class

* maybe fix?

* make dispatch work when unsharded

* revert some changes

* another try

* another try

* another fix

* updates

* update PyTorch version in CI

* fix?

* updates

* add a little more safety
epwalsh authored May 16, 2024
1 parent b2c3c09 commit 22db20c
Showing 8 changed files with 157 additions and 127 deletions.
9 changes: 5 additions & 4 deletions .github/actions/setup-venv/action.yml
@@ -11,7 +11,7 @@ inputs:
   torch-version:
     description: The PyTorch version to install
     required: false
-    default: '==2.2.1'
+    default: '==2.3.0'
 runs:
   using: composite
   steps:
@@ -34,16 +34,17 @@ runs:
       id: virtualenv-cache
       with:
         path: .venv
-        key: ${{ inputs.cache-prefix }}-${{ runner.os }}-${{ env.PYTHON_VERSION }}-${{ hashFiles('*requirements.txt', '*pyproject.toml') }}
+        key: ${{ inputs.cache-prefix }}-${{ runner.os }}-${{ env.PYTHON_VERSION }}-${{ inputs.torch-version }}-${{ hashFiles('*requirements.txt', '*pyproject.toml') }}
+        restore-keys: |
+          ${{ inputs.cache-prefix }}-${{ runner.os }}-${{ env.PYTHON_VERSION }}-${{ inputs.torch-version }}
     - if: steps.virtualenv-cache.outputs.cache-hit != 'true'
       shell: bash
       run: |
         # Set up virtual environment without cache hit.
         test -d .venv || virtualenv -p $(which python) --copies --reset-app-data .venv
         . .venv/bin/activate
-        #pip install 'torch${{ inputs.torch-version }}' --extra-index-url https://download.pytorch.org/whl/cpu
-        pip install 'torch${{ inputs.torch-version }}'
+        pip install 'torch${{ inputs.torch-version }}' --extra-index-url https://download.pytorch.org/whl/cpu
         pip install -e .[all]
     - if: steps.virtualenv-cache.outputs.cache-hit == 'true'
4 changes: 1 addition & 3 deletions .github/workflows/main.yml
@@ -24,9 +24,7 @@ env:
 jobs:
   checks:
     name: ${{ matrix.task.name }}
-    # TODO: change to 'ubuntu-latest' once repo is public (will have more RAM then), and update the torch
-    # install command in the setup-venv action.
-    runs-on: [macos-13]
+    runs-on: [ubuntu-latest]
     timeout-minutes: 5
     strategy:
       fail-fast: false
21 changes: 4 additions & 17 deletions src/olmo_core/distributed/checkpoint.py
@@ -1321,33 +1321,20 @@ def _patch_key(model: nn.Module, key: str) -> str:
 def _get_local_tensor_data(tensor: torch.Tensor) -> torch.Tensor:
     if isinstance(tensor, DTensor):
         return tensor.to_local()
+    elif isinstance(tensor, ShardedFlatTensor):
+        return tensor.sharded_data
     else:
         return tensor.data
 
 
 def _wrap_tensor_for_sharded_parameter(tensor: torch.Tensor, param: Optional[torch.Tensor]) -> torch.Tensor:
-    if isinstance(tensor, DTensor):
+    if isinstance(tensor, DTensor) or (isinstance(tensor, ShardedFlatTensor) and tensor.metadata_set):
         return tensor
 
-    # TODO: (fixme) when you call `torch.empty_like(x)` on a `ShardedFlatTensor`, `x`, you get
-    # a `ShardedFlatTensor` without the metadata. Since PyTorch optimizer's use `torch.empty_like()`
-    # on each param to initialize its state, we run into an issue unless we still call `ShardedFlatTensor.wrap()`
-    # below.
-    # if isinstance(tensor, ShardedFlatTensor):
-    #     return tensor
-
     if isinstance(param, ShardedFlatTensor):
         return param.wrap(tensor, requires_grad=False)
     elif isinstance(param, DTensor):
-        return DTensor(  # type: ignore
-            tensor,
-            param.device_mesh,
-            param.placements,
-            shape=param.size(),
-            dtype=tensor.dtype,
-            requires_grad=False,
-            stride=param.stride(),
-        )
+        return DTensor.from_local(tensor, device_mesh=param.device_mesh, placements=param.placements)
     elif isinstance(param, nn.Parameter) and isinstance(param.data, DTensor):
         return _wrap_tensor_for_sharded_parameter(tensor, param.data)
     else:
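The `tensor.metadata_set` guard above matters because `torch.empty_like()` on a tensor subclass returns the subclass type but drops any Python-level attributes, and PyTorch optimizers allocate their state buffers with `torch.empty_like(param)`; this is the situation the removed TODO describes. A minimal sketch of that behavior, using a stand-in subclass (`TaggedTensor` and `sharding_spec` are illustrative names, not olmo-core API):

# Illustrative only: a plain tensor subclass that carries metadata as a Python attribute.
import torch

class TaggedTensor(torch.Tensor):
    pass

x = torch.zeros(8).as_subclass(TaggedTensor)
x.sharding_spec = {"world_size": 4, "rank": 1}  # stand-in for real sharding metadata

# Optimizers call `torch.empty_like(param)` when initializing per-parameter state.
# The result keeps the subclass type but loses the attribute, so code that relies on
# the metadata has to re-wrap the tensor (or, as above, check that metadata is set).
state = torch.empty_like(x)
print(type(state).__name__)             # TaggedTensor
print(hasattr(state, "sharding_spec"))  # False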
12 changes: 6 additions & 6 deletions src/olmo_core/distributed/fsdp/flat_param_handle.py
@@ -223,7 +223,7 @@ def shard_params(
 
         # Now that we have all of the flat parameters we need to collate all of their data into a single
         # sharded flat tensor, then set the data for each flat parameter as a view into that flat tensor.
-        local_flat_sharded_data = torch.cat([flat_param.data for flat_param in flat_params])
+        local_flat_sharded_data = torch.cat([flat_param.sharded_data for flat_param in flat_params])
         params_data = ShardedFlatTensor(
             F.pad(local_flat_sharded_data, (0, padded_sharded_numel - local_flat_sharded_data.numel()))
         )
@@ -245,7 +245,7 @@ def shard_params(
         )
         offset = 0
         for flat_param in flat_params:
-            flat_param.data = params_data[offset : offset + flat_param.numel()]
+            flat_param.sharded_data = params_data[offset : offset + flat_param.numel()]
             offset += flat_param.numel()
 
         return cls(
@@ -319,7 +319,7 @@ def unshard_(
             assert self.params_data.is_sharded
             self.params_data.unshard_(dtype=dtype, rank0_only=rank0_only)
             if set_grads and self.requires_grad:
-                self.params_unsharded_grad = torch.zeros_like(self.params_data)
+                self.params_unsharded_grad = torch.zeros_like(self.params_data.data)
         else:
             assert not self.params_data.is_sharded
             # We prefer to use `all_gather_into_tensor()` directly when possible as it involves
@@ -342,9 +342,9 @@
         offset = 0
         for param in self.params:
             if rank0_only and local_rank != 0:
-                unsharded_data = torch.empty_like(self.params_data)
+                unsharded_data = torch.empty_like(self.params_data.data)
             else:
-                unsharded_data = self.params_data[offset : offset + param.unsharded_numel]
+                unsharded_data = self.params_data.data[offset : offset + param.unsharded_numel]
 
             param.unshard_(unsharded_data, dtype=dtype, rank0_only=rank0_only)
 
@@ -369,7 +369,7 @@ def reshard_(self, writeback: bool = False):
             flat_param.reshard_(writeback=False)
             if writeback:
                 # Reset the view into the new `params_data`.
-                flat_param.data = self.params_data[offset : offset + flat_param.sharded_numel]
+                flat_param.sharded_data = self.params_data[offset : offset + flat_param.sharded_numel]
                 offset += flat_param.sharded_numel
 
     def pre_reduce_scatter_grads_(
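For context on the `sharded_data` assignments above: the handle keeps one flat, padded buffer per group of parameters and points each flat parameter at a view into it, so the loops re-establish those views whenever the buffer is rebuilt. A minimal sketch of the collate-pad-view pattern with plain tensors (the names `shards` and `padded_numel` are illustrative, not olmo-core API):

import torch
import torch.nn.functional as F

# Per-rank shards of three parameters, plus a padded target size (e.g. aligned to the world size).
shards = [torch.randn(3), torch.randn(5), torch.randn(2)]
padded_numel = 12

# Collate into a single flat buffer and right-pad it to the fixed size.
flat = torch.cat(shards)
flat = F.pad(flat, (0, padded_numel - flat.numel()))

# Hand each parameter back a view into the flat buffer.
views = []
offset = 0
for shard in shards:
    views.append(flat[offset : offset + shard.numel()])
    offset += shard.numel()

# The views alias the buffer's storage: mutating the buffer is visible through every view,
# which is why no extra copies are needed after the buffer is rebuilt.
flat.zero_()
print(all(v.abs().sum().item() == 0.0 for v in views))  # True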
23 changes: 12 additions & 11 deletions src/olmo_core/distributed/tensors/sharded_flat_parameter.py
@@ -27,19 +27,20 @@ def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad: bool = True
         )
 
         if isinstance(data, ShardedFlatTensor):
-            setattr(
-                param,
-                cls.SHARDED_FLAT_TENSOR_METADATA_NAME,
-                getattr(data, cls.SHARDED_FLAT_TENSOR_METADATA_NAME, {}).copy(),
-            )
+            param._local_tensor = data._local_tensor
+            param._sharding_spec = data._sharding_spec
+            param._process_group = data._process_group
         else:
-            setattr(param, cls.SHARDED_FLAT_TENSOR_METADATA_NAME, {})
+            param._local_tensor = None if data is None else data.data.detach()
+            param._sharding_spec = None  # type: ignore[assignment]
+            param._process_group = None
+
+        param._global_tensor = None
 
         return param
 
     def __repr__(self) -> str:
-        r = torch.Tensor.__repr__(self)
-        if r.startswith("Parameter("):  # ) -- the open parenthesis confuses treesitter sometimes
-            r = r.replace("Parameter(", "", 1)  # ) -- the open parenthesis confuses treesitter sometimes
-            r = r[:-1]
-        return r
+        if self._global_tensor is not None:
+            return f"ShardedFlatParameter(local_tensor={self._local_tensor}, global_tensor={self._global_tensor}, requires_grad={self.requires_grad})"
+        else:
+            return f"ShardedFlatParameter(local_tensor={self._local_tensor}, requires_grad={self.requires_grad})"
