OOM for training llama #1900

Open · dkapur17 opened this issue Jan 7, 2025 · 5 comments
Labels: question (Further information is requested)

Comments


dkapur17 commented Jan 7, 2025

I'm trying to use the llama-3.2-1B model with the Python API on a machine with 4 Tesla V100s (4 × 16 GB), but the process keeps failing due to OOM. Watching nvidia-smi, I see memory usage shoot up to 16 GB on each GPU before the process dies. From my understanding, the 1B model should fit in much less VRAM, so maybe I'm doing something incorrectly. Here is my code:

import os

import lightning as L
import torch
from litgpt.api import LLM
from litgpt.data import Alpaca2k


class LitLLM(L.LightningModule):
    def __init__(self, tokenizer_dir=None, trainer_ckpt_path=None):
        super().__init__()
 
        self.llm = LLM.load("meta-llama/Llama-3.2-1B", distribute=None, access_token=os.getenv("HF_TOKEN"))
        self.trainer_ckpt_path = trainer_ckpt_path

    def setup(self, stage):
        self.llm.trainer_setup(trainer_ckpt=self.trainer_ckpt_path)
        
    def training_step(self, batch):
        logits, loss = self.llm(input_ids=batch["input_ids"], target_ids=batch["labels"])
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch):
        logits, loss = self.llm(input_ids=batch["input_ids"], target_ids=batch["labels"])
        self.log("validation_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        warmup_steps = 10
        optimizer = torch.optim.AdamW(self.llm.model.parameters(), lr=0.0002, weight_decay=0.0, betas=(0.9, 0.95))
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps)
        return [optimizer], [scheduler]


batch_size = 2
accumulate_grad_batches = 1

lit_model = LitLLM()
data = Alpaca2k()

data.connect(lit_model.llm.tokenizer, batch_size=batch_size, max_seq_length=512)

trainer = L.Trainer(
    devices=4,
    accelerator="cuda",
    max_epochs=1,
    accumulate_grad_batches=accumulate_grad_batches,
    precision="bf16-true",
)
trainer.fit(lit_model, data)

The process dies before even the first training step. I also tried a few approaches with quantization, e.g. passing quantize (and other params) to self.llm.distribute in the setup method, but none of them seem to work. Any ideas on what I might be doing wrong? Thanks.
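Roughly, the quantization attempt looked like the sketch below; the exact llm.distribute() arguments shown here are an assumption about litgpt's Python API and may not match the call actually used.

# A variant of the setup() method above with quantization (sketch only;
# the llm.distribute() keyword arguments are assumed, not verified):
def setup(self, stage):
    self.llm.distribute(
        accelerator="cuda",
        precision="bf16-true",
        quantize="bnb.nf4",  # bitsandbytes 4-bit (NF4) quantization
    )
    self.llm.trainer_setup(trainer_ckpt=self.trainer_ckpt_path)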

dkapur17 added the question (Further information is requested) label on Jan 7, 2025
rasbt (Collaborator) commented Jan 7, 2025

Thanks for the feedback. It does work on 4 x L4s, which have 24 GB each. I can see that the usage is around 22-24 GB. Other than trying a smaller batch size or block size, or perhaps a different multi-GPU strategy, I am not sure how this can be improved.
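A minimal sketch of the smaller batch size / shorter sequence length idea, reusing the code from above (the specific values are only illustrative):

# Smaller per-device batches and shorter sequences reduce activation memory;
# gradient accumulation keeps the effective batch size comparable.
batch_size = 1
accumulate_grad_batches = 8  # illustrative value

data = Alpaca2k()
data.connect(lit_model.llm.tokenizer, batch_size=batch_size, max_seq_length=256)

trainer = L.Trainer(
    devices=4,
    accelerator="cuda",
    max_epochs=1,
    accumulate_grad_batches=accumulate_grad_batches,
    precision="bf16-true",
)
trainer.fit(lit_model, data)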

dkapur17 (Author) commented Jan 8, 2025

@rasbt thanks for the quick reply. So is it taking 22 GB in total across the GPUs or on each GPU? I would think a sequential load strategy could help split the model across the GPUs, and 64 GB total should be enough for it, but when I use distribute it looks like it conflicts with the trainer. What would be the right way to distribute the model across the GPUs and then train it with the trainer? Also, any inputs on quantizing the model?

rasbt (Collaborator) commented Jan 8, 2025

It was on each GPU. I think it uses substantially less RAM than 22 x 4 GB in total, though; it might work just fine on a single GPU with 40 GB, but I haven't tried. You could also consider an FSDP strategy with cpu_offload=True to reduce GPU RAM usage, but training will then take a bit longer. Alternatively, the first thing I'd try in your case is to set the batch_size to 1 and then increase the gradient accumulation steps.
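A minimal sketch of the FSDP suggestion, assuming Lightning's FSDPStrategy; the offload flag and accumulation value are illustrative and not tested with the litgpt wrapper in this setup:

from lightning.pytorch.strategies import FSDPStrategy

trainer = L.Trainer(
    devices=4,
    accelerator="cuda",
    # Offloads parameters and gradients to CPU between uses: lower VRAM, slower steps.
    strategy=FSDPStrategy(cpu_offload=True),
    max_epochs=1,
    accumulate_grad_batches=8,  # pair with batch_size=1 in data.connect()
    precision="bf16-true",
)
trainer.fit(lit_model, data)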

dkapur17 (Author) commented Jan 8, 2025

Interestingly, using the CLI tool I'm even able to finetune Llama 3.1 8B with no quantization across the 4 GPUs, although I suspect that's thanks to LoRA. I'll need to check whether it works with the Python API as well.

rasbt (Collaborator) commented Jan 8, 2025

Ah yes, litgpt finetune ... uses LoRA by default. For full finetuning, it's litgpt finetune_full ...
