
Adding Native Distributed Data Parallels Support #50

Open
hiberfil opened this issue May 3, 2024 · 1 comment
Labels: good first issue, help wanted

Comments


hiberfil commented May 3, 2024

Hi, I was wondering if there are any efforts toward great.py natively supporting DistributedDataParallel (DDP)? Currently I am using a workaround: editing my own trainer file and saving the model weights via torch.save.

Below is how I invoke it.

torchrun --nproc_per_node=8 ddptest.py

import os
import pandas as pd
from be_great import GReaT
import torch.distributed as dist
import torch
from collections import OrderedDict

def main():
    # Set CUDA devices for each process
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    dataFile = "/edit/for/your/own/repo.csv"
    data = pd.read_csv(dataFile)

    great = GReaT("gpt2-xl",         
                      batch_size=8,
                      epochs=50,                           
                      fp16=True
                     )

    # Move the model to the appropriate GPU
    great.model.to(local_rank)

    # Wrap the model for distributed training
    great.model = torch.nn.parallel.DistributedDataParallel(
        great.model, device_ids=[local_rank], output_device=local_rank
    )

    trainer = great.fit(data, data.columns.to_list())

    
    # Save the model only from the rank 0 process
    if dist.get_rank() == 0:
        # Create a new state dict with corrected key names
        state_dict = great.model.state_dict()
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:]  # remove `module.`
            new_state_dict[name] = v

        # Save the model with the modified state dict
        torch.save(new_state_dict, "/edit/for/your/own/model.pt")


if __name__ == "__main__":
    # Initialize the distributed process group
    dist.init_process_group(backend="nccl")
    main()
    # Clean up the process group once training and saving are done
    dist.destroy_process_group()
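
For completeness, here is a rough, untested sketch of how the saved weights could be loaded back into a fresh GReaT instance later (the base model name and path are the placeholders from the script above; sampling also needs the column metadata that GReaT records during fit, so adapt this to your setup):

import torch
from be_great import GReaT

# Re-create the wrapper with the same base model used for training
great = GReaT("gpt2-xl")

# Load the cleaned state dict (the "module." prefix was already stripped before saving)
state_dict = torch.load("/edit/for/your/own/model.pt", map_location="cpu")
great.model.load_state_dict(state_dict)

Stripping the prefix at save time, as in the script above, keeps the checkpoint usable outside a DDP context.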

Again thank you so much for this awesome framework.

@unnir added the "help wanted" and "good first issue" labels on May 6, 2024
unnir (Collaborator) commented May 6, 2024

Hi @hiberfil,

Thank you for choosing our framework :)

So far we do not have plans to add native DistributedDataParallel support. However, it would be great to have, so any contributions are very welcome.

Also, thank you for providing a simple workaround script; it will definitely be useful for others!
