A few doc updates
mattdangerw committed Aug 1, 2023
1 parent 206af68, commit ea9a413
Showing 1 changed file with 12 additions and 9 deletions: keras_nlp/layers/modeling/rotary_embedding.py
@@ -24,18 +24,19 @@ class RotaryEmbedding(keras.layers.Layer):
matrix. It calculates the rotary encoding with a mix of sine and
cosine functions with geometrically increasing wavelengths.
Defined and formulated in [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864v4).
- The input must be a tensor with shape [batch_size, sequence_length, feature_dim].
- This layer will return new tensor after applying rotational encoding.
+ The input must be a tensor with a sequence dimension and a feature
+ dimension. Typically, this will either be an input with shape
+ `(batch_size, sequence_length, feature_length)` or
+ `(batch_size, sequence_length, num_heads, feature_length)`.
+ This layer will return a new tensor with the rotary embedding applied to
+ the input tensor.
Args:
-     percentage: float. The percentage of attn_head_size over which rotation
-         should be applied. Defaults to 0
      max_wavelength: int. The maximum angular wavelength of the sine/cosine
-         curves, as described in Attention is All You Need. Defaults to
-         `10000`.
-     scaling_factor: float. The scaling factor used to scale frequency range
-     sequence_axis: int. Sequence axis in the input tensor
-     feature_axis: int. Feature axis in the input tensor
+         curves.
+     scaling_factor: float. The scaling factor used to scale frequency range.
+     sequence_axis: int. Sequence axis in the input tensor.
+     feature_axis: int. Feature axis in the input tensor.
Examples:
@@ -45,10 +46,12 @@ class RotaryEmbedding(keras.layers.Layer):
sequence_length = 256
num_heads = 8
# No multi-head dimension.
tensor = tf.ones((batch_size, sequence_length, feature_length))
rot_emb_layer = RotaryEmbedding()
tensor_rot = rot_emb_layer(tensor)
# With multi-head dimension.
tensor = tf.ones((batch_size, sequence_length, num_heads, feature_length))
tensor_rot = rot_emb_layer(tensor)
```
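To make the "mix of sine and cosine functions with geometrically increasing wavelengths" concrete, here is a rough NumPy sketch of the rotary idea from the RoFormer paper: each consecutive pair of features is treated as a 2-D vector and rotated by an angle proportional to its position. This is an illustration only, not the keras_nlp implementation; the interleaved pairing and the way `scaling_factor` divides positions are assumptions and may differ from the actual layer.

```python
import numpy as np

def rope_sketch(x, max_wavelength=10000, scaling_factor=1.0):
    # x: (batch_size, sequence_length, feature_length) with an even feature_length.
    _, sequence_length, feature_length = x.shape
    # Token positions, optionally compressed by the scaling factor (assumed usage).
    positions = np.arange(sequence_length) / scaling_factor
    # Geometrically increasing wavelengths across feature pairs,
    # i.e. geometrically decreasing rotation frequencies.
    inverse_freq = 1.0 / (
        max_wavelength ** (np.arange(0, feature_length, 2) / feature_length)
    )
    angles = np.outer(positions, inverse_freq)  # (sequence_length, feature_length // 2)
    sin, cos = np.sin(angles), np.cos(angles)
    x_even, x_odd = x[..., 0::2], x[..., 1::2]  # each pair forms one 2-D vector
    out = np.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin  # standard 2-D rotation
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out

# Same toy shapes as the docstring example above.
x_rot = rope_sketch(np.ones((16, 256, 64)))
```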

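As a usage note on the `sequence_axis` and `feature_axis` arguments documented above: they let the rotation be applied to layouts other than the two shapes listed in the docstring. The sketch below assumes the layer is exported as `keras_nlp.layers.RotaryEmbedding` and that the axes can be pointed at a heads-first `(batch_size, num_heads, sequence_length, feature_length)` layout; both details are assumptions, not confirmed by this diff.

```python
import tensorflow as tf
from keras_nlp.layers import RotaryEmbedding  # assumed public import path

batch_size, num_heads, sequence_length, feature_length = 16, 8, 256, 64

# Heads-first layout, e.g. attention inputs already transposed to
# (batch_size, num_heads, sequence_length, feature_length).
tensor = tf.ones((batch_size, num_heads, sequence_length, feature_length))

# Point the layer at the correct axes rather than relying on the defaults.
rot_emb_layer = RotaryEmbedding(sequence_axis=2, feature_axis=-1)
tensor_rot = rot_emb_layer(tensor)  # same shape as `tensor`
```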