Skip to content

Commit

Permalink
feat(utils): add NPU Support (#777)
Browse files Browse the repository at this point in the history
  • Loading branch information
shen-shanshan authored Oct 10, 2024
1 parent 71b42e0 commit a1aebd4
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 4 deletions.
8 changes: 5 additions & 3 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,11 @@ def _load(
vocos = (
Vocos(feature_extractor=feature_extractor, backbone=backbone, head=head)
.to(
# vocos on mps will crash, use cpu fallback
# Vocos on mps will crash, use cpu fallback.
# Plus, complex dtype used in the decode process of Vocos is not supported in torch_npu now,
# so we put this calculation of data on CPU instead of NPU.
"cpu"
if "mps" in str(device)
if "mps" in str(device) or "npu" in str(device)
else device
)
.eval()
Expand Down Expand Up @@ -422,7 +424,7 @@ def _infer(

@torch.inference_mode()
def _vocos_decode(self, spec: torch.Tensor) -> np.ndarray:
    """Decode a mel spectrogram to a waveform via Vocos.

    Vocos crashes on MPS and relies on complex dtypes that torch_npu
    does not support, so on those backends the spectrogram is moved to
    CPU before decoding. The result is always a CPU numpy array.
    """
    device_name = str(self.device)
    if "mps" in device_name or "npu" in device_name:
        # Fall back to CPU for backends Vocos cannot handle.
        spec = spec.cpu()
    return self.vocos.decode(spec).cpu().numpy()
Expand Down
36 changes: 35 additions & 1 deletion ChatTTS/utils/gpu.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import torch

try:
import torch_npu
except ImportError:
pass

from .log import logger


Expand All @@ -21,6 +26,26 @@ def select_device(min_memory=2047, experimental=False):
device = torch.device("cpu")
else:
device = torch.device(f"cuda:{selected_gpu}")
elif _is_torch_npu_available():
"""
Using Ascend NPU to accelerate the process of inferencing when GPU is not found.
"""
selected_npu = 0
max_free_memory = -1
for i in range(torch.npu.device_count()):
props = torch.npu.get_device_properties(i)
free_memory = props.total_memory - torch.npu.memory_reserved(i)
if max_free_memory < free_memory:
selected_npu = i
max_free_memory = free_memory
free_memory_mb = max_free_memory / (1024 * 1024)
if free_memory_mb < min_memory:
logger.get_logger().warning(
f"NPU {selected_npu} has {round(free_memory_mb, 2)} MB memory left. Switching to CPU."
)
device = torch.device("cpu")
else:
device = torch.device(f"npu:{selected_npu}")
elif torch.backends.mps.is_available():
"""
Currently MPS is slower than CPU while needs more memory and core utility,
Expand All @@ -34,7 +59,16 @@ def select_device(min_memory=2047, experimental=False):
logger.get_logger().info("found Apple GPU, but use CPU.")
device = torch.device("cpu")
else:
logger.get_logger().warning("no GPU found, use CPU instead")
logger.get_logger().warning("no GPU or NPU found, use CPU instead")
device = torch.device("cpu")

return device


def _is_torch_npu_available() -> bool:
    """Return True when an Ascend NPU backend is importable and usable.

    ``torch.npu`` exists only after ``torch_npu`` has been imported, so
    touching it raises AttributeError otherwise; ``device_count()`` raises
    RuntimeError when no NPU device is present. Either failure means the
    backend is unavailable.
    """
    try:
        torch.npu.device_count()
        usable = torch.npu.is_available()
    except (AttributeError, RuntimeError):
        usable = False
    return usable

0 comments on commit a1aebd4

Please sign in to comment.