Skip to content

Commit

Permalink
Make dask_cuda import lazy (#220)
Browse files Browse the repository at this point in the history
Co-authored-by: James Braza <[email protected]>
  • Loading branch information
sidnarayanan and jamesbraza authored Jan 23, 2025
1 parent 4d964e8 commit 7d383af
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion ldp/nn/handlers/transformer_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import tree
from dask import config
from dask.distributed import Client
from dask_cuda import LocalCUDACluster
from dask_jobqueue import SLURMCluster
from pydantic import BaseModel, ConfigDict, Field, field_validator
from torch import nn
Expand Down Expand Up @@ -483,6 +482,9 @@ def _init_local_cluster(
self, config: TransformerHandlerConfig, parallel_mode_config: ParallelModeConfig
):
"""Initialize a Dask cluster on local machine."""
# lazy import since dask-cuda only works on Linux machines
from dask_cuda import LocalCUDACluster

self.cluster = LocalCUDACluster(
n_workers=parallel_mode_config.num_workers,
threads_per_worker=parallel_mode_config.num_cpus_per_worker,
Expand Down

0 comments on commit 7d383af

Please sign in to comment.