Can't run Mistral quantized on T4 #417

Open
emillykkejensen opened this issue Apr 16, 2024 · 5 comments
Labels
enhancement New feature or request

Comments

@emillykkejensen

System Info

+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  Tesla T4                       Off |   00000001:00:00.0 Off |                  Off |
| N/A   28C    P0             24W /   70W |       0MiB /  16384MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

docker run --runtime nvidia --gpus all --ipc=host -p 8080:80 \
	-v $PWD/data:/data \
	ghcr.io/predibase/lorax:latest \
	--model-id mistralai/Mistral-7B-v0.1 \
	--quantize bitsandbytes-nf4

Information

  • Docker
  • The CLI directly

Tasks

  • An officially supported command
  • My own modifications

Reproduction

I'm simply trying to run mistralai/Mistral-7B-v0.1 with 4-bit quantization (bitsandbytes-nf4) on my T4. However, it errors out with 'Mistral model requires flash attn v2':

2024-04-16T14:50:32.809986Z  INFO download: lorax_launcher: Successfully downloaded weights.
2024-04-16T14:50:32.810173Z  INFO shard-manager: lorax_launcher: Starting shard rank=0
2024-04-16T14:50:38.132469Z  WARN lorax_launcher: flash_attn.py:48 Unable to use Flash Attention V2: GPU with CUDA capability 7 5 is not supported for Flash Attention V2

2024-04-16T14:50:38.267395Z ERROR lorax_launcher: server.py:271 Error when initializing model
Traceback (most recent call last):
  File "/opt/conda/bin/lorax-server", line 8, in <module>
    sys.exit(app())
  File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 311, in __call__
    return get_command(self)(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1157, in __call__
    return self.main(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 778, in main
    return _main(
  File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 216, in _main
    rv = self.invoke(ctx)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1688, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 783, in invoke
    return __callback(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 683, in wrapper
    return callback(**use_params)  # type: ignore
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/cli.py", line 89, in serve
    server.serve(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 321, in serve
    asyncio.run(
  File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 636, in run_until_complete
    self.run_forever()
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
    self._run_once()
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
    handle._run()
  File "/opt/conda/lib/python3.10/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)
> File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 267, in serve_inner
    model = get_model(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/__init__.py", line 179, in get_model
    from lorax_server.models.flash_mistral import FlashMistral
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_mistral.py", line 10, in <module>
    from lorax_server.models.custom_modeling.flash_mistral_modeling import (
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_mistral_modeling.py", line 49, in <module>
    raise ImportError("Mistral model requires flash attn v2")
ImportError: Mistral model requires flash attn v2

2024-04-16T14:50:39.215837Z ERROR shard-manager: lorax_launcher: Shard complete standard error output:

Traceback (most recent call last):

  File "/opt/conda/bin/lorax-server", line 8, in <module>
    sys.exit(app())

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/cli.py", line 89, in serve
    server.serve(

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 321, in serve
    asyncio.run(

  File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)

  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 267, in serve_inner
    model = get_model(

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/__init__.py", line 179, in get_model
    from lorax_server.models.flash_mistral import FlashMistral

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_mistral.py", line 10, in <module>
    from lorax_server.models.custom_modeling.flash_mistral_modeling import (

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_mistral_modeling.py", line 49, in <module>
    raise ImportError("Mistral model requires flash attn v2")

ImportError: Mistral model requires flash attn v2
 rank=0
2024-04-16T14:50:39.314620Z ERROR lorax_launcher: Shard 0 failed to start
2024-04-16T14:50:39.314636Z  INFO lorax_launcher: Shutting down shards

Expected behavior

The model should load and run.

@tgaddair
Contributor

Hey @emillykkejensen, unfortunately our minimum supported architecture at the moment is Ampere, due to the flash attention dependency. Please see the system requirements here: https://github.com/predibase/lorax?tab=readme-ov-file#requirements
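
For reference, you can quickly check whether a GPU clears that bar with a couple of lines of PyTorch (just a sketch, assuming torch is available, e.g. inside the container):

import torch

# Flash Attention v2 kernels require compute capability 8.0+ (Ampere);
# the T4 (Turing) reports 7.5, which is why the import guard fires.
major, minor = torch.cuda.get_device_capability(0)
print(f"Compute capability: {major}.{minor}")
print("Flash Attention v2 supported:", (major, minor) >= (8, 0))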

tgaddair added the enhancement (New feature or request) label on Apr 16, 2024
@emillykkejensen
Author

Fair enough. However, one could argue that the point of QLoRA, among other things, is to serve on smaller (older and cheaper) GPUs that predate Ampere? Is there anything in the works, or?
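
For what it's worth, the NF4 quantization itself has no Ampere requirement; something along these lines (plain transformers + bitsandbytes, just a sketch and not LoRAX code, roughly what --quantize bitsandbytes-nf4 asks for) is the kind of setup people run on T4s:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# 4-bit NF4 quantization config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation="eager",  # stay off flash-attn on pre-Ampere GPUs
)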

@tgaddair
Contributor

Yes, we have plans to move our attention computation over to the FlashInfer project, which is working on support for Volta and Turing GPUs. So hopefully that will address the issue.

@emillykkejensen
Author

Sounds good 😊 I'm sure you are already aware, but on the off chance you're not, I can see that there is a fix in TGI? However, it seems they simply fix it by falling back to loading the full (non-flash) model?
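
Something like this, I guess: fall back to the ordinary (non-flash) implementation whenever the GPU can't do flash-attn v2 (a rough illustrative sketch using the plain transformers API, not the actual TGI code):

import torch
from transformers import AutoModelForCausalLM

def load_mistral(model_id: str = "mistralai/Mistral-7B-v0.1"):
    # Flash Attention v2 needs Ampere (compute capability 8.0) or newer.
    if torch.cuda.is_available() and torch.cuda.get_device_capability(0) >= (8, 0):
        attn = "flash_attention_2"  # fast flash-attn path
    else:
        attn = "eager"  # full, non-flash model; runs on T4 / V100
    return AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        attn_implementation=attn,
    )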

@nethi

nethi commented Jul 6, 2024

Is it fair to assume that this should now work, given that PR #440 is merged? With the latest versions I seem to get past the FA2 errors, but I then run into a different issue: #535
