
[Liger] liger DPO support #2568

Open · wants to merge 14 commits into `main`
22 changes: 20 additions & 2 deletions docs/source/reducing_memory_usage.md
@@ -16,7 +16,7 @@ Sequence lengths in the dataset can vary widely, and by default, TRL does not mo

To reduce memory usage, it’s important to truncate sequences to a reasonable length. Even discarding just a few tokens from the dataset can result in significant memory savings by minimizing unnecessary padding. Truncation is a good practice and should always be applied to ensure efficient use of resources. While the truncation limit doesn’t need to be overly restrictive, setting a sensible value is essential for optimal performance.

-<hfoptions id="dpo">
+<hfoptions id="truncation">
<hfoption id="DPO">

DPO truncation is applied first to the prompt and to the completion via the `max_prompt_length` and `max_completion_length` parameters. The `max_length` parameter is then used to truncate the resulting sequence.
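The order of operations can be illustrated in plain Python (a simplified sketch; TRL's actual implementation works on tokenized batches and may differ in which side of each sequence is kept):

```python
def truncate_dpo_sequence(prompt_ids, completion_ids,
                          max_prompt_length, max_completion_length, max_length):
    # 1. Truncate the prompt first (keeping its end, so it stays adjacent
    #    to the completion)
    prompt_ids = prompt_ids[-max_prompt_length:]
    # 2. Truncate the completion from the right
    completion_ids = completion_ids[:max_completion_length]
    # 3. Finally, truncate the concatenated sequence to max_length
    return (prompt_ids + completion_ids)[:max_length]
```

For example, a 10-token prompt and 5-token completion with `max_prompt_length=4`, `max_completion_length=3`, and `max_length=6` yields a 6-token sequence: 4 prompt tokens followed by the first 2 completion tokens.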
@@ -84,4 +84,22 @@ training_args = SFTConfig(..., packing=True, max_seq_length=512)

Packing may cause batch contamination, where adjacent sequences influence one another. This can be problematic for some applications. For more details, see [#1230](https://github.com/huggingface/trl/issues/1230).

</Tip>

## Liger for reducing peak memory usage

Liger Kernel provides fused Triton kernels for LLM training. For DPO, the fused linear-and-loss kernel computes the loss chunk by chunk instead of materializing the full logits tensor, which can substantially reduce peak memory usage.

<hfoptions id="liger">
<hfoption id="DPO">

To reduce peak memory usage with the Liger DPO loss, enable `use_liger_loss` in the config:

```python
from trl import DPOConfig

training_args = DPOConfig(..., use_liger_loss=True)
```
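The memory saving comes from chunking: the final projection and the loss are fused, so only a small slice of the logits matrix exists at any time. A minimal NumPy sketch of the chunking idea (this is not Liger's actual Triton implementation, and plain cross-entropy stands in for the DPO loss for brevity):

```python
import numpy as np

def chunked_nll(hidden, weight, labels, chunk_size=2):
    """Negative log-likelihood computed chunk by chunk: only a
    (chunk_size, vocab) slice of the logits is materialized at a time."""
    total = 0.0
    for i in range(0, hidden.shape[0], chunk_size):
        logits = hidden[i:i + chunk_size] @ weight.T  # small logits slice
        # numerically stable log-sum-exp over the vocab axis
        m = logits.max(axis=1, keepdims=True)
        lse = (m + np.log(np.exp(logits - m).sum(axis=1, keepdims=True))).squeeze(1)
        rows = np.arange(logits.shape[0])
        total += (lse - logits[rows, labels[i:i + chunk_size]]).sum()
    return total / hidden.shape[0]
```

The result is identical to the full-matrix computation, but peak memory scales with `chunk_size` rather than with the full sequence length.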

</hfoption>

@kashif I've added this section in the new guide for reducing memory usage, if you've words to fill it

</hfoptions>
76 changes: 75 additions & 1 deletion tests/test_dpo_trainer.py
@@ -29,7 +29,12 @@
PreTrainedTokenizerBase,
is_vision_available,
)
-from transformers.testing_utils import require_peft, require_torch_gpu_if_bnb_not_multi_backend_enabled, require_vision
+from transformers.testing_utils import (
+    require_liger_kernel,
+    require_peft,
+    require_torch_gpu_if_bnb_not_multi_backend_enabled,
+    require_vision,
+)

from trl import DPOConfig, DPOTrainer, FDivergenceType

@@ -1197,6 +1202,75 @@ def test_padding_free(self):
if param.sum() != 0:
self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))

    @require_liger_kernel
    @parameterized.expand([(0.1,), (0.5,)])
    def test_dpo_trainer_with_liger(self, beta):
        """Test DPO trainer with Liger loss enabled.

        This test verifies that:
        1. Training runs successfully with Liger loss
        2. Model parameters update as expected
        3. Loss values are reasonable and finite
        4. Training works with both default and custom beta values
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = DPOConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                max_steps=3,
                remove_unused_columns=False,
                gradient_accumulation_steps=1,
                learning_rate=9e-1,
                eval_strategy="steps",
                beta=beta,
                use_liger_loss=True,  # Enable Liger loss
                report_to="none",
            )

            dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

            trainer = DPOTrainer(
                model=self.model,
                ref_model=self.ref_model,  # Add reference model
                args=training_args,
                processing_class=self.tokenizer,
                train_dataset=dummy_dataset["train"],
                eval_dataset=dummy_dataset["test"],
            )

            # Store initial parameters
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            # Train the model
            train_output = trainer.train()

            # Verify training completed successfully
            self.assertIsNotNone(train_output)
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Verify loss is finite
            self.assertTrue(np.isfinite(trainer.state.log_history[-1]["train_loss"]))

            # Check parameters have been updated
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                # Only check non-zero parameters
                if param.sum() != 0:
                    self.assertFalse(torch.equal(param, new_param))
                    # Verify new parameters are finite
                    self.assertTrue(torch.isfinite(new_param).all())

            # Verify model can still do forward pass after training
            dummy_batch = next(iter(trainer.get_train_dataloader()))
            model_inputs = {
                "input_ids": dummy_batch["prompt_input_ids"],
                "attention_mask": dummy_batch["prompt_attention_mask"],
            }
            with torch.no_grad():
                output = trainer.model(**model_inputs)
            self.assertIsNotNone(output)
            self.assertIsNone(output.loss)


@require_vision
class DPOVisionTrainerTester(unittest.TestCase):
17 changes: 17 additions & 0 deletions trl/trainer/dpo_config.py
@@ -116,6 +116,11 @@ class DPOConfig(TrainingArguments):
- `"apo_zero"`: APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
- `"apo_down"`: APO-down loss from the [APO](https://huggingface.co/papers/2408.06266) paper.

        use_liger_loss (`bool`, *optional*, defaults to `False`):
            Whether to use Liger's fused DPO loss, which reduces peak memory usage.
        base_model_attribute_name (`str`, *optional*, defaults to `"model"`):
            Name of the model attribute that contains the base model. Used to retrieve the base model when
            `use_liger_loss` is `True` and the model does not have a `get_decoder` method.
beta (`float`, *optional*, defaults to `0.1`):
Parameter controlling the deviation from the reference model. Higher β means less deviation from the
reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
@@ -291,6 +296,18 @@ class DPOConfig(TrainingArguments):
],
},
)
    use_liger_loss: bool = field(
        default=False,
        metadata={"help": "Whether to use Liger's fused DPO loss, which reduces peak memory usage."},
    )
    base_model_attribute_name: str = field(
        default="model",
        metadata={
            "help": "Name of the model attribute that contains the base model. Used to retrieve the base model "
            "when `use_liger_loss` is `True` and the model does not have a `get_decoder` method."
        },
    )
beta: float = field(
default=0.1,
metadata={