
[RFC] Single-Device-Abstract DDP #52

Open
lllukehuang opened this issue Aug 30, 2024 · 1 comment
lllukehuang commented Aug 30, 2024

Single-Device-Abstract DDP

Motivation

In current PyTorch DDP, a model that contains Dropout operations produces results under distributed training that are not consistent with single-device training. This is mainly because the RNG state offset is replicated across DP workers, so every DP worker computes the same Dropout mask for its micro-batch. On a single device, by contrast, the sequentially processed micro-batches do not share an RNG offset and therefore receive different Dropout masks, which causes the DP Dropout results to diverge from the single-device computation. We resolve this issue in veScale through a detailed understanding of how GPUs generate random numbers in parallel, combined with a patch to torch's CUDA random number generation.
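The minimal sketch below (plain PyTorch on a CUDA device; it is not veScale code, just an illustration of the mismatch) shows the effect: two DP workers that start from the same CUDA RNG state draw identical Dropout masks, while two micro-batches processed sequentially on one device draw different ones.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.ones(4, 4, device="cuda")

# "DP" view: both workers start from the same CUDA RNG state.
state = torch.cuda.get_rng_state()
mask_rank0 = F.dropout(x, p=0.5)
torch.cuda.set_rng_state(state)              # rank 1 begins from the same state
mask_rank1 = F.dropout(x, p=0.5)
print(torch.equal(mask_rank0, mask_rank1))   # True: identical masks across DP ranks

# Single-device view: the two micro-batches share one RNG stream.
torch.cuda.set_rng_state(state)
mask_mb0 = F.dropout(x, p=0.5)
mask_mb1 = F.dropout(x, p=0.5)               # the offset has advanced
print(torch.equal(mask_mb0, mask_mb1))       # almost surely False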


We have validated the prototype on several open-source models, including Llama2, Llama3, GPT2, and Mixtral, and it successfully keeps the loss curve consistent with single-device training when DP is enabled.

We welcome any and all feedback on this effort!

Design

To ensure consistency in random number generation between the distributed and single-device scenarios, veScale introduced the ThreadBasedRNGTracer, which regulates the thread IDs used during CUDA random number generation so that they are identical in the single-device and parallel cases. However, the existing Random Op handling only considers thread ID adjustments for Tensor Parallelism and overlooks the random number consistency issues introduced by Data Parallelism.

Following veScale's existing approach to handling TP Random operations, we can inject additional DP-related information into the torch CUDA RNG state. During CUDA random number generation, this DP information is retrieved from the RNG state, enabling each worker to generate the correct local random results.
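As a purely illustrative sketch of this idea (the field layout and byte widths are assumptions, not veScale's actual state format), DP metadata could be appended to the byte tensor returned by torch.cuda.get_rng_state(); the RFC's patched CUDA RNG implementation is what would read such extra fields back during generation.

import torch

def pack_dp_info(state: torch.Tensor, dp_size: int, dp_rank: int) -> torch.Tensor:
    # Append dp_size and dp_rank (one byte each, for illustration only)
    # to the CUDA RNG state byte tensor.
    extra = torch.tensor([dp_size, dp_rank], dtype=torch.uint8)
    return torch.cat([state, extra])

def unpack_dp_info(packed: torch.Tensor):
    # Recover the DP fields and the original state.
    dp_size, dp_rank = int(packed[-2]), int(packed[-1])
    return packed[:-2], dp_size, dp_rank

state = torch.cuda.get_rng_state()
packed = pack_dp_info(state, dp_size=2, dp_rank=1)
state_back, dp_size, dp_rank = unpack_dp_info(packed)
assert torch.equal(state, state_back) and (dp_size, dp_rank) == (2, 1)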


For example, consider a scenario with 4 GPUs and a parallel setup of DP=2 and TP=2. In this case, veScale manually adjusts the thread ID on each GPU based on that GPU's DP and TP ranks, so that the random numbers it generates match the single-device state.
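To make the adjustment concrete, here is a hypothetical sketch (the layout, sharding scheme, and formula are assumptions for illustration, not veScale's actual code) in which each rank maps a local element back to the Philox thread index that the same element would have used on a single device, assuming the batch dimension is sharded across DP, the hidden dimension across TP, and row-major thread enumeration over the global tensor.

def global_thread_index(dp_rank, tp_rank, row, col,
                        local_batch, local_hidden, tp_size):
    # Thread index this local element would use on a single device.
    global_hidden = local_hidden * tp_size          # full hidden size
    global_row = dp_rank * local_batch + row        # undo the DP shard of the batch dim
    global_col = tp_rank * local_hidden + col       # undo the TP shard of the hidden dim
    return global_row * global_hidden + global_col  # row-major enumeration

# DP=2, TP=2: four GPUs, each holding a [2, 4] shard of a [4, 8] global tensor.
for dp_rank in range(2):
    for tp_rank in range(2):
        first = global_thread_index(dp_rank, tp_rank, 0, 0,
                                    local_batch=2, local_hidden=4, tp_size=2)
        print(f"dp_rank={dp_rank} tp_rank={tp_rank} -> first global thread {first}")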


RNGTracker API

To ensure the correct generation of single-device-abstracted random numbers when DP is enabled, users need to wrap the random number generation code with the _distribute_region context manager. The DTensorSpec provided to this context manager includes information related to TP, while the dp_size and dp_rank parameters specify the DP-related information.

import vescale.dtensor.random as random

# dtensor_spec is the DTensorSpec describing the tensor's TP placement
with random._rng_tracker._distribute_region(dtensor_spec, dp_size, dp_rank):
    ...  # random number generating ops (e.g., Dropout) go here

Compared with the current master branch, _distribute_region now additionally takes dp_size and dp_rank. When these two values are not provided, dp_size defaults to 1 and dp_rank defaults to 0, meaning the DP-based random number adjustment is not enabled by default.
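As a hedged usage sketch (the helper name and the way the spec is obtained are hypothetical, not part of this RFC), a Dropout call could be wrapped as follows so that each DP worker reproduces the mask it would have produced on a single device:

import torch.nn.functional as F
import vescale.dtensor.random as random

def dropout_single_device_abstracted(x, dtensor_spec, dp_size, dp_rank, p=0.1):
    # Omitting dp_size/dp_rank falls back to dp_size=1, dp_rank=0,
    # i.e. the existing TP-only behavior.
    with random._rng_tracker._distribute_region(dtensor_spec, dp_size, dp_rank):
        return F.dropout(x, p=p, training=True)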


cc @leonardo0lyj @MackZackA @JsBlueCat

vocaltract added the rfc (Let's discuss a proposal) label on Aug 30, 2024
@leonardo0lyj (Collaborator)

@lllukehuang Great work, indeed!

leonardo0lyj changed the title from "[RFC] Single-Device-Abstract DDP Dropout" to "[RFC] Single-Device-Abstract DDP" on Sep 2, 2024