diff --git a/ldp/nn/handlers/transformer_handler.py b/ldp/nn/handlers/transformer_handler.py index c110d5e..6fd7b74 100644 --- a/ldp/nn/handlers/transformer_handler.py +++ b/ldp/nn/handlers/transformer_handler.py @@ -17,7 +17,6 @@ import tree from dask import config from dask.distributed import Client -from dask_cuda import LocalCUDACluster from dask_jobqueue import SLURMCluster from pydantic import BaseModel, ConfigDict, Field, field_validator from torch import nn @@ -483,6 +482,9 @@ def _init_local_cluster( self, config: TransformerHandlerConfig, parallel_mode_config: ParallelModeConfig ): """Initialize a Dask cluster on local machine.""" + # lazy import since dask-cuda only works on Linux machines + from dask_cuda import LocalCUDACluster + self.cluster = LocalCUDACluster( n_workers=parallel_mode_config.num_workers, threads_per_worker=parallel_mode_config.num_cpus_per_worker,