From bafc98522b7dd8888224d07d4164a42f6634f3cf Mon Sep 17 00:00:00 2001
From: Abheesht
Date: Fri, 8 Nov 2024 23:32:25 +0530
Subject: [PATCH 1/2] Fix `return_attention_scores` bug

---
 .../layers/modeling/transformer_encoder.py | 23 ++++++-------------
 1 file changed, 7 insertions(+), 16 deletions(-)

diff --git a/keras_hub/src/layers/modeling/transformer_encoder.py b/keras_hub/src/layers/modeling/transformer_encoder.py
index 5ed121e457..ccfec182bd 100644
--- a/keras_hub/src/layers/modeling/transformer_encoder.py
+++ b/keras_hub/src/layers/modeling/transformer_encoder.py
@@ -207,22 +207,13 @@ def call(
         if self.normalize_first:
             x = self._self_attention_layer_norm(x)
 
-        if return_attention_scores:
-            x, attention_scores = self._self_attention_layer(
-                query=x,
-                value=x,
-                attention_mask=self_attention_mask,
-                return_attention_scores=return_attention_scores,
-                training=training,
-            )
-            return x, attention_scores
-        else:
-            x = self._self_attention_layer(
-                query=x,
-                value=x,
-                attention_mask=self_attention_mask,
-                training=training,
-            )
+        x, attention_scores = self._self_attention_layer(
+            query=x,
+            value=x,
+            attention_mask=self_attention_mask,
+            return_attention_scores=True,
+            training=training,
+        )
 
         x = self._self_attention_dropout(x, training=training)
         x = x + residual


From 9d50369c1c51f7ecea1ad7692c10b5fd585be8bb Mon Sep 17 00:00:00 2001
From: Abheesht
Date: Fri, 8 Nov 2024 23:52:57 +0530
Subject: [PATCH 2/2] Keep the if...else block

---
 .../layers/modeling/transformer_encoder.py | 22 +++++++++++++++-------
 1 file changed, 15 insertions(+), 7 deletions(-)

diff --git a/keras_hub/src/layers/modeling/transformer_encoder.py b/keras_hub/src/layers/modeling/transformer_encoder.py
index ccfec182bd..4f471a12dd 100644
--- a/keras_hub/src/layers/modeling/transformer_encoder.py
+++ b/keras_hub/src/layers/modeling/transformer_encoder.py
@@ -207,13 +207,21 @@ def call(
         if self.normalize_first:
             x = self._self_attention_layer_norm(x)
 
-        x, attention_scores = self._self_attention_layer(
-            query=x,
-            value=x,
-            attention_mask=self_attention_mask,
-            return_attention_scores=True,
-            training=training,
-        )
+        if return_attention_scores:
+            x, attention_scores = self._self_attention_layer(
+                query=x,
+                value=x,
+                attention_mask=self_attention_mask,
+                return_attention_scores=return_attention_scores,
+                training=training,
+            )
+        else:
+            x = self._self_attention_layer(
+                query=x,
+                value=x,
+                attention_mask=self_attention_mask,
+                training=training,
+            )
 
         x = self._self_attention_dropout(x, training=training)
         x = x + residual
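
A minimal usage sketch of the code path these two patches touch, not part of the patches themselves. It assumes the layer is exposed as keras_hub.layers.TransformerEncoder and that, with the fix applied, calling it with return_attention_scores=True yields an (outputs, attention_scores) pair from the layer's final return (that return statement is outside the hunks above); shapes and hyperparameters are illustrative.

import numpy as np
import keras_hub

# Build an encoder block; intermediate_dim and num_heads are illustrative.
encoder = keras_hub.layers.TransformerEncoder(intermediate_dim=64, num_heads=4)

# Dummy input of shape (batch, sequence_length, feature_dim).
inputs = np.random.rand(2, 10, 32).astype("float32")

# Before the fix, requesting scores returned early from call(), skipping
# dropout, the residual add, and the feedforward sub-block; with the fix,
# the full block runs and the scores are returned alongside the outputs.
outputs, attention_scores = encoder(inputs, return_attention_scores=True)

print(outputs.shape)           # (2, 10, 32), same shape as the inputs
print(attention_scores.shape)  # (2, 4, 10, 10), per-head attention weights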