diff --git a/ChatTTS/core.py b/ChatTTS/core.py index 75f983924..c178a9ad2 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -268,17 +268,19 @@ def _load( self.vocos = vocos self.logger.log(logging.INFO, "vocos loaded.") + # computation of MelSpectrogram on npu is not support now, use cpu fallback. + dvae_device = torch.device("cpu") if "npu" in str(self.device) else device dvae = DVAE( decoder_config=asdict(self.config.dvae.decoder), encoder_config=asdict(self.config.dvae.encoder), vq_config=asdict(self.config.dvae.vq), dim=self.config.dvae.decoder.idim, coef=coef, - device=device, + device=dvae_device, ) coef = str(dvae) assert dvae_ckpt_path, "dvae_ckpt_path should not be None" - dvae.load_pretrained(dvae_ckpt_path, device) + dvae.load_pretrained(dvae_ckpt_path, dvae_device) self.dvae = dvae.eval() self.logger.log(logging.INFO, "dvae loaded.")