From 1efc21532f188142d60bf655a85f4076761bd9cb Mon Sep 17 00:00:00 2001
From: Adam Stachowicz
Date: Mon, 20 Jan 2025 12:59:06 +0200
Subject: [PATCH] Add HPU support

---
 src/transformers/training_args.py      | 4 ++++
 src/transformers/utils/__init__.py     | 1 +
 src/transformers/utils/import_utils.py | 9 +++++++++
 3 files changed, 14 insertions(+)

diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index a7b2ba0db3a..7cfef9271c0 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -48,6 +48,7 @@
     is_torch_available,
     is_torch_bf16_cpu_available,
     is_torch_bf16_gpu_available,
+    is_torch_hpu_available,
     is_torch_mlu_available,
     is_torch_mps_available,
     is_torch_musa_available,
@@ -2268,6 +2269,9 @@ def _setup_devices(self) -> "torch.device":
         elif is_torch_npu_available():
             device = torch.device("npu:0")
             torch.npu.set_device(device)
+        elif is_torch_hpu_available():
+            device = torch.device("hpu:0")
+            torch.hpu.set_device(device)
         else:
             # if n_gpu is > 1 we'll use nn.DataParallel.
             # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 919f9f8bde0..3e299c5bb92 100755
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -216,6 +216,7 @@
     is_torch_fx_available,
     is_torch_fx_proxy,
     is_torch_greater_or_equal,
+    is_torch_hpu_available,
     is_torch_mlu_available,
     is_torch_mps_available,
     is_torch_musa_available,
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index ac07281b3d3..39cc28ee883 100755
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -765,6 +765,15 @@ def is_torch_musa_available(check_device=False):
     return hasattr(torch, "musa") and torch.musa.is_available()
 
 
+@lru_cache()
+def is_torch_hpu_available():
+    if not is_torch_available() or importlib.util.find_spec("habana_frameworks") is None:
+        return False
+    import habana_frameworks.torch  # noqa: F401  # registers the HPU backend on torch
+    import torch
+    return hasattr(torch, "hpu") and torch.hpu.is_available()
+
+
 def is_torchdynamo_available():
     if not is_torch_available():
         return False