From 4ffa19483934c163ab65b03df54ee41d90c8801f Mon Sep 17 00:00:00 2001 From: qiyulei-mt Date: Thu, 23 Jan 2025 10:35:05 +0800 Subject: [PATCH] support musa backend in FlagEmbedding --- FlagEmbedding/abc/inference/AbsEmbedder.py | 17 +++++++++++++++-- FlagEmbedding/abc/inference/AbsReranker.py | 17 +++++++++++++++-- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/FlagEmbedding/abc/inference/AbsEmbedder.py b/FlagEmbedding/abc/inference/AbsEmbedder.py index e64dda26..ba49c213 100644 --- a/FlagEmbedding/abc/inference/AbsEmbedder.py +++ b/FlagEmbedding/abc/inference/AbsEmbedder.py @@ -13,6 +13,11 @@ import numpy as np from transformers import is_torch_npu_available +try: + import torch_musa +except Exception: + pass + logger = logging.getLogger(__name__) @@ -125,6 +130,8 @@ def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[s return [f"cuda:{i}" for i in range(torch.cuda.device_count())] elif is_torch_npu_available(): return [f"npu:{i}" for i in range(torch.npu.device_count())] + elif hasattr(torch, "musa") and torch.musa.is_available(): + return [f"musa:{i}" for i in range(torch.musa.device_count())] elif torch.backends.mps.is_available(): try: return [f"mps:{i}" for i in range(torch.mps.device_count())] @@ -135,12 +142,18 @@ def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[s elif isinstance(devices, str): return [devices] elif isinstance(devices, int): - return [f"cuda:{devices}"] + if hasattr(torch, "musa") and torch.musa.is_available(): + return [f"musa:{devices}"] + else: + return [f"cuda:{devices}"] elif isinstance(devices, list): if isinstance(devices[0], str): return devices elif isinstance(devices[0], int): - return [f"cuda:{device}" for device in devices] + if hasattr(torch, "musa") and torch.musa.is_available(): + return [f"musa:{device}" for device in devices] + else: + return [f"cuda:{device}" for device in devices] else: raise ValueError("devices should be a string or an integer or a list of strings or a list of integers.") else: diff --git 
a/FlagEmbedding/abc/inference/AbsReranker.py b/FlagEmbedding/abc/inference/AbsReranker.py index be1481f6..633bae80 100644 --- a/FlagEmbedding/abc/inference/AbsReranker.py +++ b/FlagEmbedding/abc/inference/AbsReranker.py @@ -12,6 +12,11 @@ from tqdm import tqdm, trange from transformers import is_torch_npu_available +try: + import torch_musa +except Exception: + pass + logger = logging.getLogger(__name__) @@ -107,6 +112,8 @@ def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[s return [f"cuda:{i}" for i in range(torch.cuda.device_count())] elif is_torch_npu_available(): return [f"npu:{i}" for i in range(torch.npu.device_count())] + elif hasattr(torch, "musa") and torch.musa.is_available(): + return [f"musa:{i}" for i in range(torch.musa.device_count())] elif torch.backends.mps.is_available(): return ["mps"] else: @@ -114,12 +121,18 @@ def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[s elif isinstance(devices, str): return [devices] elif isinstance(devices, int): - return [f"cuda:{devices}"] + if hasattr(torch, "musa") and torch.musa.is_available(): + return [f"musa:{devices}"] + else: + return [f"cuda:{devices}"] elif isinstance(devices, list): if isinstance(devices[0], str): return devices elif isinstance(devices[0], int): - return [f"cuda:{device}" for device in devices] + if hasattr(torch, "musa") and torch.musa.is_available(): + return [f"musa:{device}" for device in devices] + else: + return [f"cuda:{device}" for device in devices] else: raise ValueError("devices should be a string or an integer or a list of strings or a list of integers.") else: