fix(gpt): flash_attention_2 impl.
fumiama committed Jul 2, 2024
1 parent 6b13131 commit b607174
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions ChatTTS/model/gpt.py
@@ -183,15 +183,15 @@ class _GenerationInputs:
     attention_mask: Optional[torch.Tensor] = None
     inputs_embeds: Optional[torch.Tensor] = None
 
-    def to(self, device: torch.device):
+    def to(self, device: torch.device, dtype: torch.dtype):
         if self.attention_mask is not None:
-            self.attention_mask = self.attention_mask.to(device)
+            self.attention_mask = self.attention_mask.to(device, dtype=dtype)
         if self.position_ids is not None:
-            self.position_ids = self.position_ids.to(device)
+            self.position_ids = self.position_ids.to(device, dtype=dtype)
         if self.inputs_embeds is not None:
-            self.inputs_embeds = self.inputs_embeds.to(device)
+            self.inputs_embeds = self.inputs_embeds.to(device, dtype=dtype)
         if self.cache_position is not None:
-            self.cache_position = self.cache_position.to(device)
+            self.cache_position = self.cache_position.to(device, dtype=dtype)
 
     def _prepare_generation_inputs(
         self,
@@ -422,7 +422,7 @@ def generate(
             del inputs_ids_emb, model_input.input_ids
             model_input.inputs_embeds = emb
 
-            model_input.to(self.device_gpt)
+            model_input.to(self.device_gpt, self.gpt.dtype)
 
             outputs: BaseModelOutputWithPast = self.gpt(
                 attention_mask=model_input.attention_mask,
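Why the dtype plumbing matters: the flash_attention_2 backend in transformers only runs in half precision (fp16/bf16), so when the GPT backbone is loaded in fp16, the floating-point tensors handed to its forward pass (inputs_embeds, attention_mask) should be cast to that same dtype, or the call can fail with a dtype mismatch. Below is a minimal, self-contained sketch of the pattern the commit adopts; GenerationInputs here is a pared-down stand-in for the repo's _GenerationInputs, and the tensor shapes are illustrative, not ChatTTS's actual configuration:

import torch
from dataclasses import dataclass
from typing import Optional

@dataclass
class GenerationInputs:
    # Pared-down mirror of _GenerationInputs: only the two fields needed
    # to demonstrate the device/dtype cast.
    attention_mask: Optional[torch.Tensor] = None
    inputs_embeds: Optional[torch.Tensor] = None

    def to(self, device: torch.device, dtype: torch.dtype):
        # Cast every present tensor to the backbone's device and dtype so
        # half-precision attention kernels receive matching inputs.
        if self.attention_mask is not None:
            self.attention_mask = self.attention_mask.to(device, dtype=dtype)
        if self.inputs_embeds is not None:
            self.inputs_embeds = self.inputs_embeds.to(device, dtype=dtype)

# Usage: match whatever precision the backbone was loaded with.
inputs = GenerationInputs(
    attention_mask=torch.ones(1, 16),
    inputs_embeds=torch.randn(1, 16, 768),
)
inputs.to(torch.device("cpu"), torch.float16)
assert inputs.inputs_embeds.dtype == torch.float16

Note that the call site in the diff passes self.gpt.dtype rather than a hard-coded dtype, so the cast stays correct whether the backbone happens to run in fp16, bf16, or fp32.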
