diff --git a/lib/bumblebee/text/llama.ex b/lib/bumblebee/text/llama.ex
index dac7b2ea..05995957 100644
--- a/lib/bumblebee/text/llama.ex
+++ b/lib/bumblebee/text/llama.ex
@@ -26,6 +26,13 @@ defmodule Bumblebee.Text.Llama do
         default: 11008,
         doc: "the dimensionality of intermediate layers"
       ],
+      attention_head_size: [
+        default: nil,
+        doc: """
+        the size of the key, value, and query projection per attention head.
+        Defaults to `div(hidden_size, num_attention_heads)`
+        """
+      ],
       num_blocks: [
         default: 32,
         doc: "the number of Transformer blocks in the model"
@@ -169,6 +176,7 @@ defmodule Bumblebee.Text.Llama do
   def init_cache(spec, batch_size, max_length, _inputs) do
     Layers.Decoder.init_cache(batch_size, max_length,
       hidden_size: spec.hidden_size,
+      attention_head_size: spec.attention_head_size,
       decoder_num_attention_heads: spec.num_attention_heads,
       decoder_num_blocks: spec.num_blocks
     )
@@ -321,6 +329,7 @@ defmodule Bumblebee.Text.Llama do
     Layers.Transformer.blocks(hidden_state,
       attention_mask: attention_mask,
       attention_head_mask: attention_head_mask,
+      attention_head_size: spec.attention_head_size,
       cache: cache,
       num_blocks: spec.num_blocks,
       num_attention_heads: spec.num_attention_heads,
@@ -431,6 +440,7 @@ defmodule Bumblebee.Text.Llama do
       num_blocks: {"num_hidden_layers", number()},
       num_attention_heads: {"num_attention_heads", number()},
       num_key_value_heads: {"num_key_value_heads", number()},
+      attention_head_size: {"head_dim", number()},
       intermediate_size: {"intermediate_size", number()},
       activation: {"hidden_act", activation()},
       rotary_embedding_base: {"rope_theta", number()},
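
For context, here is a minimal usage sketch of the new option through the public Bumblebee API. It is not part of the diff; the checkpoint name is a placeholder, and the explicit `attention_head_size: 128` value is only illustrative.

```elixir
# Usage sketch (not part of this diff). The repo below is an example placeholder;
# any Llama checkpoint on the Hugging Face Hub would work the same way.
repo = {:hf, "meta-llama/Llama-2-7b-hf"}

{:ok, spec} = Bumblebee.load_spec(repo)

# If config.json provides "head_dim", it is loaded into :attention_head_size.
# Otherwise the option stays nil and the effective head size falls back to
# div(hidden_size, num_attention_heads). Overriding it explicitly:
spec = Bumblebee.configure(spec, attention_head_size: 128)

{:ok, %{model: _model, params: _params}} = Bumblebee.load_model(repo, spec: spec)
```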