Add multi gpu support #3548
base: master
Conversation
Hello @jeffpicard, this is awesome, thanks for the PR! @helpmefindaname @HallerPatrick can you take a look?
Hey @jeffpicard, thanks for the PR. I tested your changes with different numbers of GPUs and can, more or less, reproduce your speedups! I also like the approach of setting everything up in-process to isolate the distribution logic to the training logic. For the logging, we could go simple with:

```python
if flair.distributed:
    self.model = DistributedModel(self.model, device_ids=[flair.device.index])
    # Disable logging in distributed mode for all but the main process
    log.disabled = not is_main_process()
```

Here are some points from my side:
On a side note, maybe we can also implement multi-gpu support for the LM trainer, @alanakbik :) Thank you!
Many thanks for the thoughtful review!
I'll look into distributing across processes inside the call to
Ah, great idea, thanks!
I'll follow up soon.
Hi @jeffpicard, thank you for creating this draft. Conceptually, I think this is a good way to finally integrate multi-gpu training in flair. I tested this on 2 RTX A4000 GPUs, by increasing the

I observed that the logging with multi-gpu is somehow off by 1 epoch:

Also, since currently the non-main processes also log values, I could observe the following:

I wonder how the plugins are impacted by multi-gpu. Logger plugins should obviously only work on the main process, while others, like the lr-scheduler plugins, need to be run on every process.

using the
Thanks for looking at this @helpmefindaname !
Ahh, sorry about that. I think it's from the new call to
DistributedDataParallel should be synchronizing every backward(), but there was a bug. I fixed it.
Thanks to @HallerPatrick's good idea to disable the logger on all but the main process, I've simplified the plugins to run on all processes. This makes the

I ran into an unfortunate wrinkle -- I no longer see a speedup after the following bug fix: I noticed the gradients were not the same across all GPUs for a given epoch_num and batch_no, like they should be. I think this is because pytorch's synchronization implementation relies on hooks that get called when you
fixes the gradients, but makes multiple GPUs a bit slower than a single GPU. Any idea what could be going on that's making it slower?
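(For readers following the thread: the point about hooks can be illustrated with a minimal, hypothetical training step -- this is not flair's trainer code, just a sketch of why the `DistributedDataParallel` wrapper itself has to be called.)

```python
from torch.nn.parallel import DistributedDataParallel as DDP


def training_step(wrapped_model: DDP, batch, optimizer):
    """Hypothetical step showing where DDP's gradient synchronization fires."""
    optimizer.zero_grad()
    loss = wrapped_model(batch)  # calling the DDP wrapper arms the all-reduce hooks
    # loss = wrapped_model.module(batch) would bypass DDP, so each GPU would
    # keep its own, unsynchronized gradients -- the symptom described above.
    loss.backward()              # gradients are all-reduced across processes here
    optimizer.step()
```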
Aha, with a bigger batch size, multiple GPUs are faster again. There's a little overhead to synchronizing the gradients, so the bigger the batch size, the more the overhead can be amortized. I've fixed most of what's mentioned above
I'll push these changes up. I'm still stuck on:
Let me know if you have any thoughts on
And here's an example of running it on the latest commit:

```python
from flair.datasets import IMDB
from flair.embeddings import DocumentTFIDFEmbeddings, TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer

if __name__ == "__main__":
    corpus = IMDB()
    corpus.downsample(0.01)
    label_type = "sentiment"
    label_dictionary = corpus.make_label_dictionary(label_type)

    embeddings = DocumentTFIDFEmbeddings(train_dataset=corpus.train)
    # embeddings = TransformerDocumentEmbeddings(model="distilbert-base-uncased")  # serialization error

    model = TextClassifier(embeddings, label_type, label_dictionary=label_dictionary)

    trainer = ModelTrainer(model, corpus)
    trainer.fine_tune("./tmp", max_epochs=1, mini_batch_size=16, multi_gpu=True)
```
Oh, this can be resolved without a big refactor by patching forward similar to what Fabric does here.
I see there are other objects that can't be pickled, like Sentences with spans. Since a lot of things might not work, I'll plan to go back to the

I'll add a commit soon with these changes, which I hope can take this out of draft.
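(To make the wrapping idea concrete, here is a rough sketch of a Fabric-style wrapper -- the class name and signature are illustrative and may differ from the PR's actual `DistributedModel`.)

```python
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


class DistributedModelSketch(nn.Module):
    """Illustrative wrapper: forward() is routed through DDP so gradients are
    synchronized, while every other attribute falls back to the original model.
    Requires torch.distributed to be initialized before construction."""

    def __init__(self, model: nn.Module, device_index: int):
        super().__init__()
        self.ddp = DDP(model, device_ids=[device_index])

    def forward(self, *args, **kwargs):
        # Going through DDP.forward is what arms the gradient all-reduce
        # that runs during backward().
        return self.ddp(*args, **kwargs)

    def __getattr__(self, name):
        # nn.Module resolves registered submodules/parameters; anything else
        # (e.g. save(), label_type) is looked up on the wrapped model.
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.ddp.module, name)
```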
Done! @helpmefindaname and @HallerPatrick can you please take another look?

This looks good to me. What do you think about merging?
Thanks for running CI. I think the same error occurs on master. I think @helpmefindaname already fixed it in this commit, but it hasn't merged yet. I've added that commit's diff to this branch, with a minor bug fix changing

I think CI should pass now.
@helpmefindaname @HallerPatrick gentle reminder that the action is on you, if you're able to give it another look. I think this is good and don't plan to touch it. I've added a couple of comments to make reading it easier.
f" - loss {intermittent_loss:.8f}" | ||
f" - time (sec): {(current_time - epoch_start_time):.2f}" | ||
f" - samples/sec: {epoch_train_samples / (current_time - epoch_start_time):.2f}" | ||
f" - samples/sec: {samples_per_second:.2f}" |
Summarizing the metrics that will be logged while training:
- loss is the average loss across processes / data splits
- samples/sec is the sum across processes (i.e. overall speed)
- loss and momentum are technically printed from the main process, but are equal on all processes
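(A minimal sketch of how such cross-process aggregation can be done with `torch.distributed` primitives -- illustrative only; the exact calls used in the PR may differ.)

```python
import torch
import torch.distributed as dist


def aggregate_metrics(local_loss: float, local_samples: int, elapsed_sec: float, device):
    """Average the loss over processes and sum samples for overall throughput."""
    loss = torch.tensor(local_loss, dtype=torch.float64, device=device)
    samples = torch.tensor(float(local_samples), dtype=torch.float64, device=device)
    dist.all_reduce(loss, op=dist.ReduceOp.SUM)     # sum of per-process losses
    loss /= dist.get_world_size()                   # -> average loss across data splits
    dist.all_reduce(samples, op=dist.ReduceOp.SUM)  # total samples across all GPUs
    return loss.item(), samples.item() / elapsed_sec
```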
```python
# do not switch the attention implementation upon reload.
config_dict["attn_implementation"] = self.model.config._attn_implementation
config_dict.pop("_attn_implementation_autoset", None)
```
This is unrelated to this PR but necessary to get CI to pass. It's based on helpmefindaname's commit.
```diff
@@ -722,7 +772,7 @@ def train_custom(
             if not determine_best_epoch_using_dev_score:
                 validation_scores = (train_loss,)

-            if epoch_train_loss < best_epoch_score:
+            if train_loss < best_epoch_score:
```
This line is a bugfix (typo?)
Hey @jeffpicard, sorry for the late replies. So I am taking a look now. I am testing with your example:

```python
from flair.datasets import IMDB
from flair.embeddings import DocumentTFIDFEmbeddings, TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer
from flair.distributed_utils import launch_distributed


def main():
    corpus = IMDB()
    corpus.downsample(0.01)
    label_type = "sentiment"
    label_dictionary = corpus.make_label_dictionary(label_type)
    embeddings = DocumentTFIDFEmbeddings(train_dataset=corpus.train)
    # embeddings = TransformerDocumentEmbeddings(model="distilbert-base-uncased")  # serialization error
    model = TextClassifier(embeddings, label_type, label_dictionary=label_dictionary)
    trainer = ModelTrainer(model, corpus)
    trainer.fine_tune("./tmp", max_epochs=1, mini_batch_size=16, multi_gpu=True)


if __name__ == "__main__":
    launch_distributed(main)
```

I think in your earlier message you forgot to add

I also had a problem running the example, because I tried running the script with

Then I ran into another problem: the models that should be moved to different devices have different parameters:
The decoder has a different number of features. Is this due to starting multiple processes and the data processing yielding different results in each process? Generally, data preprocessing should only be done on the main process. So this works:

```python
from flair.datasets import IMDB
from flair.embeddings import DocumentTFIDFEmbeddings, TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer
from flair.distributed_utils import launch_distributed


def main():
    corpus = IMDB()
    corpus.downsample(0.01)
    label_type = "sentiment"
    label_dictionary = corpus.make_label_dictionary(label_type)
    launch_distributed(train, corpus, label_type, label_dictionary)


def train(corpus, label_type, label_dictionary):
    embeddings = DocumentTFIDFEmbeddings(train_dataset=corpus.train)
    # embeddings = TransformerDocumentEmbeddings(model="distilbert-base-uncased")  # serialization error
    model = TextClassifier(embeddings, label_type, label_dictionary=label_dictionary)
    trainer = ModelTrainer(model, corpus)
    trainer.fine_tune("./tmp", max_epochs=1, mini_batch_size=16, multi_gpu=True)


if __name__ == "__main__":
    main()
```

I think in Hugging Face all preprocessing operations are agnostic to this, but I am not sure. I don't know if it makes sense to add some type of guard to all data processing operations (sounds like a lot of work) or just make it quite clear that just adding

Any thoughts on that?
```diff
@@ -32,6 +32,7 @@ filterwarnings = [
     'ignore:`resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.',  # transformers calls deprecated hf_hub
     "ignore:`torch.cuda.amp.GradScaler",  # GradScaler changes in torch 2.3.0 but we want to be backwards compatible.
     "ignore:`clean_up_tokenization_spaces` was not set",  # Default behavior changes in transformers v4.45, raising irrelevant FutureWarning for serialized models.
+    "ignore:1Torch was not compiled with flash attention",  # You might want to install flash attention, but you don't have to.
```
Is this a typo?
Hi! This is a draft PR that adds multi gpu support. @alanakbik and others: would you be interested in incorporating something like this? The core functionality is working and I've pasted a short script below demonstrating its usage. I get a near-linear speed increase -- for 1 epoch it took: 16 CPUs --> 368s, 1 GPU --> 32s, 4 GPUs --> 8s when running on an AWS g5.12xlarge instance with 4 A10 GPUs.
There's a related issue here, and a past PR that never ended up merging.
The approach:

- Use pytorch's `DistributedDataParallel` rather than another package like fabric, accelerate, or deepspeed. This gives more control and visibility into exactly what's happening and avoids needing to integrate another large pytorch project's design on how to handle e.g. AMP. However, it leaves more to be handled in flair, such as multi-node / TPUs etc. I'm open to discussing/implementing other approaches if you have preferences.
- Launch the processes through a `launch_distributed` mechanism. This means 1) user code will be running num_gpus times, which can be unintuitive, and 2) existing flair scripts won't automatically use multi-gpus without refactoring. I think a simpler approach may be possible by spawning processes inside `Trainer.train_custom`. However, I ran into problems doing it this way (e.g. `TransformerEmbeddings` and `Pluggable._event_queue` would not serialize correctly), and many multi-gpu projects involve this kind of complexity. I think this PR is still a step toward that better future though, and existing CPU/single-gpu usage is unchanged.

There are still TODOs. For example, the logging inside `.train_custom` prints out multiple times (once for each process/gpu). If you connect with the approach, I can add new commits fixing this by adding statements like `if is_main_process():` or `torch.distributed.gather_object` to aggregate metrics across processes, similar to what's done for the eval steps in this PR.

Example usage:
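(The example below mirrors the corrected script from the review discussion above, where the training entry point is wrapped in `launch_distributed`; the snippet originally attached here may have differed slightly.)

```python
from flair.datasets import IMDB
from flair.embeddings import DocumentTFIDFEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer
from flair.distributed_utils import launch_distributed


def main():
    corpus = IMDB()
    corpus.downsample(0.01)
    label_type = "sentiment"
    label_dictionary = corpus.make_label_dictionary(label_type)
    # Data preparation happens once here; training then runs in one process per GPU.
    launch_distributed(train, corpus, label_type, label_dictionary)


def train(corpus, label_type, label_dictionary):
    embeddings = DocumentTFIDFEmbeddings(train_dataset=corpus.train)
    model = TextClassifier(embeddings, label_type, label_dictionary=label_dictionary)
    trainer = ModelTrainer(model, corpus)
    trainer.fine_tune("./tmp", max_epochs=1, mini_batch_size=16, multi_gpu=True)


if __name__ == "__main__":
    main()
```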