
Commit

Fix DPO, ORPO
danielhanchen committed Oct 24, 2024
1 parent e561366 commit 4ff247a
Showing 3 changed files with 32 additions and 4 deletions.
1 change: 0 additions & 1 deletion unsloth/kernels/cross_entropy_loss.py
@@ -375,7 +375,6 @@ def fast_cross_entropy_loss(
     )
     if n_items is None:
         n_items = torch.count_nonzero(labels != -100)
-    print(n_items)
     return loss.sum() / n_items
 pass
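For context, fast_cross_entropy_loss divides the summed loss by the number of tokens that actually contribute to it, i.e. labels not equal to the -100 ignore index; the hunk above removes a stray debug print of that count. A minimal sketch of the counting step, assuming plain PyTorch and an invented labels tensor:

import torch

# Hypothetical labels: -100 marks padded / masked positions (the usual ignore index).
labels = torch.tensor([[-100, 5, 7, -100],
                       [   2, 3, -100, -100]])

# Count only the positions that contribute to the loss.
n_items = torch.count_nonzero(labels != -100)
print(n_items)  # tensor(4)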
33 changes: 31 additions & 2 deletions unsloth/models/_utils.py
@@ -1172,10 +1172,10 @@ def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs):

 def patch_gradient_accumulation_fix(Trainer):
     # Fixes gradient accumulation
+    import inspect
     if hasattr(Trainer, "get_batch_samples"):
-        from inspect import getsource
         if \
-            not getsource(Trainer.get_batch_samples).strip()\
+            not inspect.getsource(Trainer.get_batch_samples).strip()\
             .endswith("return batch_samples, num_items_in_batch"):

             raise NotImplementedError("Unsloth: Please make a Github issue immediately!!")
@@ -1198,4 +1198,33 @@ def patch_gradient_accumulation_fix(Trainer):
             '`pip install --upgrade --no-cache-dir unsloth git+https://github.com/huggingface/transformers.git git+https://github.com/huggingface/trl.git`'
         )
     pass
+
+    # Also fix up loss scaling ie negate loss *= self.args.gradient_accumulation_steps
+    if "num_items_in_batch" not in inspect.signature(Trainer.training_step).parameters: return
+
+    function = inspect.getsource(Trainer.training_step)
+    where = function.find("def")
+    function = function.split("\n")
+    function = "\n".join(x[where:] for x in function)
+
+    # Import all variables that need importing
+    import transformers.trainer
+    items_in_trainer = dir(transformers.trainer)
+    good_items = []
+    for item in items_in_trainer:
+        # TODO: Support Deepspeed
+        if item.startswith(("deepspeed", "xm", "met", "smp")): continue
+        if item in function: good_items.append(item)
+    pass
+    exec("from transformers.trainer import (" + ", ".join(x for x in good_items) + ")", globals())
+
+    # Accelerate does / self.args.gradient_accumulation_steps internally, so if we already
+    # summed it up and did the division before hand, we have to negate it.
+    function = function.replace(
+        "loss *= self.args.gradient_accumulation_steps",
+        "if num_items_in_batch is not None: loss *= self.args.gradient_accumulation_steps",
+    )
+    function = function.replace("def training_step", "def _unsloth_training_step", 1)
+    exec(function, globals())
+    Trainer.training_step = _unsloth_training_step
 pass
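For readers unfamiliar with this style of patching, here is a minimal, self-contained sketch of the same source-rewriting technique applied to a toy class (the names Toy, step, and _patched_step are invented purely for illustration): fetch the method's source with inspect.getsource, strip its leading indentation, rewrite the offending line as a string, exec the result, and rebind it onto the class.

import inspect

class Toy:  # hypothetical stand-in for transformers.Trainer
    def step(self, loss):
        loss *= 2  # the line we want to make conditional
        return loss

source = inspect.getsource(Toy.step)
where  = source.find("def")                                   # width of the leading indentation
source = "\n".join(x[where:] for x in source.split("\n"))     # de-indent the whole method
source = source.replace("loss *= 2", "if loss > 0: loss *= 2")
source = source.replace("def step", "def _patched_step", 1)   # avoid clobbering the original name
exec(source, globals())                                       # defines _patched_step at module scope
Toy.step = _patched_step                                      # rebind the rewritten method

print(Toy().step(-3))  # -3: the scaling is now skipped for non-positive loss

The guard added by the real patch follows the same idea: as the in-diff comment says, Accelerate divides the loss by self.args.gradient_accumulation_steps internally, so when the loss has already been normalised by num_items_in_batch over the whole accumulated batch, multiplying by gradient_accumulation_steps first cancels that later division; when num_items_in_batch is None the multiplication is simply skipped.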
2 changes: 1 addition & 1 deletion unsloth/save.py
@@ -145,7 +145,7 @@ def _free_cached_model(model):

 def _merge_lora(layer, name):

-    bias = None
+    bias = getattr(layer, "bias", None)
     if isinstance(layer, (Bnb_Linear4bit, Peft_Linear4bit, Peft_Linear)):
         # Is LoRA so we need to merge!
         W, quant_state, A, B, s, bias = get_lora_parameters_bias(layer)
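A note on the save.py change: seeding bias from getattr(layer, "bias", None) rather than hard-coding None lets a plain (non-LoRA) layer's bias survive the merge, while layers without a bias attribute still fall back to None. A toy illustration of the getattr fallback, assuming standard torch.nn layers (not the _merge_lora code path itself):

import torch.nn as nn

linear  = nn.Linear(4, 4, bias=True)    # has a real bias parameter
no_bias = nn.Linear(4, 4, bias=False)   # bias attribute exists but is None
embed   = nn.Embedding(10, 4)           # no bias attribute at all

for layer in (linear, no_bias, embed):
    bias = getattr(layer, "bias", None)  # never raises, unlike layer.bias
    print(type(layer).__name__, bias is not None)
# Linear True
# Linear False
# Embedding False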

0 comments on commit 4ff247a
