Short explanation
The subpackage MultiScaleDeformableAttention.so in this project causes unexpected behavior in transformers.DeformableDetrModel. The behavior can be summarized as follows:
If ninja is not installed, nothing goes wrong.
If the CUDA toolkit and its devel files are installed locally, nothing goes wrong, no matter whether ninja is installed.
If the CUDA devel files are not installed and only the PyPI-provided CUDA runtime is available, then every time transformers.DeformableDetrModel is accessed, two error messages are shown:
Could not load the custom kernel for multi-scale deformable attention: CUDA_HOME environment variable is not set. Please set it to your CUDA install root.
Could not load the custom kernel for multi-scale deformable attention: /root/.cache/torch_extensions/pyxxx_cuxxx/MultiScaleDeformableAttention/MultiScaleDeformableAttention.so: cannot open shared object file: No such file or directory
Note that the above messages will NOT show as long as ninja is not installed. These messages do not appear to affect the processing of transformers.DeformableDetrModel; the model is still able to produce outputs while these messages are shown.
Background
Nowadays, many deep learning frameworks automatically install their own CUDA runtime when installed from pip wheels. For example,
python -m pip install torch
Running the above command on Linux (1) with a GPU (and its drivers) available and (2) without CUDA installed, the resulting installation will include the CUDA runtime libraries.
However, such an installation does not provide the devel files (including the headers and some shared libraries). In this case, torch.cuda.is_available() returns True, but the CUDA_HOME environment variable is not set.
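This situation can be verified with a short check (a minimal sketch; the printed values in the comments assume the pip-wheel-only setup described above):

import torch
from torch.utils.cpp_extension import CUDA_HOME  # resolved from the CUDA_HOME env var or a locally installed toolkit

print(torch.cuda.is_available())  # True: the pip wheel ships its own CUDA runtime
print(CUDA_HOME)                  # None: no local toolkit and no CUDA_HOME set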
For some strange reason, this package MultiScaleDeformableAttention.so requires the CUDA_HOME environment variable if and only if ninja is installed. This behavior can be observed when using the transformers package.
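The ninja dependence appears to come from how transformers guards the kernel build. Below is a simplified sketch of that guard (paraphrased, not the actual transformers code; details differ between versions): the JIT build is only attempted when ninja is present, so without ninja the branch is skipped silently, while with ninja but without CUDA_HOME the build fails and the failure is logged.

import logging
from transformers.utils import is_ninja_available, is_torch_cuda_available

logger = logging.getLogger(__name__)

def try_load_kernels(load_cuda_kernels):
    # load_cuda_kernels is a stand-in for the internal builder, which calls
    # torch.utils.cpp_extension.load(...) and therefore needs CUDA_HOME.
    if is_torch_cuda_available() and is_ninja_available():
        try:
            return load_cuda_kernels()
        except Exception as e:
            logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
    return None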
Reproduce the error
Start a new Docker container with:
docker run --gpus all -it --rm --shm-size=1g python:3.10-slim bash
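The intermediate steps were not preserved in this report; inside the container, they are presumably something like the following (the checkpoint name is illustrative):

python -m pip install torch transformers timm

# repro.py: a minimal script that touches transformers.DeformableDetrModel
import torch
from transformers import DeformableDetrModel

model = DeformableDetrModel.from_pretrained("SenseTime/deformable-detr")
print("loaded:", type(model).__name__)

Running python repro.py at this point (without ninja) produces no warnings. Then install ninja:

python -m pip install ninja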
Run the same script again; this time, the following warning messages will show:
!! WARNING !!
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Your compiler (c++) is not compatible with the compiler Pytorch was
built with for this platform, which is g++ on linux. Please
use g++ to to compile your extension. Alternatively, you may
compile PyTorch from source using c++, and then you can also use
c++ to compile your extension.
See https://github.com/pytorch/pytorch/blob/master/CONTRIBUTING.md for help
with compiling PyTorch from source.
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
!! WARNING !!
warnings.warn(WRONG_COMPILER_WARNING.format(
Could not load the custom kernel for multi-scale deformable attention: CUDA_HOME environment variable is not set. Please set it to your CUDA install root.
Could not load the custom kernel for multi-scale deformable attention: /root/.cache/torch_extensions/py310_cu124/MultiScaleDeformableAttention/MultiScaleDeformableAttention.so: cannot open shared object file: No such file or directory
(the above message is repeated 11 times in total)
As expected, /root/.cache/torch_extensions/py310_cu124/MultiScaleDeformableAttention/ is empty.
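This can be confirmed from inside the container; the directory exists (presumably created by the failed build attempt) but contains no .so file:

ls -A /root/.cache/torch_extensions/py310_cu124/MultiScaleDeformableAttention/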
This issue was initially reported in the transformers repository, and the maintainer of transformers suggested that I submit the issue here. The related issue is:
MultiScaleDeformableAttention.so is not found in /root/.cache/torch_extensions if ninja is installed with transformers (huggingface/transformers#35349)
Personally, I think it is not reasonable for the dynamic library MultiScaleDeformableAttention to require the devel files at run time. Note also that this behavior only exists when ninja is installed. Therefore, I wonder whether there should be a run-time check that skips the ninja-related build path when the build cannot succeed.
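One possible shape for such a check (a sketch only, not the project's actual code): gate the JIT build on a local toolkit being resolvable, not on ninja alone.

import shutil
from torch.utils.cpp_extension import CUDA_HOME

def can_jit_build_cuda_extension() -> bool:
    # Require both ninja on PATH and a resolvable CUDA install root; with a
    # pip-wheel-only CUDA runtime, CUDA_HOME is None and the build is skipped.
    return shutil.which("ninja") is not None and CUDA_HOME is not None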
Further materials
I have tried to use some Docker images with CUDA preinstalled. The tested images are:
nvcr.io/nvidia/pytorch:24.12-py3 (Ubuntu 24.04)
cainmagi/deformable-detr (Debian 12)
In both images, the CUDA toolkit and the devel files are preinstalled. In this case,
MultiScaleDeformableAttention can be built successfully, and
transformers will not complain about MultiScaleDeformableAttention.so, even if we do not build MultiScaleDeformableAttention and copy it to /root/.cache/torch_extensions/ beforehand.
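In such images, the difference can be seen from the resolved install root (a quick check; in these images it typically prints a path such as /usr/local/cuda rather than None):

python -c 'from torch.utils.cpp_extension import CUDA_HOME; print(CUDA_HOME)'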