From b607174e937d580ca4d0187aa57551ed9a136404 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com>
Date: Tue, 2 Jul 2024 16:23:08 +0900
Subject: [PATCH] fix(gpt): flash_attention_2 impl.

---
 ChatTTS/model/gpt.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/ChatTTS/model/gpt.py b/ChatTTS/model/gpt.py
index 94ea29c79..4109adf59 100644
--- a/ChatTTS/model/gpt.py
+++ b/ChatTTS/model/gpt.py
@@ -183,15 +183,15 @@ class _GenerationInputs:
     attention_mask: Optional[torch.Tensor] = None
     inputs_embeds: Optional[torch.Tensor] = None
 
-    def to(self, device: torch.device):
+    def to(self, device: torch.device, dtype: torch.dtype):
         if self.attention_mask is not None:
             self.attention_mask = self.attention_mask.to(device)
         if self.position_ids is not None:
             self.position_ids = self.position_ids.to(device)
         if self.inputs_embeds is not None:
-            self.inputs_embeds = self.inputs_embeds.to(device)
+            self.inputs_embeds = self.inputs_embeds.to(device, dtype=dtype)
         if self.cache_position is not None:
             self.cache_position = self.cache_position.to(device)
 
     def _prepare_generation_inputs(
         self,
@@ -422,7 +422,7 @@ def generate(
             del inputs_ids_emb, model_input.input_ids
             model_input.inputs_embeds = emb
 
-            model_input.to(self.device_gpt)
+            model_input.to(self.device_gpt, self.gpt.dtype)
 
             outputs: BaseModelOutputWithPast = self.gpt(
                 attention_mask=model_input.attention_mask,