diff --git a/docs/source/reducing_memory_usage.md b/docs/source/reducing_memory_usage.md
index 6c05490616..ad943d6b87 100644
--- a/docs/source/reducing_memory_usage.md
+++ b/docs/source/reducing_memory_usage.md
@@ -16,7 +16,7 @@ Sequence lengths in the dataset can vary widely. When data is batched, sequences
 
 To reduce memory usage, it’s important to truncate sequences to a reasonable length. While TRL trainers truncate sequences by default, you may want to adjust the default truncation length to better align with your specific use case.
 
-
+
 DPO truncation is applied first to the prompt and to the completion via the `max_prompt_length` and `max_completion_length` parameters. The `max_length` parameter is then used to truncate the resulting sequence.
@@ -94,6 +94,21 @@ Packing may cause batch contamination, where adjacent sequences influence one an
 
+## Liger for reducing peak memory usage
+
+[Liger Kernel](https://github.com/linkedin/Liger-Kernel) provides a fused, chunked implementation of the DPO loss (`LigerFusedLinearDPOLoss`). It computes the language modeling head projection and the loss chunk by chunk, so the full logits tensor is never materialized, which can significantly reduce peak memory usage during training.
+
+To enable it, set `use_liger_loss=True` in the [`DPOConfig`]:
+
+```python
+from trl import DPOConfig
+
+training_args = DPOConfig(..., use_liger_loss=True)
+```
+
 ## Disabling model gathering for generation in online methods
 
 When using DeepSpeed ZeRO-3, model weights are sharded across multiple GPUs. Online methods involve generating completions from the model as part of the training process. During this step, the model weights are temporarily gathered on a single GPU for generation. For very large models, this gathering can lead to out-of-memory (OOM) errors, as described in this issue: [#2250](https://github.com/huggingface/trl/issues/2250#issue-2598304204).
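As a companion to the updated doc above, here is a minimal sketch (not part of the patch) showing how the truncation parameters and the new `use_liger_loss` flag can be combined in one `DPOConfig`; the output directory and length values are illustrative placeholders:

```python
from trl import DPOConfig

training_args = DPOConfig(
    output_dir="dpo-memory-efficient",  # placeholder
    max_prompt_length=512,              # truncate the prompt first...
    max_completion_length=512,          # ...and the completion
    max_length=1024,                    # then truncate the concatenated sequence
    use_liger_loss=True,                # fused, chunked DPO loss from liger-kernel
)
```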
diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py
index c4a0232ee3..5066dc98f8 100644
--- a/tests/test_dpo_trainer.py
+++ b/tests/test_dpo_trainer.py
@@ -29,7 +29,12 @@
     PreTrainedTokenizerBase,
     is_vision_available,
 )
-from transformers.testing_utils import require_peft, require_torch_gpu_if_bnb_not_multi_backend_enabled, require_vision
+from transformers.testing_utils import (
+    require_liger_kernel,
+    require_peft,
+    require_torch_gpu_if_bnb_not_multi_backend_enabled,
+    require_vision,
+)
 
 from trl import DPOConfig, DPOTrainer, FDivergenceType
 
@@ -1227,6 +1232,75 @@ def test_padding_free(self):
                 if param.sum() != 0:  # ignore 0 biases
                     self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
 
+    @require_liger_kernel
+    @parameterized.expand([(0.1,), (0.5,)])
+    def test_dpo_trainer_with_liger(self, beta):
+        """Test DPO trainer with Liger loss enabled.
+
+        This test verifies that:
+        1. Training runs successfully with Liger loss
+        2. Model parameters update as expected
+        3. Loss values are reasonable and finite
+        4. Training works with both default and custom beta values
+        """
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = DPOConfig(
+                output_dir=tmp_dir,
+                per_device_train_batch_size=2,
+                max_steps=3,
+                remove_unused_columns=False,
+                gradient_accumulation_steps=1,
+                learning_rate=9e-1,
+                eval_strategy="steps",
+                beta=beta,
+                use_liger_loss=True,  # Enable Liger loss
+                report_to="none",
+            )
+
+            dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
+
+            trainer = DPOTrainer(
+                model=self.model,
+                ref_model=self.ref_model,  # Add reference model
+                args=training_args,
+                processing_class=self.tokenizer,
+                train_dataset=dummy_dataset["train"],
+                eval_dataset=dummy_dataset["test"],
+            )
+
+            # Store initial parameters
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            # Train the model
+            train_output = trainer.train()
+
+            # Verify training completed successfully
+            self.assertIsNotNone(train_output)
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+            # Verify loss is finite
+            self.assertTrue(np.isfinite(trainer.state.log_history[-1]["train_loss"]))
+
+            # Check parameters have been updated
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                # Only check non-zero parameters
+                if param.sum() != 0:
+                    self.assertFalse(torch.equal(param, new_param))
+                    # Verify new parameters are finite
+                    self.assertTrue(torch.isfinite(new_param).all())
+
+            # Verify model can still do forward pass after training
+            dummy_batch = next(iter(trainer.get_train_dataloader()))
+            model_inputs = {
+                "input_ids": dummy_batch["prompt_input_ids"],
+                "attention_mask": dummy_batch["prompt_attention_mask"],
+            }
+            with torch.no_grad():
+                output = trainer.model(**model_inputs)
+            self.assertIsNotNone(output)
+            self.assertIsNone(output.loss)
+
 
 @require_vision
 class DPOVisionTrainerTester(unittest.TestCase):
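For readers who want to try the same flow outside the unittest harness, here is a standalone sketch mirroring the new test. It assumes a small causal LM; the model name is a placeholder and not taken from the patch, while the dataset and the `DPOTrainer` arguments follow the test above.

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder: any small causal LM works
model = AutoModelForCausalLM.from_pretrained(model_id)
ref_model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Same preference dataset used by the test suite
dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

training_args = DPOConfig(output_dir="liger-dpo", max_steps=3, use_liger_loss=True, report_to="none")
trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=dataset["train"],
)
trainer.train()
```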
diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py
index b7c18e11cc..2ab34cd7f9 100644
--- a/trl/trainer/dpo_config.py
+++ b/trl/trainer/dpo_config.py
@@ -119,6 +119,11 @@ class DPOConfig(TrainingArguments):
 
             - `"apo_zero"`: APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
             - `"apo_down"`: APO-down loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
+        use_liger_loss (`bool`, *optional*, defaults to `False`):
+            Whether to use Liger loss.
+        base_model_attribute_name (`str`, *optional*, defaults to `"model"`):
+            Name of the attribute in the model that contains the base model. Used to get the base model when
+            `use_liger_loss` is `True` and the model does not have a `get_decoder` method.
         beta (`float`, *optional*, defaults to `0.1`):
             Parameter controlling the deviation from the reference model. Higher β means less deviation from the
             reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
@@ -301,6 +306,18 @@ class DPOConfig(TrainingArguments):
             ],
         },
     )
+    use_liger_loss: bool = field(
+        default=False,
+        metadata={"help": "Whether to use Liger loss."},
+    )
+    base_model_attribute_name: str = field(
+        default="model",
+        metadata={
+            "help": "Name of the attribute in the model that contains the base model. Used to get the base model "
+            "when `use_liger_loss` is `True` and the model does not have a `get_decoder` method."
+        },
+    )
     beta: float = field(
         default=0.1,
         metadata={
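To make `base_model_attribute_name` concrete, here is an illustrative sketch of the lookup it controls, mirroring the resolution logic added to the trainer further down in this patch. The helper name `get_base_model` is hypothetical and not part of the patch.

```python
import torch.nn as nn


def get_base_model(model: nn.Module, base_model_attribute_name: str = "model") -> nn.Module:
    """Return the backbone that produces hidden states (everything before the LM head)."""
    if hasattr(model, "get_decoder"):
        return model.get_decoder()
    # Fall back to a named attribute, e.g. `model.model` on most causal LM wrappers
    return getattr(model, base_model_attribute_name, model)
```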
diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 218f1af5a9..9502509095 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -50,7 +50,7 @@
 from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
 from transformers.trainer_callback import TrainerCallback
 from transformers.trainer_utils import EvalLoopOutput
-from transformers.utils import is_peft_available, is_torch_xpu_available
+from transformers.utils import is_liger_kernel_available, is_peft_available, is_torch_xpu_available
 from transformers.utils.deprecation import deprecate_kwarg
 
 from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
@@ -75,6 +75,9 @@
 if is_peft_available():
     from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training
 
+if is_liger_kernel_available():
+    from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss
+
 if is_wandb_available():
     import wandb
 
@@ -83,6 +86,13 @@
     import deepspeed
 
 
+def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int) -> torch.Tensor:
+    """Shift input ids one token to the right, and pad with pad_token_id."""
+    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
+    shifted_input_ids[:, 0] = decoder_start_token_id
+    # Replace any -100 values (and the initial zeros) with the pad token id
+    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+    return shifted_input_ids
+
+
 @dataclass
 class DataCollatorForPreference(DataCollatorMixin):
     """
@@ -388,6 +398,17 @@ def make_inputs_require_grad(module, input, output):
         if self.ref_model is not None:
             disable_dropout_in_model(self.ref_model)
 
+        # Liger kernel
+        if args.use_liger_loss:
+            if not is_liger_kernel_available():
+                raise ImportError(
+                    "You set `use_liger_loss=True` but the liger kernel is not available. "
+                    "Please install liger-kernel first: `pip install liger-kernel`"
+                )
+            self.dpo_loss_fn = LigerFusedLinearDPOLoss(
+                ignore_index=args.label_pad_token_id, beta=args.beta, use_ref_model=not args.reference_free
+            )
+
         self.max_length = args.max_length
         self.generate_during_eval = args.generate_during_eval
         self.label_pad_token_id = args.label_pad_token_id
@@ -1093,6 +1114,174 @@ def dpo_loss(
 
         return losses, chosen_rewards, rejected_rewards
 
+    def _compute_loss_liger(self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]):
+        concatenated_batch = self.concatenated_inputs(batch, padding_value=self.padding_value)
+
+        model_kwargs = {}
+        if self.aux_loss_enabled:
+            model_kwargs["output_router_logits"] = True
+
+        # Add the pixel values and attention masks for vision models
+        if "pixel_values" in concatenated_batch:
+            model_kwargs["pixel_values"] = concatenated_batch["pixel_values"]
+        if "pixel_attention_mask" in concatenated_batch:
+            model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"]
+        if "image_sizes" in concatenated_batch:
+            model_kwargs["image_sizes"] = concatenated_batch["image_sizes"]
+
+        if self.is_encoder_decoder:
+            # 1. Get encoder outputs
+            encoder_outputs = model.get_encoder()(
+                concatenated_batch["prompt_input_ids"],
+                attention_mask=concatenated_batch["prompt_attention_mask"],
+                return_dict=True,
+            )
+            # 2. Prepare decoder inputs
+            decoder_input_ids = shift_tokens_right(
+                concatenated_batch["completion_input_ids"], self.padding_value, model.config.decoder_start_token_id
+            )
+            # 3. Get decoder outputs
+            decoder_outputs = model.get_decoder()(
+                input_ids=decoder_input_ids,
+                attention_mask=concatenated_batch["completion_attention_mask"],
+                encoder_hidden_states=encoder_outputs.last_hidden_state,
+                encoder_attention_mask=concatenated_batch["prompt_attention_mask"],
+                use_cache=False,
+            )
+            hidden_states = decoder_outputs.last_hidden_state
+
+            ref_hidden_states = None
+            if not self.reference_free and self.ref_model is not None:
+                ref_encoder_outputs = self.ref_model.get_encoder()(
+                    concatenated_batch["prompt_input_ids"],
+                    attention_mask=concatenated_batch["prompt_attention_mask"],
+                    return_dict=True,
+                )
+                ref_decoder_outputs = self.ref_model.get_decoder()(
+                    input_ids=decoder_input_ids,
+                    attention_mask=concatenated_batch["completion_attention_mask"],
+                    encoder_hidden_states=ref_encoder_outputs.last_hidden_state,
+                    encoder_attention_mask=concatenated_batch["prompt_attention_mask"],
+                    use_cache=False,
+                )
+                ref_hidden_states = ref_decoder_outputs.last_hidden_state
+            elif not self.reference_free:
+                with self.null_ref_context():
+                    ref_encoder_outputs = model.get_encoder()(
+                        concatenated_batch["prompt_input_ids"],
+                        attention_mask=concatenated_batch["prompt_attention_mask"],
+                        return_dict=True,
+                    )
+                    ref_decoder_outputs = model.get_decoder()(
+                        input_ids=decoder_input_ids,
+                        attention_mask=concatenated_batch["completion_attention_mask"],
+                        encoder_hidden_states=ref_encoder_outputs.last_hidden_state,
+                        encoder_attention_mask=concatenated_batch["prompt_attention_mask"],
+                        use_cache=False,
+                    )
+                    ref_hidden_states = ref_decoder_outputs.last_hidden_state
+
+            labels = concatenated_batch["completion_input_ids"]
+        else:
+            # For decoder-only models
+            input_ids = torch.cat(
+                (concatenated_batch["prompt_input_ids"], concatenated_batch["completion_input_ids"]), dim=1
+            )
+            attention_mask = torch.cat(
+                (concatenated_batch["prompt_attention_mask"], concatenated_batch["completion_attention_mask"]),
+                dim=1,
+            )
+
+            # Get the base model outputs (before LM head)
+            if hasattr(model, "get_decoder"):
+                base_model = model.get_decoder()
+            else:
+                base_model = getattr(model, self.args.base_model_attribute_name, model)
+
+            outputs = base_model(
+                input_ids,
+                attention_mask=attention_mask,
+                use_cache=False,
+                **model_kwargs,
+            )
+            hidden_states = outputs.last_hidden_state[:, :-1]
+
+            # Get reference hidden states if needed
+            ref_hidden_states = None
+            if not self.reference_free and self.ref_model is not None:
+                if hasattr(self.ref_model, "get_decoder"):
+                    ref_base_model = self.ref_model.get_decoder()
+                else:
+                    ref_base_model = getattr(self.ref_model, self.args.base_model_attribute_name, self.ref_model)
+
+                ref_outputs = ref_base_model(
+                    input_ids,
+                    attention_mask=attention_mask,
+                    use_cache=False,
+                    **model_kwargs,
+                )
+                ref_hidden_states = ref_outputs.last_hidden_state[:, :-1]
+            elif not self.reference_free:
+                if hasattr(model, "get_decoder"):
+                    ref_base_model = model.get_decoder()
+                else:
+                    ref_base_model = getattr(model, self.args.base_model_attribute_name, model)
+                with self.null_ref_context():
+                    ref_outputs = ref_base_model(
+                        input_ids,
+                        attention_mask=attention_mask,
+                        use_cache=False,
+                        **model_kwargs,
+                    )
+                    ref_hidden_states = ref_outputs.last_hidden_state[:, :-1]
+
+            labels = input_ids[:, 1:]  # Shift labels for causal LM: position i predicts token i + 1
+
+        # Get the LM head
+        lm_head = model.get_output_embeddings()
+
+        # Get reference model weights if needed
+        ref_weight = None
+        ref_bias = None
+        if not self.reference_free:
+            if self.ref_model is not None:
+                ref_lm_head = self.ref_model.get_output_embeddings()
+            else:
+                with self.null_ref_context():
+                    ref_lm_head = model.get_output_embeddings()
+            ref_weight = ref_lm_head.weight
+            ref_bias = ref_lm_head.bias if hasattr(ref_lm_head, "bias") else None
+
+        # Compute loss using Liger kernel
+        loss_output = self.dpo_loss_fn(
+            lm_head.weight,
+            hidden_states,
+            labels,
+            bias=lm_head.bias if hasattr(lm_head, "bias") else None,
+            ref_input=ref_hidden_states if not self.reference_free else None,
+            ref_weight=ref_weight if not self.reference_free else None,
+            ref_bias=ref_bias if not self.reference_free else None,
+        )
+        (
+            loss,
+            (chosen_logps, rejected_logps, chosen_logits_mean, rejected_logits_mean, nll_loss, *aux_outputs),
+        ) = loss_output
+
+        output = {
+            "loss": loss,
+            "chosen_logps": chosen_logps,
+            "rejected_logps": rejected_logps,
+            "mean_chosen_logits": chosen_logits_mean,
+            "mean_rejected_logits": rejected_logits_mean,
+            "nll_loss": nll_loss,
+            "chosen_rewards": aux_outputs[0],
+            "rejected_rewards": aux_outputs[1],
+        }
+        if self.aux_loss_enabled:
+            output["aux_loss"] = outputs.aux_loss
+
+        return output
+
     def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]):
         """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
 
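For intuition, here is a sketch of the unfused computation that the fused `dpo_loss_fn` call above avoids: projecting every position through the LM head materializes a full `(batch, seq, vocab)` logits tensor before per-token log-probabilities are gathered, and that tensor dominates peak memory. This is my own illustration, not code from the patch or from liger-kernel; the helper name `naive_sequence_logps` is hypothetical.

```python
import torch
import torch.nn.functional as F


def naive_sequence_logps(hidden_states, lm_head_weight, labels, ignore_index=-100):
    # (batch, seq, vocab): the large intermediate the chunked kernel never builds
    logits = hidden_states @ lm_head_weight.T
    logps = F.log_softmax(logits.float(), dim=-1)
    mask = labels != ignore_index
    # Gather the log-prob of each target token, ignoring padded positions
    per_token = torch.gather(logps, 2, labels.masked_fill(~mask, 0).unsqueeze(-1)).squeeze(-1)
    return (per_token * mask).sum(-1)  # summed log-prob per sequence
```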
@@ -1224,8 +1413,8 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to
 
         if self.args.rpo_alpha is not None:
             # Only use the chosen logits for the RPO loss
-            chosen_logits = logits[:num_examples]
-            chosen_labels = labels[:num_examples]
+            chosen_logits = logits[:num_examples, :-1] if not self.is_encoder_decoder else logits[:num_examples]
+            chosen_labels = labels[:num_examples, 1:] if not self.is_encoder_decoder else labels[:num_examples]
 
             # Compute the log probabilities of the labels
             output["nll_loss"] = F.cross_entropy(
@@ -1268,18 +1457,24 @@ def get_batch_loss_metrics(
         """Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
         metrics = {}
 
-        model_output = self.concatenated_forward(model, batch)
-
-        # if ref_chosen_logps and ref_rejected_logps in batch use them, otherwise use the reference model
-        if "ref_chosen_logps" in batch and "ref_rejected_logps" in batch:
-            ref_chosen_logps = batch["ref_chosen_logps"]
-            ref_rejected_logps = batch["ref_rejected_logps"]
+        if self.args.use_liger_loss and self.loss_type == "sigmoid":
+            model_output = self._compute_loss_liger(model, batch)
+            losses = model_output["loss"]
+            chosen_rewards = model_output["chosen_rewards"]
+            rejected_rewards = model_output["rejected_rewards"]
         else:
-            ref_chosen_logps, ref_rejected_logps = self.compute_ref_log_probs(batch)
+            model_output = self.concatenated_forward(model, batch)
 
-        losses, chosen_rewards, rejected_rewards = self.dpo_loss(
-            model_output["chosen_logps"], model_output["rejected_logps"], ref_chosen_logps, ref_rejected_logps
-        )
+            # if ref_chosen_logps and ref_rejected_logps in batch use them, otherwise use the reference model
+            if "ref_chosen_logps" in batch and "ref_rejected_logps" in batch:
+                ref_chosen_logps = batch["ref_chosen_logps"]
+                ref_rejected_logps = batch["ref_rejected_logps"]
+            else:
+                ref_chosen_logps, ref_rejected_logps = self.compute_ref_log_probs(batch)
+
+            losses, chosen_rewards, rejected_rewards = self.dpo_loss(
+                model_output["chosen_logps"], model_output["rejected_logps"], ref_chosen_logps, ref_rejected_logps
+            )
 
         reward_accuracies = (chosen_rewards > rejected_rewards).float()
 
         if self.args.rpo_alpha is not None:
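One behavioral consequence of the dispatch above worth noting: the Liger path is only taken when `loss_type` is the default `"sigmoid"`; with any other loss type, training falls back to the standard `concatenated_forward` path even if `use_liger_loss=True`. A tiny sketch of the two configurations (output directories are placeholders):

```python
from trl import DPOConfig

# Uses the fused Liger DPO loss (sigmoid is the default loss_type)
liger_args = DPOConfig(output_dir="out-liger", loss_type="sigmoid", use_liger_loss=True)

# Falls back to the standard DPO path despite use_liger_loss=True
ipo_args = DPOConfig(output_dir="out-ipo", loss_type="ipo", use_liger_loss=True)
```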