[DTensor & DModule & DOptim] feature updates (#20)
## In this PR, we update several features in our DTensor & DModule & DOptim implementations.

### DTensor Updates:
1. Support more DTensor ops.
2. Update sharding strategies.

### DModule Updates:
1. Decouple uneven support and run check.
2. Reduce some CPU overhead.

### DOptim Updates:
1. More friendly API.
2.  Unit test updates.
3.  Reorder some communication for better results.

### Other Updates/fixes:
1. Minor updates to our nanoGPT model and test results.
jc-bytedance authored Apr 3, 2024
1 parent c6981e6 commit 364c3b2
Showing 38 changed files with 2,085 additions and 700 deletions.
16 changes: 14 additions & 2 deletions python/example/nanogpt_4D_finetune/README.md
@@ -2,10 +2,10 @@

In this example, we demonstrate how to finetune a pre-trained GPT2 using veScale. The example is built upon @karpathy's [nanoGPT](https://github.com/karpathy/nanoGPT/) project. With near-zero changes in the model code and minimal changes in the training code, we can finetune a pre-trained GPT2 on the Shakespeare dataset and utilize multiple GPUs via 4D parallelism: Data, Tensor, Sequence, and Optimizer Parallelism. The correctness of our implementation is verified by comparing both the training and the validation loss with the single-GPU result produced by nanoGPT. The difference is negligible when the computation is conducted in fp32, and about 1% in bf16.

## Prerequisites
## Prerequisite

```
pip3 install datasets tiktoken
pip3 install tiktoken datasets
```

## Run
@@ -35,6 +35,18 @@ Here are the training Loss and validation loss curves plot for fp32 runs that la

![figure](./figures/nanoGPT_finetune_4d_train_loss_fp32_200.jpg)

For the bf16 runs, in `base_train.py`, instead of using `torch.amp.autocast`, we cast the model to bf16 directly, so both the gradients and the optimizer states end up in bf16 automatically. For a fair comparison, we modify veScale to store both the gradients and the optimizer states in bf16 instead of fp32.
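
As an illustrative sketch (not the actual training scripts), the difference between the two approaches looks roughly like this; the model and inputs below are stand-ins:

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 16)   # stand-in for the GPT model
x = torch.randn(4, 16)

# autocast-style mixed precision (roughly what upstream nanoGPT does with amp):
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y_amp = model(x)        # weights stay fp32; compute runs in bf16

# direct bf16 cast (what the bf16 runs here do): weights and gradients are bf16,
# and with our modification the optimizer states are stored in bf16 as well.
model_bf16 = model.to(torch.bfloat16)
y_bf16 = model_bf16(x.to(torch.bfloat16))
y_bf16.sum().backward()
print(model_bf16.weight.grad.dtype)  # torch.bfloat16
```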

![figure](./figures/nanoGPT_finetune_4d_forcebf16_val_loss_bf16_200.jpg)


![figure](./figures/nanoGPT_finetune_4d_forcebf16_train_loss_bf16_200.jpg)

## Difference from the upstream nanoGPT

1. When training with bf16 (`--dtype='bfloat16'`), the model is cast to bf16 and we remove the use of `amp.autocast`.
2. Mini-batch sampling is done on rank 0, and the sampled indices are then broadcast to the other ranks. This ensures that `base_train.py` and `finetune_4D.py` work on identical batches at every iteration (a sketch of this scheme follows this list).
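
Below is a hedged sketch of that rank-0 sampling scheme; the function and variable names are illustrative rather than taken from `finetune_4D.py`, and an initialized `torch.distributed` process group is assumed:

```python
import torch
import torch.distributed as dist

def get_batch_indices(num_samples: int, batch_size: int) -> torch.Tensor:
    """Rank 0 samples the mini-batch indices; every other rank receives them."""
    if dist.get_rank() == 0:
        ix = torch.randint(num_samples, (batch_size,), dtype=torch.int64)
    else:
        ix = torch.empty(batch_size, dtype=torch.int64)
    dist.broadcast(ix, src=0)  # all ranks now hold the identical indices
    return ix
```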

## Caveats

1. `torch.compile` for veScale is still experimental. We run the single GPU baseline with the `compile` flag off.
5 changes: 3 additions & 2 deletions python/example/nanogpt_4D_finetune/sharding_plan.py
@@ -22,12 +22,13 @@
"transformer.wte.output": [[Replicate()]],
"transformer.wpe.input": [[Replicate()]],
"transformer.wpe.output": [[Replicate()]],
r"transformer.h.\d+.ln_1.input": [[Shard(1)]],
r"transformer.h.\d+.input": [[Shard(1)]],
r"transformer.h.\d+.attn.input": [[Replicate()]],
r"transformer.h.\d+.attn.c_proj.output": [[Replicate()]],
r"transformer.h.\d+.ln_2.input": [[Shard(1)]],
r"transformer.h.\d+.attn.output": [[Shard(1)]],
r"transformer.h.\d+.mlp.c_fc.input": [[Replicate()]],
r"transformer.h.\d+.mlp.c_proj.output": [[Replicate()]],
r"transformer.h.\d+.mlp.output": [[Shard(1)]],
"transformer.ln_f.input": [[Shard(1)]],
"lm_head.input": [[Shard(2)]],
"lm_head.output": [[Replicate()]],
22 changes: 13 additions & 9 deletions python/vescale/ddp/distributed_data_parallel.py
@@ -9,10 +9,8 @@
import torch
import torch.distributed.distributed_c10d as c10d

from vescale.dmodule._dmodule import DModule
from vescale.dtensor.dtensor import DTensor
from vescale.dtensor.device_mesh import DeviceMesh
from vescale.dtensor.placement_types import DTensorSpec
from vescale.ddp.grad_buffer import GradBuffer


@@ -62,7 +60,6 @@ class DistributedDataParallel(torch.nn.Module):
# run the forward.
ddp_module(torch.rand(xxx))
```
TODO: remove `shared` attribute attached by Megatron.
"""

def __init__(
@@ -192,14 +189,26 @@ def param_hook(*unused):
assert param.grad is not None, "param.grad being None is not safe when overlap_grad_reduce is True"
# NOTE: it seems that there are some places where grad_added_to_main_grad is True.
# what will happen then?

# TODO: remove grad_added_to_main_grad attribute.
model_parallel_device_mesh, placements = None, None
if param.grad is not None and not param.grad_added_to_main_grad:
if isinstance(param.data, DTensor):
param.main_grad.add_(param.grad._local_tensor.data) # add DTensor's data
param.main_grad._spec: DTensorSpec = param.grad._spec # save DTensor's spec
model_parallel_device_mesh = param.grad._spec.mesh
placements = param.grad._spec.placements
else:
param.main_grad.add_(param.grad.data)
param.grad = None

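# If the DTensor gradient carries a Partial placement, immediately launch an
# async all-reduce of the accumulated main_grad over the corresponding mesh
# dimension's process group (see GradBuffer.register_partial_grad_ready).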
if (
model_parallel_device_mesh is not None
and placements is not None
and any(p.is_partial() for p in placements)
):
param_to_grad_buffer[param].register_partial_grad_ready(
param, model_parallel_device_mesh, placements
)
if self.overlap_grad_reduce:
param_to_grad_buffer[param].register_grad_ready(param)

@@ -229,11 +238,6 @@ def finish_grad_sync(self):
for grad_buffer in self.grad_buffers.values():
grad_buffer.finish_grad_sync()

# NOTE: here we do DDP.AllReduce(Mean) before DModule.AllReduce(Sum),
# which can cause different precision with Megatron baseline.
if DModule.is_dmodule(self.module):
self.module.finish_grad_sync()

for expert_grad in self.expert_grads:
expert_grad /= self.data_parallel_world_size

53 changes: 52 additions & 1 deletion python/vescale/ddp/grad_buffer.py
@@ -6,12 +6,14 @@

import math
import warnings
from typing import Dict, List, Union
from typing import Dict, List, Union, Sequence

import torch
import torch.distributed as dist

from vescale.dtensor.dtensor import DTensor
from vescale.dtensor.device_mesh import DeviceMesh
from vescale.dtensor.placement_types import Placement


def get_param_nelements(param: Union[torch.nn.Parameter, torch.Tensor]) -> int:
@@ -75,7 +77,9 @@ def reset(self):
"""
self.params_with_grad = set()
self.communication_handle = None
self.partial_grad_communication_handle = None
self.communication_issued = False
self.partial_grad_communication_issued = False

def shard_buffer(self, buffer: torch.Tensor):
"""
@@ -88,6 +92,23 @@ def shard_buffer(self, buffer: torch.Tensor):
]
return sharded_buffer

def all_reduce_partial_grad(
self, partial_main_grad, model_parallel_device_mesh: DeviceMesh, placements: Sequence[Placement]
):
# wait for the previous partial-grad all-reduce to finish
if self.partial_grad_communication_handle is not None and self.partial_grad_communication_issued:
self.partial_grad_communication_handle.wait()

# TODO: there may be other invalid cases, we should add more checks here.
partial_mesh_idxes = [i for i, p in enumerate(placements) if p.is_partial()]
assert len(partial_mesh_idxes) == 1, "currently, we only consider a single Partial on the same mesh dim."
model_parallel_pg = model_parallel_device_mesh.get_dim_groups(partial_mesh_idxes[0])

self.partial_grad_communication_handle = dist.all_reduce(
partial_main_grad, group=model_parallel_pg, async_op=True
)
self.partial_grad_communication_issued = True

def start_grad_sync(self):
"""
Initiates grad sync (all-reduce or reduce-scatter) communication operation
@@ -97,6 +118,11 @@ def start_grad_sync(self):
communication call. When overlap_grad_reduce is set to False, makes
synchronous call.
"""

# We must wait until all partial grads in this bucket are all-reduced.
if self.partial_grad_communication_handle is not None and self.partial_grad_communication_issued:
self.partial_grad_communication_handle.wait()

assert (
self.communication_handle is None and not self.communication_issued
), "Should not have multiple communication calls in flight at once"
@@ -164,6 +190,19 @@ def register_grad_ready(self, param: torch.nn.Parameter):
if len(self.params_with_grad) == len(self.params):
self.start_grad_sync()

def register_partial_grad_ready(
self,
param: torch.nn.Parameter,
model_parallel_device_mesh: DeviceMesh,
placements: Sequence[Placement],
):
"""
Immediately trigger partial gradient all-reduce in an async way.
"""
assert param in self.params, "Param is not in the bucket"
assert any(p.is_partial() for p in placements), "Param's grad should be partial sharded"
self.all_reduce_partial_grad(param.main_grad, model_parallel_device_mesh, placements)


class GradBuffer:
"""
@@ -411,3 +450,15 @@ def register_grad_ready(self, param: torch.nn.Parameter):
if self.is_last_microbatch:
bucket = self.param_to_bucket[param]
bucket.register_grad_ready(param)

def register_partial_grad_ready(
self,
param: torch.nn.Parameter,
model_parallel_device_mesh: DeviceMesh,
placements: Sequence[Placement],
):
"""
Immediately trigger partial gradient all-reduce in an async way.
"""
bucket = self.param_to_bucket[param]
bucket.register_partial_grad_ready(param, model_parallel_device_mesh, placements)
4 changes: 3 additions & 1 deletion python/vescale/dmodule/_dmodule.py
@@ -254,7 +254,9 @@ def _distribute_parameter(

# regular initialization
if is_sharded:
dt = DTensor.from_local(t, device_mesh, pi.placements, run_check=pi.run_check)
dt = DTensor.from_local(
t, device_mesh, pi.placements, run_check=pi.run_check, support_uneven=pi.support_uneven
)
else:
dt = distribute_tensor(t, device_mesh, pi.placements)
return nn.Parameter(dt, requires_grad=param.requires_grad) if is_param else dt
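
For reference, here is a minimal sketch of the decoupled `run_check` and `support_uneven` flags at a `DTensor.from_local` call site. The mesh, placements, and the `Shard` import path are illustrative assumptions (an initialized process group is required), and the flag semantics noted in the comments are inferred from the flag names and this diff rather than stated API guarantees:

```python
import torch
from vescale.dtensor.dtensor import DTensor
from vescale.dtensor.device_mesh import DeviceMesh
from vescale.dtensor.placement_types import Shard  # assumed import path

# Illustration only: requires torch.distributed to be initialized on 4 ranks.
mesh = DeviceMesh("cuda", [0, 1, 2, 3])
local = torch.randn(7, 16, device="cuda")  # this rank's shard; dim 0 may differ per rank

dt = DTensor.from_local(
    local,
    mesh,
    [Shard(0)],
    run_check=False,      # skip the cross-rank consistency check for known-good inputs
    support_uneven=True,  # allow unevenly sized local shards across ranks
)
```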
31 changes: 9 additions & 22 deletions python/vescale/dmodule/_grad_sync.py
@@ -18,10 +18,9 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
################################################################################

"""This file handles gradient allreduce for DModule
"""This file handles gradient allreduce for DModule with no DDP
NOTE:
- If wrapped by DDP, it is called after DDP.finish_grad_sync()
- `generate_grad_sync_list` is not recommended to be placed into a param.grad pre-hook, because:
i) having multiple hooks on param.grad complicates the design and debugging
ii) gradient accumulation will repeatedly fire param.grad pre-hook, degrading performance
@@ -47,21 +46,12 @@ def generate_grad_sync_list(candidate: List[Tuple[str, DTensor]]) -> List[Tuple[
for fqn, param in candidate:
assert param.requires_grad
assert isinstance(param.data, DTensor)
if hasattr(param, "main_grad"):
if param.main_grad is None:
continue
grad_spec = getattr(param.main_grad, "_spec", None)
assert grad_spec is not None, "DDP's .main_grad must save DTensor .grad's _spec"
placements = grad_spec.placements
fqn += ".main_grad"
grad = param.main_grad
else:
assert hasattr(param, "grad")
if param.grad is None:
continue
placements = param.grad.placements
fqn += ".grad"
grad = param.grad
assert hasattr(param, "grad")
if param.grad is None:
continue
placements = param.grad.placements
fqn += ".grad"
grad = param.grad
if any(p.is_partial() for p in placements):
grad_sync_list.append((fqn, grad))
return grad_sync_list
@@ -122,11 +112,8 @@ def sync_gradients(grad_sync_list: List[Tuple[str, Union[Tensor, DTensor]]], dev
# get local tensors to allreduce + get process group to allreduce
local_gradients = []
partial_mesh_idxes = set()
for fqn, grad in grad_sync_list:
if fqn.endswith("main_grad"):
local_gradients.append(grad.data)
else:
local_gradients.append(grad._local_tensor)
for _, grad in grad_sync_list:
local_gradients.append(grad._local_tensor)
partial_mesh_idxes.update([i for i, p in enumerate(grad._spec.placements) if p.is_partial()])
assert len(partial_mesh_idxes) == 1, "currently, we only consider a single Partial on the same mesh dim."
partial_pg = device_mesh.get_dim_groups(partial_mesh_idxes.pop())
15 changes: 11 additions & 4 deletions python/vescale/dmodule/_hook.py
@@ -51,7 +51,14 @@ def _convert_by_pi(
return x
return x.redistribute(device_mesh, pi.placements, async_op=pi.async_op)
if isinstance(x, torch.Tensor):
return DTensor.from_local(x, device_mesh, pi.placements, run_check=pi.run_check, async_input=pi.async_op)
return DTensor.from_local(
x,
device_mesh,
pi.placements,
run_check=pi.run_check,
support_uneven=pi.support_uneven,
async_input=pi.async_op,
)
if not raise_err:
logging.info("binding a placement %s with a %s obj: %s. The placement is ignored.", pi.placements, type(x), x)
return x
@@ -198,9 +205,9 @@ def _hook(
output_pis: FwdPIs,
):
if isinstance(output, Sequence) and isinstance(output_pis, Sequence):
assert len(output) == len(
output_pis
), f"Mismatched actual output size: {output} vs. plaments size: {output_pis}!"
assert (
len(output) == len(output_pis)
), f"Mismatched actual output size: {[x._spec if isinstance(x, DTensor) else x for x in output]} vs. plaments size: {output_pis}!"
return [PostHookOutput._convert(o, pi, device_mesh) for o, pi in zip(output, output_pis)]
if isinstance(output, DTensor) and output_pis[0] is not None:
return PostHookOutput._convert(output, output_pis[0], device_mesh)
12 changes: 8 additions & 4 deletions python/vescale/dmodule/placements_interface.py
@@ -34,7 +34,7 @@ class PlacementsInterface:
async_op: bool = True # flag for DTensor.redistribute/from_local
defer_reshard: bool = False # flag for deferred resharding mode
run_check: bool = True # flag for DTensor.from_local
skippable_op: bool = True # flag for DTensor.redistribute # TODO: to enable
support_uneven: bool = True # flag for DTensor.from_local
grad: Optional[Sequence[Placement]] = None # the placement to enforce on this tensor.grad

@classmethod
@@ -43,9 +43,13 @@ def from_placements(cls, placements: Any) -> Any:
return placements
return cls(placements)

def normalize_placements(self, mesh_ndim: int) -> None:
self.placements = normalize_placements(self.placements, mesh_ndim)
self.grad = normalize_placements(self.grad, mesh_ndim)
def normalize_placements(self, mesh_ndim: int, *, tensor_ndim: int = 0, none_as_replicate: bool = False) -> None:
self.placements = normalize_placements(
self.placements, mesh_ndim, tensor_ndim=tensor_ndim, none_as_replicate=none_as_replicate
)
self.grad = normalize_placements(
self.grad, mesh_ndim, tensor_ndim=tensor_ndim, none_as_replicate=none_as_replicate
)

def is_none(self) -> bool:
"""Is it equivalent to `None` placements;
4 changes: 3 additions & 1 deletion python/vescale/dtensor/__init__.py
@@ -54,7 +54,9 @@ def _dtensor_init_helper(
device_mesh = device_mesh or mesh_resources.get_current_mesh()
device = device_mesh.device_type
# get placements
placements: Tuple[Placement] = normalize_placements(placements, device_mesh.ndim, none_as_replicate=True)
placements: Tuple[Placement] = normalize_placements(
placements, device_mesh.ndim, tensor_ndim=len(global_shape), none_as_replicate=True
)
# get local tensor shape
local_shape = compute_local_shape(global_shape, device_mesh, placements)
# initialize the local tensor