diff --git a/ChatTTS/core.py b/ChatTTS/core.py
index c38ad8957..4447dca54 100644
--- a/ChatTTS/core.py
+++ b/ChatTTS/core.py
@@ -253,9 +253,11 @@ def _load(
         vocos = (
             Vocos(feature_extractor=feature_extractor, backbone=backbone, head=head)
             .to(
-                # vocos on mps will crash, use cpu fallback
+                # Vocos crashes on MPS, so fall back to CPU.
+                # Also, the complex dtype used in Vocos's decode process is not yet
+                # supported by torch_npu, so we run this computation on CPU instead of NPU.
                 "cpu"
-                if "mps" in str(device)
+                if "mps" in str(device) or "npu" in str(device)
                 else device
             )
             .eval()
@@ -422,7 +424,7 @@ def _infer(
 
     @torch.inference_mode()
     def _vocos_decode(self, spec: torch.Tensor) -> np.ndarray:
-        if "mps" in str(self.device):
+        if "mps" in str(self.device) or "npu" in str(self.device):
             return self.vocos.decode(spec.cpu()).cpu().numpy()
         else:
             return self.vocos.decode(spec).cpu().numpy()
diff --git a/ChatTTS/utils/gpu.py b/ChatTTS/utils/gpu.py
index 58aeb3eea..f0a58bbd7 100644
--- a/ChatTTS/utils/gpu.py
+++ b/ChatTTS/utils/gpu.py
@@ -1,5 +1,10 @@
 import torch
 
+try:
+    import torch_npu
+except ImportError:
+    pass
+
 from .log import logger
 
 
@@ -21,6 +26,26 @@ def select_device(min_memory=2047, experimental=False):
             device = torch.device("cpu")
         else:
             device = torch.device(f"cuda:{selected_gpu}")
+    elif _is_torch_npu_available():
+        """
+        Use an Ascend NPU to accelerate inference when no CUDA GPU is found.
+        """
+        selected_npu = 0
+        max_free_memory = -1
+        for i in range(torch.npu.device_count()):
+            props = torch.npu.get_device_properties(i)
+            free_memory = props.total_memory - torch.npu.memory_reserved(i)
+            if max_free_memory < free_memory:
+                selected_npu = i
+                max_free_memory = free_memory
+        free_memory_mb = max_free_memory / (1024 * 1024)
+        if free_memory_mb < min_memory:
+            logger.get_logger().warning(
+                f"NPU {selected_npu} has {round(free_memory_mb, 2)} MB memory left. Switching to CPU."
+            )
+            device = torch.device("cpu")
+        else:
+            device = torch.device(f"npu:{selected_npu}")
     elif torch.backends.mps.is_available():
         """
         Currently MPS is slower than CPU while needs more memory and core utility,
@@ -34,7 +59,16 @@ def select_device(min_memory=2047, experimental=False):
             logger.get_logger().info("found Apple GPU, but use CPU.")
             device = torch.device("cpu")
     else:
-        logger.get_logger().warning("no GPU found, use CPU instead")
+        logger.get_logger().warning("no GPU or NPU found, use CPU instead")
         device = torch.device("cpu")
 
     return device
+
+
+def _is_torch_npu_available():
+    try:
+        # raises an AttributeError if torch_npu was not imported, or a RuntimeError if no NPU is found
+        _ = torch.npu.device_count()
+        return torch.npu.is_available()
+    except (AttributeError, RuntimeError):
+        return False
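
A quick way to exercise the new selection path end to end (a hedged sketch: it assumes the patched checkout is importable, and the printed device depends entirely on the host):

import torch
from ChatTTS.utils.gpu import select_device

device = select_device(min_memory=2047)
print(device)
# Expected, depending on the host:
#   CUDA GPU with enough free memory   -> cuda:<idx>
#   Ascend NPU with enough free memory -> npu:<idx>
#   Apple Silicon (experimental=True)  -> mps, otherwise cpu
#   no accelerator                     -> cpu, after the "no GPU or NPU found" warning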
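The core.py guard can also be checked in isolation. The sketch below mirrors the patched _vocos_decode with a stand-in tensor (hypothetical shape; vocos.decode itself is not called):

import torch

spec = torch.randn(1, 100, 512)   # stand-in for the decoder's output spectrum
device = torch.device("cpu")      # imagine torch.device("npu:0") on Ascend hardware

# Same guard as _vocos_decode: torch_npu (like MPS) does not support the
# complex dtype used inside Vocos's decode, so move the spectrum to CPU first.
if "mps" in str(device) or "npu" in str(device):
    spec = spec.cpu()
# vocos.decode(spec) would now run entirely on CPU.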