
Jetson (aarch64) support #724

Open · wants to merge 6 commits into main
Conversation

@jasl commented Dec 14, 2023

I refactored setup.py to make it work on my Jetson AGX Orin; I think it will also help future ARM + GPU platforms.

I didn't want to make it complex, so I just allow setting the CUDA gencode from an environment variable. Jetson is compute_87 / sm_87.
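
For reference, here is a minimal sketch of what such an environment override could look like in setup.py (not the exact diff from this PR; the CUDA_GENCODE variable name and the Jetson value come from the discussion below, the fallback list is illustrative):

```python
import os

# Hypothetical sketch: if CUDA_GENCODE is set, use it verbatim;
# otherwise fall back to a default gencode list for server GPUs.
cc_flag = []
cuda_gencode = os.environ.get("CUDA_GENCODE")  # e.g. "arch=compute_87,code=sm_87" on Jetson Orin
if cuda_gencode:
    cc_flag += ["-gencode", cuda_gencode]
else:
    cc_flag += ["-gencode", "arch=compute_80,code=sm_80"]
    cc_flag += ["-gencode", "arch=compute_90,code=sm_90"]
```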

@jasl changed the title from "Jetson support" to "Jetson (aarch64) support" on Dec 14, 2023
jasl and others added 4 commits December 30, 2023 18:05
Co-authored-by: Aaron Gokaslan <[email protected]>
@FenardH commented May 2, 2024

Hello, thanks for the amazing work.

I installed flash attention from source using your committed setup.py (commit hash 0097ec4) for the Jetson Orin. The installation completed without errors and I can import it in Python successfully. However, all the unit tests fail when I run test_flash_attn.py. I don't know whether this is normal. Do we have other ways to test/check whether flash attention works on the Orin? Thank you.

@jasl (Author) commented May 2, 2024

Which version of JetPack are you using? I've only tried on JP 6.0 DP.

@jasl (Author) commented May 2, 2024

I'm waiting for the JP 6.0 production release. I guess we just need to let setup.py support sm_87.

@FenardH commented May 2, 2024

Hello,

Thank you for the quick reply.

Mine is JP 5.1.2. A couple of months ago, around the release of FlashAttention 2, I tried to install it with compute_87 or sm_87 set, but both attempts failed on the same JP. Do you have any idea what's wrong here? Thank you again.

Best regards,

@jasl (Author) commented May 2, 2024

I haven't tried 5.1.x. I guess the reason is that its CUDA is too old. I had to upgrade to 6.0 because Ubuntu 18.04, CUDA 11.4, and Python 3.7 are too old to run recent LLM and Stable Diffusion releases.

@FenardH commented May 3, 2024

Hello again. I upgraded the Orin to JP 6.0 DP today and tried to install flash_attn 2 again from your fork (the aarch64 branch). The upgraded JP did not fix the installation: I noticed the build was compiling with CUDA gencode compute_90 and sm_90 instead of 87 for the Orin. Could you please share more detail on how you install the package from source? Thank you.

@jasl (Author) commented May 3, 2024

I didn't want to make it complex (Jetson isn't that popular), so the PR introduces an environment variable, CUDA_GENCODE, which you can use to override the gencode when compiling on the Jetson platform.

You can use this command:

MAX_JOBS=8 FORCE_BUILD=True CUDA_GENCODE='arch=compute_87,code=sm_87' pip3 wheel --wheel-dir=dist --no-deps --verbose .
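
After installing the wheel, a quick smoke test (a minimal sketch, assuming the public flash_attn_func API and an fp16 forward pass on the GPU; not a replacement for the full test suite) is:

```python
import torch
from flash_attn import flash_attn_func

# Tiny forward pass as a sanity check; shapes are (batch, seqlen, nheads, headdim).
q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)

out = flash_attn_func(q, k, v, causal=True)
print(out.shape)  # expected: torch.Size([1, 128, 8, 64])
```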
