diff --git a/caikit_nlp/toolkit/torch_run.py b/caikit_nlp/toolkit/torch_run.py index 3a8879c8..43184f2d 100644 --- a/caikit_nlp/toolkit/torch_run.py +++ b/caikit_nlp/toolkit/torch_run.py @@ -24,7 +24,8 @@ # Third Party from torch import cuda -from torch.distributed.launcher.api import LaunchConfig, Std +from torch.distributed.elastic.multiprocessing.api import Std +from torch.distributed.launcher.api import LaunchConfig import torch.distributed as dist # First Party diff --git a/pyproject.toml b/pyproject.toml index bcb82db0..e1ce63ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ dependencies = [ "scipy>=1.8.1", "sentence-transformers>=2.3.1,<2.4.0", "tokenizers>=0.13.3", - "torch>=2.0.1,<2.3.0", + "torch>=2.0.1", "tqdm>=4.65.0", "transformers>=4.32.0", "peft==0.6.0",