
Commit

Merge branch 'master' into torch_load_weight_only_warning
loadams authored Oct 25, 2024
2 parents 4d291e0 + 6e6563d commit 69fc219
Showing 16 changed files with 148 additions and 82 deletions.
13 changes: 6 additions & 7 deletions .github/workflows/xpu-max1100.yml
@@ -36,7 +36,7 @@ jobs:
   unit-tests:
     runs-on: [self-hosted, intel, xpu]
     container:
-      image: intel/oneapi-basekit:2024.1.1-devel-ubuntu22.04
+      image: intel/oneapi-basekit:2024.2.1-0-devel-ubuntu22.04
       ports:
         - 80
       options: --privileged -it --rm --device /dev/dri:/dev/dri -v /dev/dri/by-path:/dev/dri/by-path --ipc=host --cap-add=ALL
@@ -47,12 +47,11 @@ jobs:
         run: |
           apt-get update
           apt-get install clinfo libaio-dev python3-pip -y
-          pip install torch==2.1.0.post2 -f https://developer.intel.com/ipex-whl-stable-xpu
-          pip install intel-extension-for-pytorch==2.1.30+xpu -f https://developer.intel.com/ipex-whl-stable-xpu
-          pip install intel-extension-for-pytorch-deepspeed==2.1.30 -f https://developer.intel.com/ipex-whl-stable-xpu
-          pip install oneccl_bind_pt==2.1.300+xpu -f https://developer.intel.com/ipex-whl-stable-xpu
-          pip install torchvision==0.16.0.post2 -f https://developer.intel.com/ipex-whl-stable-xpu
-          pip install py-cpuinfo numpy==1.26
+          pip install torch==2.3.1 -f https://pytorch-extension.intel.com/release-whl/stable/xpu/us/torch/
+          pip install intel-extension-for-pytorch==2.3.110+xpu -f https://pytorch-extension.intel.com/release-whl/stable/xpu/us/intel-extension-for-pytorch/
+          pip install oneccl_bind_pt==2.3.100+xpu -f https://pytorch-extension.intel.com/release-whl/stable/xpu/us/oneccl-bind-pt/
+          pip install torchvision==0.18.1 -f https://pytorch-extension.intel.com/release-whl/stable/xpu/us/torchvision/
+          pip install py-cpuinfo numpy
           pip install .[dev,autotuning]
       - name: Check container state
4 changes: 2 additions & 2 deletions accelerator/xpu_accelerator.py
@@ -26,8 +26,8 @@ def is_synchronized_device(self):
         return False
 
     def use_host_timers(self):
-        # WA XPU event will be consolidated in 2.5
-        if ipex.__version__ < '2.5':
+        # WA XPU event will be consolidated in 2.6
+        if ipex.__version__ < '2.6':
             return True
         else:
             return self.is_synchronized_device()
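Note on the version gate above: comparing ipex.__version__ as a string happens to work for the releases involved, but strings sort lexicographically (e.g. '2.10' < '2.6'). A hedged sketch, not part of this commit, of a more robust gate using the packaging library:

    from packaging import version
    import intel_extension_for_pytorch as ipex  # assumes an XPU build of ipex is installed

    # True while the XPU event workaround is still needed (ipex older than 2.6).
    needs_host_timers = version.parse(ipex.__version__) < version.parse("2.6")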
5 changes: 4 additions & 1 deletion csrc/aio/py_lib/deepspeed_cpu_op.cpp
@@ -38,7 +38,10 @@ void cpu_op_desc_t::finish()
 {
     if (_use_bounce_buffer) {
         if (_read_op) {
-            if (_buffer.is_cuda()) { _buffer.copy_(_cpu_buffer.to(torch::kCUDA)); }
+            if (_buffer.is_cuda()) {
+                _buffer.copy_(_cpu_buffer.to(torch::Device(torch::kCUDA, _buffer.get_device()),
+                                             /*non_blocking=*/true));
+            }
             if (_buffer.is_xpu()) { _buffer.copy_(_cpu_buffer.to(torch::kXPU)); }
             if (_buffer.is_cpu()) { _buffer.copy_(_cpu_buffer); }
 #if defined(__ENABLE_CANN__)
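For readers following the bounce-buffer change above from the Python side, a hedged sketch (not part of the commit, names illustrative) of the pattern it implements: data read into a pinned CPU staging buffer is copied to the destination tensor's own CUDA device with non_blocking=true, rather than an implicit blocking copy to the current default device.

    import torch

    if torch.cuda.is_available():
        dst = torch.empty(1 << 20, dtype=torch.uint8, device="cuda:0")   # destination lives on a specific GPU
        bounce = torch.empty(dst.numel(), dtype=dst.dtype).pin_memory()  # pinned host staging buffer
        # ... the aio worker threads would fill `bounce` from disk here ...
        dst.copy_(bounce.to(dst.device, non_blocking=True))              # async H2D copy to dst's own device
        torch.cuda.synchronize(dst.device)                               # ensure the copy finished before use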
2 changes: 2 additions & 0 deletions csrc/aio/py_lib/deepspeed_pin_tensor.cpp
@@ -15,6 +15,7 @@ deepspeed_pin_tensor_t::~deepspeed_pin_tensor_t()
 {
     for (auto iter = _locked_tensors.begin(); iter != _locked_tensors.end(); ++iter) {
         munlock(iter->first, iter->second);
+        std::free((void*)iter->first);
     }
     _locked_tensors.clear();
 }
@@ -43,6 +44,7 @@ bool deepspeed_pin_tensor_t::free(torch::Tensor& locked_tensor)
     auto addr = locked_tensor.data_ptr();
     if (_locked_tensors.find(addr) != _locked_tensors.end()) {
         munlock(addr, _locked_tensors[addr]);
+        std::free(addr);
         _locked_tensors.erase(addr);
         return true;
     }
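Context for the two std::free additions: this allocator malloc's and mlock's the pinned bounce buffers, and previously the destructor and free() only called munlock, leaving the host allocation itself unreleased. A hedged sketch of the Python-level lifecycle this backs; the handle method names are taken from DeepSpeed's aio bindings and should be treated as assumptions to verify against your build:

    import torch
    from deepspeed.ops.op_builder import AsyncIOBuilder

    aio = AsyncIOBuilder().load()
    # block_size, queue_depth, single_submit, overlap_events, intra_op_parallelism
    handle = aio.aio_handle(1024 * 1024, 128, False, False, 1)

    example = torch.empty(0, dtype=torch.uint8)
    pinned = handle.new_cpu_locked_tensor(1 << 20, example)  # pinned (mlock'd) host tensor
    # ... use `pinned` as the staging buffer for the handle's pread/pwrite calls ...
    handle.free_cpu_locked_tensor(pinned)  # after this commit, also releases the malloc'd memory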
3 changes: 1 addition & 2 deletions csrc/gds/py_lib/deepspeed_gds_op.cpp
@@ -58,7 +58,6 @@ void gds_op_desc_t::add_buffer_to_registry(const torch::Tensor& buffer)
     const int64_t device = buffer.get_device();
     void* reg_ptr = buffer.data_ptr();
 
-    // std::cout << "REG PTR " << reg_ptr << std::endl;
     // TODO: add checking to make sure pointer isn't already in set
     const auto it = base_ptr_registry.find(device);
     if (it == base_ptr_registry.end()) {
@@ -94,7 +93,7 @@ gds_op_desc_t::gds_op_desc_t(const bool read_op,
                              const torch::Tensor& buffer,
                              const int fd,
                              const char* filename,
-                             const long long int file_num_bytes,
+                             const int64_t file_num_bytes,
                              const int intra_op_parallelism,
                              const bool validate)
     : io_op_desc_t(read_op, buffer, fd, filename, file_num_bytes, intra_op_parallelism, validate)
2 changes: 1 addition & 1 deletion csrc/gds/py_lib/deepspeed_gds_op.h
@@ -22,7 +22,7 @@ struct gds_op_desc_t : io_op_desc_t {
                   const torch::Tensor& buffer,
                   const int fd,
                   const char* filename,
-                  const long long int file_num_bytes,
+                  const int64_t file_num_bytes,
                   const int intra_op_parallelism,
                   const bool validate);
 
2 changes: 1 addition & 1 deletion csrc/gds/py_lib/deepspeed_py_gds_handle.cpp
@@ -106,7 +106,7 @@ std::shared_ptr<struct io_op_desc_t> deepspeed_gds_handle_t::_create_io_op_desc(
     const torch::Tensor& buffer,
     const int fd,
     const char* filename,
-    const long long int file_num_bytes,
+    const int64_t file_num_bytes,
     const bool validate)
 {
     if (buffer.is_cuda()) {
2 changes: 1 addition & 1 deletion csrc/gds/py_lib/deepspeed_py_gds_handle.h
@@ -41,7 +41,7 @@ struct deepspeed_gds_handle_t : deepspeed_io_handle_t {
         const torch::Tensor& buffer,
         const int fd,
         const char* filename,
-        const long long int file_num_bytes,
+        const int64_t file_num_bytes,
         const bool validate);
 
     static int s_cuFile_init;
51 changes: 0 additions & 51 deletions csrc/xpu/aio/deepspeed_cpu_op.cpp

This file was deleted.

11 changes: 8 additions & 3 deletions deepspeed/comm/torch.py
@@ -390,9 +390,14 @@ def init_device_mesh(self, mesh_shape, mesh_dim_names):
         if not required_torch_version(min_version=2.2):
             raise RuntimeError(f"Current torch version does not have device mesh"
                                f"api (torch.__version__: {torch.__version__})")
-        return torch.distributed.device_mesh.init_device_mesh(get_accelerator().current_device_name(),
-                                                               mesh_shape,
-                                                               mesh_dim_names=mesh_dim_names)
+        if not required_torch_version(max_version=2.4):
+            return torch.distributed.device_mesh.init_device_mesh(get_accelerator().device_name(),
+                                                                   mesh_shape,
+                                                                   mesh_dim_names=mesh_dim_names)
+        else:
+            return torch.distributed.device_mesh.init_device_mesh(get_accelerator().current_device_name(),
+                                                                   mesh_shape,
+                                                                   mesh_dim_names=mesh_dim_names)
 
 
 # This will become a light-weight wrapper around torch.distributed functions
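A minimal sketch of the call being gated above, for reference: on torch newer than 2.4 DeepSpeed now passes the bare accelerator type (device_name(), e.g. "cuda") instead of an indexed name (current_device_name(), e.g. "cuda:0"). The snippet below is illustrative only and assumes four ranks launched with torchrun:

    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh  # available since torch 2.2

    dist.init_process_group(backend="nccl")      # assumes the launcher set MASTER_ADDR/PORT etc.
    mesh = init_device_mesh("cuda",              # device *type*, not "cuda:0"
                            (2, 2),
                            mesh_dim_names=("dp", "tp"))
    dp_group = mesh.get_group("dp")              # process group for the data-parallel dimension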
2 changes: 1 addition & 1 deletion deepspeed/module_inject/replace_module.py
@@ -345,7 +345,7 @@ def set_lm_head(module):
                                                       "weight") and not module.embed_out.weight.is_meta and isinstance(
                                                           module.embed_out, torch.nn.Linear):
         module = replace_wo_policy(module, ("embed_out", ), 0, "embed_out")
-    elif hasattr(module.language_model, "lm_head"):
+    elif hasattr(module, "language_model") and hasattr(module.language_model, "lm_head"):
         module = replace_wo_policy(module.language_model, ("lm_head", ), 0, "lm_head")
     return module
 
9 changes: 2 additions & 7 deletions deepspeed/runtime/zero/stage_1_and_2.py
@@ -1070,14 +1070,10 @@ def average_tensor(self, tensor):
             for i, param, param_id in self.params_in_ipg_bucket:
 
                 process_group = self.dp_process_group
-
-                grad_reduc = self.get_gradient_for_reduction(param)
-                #Averages gradients at parameter level if ipg has a moe param
-                #Otherwise averaging is done at the entire buffer level at the end of the loop
+                # MoE param have different groups
                 if self.ipg_bucket_has_moe_params:
                     process_group = self.expert_dp_process_group[param.group_name] if is_moe_param(
                         param) else self.dp_process_group
-                    grad_reduc.data.div_(dist.get_world_size(group=process_group) / float(self.sequence_parallel_size))
 
                 partition_ids = self.param_to_partition_ids[i][param_id]
                 assert all([p_id < dist.get_world_size(group=process_group) for p_id in partition_ids
@@ -1116,8 +1112,7 @@ def average_tensor(self, tensor):
                 curr_size += numel
                 prev_id, prev_process_group = partition_id, process_group
 
-            if not self.ipg_bucket_has_moe_params:
-                tensor.div_(dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size))
+            tensor.div_(dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size))
 
             buckets = {}
             for i, (dst, bucket_offset, numel) in enumerate(rank_and_offsets):
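To make the simplified averaging above easier to follow, a hedged standalone sketch (illustrative names, not the optimizer's own code): the whole flattened IPG buffer is now divided once by the data-parallel world size, scaled by sequence_parallel_size, before the reduce; MoE parameters only change which process group owns their partitions.

    import torch.distributed as dist

    def scale_bucket_before_reduce(flat_buffer, dp_group, sequence_parallel_size=1):
        # Dividing by the group size up front turns the later SUM reduction
        # (reduce/reduce-scatter to the partition owners) into an average.
        flat_buffer.div_(dist.get_world_size(group=dp_group) / float(sequence_parallel_size))
        return flat_buffer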
2 changes: 1 addition & 1 deletion op_builder/xpu/async_io.py
@@ -31,7 +31,7 @@ def sources(self):
             'csrc/aio/common/deepspeed_aio_types.cpp',
             'csrc/aio/py_lib/deepspeed_pin_tensor.cpp',
             'csrc/aio/py_lib/deepspeed_py_io_handle.cpp',
-            'csrc/xpu/aio/deepspeed_cpu_op.cpp',
+            'csrc/aio/py_lib/deepspeed_cpu_op.cpp',
             'csrc/aio/py_lib/deepspeed_aio_op_desc.cpp',
         ]
 
112 changes: 112 additions & 0 deletions tests/unit/moe/test_moe.py
@@ -7,6 +7,7 @@
 import deepspeed
 import pytest
 import gc
+import random
 from unit.common import DistributedTest
 from unit.simple_model import SimplePRMoEModel, SimpleMoEModel, sequence_dataloader
 import deepspeed.comm as dist
@@ -238,3 +239,114 @@ def check_equal(logits, cap, sparse_truth, res):
                                            [2, 1, 1], [2, 2, 1], [2, 3, 1], [3, 0, 0]])
         position_dispatch_res = topkgating(logits2, 3, 1, min_capacity=1, drop_policy='position')[2]
         check_equal(logits2, 2, position_sec_sparse, position_dispatch_res)
+
+
+class TestExpertWeightGradWithZero(DistributedTest):
+    world_size = 2
+
+    @pytest.mark.parametrize("zero_stage", [0, 1, 2])
+    def test(self, zero_stage):
+
+        if not required_torch_version(min_version=1.8):
+            pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly")
+
+        def seed_everything(seed=11):
+            random.seed(seed)
+            torch.manual_seed(seed)
+            get_accelerator().manual_seed(seed)
+            get_accelerator().manual_seed_all(seed)
+            torch.backends.cudnn.deterministic = True
+            torch.backends.cudnn.benchmark = False
+
+        def get_state_dict_ep2(state_dict):
+            """
+            convert state_dict from EP=1 to EP=2
+            """
+            rank = int(deepspeed.comm.get_rank())
+            ep_state_dict = dict()
+            dst_sub_key = f"deepspeed_moe.experts.deepspeed_experts.0"
+            src_sub_key = f"deepspeed_moe.experts.deepspeed_experts.{rank}"
+            for moe_layer in ["moe_1", "moe_2"]:
+                for mlp_in_moe in [0, 1]:
+                    dst_key = f"{moe_layer}.{dst_sub_key}.{mlp_in_moe}"
+                    src_key = f"{moe_layer}.{src_sub_key}.{mlp_in_moe}"
+                    ep_state_dict[f"{dst_key}.weight"] = state_dict[f"{src_key}.weight"].detach().clone()
+                    ep_state_dict[f"{dst_key}.bias"] = state_dict[f"{src_key}.bias"].detach().clone()
+
+            for key in state_dict.keys():
+                if "deepspeed_moe.experts.deepspeed_experts" not in key:
+                    ep_state_dict[key] = state_dict[key].detach().clone()
+            return ep_state_dict
+
+        def get_models(hidden_dim):
+            model_ep1 = SimpleMoEModel(hidden_dim=hidden_dim, num_experts=2, ep_size=1, use_rts=False)
+            model_ep2 = SimpleMoEModel(hidden_dim=hidden_dim, num_experts=2, ep_size=2, use_rts=False)
+
+            state_dict_ep1 = model_ep1.state_dict()
+            state_dict_ep2 = get_state_dict_ep2(state_dict_ep1)
+            model_ep2.load_state_dict(state_dict_ep2)
+
+            model_ep1, _, _, _ = deepspeed.initialize(config=config_dict, model=model_ep1)
+            model_ep2, _, _, _ = deepspeed.initialize(config=config_dict, model=model_ep2)
+
+            return model_ep1, model_ep2
+
+        def extract_expert_grad(model, expert_id):
+
+            def _get_weight_bias(experts):
+                return ([deepspeed.utils.safe_get_full_grad(expert[0].weight)
+                         for expert in experts][expert_id].detach().clone(),
+                        [deepspeed.utils.safe_get_full_grad(expert[0].bias)
+                         for expert in experts][expert_id].detach().clone(),
+                        [deepspeed.utils.safe_get_full_grad(expert[1].weight)
+                         for expert in experts][expert_id].detach().clone(),
+                        [deepspeed.utils.safe_get_full_grad(expert[1].bias)
+                         for expert in experts][expert_id].detach().clone())
+
+            return (*_get_weight_bias(model.moe_1.deepspeed_moe.experts.deepspeed_experts),
+                    *_get_weight_bias(model.moe_2.deepspeed_moe.experts.deepspeed_experts))
+
+        seed_everything()
+
+        config_dict = {
+            "train_micro_batch_size_per_gpu": 1,
+            "steps_per_print": 1,
+            "optimizer": {
+                "type": "Adam",
+                "params": {
+                    "lr": 0.1,
+                }
+            },
+            "zero_optimization": {
+                "stage": zero_stage
+            }
+        }
+
+        hidden_dim = 4
+        total_samples = 2
+        rank = deepspeed.comm.get_rank()
+        model_ep1, model_ep2 = get_models(hidden_dim)
+
+        data_loader = sequence_dataloader(model=model_ep1,
+                                          total_samples=total_samples,
+                                          hidden_dim=hidden_dim,
+                                          device=model_ep1.device,
+                                          dtype=torch.float32)
+        expert_weight_grad_ep1 = []
+        expert_weight_grad_ep2 = []
+        for batch in data_loader:
+            loss_ep1 = model_ep1(batch[0], batch[1])
+            loss_ep2 = model_ep2(batch[0], batch[1])
+
+            model_ep1.backward(loss_ep1)
+            model_ep2.backward(loss_ep2)
+
+            expert_weight_grad_ep1.extend(extract_expert_grad(model_ep1, rank))
+            expert_weight_grad_ep2.extend(extract_expert_grad(model_ep2, 0))
+
+            model_ep1.step()
+            model_ep2.step()
+
+        assert len(expert_weight_grad_ep1) == len(expert_weight_grad_ep2)
+        for grad_from_ep1, grad_from_ep2 in zip(expert_weight_grad_ep1, expert_weight_grad_ep2):
+            assert torch.allclose(grad_from_ep1, grad_from_ep2, atol=0, rtol=1e-4)
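A brief note on the helper the new test leans on: under ZeRO the gradient hanging off param.grad may be None or sharded, so the test reads gradients through deepspeed.utils.safe_get_full_grad, which returns the full fp32 gradient for a parameter regardless of how it is partitioned. A hedged usage sketch with illustrative names:

    from deepspeed.utils import safe_get_full_grad

    def collect_full_grads(engine_module):
        """Gather {name: full fp32 grad} after engine.backward(loss); illustrative helper."""
        return {name: safe_get_full_grad(param) for name, param in engine_module.named_parameters()}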
8 changes: 5 additions & 3 deletions tests/unit/simple_model.py
@@ -79,7 +79,7 @@ def forward(self, x, y, **kwargs):
 
 class SimpleMoEModel(torch.nn.Module):
 
-    def __init__(self, hidden_dim, num_experts=4, ep_size=1, use_residual=False):
+    def __init__(self, hidden_dim, num_experts=4, ep_size=1, use_residual=False, use_rts=True):
         super(SimpleMoEModel, self).__init__()
         self.linear1 = torch.nn.Linear(hidden_dim, hidden_dim)
         expert = torch.nn.Sequential(torch.nn.Linear(hidden_dim, hidden_dim), torch.nn.Linear(hidden_dim, hidden_dim))
@@ -89,7 +89,8 @@ def __init__(self, hidden_dim, num_experts=4, ep_size=1, use_residual=False):
                                              ep_size=ep_size,
                                              use_residual=use_residual,
                                              num_experts=num_experts,
-                                             k=1)
+                                             k=1,
+                                             use_rts=use_rts)
         # interleaving MoE modules with dense to create an opportunity
         # for gradients to be merged in ZeRO stage 2 average_tensor reduce bucket
         self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
@@ -98,7 +98,8 @@ def __init__(self, hidden_dim, num_experts=4, ep_size=1, use_residual=False):
                                              ep_size=ep_size,
                                              use_residual=use_residual,
                                              num_experts=num_experts,
-                                             k=1)
+                                             k=1,
+                                             use_rts=use_rts)
         self.linear3 = torch.nn.Linear(hidden_dim, hidden_dim)
         self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
 
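Why the new use_rts knob exists: it is forwarded to deepspeed.moe.layer.MoE, whose gate applies Random Token Selection when dropping over-capacity tokens, and the gradient-equivalence test above disables it so the EP=1 and EP=2 runs route and drop tokens identically. A hedged construction sketch (assumes a distributed launcher and an initialized torch.distributed/DeepSpeed environment):

    import torch
    import deepspeed

    deepspeed.init_distributed()  # assumes MASTER_ADDR/PORT etc. were set by a launcher
    expert = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
    moe = deepspeed.moe.layer.MoE(hidden_size=8,
                                  expert=expert,
                                  num_experts=2,
                                  ep_size=1,
                                  k=1,
                                  use_rts=False)  # deterministic token dropping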
2 changes: 1 addition & 1 deletion version.txt
@@ -1 +1 @@
-0.15.3
+0.15.4
