
[FEAT] Add support for optimum-quanto #2000

Open · wants to merge 23 commits into base: main
Conversation

@BenjaminBossan (Member) commented on Aug 9, 2024:

This is unfinished; only the pure implementations are provided.

Resolves #1997

TODOs:

  • Documentation
  • Tests (should work on CPU!)
  • Install optimum-quanto for CI (awaits quanto release that contains fix to persistence)
  • Verify that QuantoLoraConv2d works
  • Optional: DoRA support
  • Optional: Mixed adapter batches support
  • Cleaner implementation? For now, this uses the private attributes _data and _scales, since overriding .data did not have any effect (see the sketch after this list).
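
To illustrate that last point, here is a minimal sketch of what merging through those private attributes could look like. This is not the code from this PR; it assumes the quantized weight exposes its int8 values as _data and its scales as _scales, as described above, and that re-quantization simply reuses the existing scales:

import torch

def merge_lora_delta(qweight, lora_A, lora_B, scaling):
    # hypothetical helper, for illustration only
    delta = scaling * (lora_B @ lora_A)
    # dequantize: int8 values times scales
    dequantized = qweight._data.to(delta.dtype) * qweight._scales
    merged = dequantized + delta
    # re-quantize with the existing scales; since overriding .data has no
    # effect, write the result back through the private attribute instead
    requantized = torch.clamp(torch.round(merged / qweight._scales), -128, 127)
    qweight._data.copy_(requantized.to(qweight._data.dtype))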

State of unit tests

Since quanto layers are subclasses of their respective torch equivalents, they will generally work with PEFT methods even if not supported explicitly. E.g. BOFT will "just work" with quanto. However, some operations, such as merging, won't work properly, as they require special handling for quanto. Therefore, these tests are skipped.

It could be argued that we should explicitly raise when trying to use a non-supported method with quanto. However, we don't do that in general, as we assume that a subclass relationship should mean that the method works with that module. We could do strict checking of type (not subclass), but who knows how much existing code would break for no reason because of that.
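
For context, that subclass relationship is easy to check directly. A minimal sketch, assuming optimum-quanto's QLinear.from_module classmethod and qint8 weights:

import torch
from optimum.quanto import QLinear, qint8

base = torch.nn.Linear(16, 16)
qlayer = QLinear.from_module(base, weights=qint8)

# the subclass check that generic PEFT dispatch relies on passes ...
assert isinstance(qlayer, torch.nn.Linear)
# ... while a strict type check would reject the quanto layer
assert type(qlayer) is not torch.nn.Linear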

Merging tests had to be relaxed: torch.allclose would require quite a high tolerance to pass. Therefore, we now instead measure that the correlation is > 0.97, which is more robust to outliers.
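
Such a relaxed check could look like the following sketch (not the exact test code):

import torch

def assert_correlated(expected, actual, threshold=0.97):
    # flatten both outputs and compute their Pearson correlation; a few
    # quantization outliers barely move it, unlike torch.allclose
    stacked = torch.stack([expected.flatten().float(), actual.flatten().float()])
    corr = torch.corrcoef(stacked)[0, 1]
    assert corr > threshold, f"correlation {corr:.4f} is not > {threshold}"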

Moreover, a bunch of tests needed to be skipped, e.g. because quanto does not support deepcopy-ing, and either the PEFT functionality (layer replication) or the test itself depends on copying. Also, quanto does not allow converting the dtype (like calling model.half()).

@BenjaminBossan (Member Author) commented:

This is what I used for "testing" so far and the results look correct:

import torch
from peft import LoraConfig, get_peft_model
from optimum.quanto import quantize, freeze, qint8
from transformers import AutoModelForCausalLM

torch.manual_seed(0)
inputs = torch.arange(5).view(-1, 1)
print("loading model")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m").eval()
with torch.inference_mode():
    output_base = model(inputs).logits
print("output_base")
print(output_base[0, 0, :5])

# quantize the model weights to int8, then freeze the quantized weights
print("quantizing model")
quantize(model, weights=qint8)
print("freezing model")
freeze(model)

with torch.inference_mode():
    output_quantized = model(inputs).logits
print("output_quantized")
print(output_quantized[0, 0, :5])

config = LoraConfig(r=64, lora_alpha=1280, lora_dropout=0.1, init_lora_weights=False)
print("adding adapter (random)")
model = get_peft_model(model, config)
model.eval()

with torch.inference_mode():
    output_lora = model(inputs).logits
    print("output_lora")
    print(output_lora[0, 0, :5])

    with model.disable_adapter():
        output_disabled = model(inputs).logits
        print("output_disabled")
        print(output_disabled[0, 0, :5])

    output_after_disabled = model(inputs).logits
    print("output_after_disabled")
    print(output_after_disabled[0, 0, :5])

model.merge_adapter()
with torch.inference_mode():
    output_merged = model(inputs).logits
print("output_merged")
print(output_merged[0, 0, :5])

model.unmerge_adapter()
with torch.inference_mode():
    output_unmerged = model(inputs).logits
print("output_unmerged")
print(output_unmerged[0, 0, :5])

unloaded = model.merge_and_unload()
with torch.inference_mode():
    output_unloaded = unloaded(inputs).logits
print("output_unloaded")
print(output_unloaded[0, 0, :5])

If someone wants to test this, they can check out this branch, or they can copy-paste the layer definitions and then dynamically dispatch to the new layers using the normal PEFT release:

from optimum.quanto import QConv2d, QLinear
from peft import LoraConfig

# copy code for QuantoLoraLinear and QuantoLoraConv2d

custom_module_mapping = {QConv2d: QuantoLoraConv2d, QLinear: QuantoLoraLinear}
config = LoraConfig(...)
config._register_custom_module(custom_module_mapping)
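
After registering the mapping, the PEFT model is created as usual and matching quanto layers are dispatched to the custom LoRA classes. A sketch of that continuation, assuming model is the quantized base model:

from peft import get_peft_model

peft_model = get_peft_model(model, config)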

@bghira commented on Aug 13, 2024:

2024-08-12 19:29:58,243 [INFO] (SaveHookManager) Loading LoRA weights from Path: /Users/bghira/Training/flux/models/checkpoint-10
'time_text_embed.timestep_embedder.linear_1.weight._data'
Traceback (most recent call last):
  File "/Users/bghira/src/SimpleTuner/train.py", line 2761, in <module>
    main()
  File "/Users/bghira/src/SimpleTuner/train.py", line 1566, in main
    accelerator.load_state(os.path.join(args.output_dir, path))
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 3098, in load_state
    hook(models, input_dir)
  File "/Users/bghira/src/SimpleTuner/helpers/training/save_hooks.py", line 416, in load_model_hook
    self._load_lora(models=models, input_dir=input_dir)
  File "/Users/bghira/src/SimpleTuner/helpers/training/save_hooks.py", line 335, in _load_lora
    incompatible_keys = set_peft_model_state_dict(
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/peft/utils/save_and_load.py", line 397, in set_peft_model_state_dict
    load_result = model.load_state_dict(peft_model_state_dict, strict=False)
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2201, in load_state_dict
    load(self, state_dict)
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2189, in load
    load(child, child_state_dict, child_prefix)  # noqa: F821
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2189, in load
    load(child, child_state_dict, child_prefix)  # noqa: F821
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2189, in load
    load(child, child_state_dict, child_prefix)  # noqa: F821
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2183, in load
    module._load_from_state_dict(
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/optimum/quanto/nn/qmodule.py", line 159, in _load_from_state_dict
    deserialized_weight = QBytesTensor.load_from_state_dict(
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/optimum/quanto/tensor/qbytes.py", line 90, in load_from_state_dict
    inner_tensors_dict[name] = state_dict.pop(prefix + name)
KeyError: 'time_text_embed.timestep_embedder.linear_1.weight._data'

i pulled the new peft build from your branch and applied the mapping to the LoraConfig, but I still see this error when it comes time to load the state dict. I think the problem is on the Diffusers side here, in the set_peft_model_state_dict call.

@BenjaminBossan (Member Author) commented:

Thanks for reporting this @bghira. I think it's an issue with optimum-quanto. I already reported this here.

@BenjaminBossan BenjaminBossan changed the title [WIP][FEAT] Add support for optimum-quanto [FEAT] Add support for optimum-quanto Sep 17, 2024
@BenjaminBossan (Member Author) commented:

Status update: Optimum-quanto v0.2.5 is released and is the minimum version for this to work. Moreover, huggingface/transformers#31732 is merged but it's not part of the latest transformers release yet. As we don't want to depend on an unreleased transformers version and as we're not in a huge hurry, let's wait for the next transformers release.

@HuggingFaceDocBuilderDev commented:
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@SunMarc (Member) left a comment:

LGTM! Sorry for the delay!

@BenjaminBossan (Member Author) commented:

The failing Windows CI is caused by a known issue in quanto that has already been fixed but not yet released.

@bghira commented on Feb 2, 2025:

still eagerly awaiting this fix :)

@bghira commented on Feb 2, 2025:

@sayakpaul with this patch in PEFT, diffusers has a circular import error on the git main branch.

@bghira commented on Feb 2, 2025:

a minimal reproducer is to install diffusers from git and this branch of peft, run a python shell, and:

>>> from diffusers.loaders.lora_base import load_lora_weights

and see the implosion:

RuntimeError: Failed to import diffusers.loaders.peft because of the following error (look up to see its traceback):
cannot import name '_fetch_state_dict' from partially initialized module 'diffusers.loaders.lora_base' (most likely due to a circular import) (/venv/lib/python3.11/site-packages/diffusers/loaders/lora_base.py)

it also fails on diffusers-v0.32.2.

both the PEFT git branch and the Diffusers git branch run together fine, but this branch does not 🤔

@bghira commented on Feb 2, 2025:

rebasing this patch on top of the latest peft main branch doesn't help the situation.

  File "/home/bghira/src/discord-tron-client/.venv/lib/python3.11/site-packages/diffusers/utils/import_utils.py", line 943, in _get_module
    return importlib.import_module("." + module_name, self.__name__)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<frozen importlib._bootstrap>", line 1204, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1176, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1147, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 690, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 940, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/home/bghira/src/discord-tron-client/.venv/lib/python3.11/site-packages/diffusers/loaders/peft.py", line 38, in <module>
    from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading
ImportError: cannot import name '_fetch_state_dict' from partially initialized module 'diffusers.loaders.lora_base' (most likely due to a circular import) (/home/bghira/src/discord-tron-client/.venv/lib/python3.11/site-packages/diffusers/loaders/lora_base.py)

this is because lora_base is trying to import from import_utils and import_utils is trying to load from lora_base.
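
For anyone unfamiliar with this failure mode, here is a minimal, generic reproduction with two hypothetical modules (unrelated to the actual diffusers code):

# a.py
from b import g   # importing a triggers the import of b ...
def f(): ...

# b.py
from a import f   # ... which imports a again while a is only partially initialized
def g(): ...

# >>> import a
# ImportError: cannot import name 'f' from partially initialized module 'a'
# (most likely due to a circular import)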

@sayakpaul (Member) commented:

If I do:

from diffusers import DiffusionPipeline 

repo_id = "DavyMorgan/tiny-sd35-pipe"
pipeline = DiffusionPipeline.from_pretrained(repo_id)

It works as expected.

@bghira commented on Feb 3, 2025:

yes, that doesn't seem to trigger the bug. i created a new venv to try and get precise steps and of course, i can load the lora loader there 🤔 like taking the car to the mechanic and the sound stops.

@bghira commented on Feb 3, 2025:

ok;

python3.11 -m venv .venv
. .venv/bin/activate
pip install git+https://github.com/huggingface/diffusers
pip install -U git+https://github.com/BenjaminBossan/peft@feat-support-optimum-quanto
pip install optimum-quanto
python
>>> from diffusers.loaders import lora_base
[snip]
RuntimeError: Failed to import diffusers.models.transformers.pixart_transformer_2d because of the following error (look up to see its traceback):
Failed to import diffusers.loaders.peft because of the following error (look up to see its traceback):
cannot import name '_fetch_state_dict' from partially initialized module 'diffusers.loaders.lora_base' (most likely due to a circular import) (/home/bghira/src/peft-quanto-fix/.venv/lib/python3.11/site-packages/diffusers/loaders/lora_base.py)

if optimum-quanto isn't installed, it will not error out; only with optimum-quanto installed do you finally see the circular import.
