diff --git a/ChatTTS/model/gpt.py b/ChatTTS/model/gpt.py
index 94ea29c79..4109adf59 100644
--- a/ChatTTS/model/gpt.py
+++ b/ChatTTS/model/gpt.py
@@ -183,15 +183,15 @@ class _GenerationInputs:
         attention_mask: Optional[torch.Tensor] = None
         inputs_embeds: Optional[torch.Tensor] = None
 
-        def to(self, device: torch.device):
+        def to(self, device: torch.device, dtype: torch.dtype):
             if self.attention_mask is not None:
-                self.attention_mask = self.attention_mask.to(device)
+                self.attention_mask = self.attention_mask.to(device, dtype=dtype)
             if self.position_ids is not None:
-                self.position_ids = self.position_ids.to(device)
+                self.position_ids = self.position_ids.to(device, dtype=dtype)
             if self.inputs_embeds is not None:
-                self.inputs_embeds = self.inputs_embeds.to(device)
+                self.inputs_embeds = self.inputs_embeds.to(device, dtype=dtype)
             if self.cache_position is not None:
-                self.cache_position = self.cache_position.to(device)
+                self.cache_position = self.cache_position.to(device, dtype=dtype)
 
     def _prepare_generation_inputs(
         self,
@@ -422,7 +422,7 @@ def generate(
             del inputs_ids_emb, model_input.input_ids
             model_input.inputs_embeds = emb
 
-            model_input.to(self.device_gpt)
+            model_input.to(self.device_gpt, self.gpt.dtype)
 
             outputs: BaseModelOutputWithPast = self.gpt(
                 attention_mask=model_input.attention_mask,
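
For illustration, a minimal self-contained sketch of the pattern this patch introduces (not the ChatTTS code itself; the GenerationInputs name and the toy shapes are made up): Tensor.to() accepts a device and a dtype together, so the helper can move every cached generation input onto the GPT's device and cast it to the GPT's compute dtype (e.g. float16) in one call, rather than moving first and casting later. Note that, mirroring the patch, integer tensors such as position_ids are cast to the floating dtype as well.

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class GenerationInputs:  # hypothetical stand-in for _GenerationInputs
    position_ids: torch.Tensor
    cache_position: torch.Tensor
    attention_mask: Optional[torch.Tensor] = None
    inputs_embeds: Optional[torch.Tensor] = None

    def to(self, device: torch.device, dtype: torch.dtype):
        # Tensor.to(device, dtype=...) performs the transfer and the cast
        # in one operation; as in the patch, every present field is
        # converted, including the integer position/cache indices.
        if self.attention_mask is not None:
            self.attention_mask = self.attention_mask.to(device, dtype=dtype)
        if self.position_ids is not None:
            self.position_ids = self.position_ids.to(device, dtype=dtype)
        if self.inputs_embeds is not None:
            self.inputs_embeds = self.inputs_embeds.to(device, dtype=dtype)
        if self.cache_position is not None:
            self.cache_position = self.cache_position.to(device, dtype=dtype)


batch, seq, hidden = 1, 4, 8
inputs = GenerationInputs(
    position_ids=torch.arange(seq).unsqueeze(0),    # int64
    cache_position=torch.arange(seq),               # int64
    attention_mask=torch.ones(batch, seq),          # float32
    inputs_embeds=torch.randn(batch, seq, hidden),  # float32
)
inputs.to(torch.device("cpu"), torch.float16)
assert inputs.inputs_embeds.dtype == torch.float16  # cast happened with the move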