added veDeviceMesh
MackZackA committed Apr 22, 2024
1 parent 2f2daaa commit b050069
Showing 24 changed files with 860 additions and 847 deletions.
769 changes: 90 additions & 679 deletions patches/patched_pytorch_v2.2.1_rc3.patch

Large diffs are not rendered by default.

15 changes: 7 additions & 8 deletions python/example/nanogpt_4D_finetune/finetune_4D.py
@@ -30,11 +30,11 @@

 import numpy as np
 import torch
-from torch.distributed import broadcast, all_reduce, barrier, init_process_group, destroy_process_group
+from torch.distributed import broadcast, all_reduce, barrier, init_process_group, destroy_process_group, get_rank

 from model import GPTConfig, GPT
+from vescale.devicemesh_api.device_mesh_api import veDeviceMesh

-from vescale.dtensor.device_mesh import init_device_mesh
 from vescale import distribute_tensor
 from vescale.dmodule.api import parallelize_module
 from vescale.dtensor.placement_types import Replicate
@@ -113,8 +113,9 @@ def main():
         device = f"cuda:{rank}"
         torch.cuda.set_device(device)
         init_process_group(backend=backend, world_size=world_size, rank=rank)
-        mesh = init_device_mesh(device, (dp_size, tp_size), mesh_dim_names=["DP", "TP"])
-        ddp_rank = mesh.get_rank() // tp_size
+
+        mesh = veDeviceMesh.init_device_mesh(device, (dp_size, tp_size), mesh_dim_names=["DP", "TP"])
+        ddp_rank = get_rank() // tp_size
     else:
         rank = 0
         ddp_rank = 0
@@ -329,8 +330,7 @@ def get_lr(it):
     # Load checkpoint
     if load_checkpoint_path:
         checkpoint_state = {"model": model, "optimizer": optimizer}
-        with mesh:
-            vescale.checkpoint.load(load_checkpoint_path, checkpoint_state)
+        vescale.checkpoint.load(load_checkpoint_path, checkpoint_state)
     # training loop
     X, Y = get_batch("train")  # fetch the very first batch
     t0 = time.time()
@@ -363,8 +363,7 @@ def get_lr(it):
             # When iter_num == 0, the training does not start, so the optimizer state is empty.
             # Don't save checkpoint
             checkpoint_state = {"model": model, "optimizer": optimizer}
-            with mesh:
-                vescale.checkpoint.save(os.path.join(save_checkpoint_path, f"iter_{iter_num}"), checkpoint_state)
+            vescale.checkpoint.save(os.path.join(save_checkpoint_path, f"iter_{iter_num}"), checkpoint_state)
             if iter_num == 0 and eval_only:
                 break

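As context for the changes above, a minimal sketch of the new initialization and checkpointing path (a hypothetical standalone script, not part of this commit; the 2x2 layout, toy model, and checkpoint directory are placeholders): the mesh now comes from the veDeviceMesh singleton, the DP rank is derived from the global rank via get_rank(), and vescale.checkpoint.save/load no longer need the `with mesh:` context.

    import os
    import torch
    import vescale
    from torch.distributed import init_process_group, get_rank
    from vescale.devicemesh_api.device_mesh_api import veDeviceMesh

    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "4"))
    dp_size, tp_size = 2, 2  # example layout; world_size must equal dp_size * tp_size

    torch.cuda.set_device(f"cuda:{rank}")
    init_process_group(backend="nccl", world_size=world_size, rank=rank)

    # Mesh creation goes through the singleton instead of init_device_mesh():
    mesh = veDeviceMesh.init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=["DP", "TP"])
    ddp_rank = get_rank() // tp_size  # DP rank from the global rank, not mesh.get_rank()

    # Checkpointing no longer wraps the calls in `with mesh:`; per the planner
    # changes below, the checkpoint code resolves the mesh via veDeviceMesh.
    model = torch.nn.Linear(8, 8).cuda()  # placeholder for the parallelized GPT model
    optimizer = torch.optim.Adam(model.parameters())
    checkpoint_state = {"model": model, "optimizer": optimizer}
    vescale.checkpoint.save(os.path.join("checkpoints", "iter_0"), checkpoint_state)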
2 changes: 0 additions & 2 deletions python/requirements.txt
@@ -7,5 +7,3 @@ tqdm
 optree
 accelerate
 transformers==4.37.2
-grpcio
-grpcio-tools
8 changes: 3 additions & 5 deletions python/vescale/checkpoint/planner/vescale/vescale_planner.py
@@ -35,7 +35,7 @@
     find_state_dict_object,
 )

-from vescale.dtensor.device_mesh import mesh_resources
+from vescale.devicemesh_api import veDeviceMesh

 logger: logging.Logger = logging.getLogger(__file__)

@@ -190,8 +190,6 @@ def create_default_local_save_plan(state_dict: Dict[str, Any], is_coordinator: b
     A function for creating local saving plan for saving checkpoint
     """
     requests = []
-    device_mesh = mesh_resources.get_current_mesh()
-    dp_device_mesh = device_mesh["DP"]
     for fqn, obj in state_dict.items():
         # Since DTensor supports submesh, adding extra check to ensure _create_write_items()
         # gets called only when the current rank is part of the mesh for the corresponding DTensor.
@@ -232,7 +230,7 @@ def create_default_local_save_plan(state_dict: Dict[str, Any], is_coordinator: b
                 op=dist.irecv,
                 tensor=recv_tensor,
                 peer=k,
-                group=dp_device_mesh.get_dim_groups(0),
+                group=veDeviceMesh.get_data_parallel_dim_groups(),
             )
             recv_tensors[k] = recv_tensor
             p2p_ops.append(recv_op)
@@ -243,7 +241,7 @@ def create_default_local_save_plan(state_dict: Dict[str, Any], is_coordinator: b
                 op=dist.isend,
                 tensor=obj.local_tensor,
                 peer=writer_rank,
-                group=dp_device_mesh.get_dim_groups(0),
+                group=veDeviceMesh.get_data_parallel_dim_groups(),
             )
             p2p_ops.append(send_op)

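The two hunks above make the same substitution on the receive and send paths: the data-parallel process group is now fetched from the global veDeviceMesh singleton rather than from a DeviceMesh looked up via mesh_resources. A minimal sketch of the pattern (the helper function is hypothetical; only the P2POp construction mirrors the code above):

    import torch
    import torch.distributed as dist
    from vescale.devicemesh_api import veDeviceMesh

    def make_send_op(local_tensor: torch.Tensor, writer_rank: int) -> dist.P2POp:
        # Before: group = device_mesh["DP"].get_dim_groups(0), with device_mesh
        # taken from mesh_resources.get_current_mesh().
        # After: the DP group comes straight from the singleton.
        return dist.P2POp(
            op=dist.isend,
            tensor=local_tensor,
            peer=writer_rank,
            group=veDeviceMesh.get_data_parallel_dim_groups(),
        )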
4 changes: 2 additions & 2 deletions python/vescale/checkpoint/storage/checkpoint_adapter.py
@@ -125,7 +125,7 @@ def _get_megatron_tp_group(world_size, pp_size, tp_size, dp_size, cur_rank) -> t

 def _deduce_parallel_plan_by_device_mesh(mesh: DeviceMesh):
     """make rank to megatron tp_rank, pp_rank map"""
-    # FIXME(cery.69) : current only support data parallel is 1
+    # FIXME : current only support data parallel is 1
     # allways parallel in last dim
     tp_size = mesh.size()
     # for rank = pp_rank * tp_size + tp_rank
@@ -261,7 +261,7 @@ def find_device_mesh(st):
     torch.save(optim, os.path.join(megatron_optim_dict_path, "optim.pt"))
     del st["optim"]
     torch.save(st, megatron_save_file)
-    # FIXME(cery.69): support dp not 1
+    # FIXME: support dp not 1
     return st


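For reference, the decomposition noted in _deduce_parallel_plan_by_device_mesh above ("rank = pp_rank * tp_size + tp_rank", valid under the dp_size == 1 restriction the FIXME mentions) inverts with a single divmod; a quick illustration:

    tp_size = 4
    for rank in range(8):
        pp_rank, tp_rank = divmod(rank, tp_size)  # e.g. rank 6 -> pp_rank 1, tp_rank 2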
18 changes: 18 additions & 0 deletions python/vescale/devicemesh_api/__init__.py
@@ -0,0 +1,18 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+from .device_mesh_api import veDeviceMesh
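Since this __init__.py re-exports the singleton, the two import paths used elsewhere in this commit resolve to the same object; a quick check (a sketch, not part of the diff):

    from vescale.devicemesh_api import veDeviceMesh as via_package
    from vescale.devicemesh_api.device_mesh_api import veDeviceMesh as via_module
    assert via_package is via_module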