Commit f75c9b6
Raise ValueError when number of dims evaluates to zero (#1198)
sampathweb authored Aug 8, 2023
1 parent 272ba83 commit f75c9b6
Showing 2 changed files with 13 additions and 1 deletion.
keras_nlp/layers/modeling/transformer_decoder.py (6 additions, 0 deletions)

```diff
@@ -148,6 +148,12 @@ def build(
         hidden_dim = decoder_sequence_shape[-1]
         # Attention head size is `hidden_dim` over the number of heads.
         head_dim = int(hidden_dim // self.num_heads)
+        if head_dim == 0:
+            raise ValueError(
+                "Attention `head_dim` computed cannot be zero. "
+                f"The `hidden_dim` value of {hidden_dim} has to be equal to "
+                f"or greater than `num_heads` value of {self.num_heads}."
+            )
 
         # Self attention layers.
         self._self_attention_layer = CachedMultiHeadAttention(
```
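
For context (not part of the commit itself), here is a minimal sketch of the failure mode this guard catches, assuming a keras_nlp release that includes the check. Since `head_dim = int(hidden_dim // num_heads)` floors to zero whenever the input's feature dimension is smaller than `num_heads`, the layer would otherwise build zero-width attention heads. The shapes below are illustrative only:

```python
import numpy as np
import keras_nlp

decoder = keras_nlp.layers.TransformerDecoder(intermediate_dim=32, num_heads=8)

# hidden_dim is the last input axis: 4 here, with num_heads=8,
# so head_dim = int(4 // 8) == 0. With this commit, build()
# fails fast with a ValueError instead of constructing a
# degenerate attention layer.
try:
    decoder(np.zeros((2, 10, 4), dtype="float32"))
except ValueError as err:
    print(err)
```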
keras_nlp/layers/modeling/transformer_encoder.py (7 additions, 1 deletion)

```diff
@@ -95,7 +95,7 @@ def __init__(
         bias_initializer="zeros",
         normalize_first=False,
         name=None,
-        **kwargs
+        **kwargs,
     ):
         super().__init__(name=name, **kwargs)
         self.intermediate_dim = intermediate_dim
@@ -113,6 +113,12 @@ def build(self, inputs_shape):
         hidden_dim = inputs_shape[-1]
         # Attention head size is `hidden_dim` over the number of heads.
         key_dim = int(hidden_dim // self.num_heads)
+        if key_dim == 0:
+            raise ValueError(
+                "Attention `key_dim` computed cannot be zero. "
+                f"The `hidden_dim` value of {hidden_dim} has to be equal to "
+                f"or greater than `num_heads` value of {self.num_heads}."
+            )
 
         # Self attention layers.
         self._self_attention_layer = keras.layers.MultiHeadAttention(
```
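
Again as an illustration rather than part of the commit, the encoder-side check behaves the same way: `key_dim` is computed identically, so the input's feature dimension must be at least `num_heads`:

```python
import numpy as np
import keras_nlp

encoder = keras_nlp.layers.TransformerEncoder(intermediate_dim=32, num_heads=8)

# Feature dimension (4) is below num_heads (8): key_dim == 0,
# so build() now raises the new ValueError up front.
try:
    encoder(np.zeros((2, 10, 4), dtype="float32"))
except ValueError as err:
    print(err)
```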
