Skip to content

Commit

Permalink
dont import numba cuda indirectly via cython
Browse files Browse the repository at this point in the history
  • Loading branch information
brandon-b-miller committed Nov 18, 2024
1 parent e29d055 commit d508aa0
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 5 deletions.
3 changes: 1 addition & 2 deletions python/cudf/cudf/_lib/aggregation.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import pylibcudf

import cudf
from cudf._lib.types import SUPPORTED_NUMPY_TO_PYLIBCUDF_TYPES
from cudf.utils import cudautils

_agg_name_map = {
"COUNT_VALID": "COUNT",
Expand Down Expand Up @@ -196,7 +195,7 @@ class Aggregation:
# Handling UDF type
nb_type = numpy_support.from_dtype(kwargs['dtype'])
type_signature = (nb_type[:],)
ptx_code, output_dtype = cudautils.compile_udf(op, type_signature)
ptx_code, output_dtype = cudf.utils.cudautils.compile_udf(op, type_signature)
output_np_dtype = cudf.dtype(output_dtype)
if output_np_dtype not in SUPPORTED_NUMPY_TO_PYLIBCUDF_TYPES:
raise TypeError(f"Result of window function has unsupported dtype {op[1]}")
Expand Down
4 changes: 1 addition & 3 deletions python/cudf/cudf/utils/_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,7 @@ def _setup_numba():
if driver_version < (12, 0):
patch_numba_linker_cuda_11()
else:
from pynvjitlink.patch import patch_numba_linker

patch_numba_linker()
numba_config.CUDA_ENABLE_PYNVJITLINK = True


class _CUDFNumbaConfig:
Expand Down

0 comments on commit d508aa0

Please sign in to comment.