Add multi gpu support #3548

Open · jeffpicard wants to merge 5 commits into master
Conversation

@jeffpicard commented Sep 24, 2024

Hi! This is a draft PR that adds multi-GPU support. @alanakbik and others: would you be interested in incorporating something like this? The core functionality is working and I've pasted a short script below demonstrating its usage. I get a near-linear speed increase -- for 1 epoch it took: 16 CPUs --> 368s, 1 GPU --> 32s, 4 GPUs --> 8s when running on an AWS g5.12xlarge instance with 4 A10 GPUs.

There's a related issue here, and a past PR that never ended up merging.

The approach:

  • This PR uses raw pytorch's DistributedDataParallel rather than another package like fabric, accelerate, or deepspeed. This gives more control and visibility into exactly what's happening, and avoids needing to adopt another large pytorch project's design for handling e.g. AMP. However, it leaves more to be handled in flair, such as multi-node training, TPUs, etc. I'm open to discussing/implementing other approaches if you have preferences.
  • In order to use multiple GPUs, users would call a launch_distributed mechanism (see the sketch after this list). This means 1) user code will be running num_gpus times, which can be unintuitive, and 2) existing flair scripts won't automatically use multiple GPUs without refactoring. I think a simpler approach may be possible by spawning processes inside Trainer.train_custom. However, I ran into problems doing it this way (e.g. TransformerEmbeddings and Pluggable._event_queue would not serialize correctly), and many multi-gpu projects involve this kind of complexity. I think this PR is still a step toward that better future, and existing CPU/single-GPU usage is unchanged.

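For orientation, here is a minimal sketch of what such a launch_distributed helper typically does. This is only an illustration of the general pattern (the helper name _entrypoint, the port, and the single-node assumption are mine), not this PR's actual implementation:

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def launch_distributed(fn, *args):
    """Spawn one process per GPU and run fn(*args) in each of them (sketch only)."""
    world_size = torch.cuda.device_count()
    os.environ.setdefault("MASTER_ADDR", "localhost")  # assumption: single-node training
    os.environ.setdefault("MASTER_PORT", "29500")      # assumption: any free port works
    mp.spawn(_entrypoint, args=(world_size, fn, args), nprocs=world_size)


def _entrypoint(rank, world_size, fn, args):
    # Each spawned process joins the process group and pins itself to one GPU.
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    try:
        fn(*args)
    finally:
        dist.destroy_process_group()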
There are still TODOs. For example, the logging inside .train_custom prints out multiple times (once for each process/gpu). If you connect with the approach, I can add new commits fixing this by adding statements like if is_main_process(): or torch.distributed.gather_object to aggregate metrics across processes, similar to what's done for the eval steps in this PR.

Example usage:

import random

import torch

from flair.datasets import IMDB
from flair.distributed_utils import launch_distributed
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer


def example(multi_gpu):
    random.seed(1337)
    corpus = IMDB()
    corpus.downsample(0.1)
    label_type = "sentiment"
    label_dictionary = corpus.make_label_dictionary(label_type)
    embeddings = TransformerDocumentEmbeddings(model="distilbert-base-uncased")
    model = TextClassifier(embeddings, label_type, label_dictionary=label_dictionary)
    trainer = ModelTrainer(model, corpus)
    mini_batch_chunk_size = 32
    num_processes = max(torch.cuda.device_count(), 1)
    mini_batch_size = mini_batch_chunk_size if multi_gpu else mini_batch_chunk_size * num_processes  # fair comparison in terms of batch-updates
    trainer.fine_tune("./tmp", multi_gpu=multi_gpu, max_epochs=2, mini_batch_chunk_size=mini_batch_chunk_size, mini_batch_size=mini_batch_size)

if __name__ == "__main__":
    multi_gpu = True
    if multi_gpu:
        launch_distributed(example, multi_gpu)
    else:
        example(multi_gpu)

@alanakbik (Collaborator)

Hello @jeffpicard, this is awesome, thanks for the PR!

@helpmefindaname @HallerPatrick can you take a look?

@HallerPatrick (Collaborator)

Hey @jeffpicard, thanks for the PR.

I tested your changes with different numbers of GPUs and can, more or less, reproduce your speedups!

I also like the approach of setting everything up in-process to isolate the distribution logic only for the training logic. For the logging, we could go simple with:

if flair.distributed:
    self.model = DistributedModel(self.model, device_ids=[flair.device.index])

    # Disable logging in distributed mode for all but the main process
    log.disabled = not is_main_process()

Here some points from my side:

  1. I am a little suspicious about the DistributedModel wrapper, where we can now arbitrarily update the DistributedDataParallel model without knowing whether it affects any distributed logic. I do see the convenience of it, though.
    Maybe we can check every __getattr__ and __setattr__ call just to be on the safe side here :P

  2. Model saving logic still runs on every process. Easily fixable.

  3. How should we handle the "best model" logic after each epoch? Should we just naively test the main process model? I don't know if this is nitpicky...

On a side note, maybe we can also implement multi-gpu support for the LM trainer @alanakbik :)

Thank you!

@jeffpicard (Author) commented Sep 25, 2024

Many thanks for the thoughtful review!

isolate the distribution logic only for the training logic

I'll look into distributing across processes inside the call to .train/.fine_tune rather than before. Some of the serialization issues (e.g. Pluggable._event_queue) should be solvable, I think.

log.disabled = not is_main_process()

Ah, great idea, thanks!

  1. I felt the same way! Thanks for calling it out. The idea is inspired by other implementations like Lightning Fabric. I'll be more careful about the implementation.

  2. I did add an if is_main_process() to Model.save, but I can move that in front of all the calls to model.save in trainer to be less surprising.

  3. I believe testing the main process model should be fine since the models on each process/GPU should be the same. However, the data should also be the same. The dev Dataset should already be the same on each process, but if train_loss is used, that's calculated only for the fraction of data the given process handles. I'll try torch.distributed.gather_object(train_loss) to average across all processes/GPUs (rough sketch below). This will also help for logging the training progress.

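For illustration, this is roughly how such an aggregation could look (the helper name is assumed; this is a sketch, not the PR's exact code):

import torch.distributed as dist


def average_across_processes(value: float) -> float:
    """Average a per-process scalar (e.g. train_loss) over all ranks (sketch only)."""
    if not (dist.is_available() and dist.is_initialized()):
        return value  # single-process training: nothing to aggregate
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, value)  # every rank receives every rank's value
    return sum(gathered) / len(gathered)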
I'll follow up soon.

@helpmefindaname (Collaborator)

Hi @jeffpicard

thank you for creating this draft. Conceptually, I think this is a good way to finally integrate multi-gpu training into flair.

I tested this on 2 RTX A4000 GPUs, increasing the mini_batch_chunk_size to be so large that all GPU memory is used, and setting the mini_batch_size to be either the same (multi-gpu) or 2x (single-gpu) to have a fair comparison in terms of batch-updates.
Also, I used ClearML for logging.
With that, I can confirm the ~2x speed improvement for 2 GPUs and that the metrics at the end are about the same (although slightly worse for multi-gpu).

I observed that somehow the logging at multi-gpu is off by 1 epoch:
[screenshot: training log]
Here you can see that there was no report for epoch 1, but an epoch 21 magically appeared. I am not sure why that is.

Also, since the non-main processes currently also log values, I could observe the following:
[screenshot: training log]
Here, the non-main process is ahead of the main process, due to not having to evaluate. I am not sure if that is good, or if we should rather synchronize the processes at the end of each epoch.
Obviously, splitting the evaluation would also be an option, but I think that would imply a lot of changes that make this PR more complicated.

I wonder how the plugins are impacted by multi-gpu training. Logger plugins should obviously only run on the main process, while others, like the lr-scheduler plugins, need to be run on every process.
Note: currently the lr-scheduler doesn't know that multi-gpu training uses a higher batch-size/fewer train steps:
[screenshot: learning rate curve]
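To make the mismatch concrete, here is a rough back-of-the-envelope (illustrative numbers, not measurements from this PR):

import math

dataset_size = 25_000      # assumed example size
mini_batch_size = 32

# With data parallelism, each optimizer step consumes mini_batch_size * world_size samples,
# so the number of scheduler steps per epoch shrinks accordingly.
steps_single_gpu = math.ceil(dataset_size / mini_batch_size)       # 782 steps per epoch
steps_two_gpus = math.ceil(dataset_size / (mini_batch_size * 2))   # 391 steps per epoch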

Using the AnnealOnPlateau scheduler doesn't work, as the non-main processes fail without an eval metric.

@jeffpicard (Author)

Thanks for looking at this @helpmefindaname !

logging at multi-gpu is off by 1 epoch

Ahh, sorry about that. I think it's from the new call to .set_epoch(epoch), which was off by 1.
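For context, here is a standalone illustration of how DistributedSampler's epoch counter is meant to be kept in sync with the training loop (not flair's actual code; num_replicas/rank are set explicitly so the snippet runs without a process group):

import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(100))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0)
loader = DataLoader(dataset, batch_size=10, sampler=sampler)

for epoch in range(1, 3):
    sampler.set_epoch(epoch)  # must match the current epoch, otherwise shuffling is shifted
    for batch in loader:
        pass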

the non-main process is ahead of the main process [...] we should rather synchronize the processes

DistributedDataParallel should be synchronizing every backward(), but there was a bug. I fixed it.

plugins

Thanks to @HallerPatrick's good idea to disable the logger on all but the main process, I've simplified the plugins to run on all processes. This makes the AnnealOnPlateau work. However, yes, if a plugin needs to synchronize information from all processes, it'll have to explicitly do that.


I ran into an unfortunate wrinkle -- I no longer see a speedup after the following bug fix: I noticed that, for a given epoch_num and batch_no, the gradients were not the same on each process/GPU, like they should be. I think this is because pytorch's synchronization relies on hooks that get called when you __call__ a model, rather than when you just use forward_loss. Changing the Trainer:

loss, datapoint_count = self.model.forward_loss(batch_step)
# becomes
loss, datapoint_count = self.model(batch_step)

fixes the gradients, but makes multiple GPUs a bit slower than a single GPU. Any idea what could be going on that's making it slower?

@jeffpicard (Author)

Any idea what could be going on that's making it slower?

Aha, with a bigger batch size, multiple GPUs are faster again. There's a little overhead to synchronizing the gradients, so the bigger the batch size, the more the overhead can be amortized.

I've fixed most of what's mentioned above:

  • Process forking now happens inside .train so all users have to do is add the multi_gpu=True argument
  • The metrics logged during training are now averaged/summed from all GPUs rather than printing the rank=0 data
  • Removed DistributedModel wrapper

I'll push these changes up.


I'm still stuck on:

  • What to do about forward vs forward_loss. In order to get the gradients to synchronize, pytorch relies on hooks run by __call__, which then invokes the special function forward. flair's trainer relies on forward_loss, which is potentially convenient because forward can just be redirected to forward_loss. But some Models also use forward. One option is to refactor all models so that either all use forward or none do, but that's complex ¯\_(ツ)_/¯.
  • I need to make TransformerEmbeddings work with pickle. Currently getting TypeError: DistilBertModel.__init__() got an unexpected keyword argument 'instance_parameters'.

Let me know if you have any thoughts on forward.

@jeffpicard (Author) commented Oct 8, 2024

And here's an example of running it on the latest commit:

from flair.datasets import IMDB
from flair.embeddings import DocumentTFIDFEmbeddings, TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer

if __name__ == "__main__":
    corpus = IMDB()
    corpus.downsample(0.01)
    label_type = "sentiment"

    label_dictionary = corpus.make_label_dictionary(label_type)
    embeddings = DocumentTFIDFEmbeddings(train_dataset=corpus.train)
    # embeddings = TransformerDocumentEmbeddings(model="distilbert-base-uncased")  # serialization error
    model = TextClassifier(embeddings, label_type, label_dictionary=label_dictionary)
    trainer = ModelTrainer(model, corpus)
    trainer.fine_tune("./tmp", max_epochs=1, mini_batch_size=16, multi_gpu=True)

@jeffpicard (Author)

What to do about forward vs forward_loss

Oh, this can be resolved without a big refactor by patching forward, similar to what Fabric does here.
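One way this could look (a sketch of the idea with an assumed wrapper name; neither Fabric's nor this PR's exact code):

import torch.nn as nn


class _ForwardLossWrapper(nn.Module):
    """Route DDP's forward() call to flair's forward_loss() so the gradient hooks fire."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.module = model

    def forward(self, *args, **kwargs):
        return self.module.forward_loss(*args, **kwargs)


# usage sketch:
# ddp_model = nn.parallel.DistributedDataParallel(_ForwardLossWrapper(model))
# loss, datapoint_count = ddp_model(batch)  # gradients now synchronize on backward()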

make TransformerEmbeddings work with pickle

I see there are other objects that can't be pickled, like Sentences with spans. Since a lot of things might not work, I'll plan to go back to the launch_distributed approach from the example in the original description. Doing it that way doesn't require everything to be picklable. There are a couple of other options, like invoking the script with torchrun script.py, or integrating Lightning Fabric, which has a launcher that doesn't require objects to pickle. If you connect more with either of those, or with committing to making all objects picklable, let me know.

I'll add a commit soon with these changes, which I hope can take this out of draft.

@jeffpicard (Author)

Done! @helpmefindaname and @HallerPatrick can you please take another look? This looks good to me. What do you think about merging?

  • Plugins
    • I added a property so that each plugin can set whether it gets run on all processes or just the main one (sketched below). Defaults to true (seemed safer). AnnealingPlugin and LinearSchedulerPlugin are the only plugins that run on all processes.
  • lr-scheduler doesn't know that multi-gpu training uses a higher batch-size
    • I modified the calculation to take this into account.
  • Example script
    • I modified the script in the top comment to reflect a minimal example of running this.

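A sketch of the idea behind that property (the property name attach_to_all_processes is assumed here, not necessarily the PR's exact API):

class TrainerPlugin:
    @property
    def attach_to_all_processes(self) -> bool:
        # Default: run on every process (safer for schedulers that must stay in sync).
        return True


class SomeLoggerPlugin(TrainerPlugin):  # hypothetical logging-style plugin
    @property
    def attach_to_all_processes(self) -> bool:
        return False  # only the main process should write logs/files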
@jeffpicard changed the title from "Draft: Add multi gpu support" to "Add multi gpu support" on Oct 25, 2024
@jeffpicard (Author)

Thanks for running CI. I think the same error occurs on master. I think @helpmefindaname already fixed it in this commit but it hasn't merged yet. I've added that commit's diff to this branch, with a minor bug fix changing del to pop.

I think CI should pass now.

@jeffpicard (Author) left a review comment

@helpmefindaname @HallerPatrick gentle reminder that the next action is on you, if you're able to give it another look. I think this is good and I don't plan to touch it further. I've added a couple of comments to make reading it easier.

Comment on lines 708 to +710
 f" - loss {intermittent_loss:.8f}"
 f" - time (sec): {(current_time - epoch_start_time):.2f}"
-f" - samples/sec: {epoch_train_samples / (current_time - epoch_start_time):.2f}"
+f" - samples/sec: {samples_per_second:.2f}"

Summarizing the metrics that will be logged while training:

  • loss is the average loss across processes / data splits
  • samples/sec is the sum across processes (i.e. overall speed)
  • loss and momentum are technically printed from the main process, but are equal on all processes

Comment on lines +1357 to +1359
# do not switch the attention implementation upon reload.
config_dict["attn_implementation"] = self.model.config._attn_implementation
config_dict.pop("_attn_implementation_autoset", None)

This is unrelated to this PR but necessary to get CI to pass. It's based on helpmefindaname's commit.

@@ -722,7 +772,7 @@ def train_custom(
 if not determine_best_epoch_using_dev_score:
     validation_scores = (train_loss,)

-if epoch_train_loss < best_epoch_score:
+if train_loss < best_epoch_score:

This line is a bugfix (typo?)

@HallerPatrick (Collaborator)

Hey @jeffpicard, sorry for the late replies.

So I am taking a look now. I am testing with your example:

from flair.datasets import IMDB
from flair.embeddings import DocumentTFIDFEmbeddings, TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer
from flair.distributed_utils import launch_distributed


def main():
    corpus = IMDB()
    corpus.downsample(0.01)
    label_type = "sentiment"

    label_dictionary = corpus.make_label_dictionary(label_type)
    embeddings = DocumentTFIDFEmbeddings(train_dataset=corpus.train)
    # embeddings = TransformerDocumentEmbeddings(model="distilbert-base-uncased")  # serialization error
    model = TextClassifier(embeddings, label_type, label_dictionary=label_dictionary)
    trainer = ModelTrainer(model, corpus)
    trainer.fine_tune("./tmp", max_epochs=1, mini_batch_size=16, multi_gpu=True)


if __name__ == "__main__":
    launch_distributed(main)

I think in your earlier message you forgot to add launch_distributed. Is that right?

I also had a problem running the example, because I tried running the script with torchrun/torch.distributed.launch, which is my fault. But we should definitely add documentation for multi-GPU!

Then I ran into another problem: the models that should be moved to the different devices have different parameters:

Moving model on device: 0
TextClassifier(
  (embeddings): DocumentTFIDFEmbeddings()
  (decoder): Linear(in_features=8921, out_features=2, bias=True)
  (dropout): Dropout(p=0.0, inplace=False)
  (locked_dropout): LockedDropout(p=0.0)
  (word_dropout): WordDropout(p=0.0)
  (loss_function): CrossEntropyLoss()
  (weights): None
  (weight_tensor) None
)
Moving model on device: 1
TextClassifier(
  (embeddings): DocumentTFIDFEmbeddings()
  (decoder): Linear(in_features=9539, out_features=2, bias=True)
  (dropout): Dropout(p=0.0, inplace=False)
  (locked_dropout): LockedDropout(p=0.0)
  (word_dropout): WordDropout(p=0.0)
  (loss_function): CrossEntropyLoss()
  (weights): None
  (weight_tensor) None
)

The decoder has a different number of features. Is this due to starting multiple processes, with the data processing yielding different results in each?

Generally data preprocessing should only be done on the main process. So this works:

from flair.datasets import IMDB
from flair.embeddings import DocumentTFIDFEmbeddings, TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer
from flair.distributed_utils import launch_distributed


def main():
    corpus = IMDB()
    corpus.downsample(0.01)
    label_type = "sentiment"
    label_dictionary = corpus.make_label_dictionary(label_type)
    launch_distributed(train, corpus, label_type, label_dictionary)

def train(corpus, label_type, label_dictionary):
    embeddings = DocumentTFIDFEmbeddings(train_dataset=corpus.train)
    # embeddings = TransformerDocumentEmbeddings(model="distilbert-base-uncased")  # serialization error
    model = TextClassifier(embeddings, label_type, label_dictionary=label_dictionary)
    trainer = ModelTrainer(model, corpus)
    trainer.fine_tune("./tmp", max_epochs=1, mini_batch_size=16, multi_gpu=True)

if __name__ == "__main__":
    main()

I think in Hugging Face all preprocessing operations are agnostic to this, but I am not sure.

I don't know if it makes sense to add some type of guard to all data processing operations (sounds like a lot of work) or to just make it quite clear that simply adding multi_gpu=True to the trainer will not work.
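For what it's worth, such a guard could look roughly like this, reusing corpus and label_type from the example above (a sketch only; it assumes is_main_process is importable from flair.distributed_utils as in this PR):

import torch.distributed as dist

from flair.distributed_utils import is_main_process  # assumed import path

# Build rank-dependent artifacts once on the main process and broadcast them,
# so every rank ends up with identical objects (e.g. the label dictionary).
holder = [corpus.make_label_dictionary(label_type) if is_main_process() else None]
dist.broadcast_object_list(holder, src=0)
label_dictionary = holder[0]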

Any thoughts on that?

@@ -32,6 +32,7 @@ filterwarnings = [
'ignore:`resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.', # transformers calls deprecated hf_hub
"ignore:`torch.cuda.amp.GradScaler", # GradScaler changes in torch 2.3.0 but we want to be backwards compatible.
"ignore:`clean_up_tokenization_spaces` was not set", # Default behavior changes in transformers v4.45, raising irrelevant FutureWarning for serialized models.
"ignore:1Torch was not compiled with flash attention", # You might want to install flash attention, but you don't have to.
A collaborator commented:

Is this a typo?
