Reserve KV cache capacity after the first model run
Hugging Face models with separate branches for the first and subsequent
iterations do not use the input KV cache buffer on the first run. As a result
they did not benefit from the pre-allocated capacity and ended up allocating a
new KV cache buffer on each run.

To resolve this, change the KV cache growth strategy to grow the buffer after
the model runs, once the capacity limit has been reached. Also replace the
hard-coded capacity with a policy that doubles the capacity each time, which
amortizes the cost of copying the old KV cache into the new buffer.
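
As a rough illustration of the amortization argument (a minimal standalone
sketch, not rten code; `copies_needed` and its growth closures are
hypothetical), count the element copies incurred while appending one element
at a time:

```rust
/// Count how many elements get copied while growing a buffer to
/// `target_len`, appending one element at a time and reallocating
/// with `grow` whenever the buffer is full.
fn copies_needed(target_len: usize, grow: impl Fn(usize) -> usize) -> usize {
    let (mut len, mut cap, mut copies) = (0usize, 1usize, 0usize);
    while len < target_len {
        if len == cap {
            cap = grow(cap); // allocate a larger buffer...
            copies += len; // ...and copy the existing contents into it
        }
        len += 1;
    }
    copies
}

fn main() {
    // Doubling: total copies stay linear in the final length.
    assert_eq!(copies_needed(1024, |cap| cap * 2), 1023);
    // Growing by a fixed step instead costs quadratically many copies.
    assert_eq!(copies_needed(1024, |cap| cap + 1), 1023 * 1024 / 2);
}
```

Doubling keeps the per-append cost amortized O(1), which is why the new
strategy grows the buffer geometrically rather than by a fixed amount.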
robertknight committed Nov 15, 2024
1 parent 773f728 commit c24e15f
Showing 1 changed file with 99 additions and 14 deletions: rten-generate/src/generator.rs
@@ -86,6 +86,76 @@ enum KvCacheData {
     BatchHeadSeqChans(NdTensor<f32, 4>),
 }
 
+impl KvCacheData {
+    /// Allocate a KV cache buffer with the given batch size, number of heads
+    /// and embed size.
+    ///
+    /// The buffer initially has capacity to be extended to a sequence length
+    /// of `seq_len_capacity`.
+    fn with_capacity(
+        batch_size: usize,
+        n_heads: Option<usize>,
+        size: usize,
+        seq_len_capacity: usize,
+    ) -> KvCacheData {
+        if let Some(n_heads) = n_heads {
+            KvCacheData::BatchHeadSeqChans(NdTensor::with_capacity(
+                [batch_size, n_heads, seq_len_capacity, size],
+                2, /* seq dim */
+            ))
+        } else {
+            KvCacheData::BatchSeqChans(NdTensor::with_capacity(
+                [batch_size, seq_len_capacity, size],
+                1, /* seq dim */
+            ))
+        }
+    }
+
+    /// Return the current sequence length of the cache.
+    fn sequence_len(&self) -> usize {
+        match self {
+            KvCacheData::BatchSeqChans(data) => data.size(1),
+            KvCacheData::BatchHeadSeqChans(data) => data.size(2),
+        }
+    }
+
+    /// Return true if the KV cache has capacity for a given sequence length.
+    fn has_capacity(&self, sequence_len: usize) -> bool {
+        match self {
+            KvCacheData::BatchSeqChans(data) => {
+                data.has_capacity(1 /* seq dim */, sequence_len)
+            }
+            KvCacheData::BatchHeadSeqChans(data) => {
+                data.has_capacity(2 /* seq dim */, sequence_len)
+            }
+        }
+    }
+
+    /// Clone this cache into a new buffer with space to store sequences of
+    /// a given size.
+    fn clone_with_capacity(&self, max_sequence_len: usize) -> KvCacheData {
+        let max_sequence_len = max_sequence_len.max(self.sequence_len());
+        match self {
+            KvCacheData::BatchSeqChans(data) => {
+                let [batch, _seq, chans] = data.shape();
+                let mut new_data =
+                    NdTensor::with_capacity([batch, max_sequence_len, chans], 1 /* seq dim */);
+                new_data.append(1, data).expect("should have capacity");
+                KvCacheData::BatchSeqChans(new_data)
+            }
+            KvCacheData::BatchHeadSeqChans(data) => {
+                let [batch, n_heads, _seq, chans] = data.shape();
+                let mut new_data = NdTensor::with_capacity(
+                    [batch, n_heads, max_sequence_len, chans],
+                    2, /* seq dim */
+                );
+                new_data.append(2, data).expect("should have capacity");
+                KvCacheData::BatchHeadSeqChans(new_data)
+            }
+        }
+    }
+}
+
 /// Key-value cache for a single layer of a transformer model.
 struct KvCache {
     /// Input ID for this cache entry.
@@ -440,23 +510,28 @@ impl<'a> Generator<'a> {
             .find_node(&output_name)
             .ok_or(GeneratorError::OutputNotFound(output_name))?;
 
-        // This value should be configurable.
-        let max_seq_len = 512;
+        // Initial sequence length capacity for KV cache buffer.
+        //
+        // For models that execute different operations on the first vs
+        // subsequent iterations (eg. Hugging Face "merged" models with
+        // past and no-past branches) the input buffer may not be used in
+        // the first iteration. Instead we need to reserve capacity once
+        // the model returns the initial KV cache.
+        //
+        // For other simpler models the input KV cache buffer is used for
+        // all iterations, in which case we would ideally reserve capacity
+        // up-front based on the max expected sequence length.
+        let max_seq_len = 1;
 
         let kv_cache_entry = KvCache {
             input_id,
             output_id,
-            cache: if let Some(n_heads) = n_heads {
-                Some(KvCacheData::BatchHeadSeqChans(NdTensor::with_capacity(
-                    [batch_size, n_heads, max_seq_len, size],
-                    2, /* seq dim */
-                )))
-            } else {
-                Some(KvCacheData::BatchSeqChans(NdTensor::with_capacity(
-                    [batch_size, max_seq_len, size],
-                    1, /* seq dim */
-                )))
-            },
+            cache: Some(KvCacheData::with_capacity(
+                batch_size,
+                n_heads,
+                size,
+                max_seq_len,
+            )),
         };
 
         if kv_pattern.encoder {
@@ -717,7 +792,7 @@ impl<'a> Generator<'a> {
             let output = outputs.remove(0);
 
             let err_context = "failed to save self-attention KV-cache";
-            let kv_cache = match output.ndim() {
+            let mut kv_cache = match output.ndim() {
                 3 => KvCacheData::BatchSeqChans(
                     output.try_into().map_err(|e| wrap_error(e, err_context))?,
                 ),
@@ -731,6 +806,16 @@
                     ));
                 }
             };
+
+            // Grow the KV cache buffer if it has reached the limit of its
+            // pre-allocated sequence length.
+            //
+            // Double the capacity each time to amortize the costs of copying
+            // the previous buffer.
+            if !kv_cache.has_capacity(kv_cache.sequence_len() + 1) {
+                kv_cache = kv_cache.clone_with_capacity(kv_cache.sequence_len() * 2);
+            }
+
             cache_entry.cache = Some(kv_cache);
         }
 
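For intuition about the resulting schedule: starting from a capacity of 1 and
doubling after each run yields capacities 2, 4, 8, ... whether or not the first
iteration used the pre-allocated input buffer. A self-contained sketch of that
behavior, with `SeqBuf` as a hypothetical stand-in for `KvCacheData` (none of
these names are part of rten-generate's API):

```rust
/// Simplified stand-in for `KvCacheData`: tracks only the cache's
/// sequence length and its pre-allocated capacity.
struct SeqBuf {
    len: usize,
    cap: usize,
}

impl SeqBuf {
    fn with_capacity(cap: usize) -> SeqBuf {
        SeqBuf { len: 0, cap }
    }

    fn has_capacity(&self, len: usize) -> bool {
        len <= self.cap
    }

    /// Mirror of `clone_with_capacity`: a new buffer at least as long
    /// as the existing data.
    fn clone_with_capacity(&self, cap: usize) -> SeqBuf {
        SeqBuf { len: self.len, cap: cap.max(self.len) }
    }
}

fn main() {
    // Start at capacity 1, as in `Generator`, then grow after each run.
    let mut cache = SeqBuf::with_capacity(1);
    for _step in 0..6 {
        cache.len += 1; // the model run appended one token's keys/values
        if !cache.has_capacity(cache.len + 1) {
            cache = cache.clone_with_capacity(cache.len * 2);
        }
        println!("len={} cap={}", cache.len, cache.cap);
    }
    // Capacity doubles on demand: 2, 4, 4, 8, 8, 8 across the six steps.
}
```

Because the initial capacity is only 1, merged models that ignore the input
buffer on the first run trigger the first doubling right after that run, with
no wasted up-front allocation.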