diff --git a/optimum/habana/transformers/models/gpt2/modeling_gpt2.py b/optimum/habana/transformers/models/gpt2/modeling_gpt2.py
index c790e15ff7..d5a26ed0d3 100644
--- a/optimum/habana/transformers/models/gpt2/modeling_gpt2.py
+++ b/optimum/habana/transformers/models/gpt2/modeling_gpt2.py
@@ -6,6 +6,8 @@
 from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel, logger
 from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
 
+from ....utils import get_device_name
+
 
 class GaudiGPT2Attention(torch.nn.Module):
     """
@@ -264,7 +266,16 @@ def gaudi_gpt2_block_forward(
     """
     residual = hidden_states
-    hidden_states = self.ln_1(hidden_states)
+
+    # TODO: remove this workaround when SynapseAI 1.13 is released
+    if not self.ln_1.training and get_device_name() == "gaudi2" and hidden_states.shape[:-1] == torch.Size([1, 1]):
+        # Change to 1,2,1600 and back to 1,1,1600
+        hidden_states = hidden_states.repeat([1, 2, 1])  # this changes the shape to 1x2x1600
+        hidden_states = self.ln_1(hidden_states)
+        hidden_states = hidden_states[:, :1, :]
+    else:
+        hidden_states = self.ln_1(hidden_states)
+
     attn_outputs = self.attn(
         hidden_states,
         layer_past=layer_past,
@@ -302,7 +313,16 @@ def gaudi_gpt2_block_forward(
         outputs = outputs + cross_attn_outputs[2:]  # add cross attentions if we output attention weights
 
     residual = hidden_states
-    hidden_states = self.ln_2(hidden_states)
+
+    # TODO: remove this workaround when SynapseAI 1.13 is released
+    if not self.ln_2.training and get_device_name() == "gaudi2" and hidden_states.shape[:-1] == torch.Size([1, 1]):
+        # Change to 1,2,1600 and back to 1,1,1600
+        hidden_states = hidden_states.repeat([1, 2, 1])  # this changes the shape to 1x2x1600
+        hidden_states = self.ln_2(hidden_states)
+        hidden_states = hidden_states[:, :1, :]
+    else:
+        hidden_states = self.ln_2(hidden_states)
+
     feed_forward_hidden_states = self.mlp(hidden_states)
     # residual connection
     hidden_states = residual + feed_forward_hidden_states
@@ -501,7 +521,14 @@ def custom_forward(*inputs):
                 if i == v[-1] and "cuda:" + str(k) != self.last_device:
                     hidden_states = hidden_states.to("cuda:" + str(k + 1))
 
-    hidden_states = self.ln_f(hidden_states)
+    # TODO: remove this workaround when SynapseAI 1.13 is released
+    if not self.ln_f.training and get_device_name() == "gaudi2" and hidden_states.shape[:-1] == torch.Size([1, 1]):
+        # Change to 1,2,1600 and back to 1,1,1600
+        hidden_states = hidden_states.repeat([1, 2, 1])  # this changes the shape to 1x2x1600
+        hidden_states = self.ln_f(hidden_states)
+        hidden_states = hidden_states[:, :1, :]
+    else:
+        hidden_states = self.ln_f(hidden_states)
 
     hidden_states = hidden_states.view(output_shape)
     # Add last hidden state
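
Note on the shape trick used at the three patched call sites: it can be checked outside the model. The sketch below is a minimal CPU-only reproduction, not Gaudi-specific code; the Gaudi2/SynapseAI device check from the patch is omitted, and hidden_size = 1600 is just the GPT2-XL value referenced in the comments. It only shows that repeating the single decoded token along the sequence dimension before LayerNorm and slicing afterwards is numerically equivalent to calling LayerNorm directly, since LayerNorm normalizes each position over the last dimension independently.

import torch

hidden_size = 1600  # GPT2-XL hidden size, used here only as an example value
ln = torch.nn.LayerNorm(hidden_size).eval()

with torch.no_grad():
    hidden_states = torch.randn(1, 1, hidden_size)  # single-token decode step

    # Direct path (what the unpatched code does).
    direct = ln(hidden_states)

    # Workaround path: duplicate the token along the sequence dimension,
    # normalize, then slice back to a single token.
    expanded = hidden_states.repeat([1, 2, 1])  # [1, 1, 1600] -> [1, 2, 1600]
    workaround = ln(expanded)[:, :1, :]         # back to [1, 1, 1600]

    print(torch.allclose(direct, workaround))  # True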