Bug fixes (#1195)
* Fix TRL

* Update mistral.py

* Patch processing_class

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Installation guide (#1165)

* chore: update chat_templates.py (#1166)

orginal -> original

* Disable Flex Attention

* Update tokenizer_utils.py

* Update _utils.py

* n_items

* Update cross_entropy_loss.py

* Fix DPO, ORPO

* Update _utils.py

* Update _utils.py

* fix/transformers-unpack (#1180)

* Fix DPO, ORPO (#1177)

* Fix TRL

* Update mistral.py

* Patch processing_class

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Installation guide (#1165)

* chore: update chat_templates.py (#1166)

orginal -> original

* Disable Flex Attention

* Update tokenizer_utils.py

* Update _utils.py

* n_items

* Update cross_entropy_loss.py

* Fix DPO, ORPO

* Update _utils.py

---------

Co-authored-by: timothelaborie <[email protected]>
Co-authored-by: Ikko Eltociear Ashimine <[email protected]>

* Add warning for missing Unpack and KwargsForCausalLM in older Transformers versions

---------

Co-authored-by: Daniel Han <[email protected]>
Co-authored-by: timothelaborie <[email protected]>
Co-authored-by: Ikko Eltociear Ashimine <[email protected]>

* Update cross_entropy_loss.py

* Update _utils.py

* Update _utils.py

* do not upcast lm_head and embeddings to float32 (#1186)

* Cleanup upcast logs (#1188)

* Fix/phi-longrope (#1193)

* Enhance rotary embedding handling in LlamaAttention and LongRopeRotaryEmbedding

* Typo

* Improve rotary embedding handling in LlamaAttention to prevent errors with short KV cache

* Update llama.py

* Update llama.py

---------

Co-authored-by: Daniel Han <[email protected]>

* Update transformers

---------

Co-authored-by: timothelaborie <[email protected]>
Co-authored-by: Ikko Eltociear Ashimine <[email protected]>
Co-authored-by: Edd <[email protected]>
Co-authored-by: Datta Nimmaturi <[email protected]>
5 people authored Oct 26, 2024
1 parent 9ca13b8 commit d76eda4
Showing 3 changed files with 17 additions and 13 deletions.
unsloth/models/_utils.py: 1 addition, 1 deletion
@@ -1209,7 +1209,7 @@ def patch_gradient_accumulation_fix(Trainer):
"Unsloth: We fixed a gradient accumulation bug, "\
"but it seems like you don't have the latest transformers version!\n"\
"Please update transformers, TRL and unsloth via:\n"\
- '`pip install --upgrade --no-cache-dir unsloth git+https://github.com/huggingface/transformers.git git+https://github.com/huggingface/trl.git`'
+ '`pip install --upgrade --no-cache-dir --no-deps unsloth transformers git+https://github.com/huggingface/trl.git`'
)
pass

unsloth/models/llama.py: 15 additions, 11 deletions
@@ -193,6 +193,10 @@ def LlamaAttention_fast_forward_inference(

# cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
# Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)

+ # Need to do it prior 2 steps before hitting full on short KV cache
+ # or else error
+ self.rotary_emb.extend_rope_embedding(Vn, seq_len + 2)
cos, sin = self.rotary_emb.get_cached(kv_seq_len)
cos = cos[position_ids].unsqueeze(1)
sin = sin[position_ids].unsqueeze(1)
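The added call above grows the rotary cos/sin cache two positions ahead of the current KV length, so the subsequent `get_cached(kv_seq_len)` lookup never indexes past the cached range when a short KV cache fills up during token-by-token decoding. A minimal standalone sketch of that pattern, using a simplified cache class of my own rather than Unsloth's actual implementation:

```python
# Hedged illustration of "extend before you read": grow a position-indexed cache
# slightly past the length you are about to index, so the read never falls off the end.
import torch

class TinyRopeCache:
    def __init__(self, dim: int = 64, base: float = 10000.0, step: int = 8192):
        self.dim, self.base, self.step = dim, base, step
        self.current_size = 0
        self.cos = self.sin = None

    def _build(self, size: int):
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        freqs = torch.outer(torch.arange(size).float(), inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos, self.sin = emb.cos(), emb.sin()
        self.current_size = size

    def extend(self, seq_len: int):
        if seq_len <= self.current_size:
            return
        # Grow in whole steps so repeated one-token extensions do not rebuild every call.
        self._build(((seq_len + self.step - 1) // self.step) * self.step)

    def get(self, seq_len: int):
        return self.cos[:seq_len], self.sin[:seq_len]

cache = TinyRopeCache()
kv_seq_len = 8191
cache.extend(kv_seq_len + 2)      # look two positions ahead, as in the patched code
cos, sin = cache.get(kv_seq_len)  # safe: the cache already covers this range
```

Without the lookahead, a cache sized exactly to the current length could be read one step past its end on the next decode step, which appears to be the error the in-code comment refers to.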
@@ -1122,7 +1126,7 @@ def get_cached(self, seq_len = None):
def extend_rope_embedding(self, x, seq_len):
if seq_len <= self.current_rope_size: return
# Iteratively grow by increments of 8192
- self.current_rope_size = math.ceil(seq_len / 8192) * 8192
+ self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192
self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype)
pass
pass
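This hunk, repeated for the other rotary-embedding classes in the next two hunks, swaps `math.ceil` for pure integer arithmetic when rounding the cache size up to the next multiple of 8192. Both expressions agree for positive sequence lengths; a quick equivalence check of my own, not part of the commit:

```python
import math

def next_multiple_old(seq_len: int, step: int = 8192) -> int:
    return math.ceil(seq_len / step) * step

def next_multiple_new(seq_len: int, step: int = 8192) -> int:
    return ((seq_len // step) + ((seq_len % step) != 0)) * step

for n in (1, 8191, 8192, 8193, 123_456, 1_000_000):
    assert next_multiple_old(n) == next_multiple_new(n)
print("both forms round up to the next multiple of 8192")
```

The new form stays entirely in integer arithmetic and avoids the float division inside `math.ceil(seq_len / 8192)`; presumably that is the motivation, since both forms are exact at these magnitudes.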
@@ -1248,7 +1252,7 @@ def get_cached(self, seq_len = None):
def extend_rope_embedding(self, x, seq_len):
if seq_len <= self.current_rope_size: return
# Iteratively grow by increments of 8192
- self.current_rope_size = math.ceil(seq_len / 8192) * 8192
+ self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192
self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype)
pass
pass
@@ -1363,7 +1367,7 @@ def get_cached(self, seq_len = None):
def extend_rope_embedding(self, x, seq_len):
if seq_len <= self.current_rope_size: return
# Iteratively grow by increments of 8192
- self.current_rope_size = math.ceil(seq_len / 8192) * 8192
+ self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192
self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype)
pass
pass
@@ -1952,10 +1956,10 @@ def get_peft_model(
# Offload!
# [TODO] First offload lm_head and embed_tokens to CPU (should be disk!!)
if "embed_tokens" in new_target_modules:
print("Unsloth: Casting embed_tokens to float32")
print("Unsloth: Training embed_tokens in mixed precision to save VRAM")

model.model.model.embed_tokens.modules_to_save.default\
- .to(device = "cuda:0", dtype = torch.float32, non_blocking = True)
+ .to(device = "cuda:0", non_blocking = True)
model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True)

# [TODO] Move old embed_tokens to CPU - should be disk!
@@ -1965,10 +1969,10 @@ def get_peft_model(
pass

if "lm_head" in new_target_modules:
print("Unsloth: Casting lm_head to float32")
print("Unsloth: Training lm_head in mixed precision to save VRAM")

model.model.lm_head.modules_to_save.default\
- .to(device = "cuda:0", dtype = torch.float32, non_blocking = True)
+ .to(device = "cuda:0", non_blocking = True)
model.model.lm_head.modules_to_save.default.requires_grad_(True)

# [TODO] Move old lm_head to CPU - should be disk!
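The changes above, mirrored in the next hunk, stop upcasting the trainable `embed_tokens` and `lm_head` copies to float32: they are now moved to the GPU in whatever dtype the model already uses and simply marked trainable. For an embedding matrix the VRAM saving is just the dtype width; a rough back-of-the-envelope estimate, using illustrative Llama-style dimensions that are my own assumption rather than values from the diff:

```python
# Rough VRAM estimate for one embedding matrix kept in bf16 vs upcast to fp32.
# vocab_size and hidden_size are illustrative assumptions, not taken from the commit.
vocab_size, hidden_size = 128_256, 4_096
params = vocab_size * hidden_size

bf16_gib = params * 2 / 2**30   # 2 bytes per bf16 weight
fp32_gib = params * 4 / 2**30   # 4 bytes per fp32 weight

print(f"bf16: {bf16_gib:.2f} GiB, fp32: {fp32_gib:.2f} GiB, "
      f"saved per matrix: {fp32_gib - bf16_gib:.2f} GiB")
# Training both embed_tokens and an untied lm_head roughly doubles the saving.
```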
@@ -2203,18 +2207,18 @@ def get_peft_model(

# Now patch lm_head and embed_tokens
if train_embed_tokens:
print("Unsloth: Casting embed_tokens to float32")
print("Unsloth: Training embed_tokens in mixed precision to save VRAM")
assert(hasattr(model.model.model.embed_tokens, "modules_to_save"))
model.model.model.embed_tokens.modules_to_save.default\
- .to(device = "cuda:0", dtype = torch.float32, non_blocking = True)
+ .to(device = "cuda:0", non_blocking = True)
model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True)
pass

if train_lm_head:
print("Unsloth: Casting lm_head to float32")
print("Unsloth: Training lm_head in mixed precision to save VRAM")
assert(hasattr(model.model.lm_head, "modules_to_save"))
model.model.lm_head.modules_to_save.default\
- .to(device = "cuda:0", dtype = torch.float32, non_blocking = True)
+ .to(device = "cuda:0", non_blocking = True)
model.model.lm_head.modules_to_save.default.requires_grad_(True)
pass

unsloth/tokenizer_utils.py: 1 addition, 1 deletion
@@ -975,7 +975,7 @@ def patch_sft_trainer_tokenizer():
" from packaging.version import Version\n"\
" if Version(transformers_version) <= Version('4.45.2'):\n"\
" print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\\n'\\\n"\
" '`pip install --upgrade --no-cache-dir unsloth git+https://github.com/huggingface/transformers.git git+https://github.com/huggingface/trl.git`')\n"\
" '`pip install --upgrade --no-cache-dir --no-deps unsloth transformers git+https://github.com/huggingface/trl.git`')\n"\
"except:\n"\
" pass\n"\
"\n\n"
