collab: Track cache writes/reads in LLM usage (#18834)
This PR extends LLM usage tracking to also record cache writes and cache
reads for Anthropic models.

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <[email protected]>
Co-authored-by: Antonio <[email protected]>
3 people authored Oct 7, 2024
1 parent c5d252b commit d55f025
Showing 9 changed files with 241 additions and 39 deletions.
4 changes: 4 additions & 0 deletions crates/anthropic/src/anthropic.rs
@@ -521,6 +521,10 @@ pub struct Usage {
     pub input_tokens: Option<u32>,
     #[serde(default, skip_serializing_if = "Option::is_none")]
     pub output_tokens: Option<u32>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub cache_creation_input_tokens: Option<u32>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub cache_read_input_tokens: Option<u32>,
 }
 
 #[derive(Debug, Serialize, Deserialize)]
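
Because the new fields are optional, a usage payload that omits them still
deserializes. A minimal, self-contained sketch of that behavior (assuming
serde and serde_json as dependencies; the trimmed Usage struct stands in for
the full one above, and the JSON is an illustrative example of Anthropic's
usage object, not captured API output):

use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize)]
struct Usage {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    input_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    output_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    cache_creation_input_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    cache_read_input_tokens: Option<u32>,
}

fn main() {
    // A usage object with a cache hit but no cache write; the absent
    // field becomes None thanks to serde(default).
    let json = r#"{"input_tokens": 12, "output_tokens": 180, "cache_read_input_tokens": 2048}"#;
    let usage: Usage = serde_json::from_str(json).unwrap();
    assert_eq!(usage.cache_read_input_tokens, Some(2048));
    assert_eq!(usage.cache_creation_input_tokens, None);
}

11 changes: 11 additions & 0 deletions (new SQL migration; file path not captured in this view)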
@@ -0,0 +1,11 @@
+alter table models
+add column price_per_million_cache_creation_input_tokens integer not null default 0,
+add column price_per_million_cache_read_input_tokens integer not null default 0;
+
+alter table usages
+add column cache_creation_input_tokens_this_month bigint not null default 0,
+add column cache_read_input_tokens_this_month bigint not null default 0;
+
+alter table lifetime_usages
+add column cache_creation_input_tokens bigint not null default 0,
+add column cache_read_input_tokens bigint not null default 0;
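
These columns extend the existing per-million-token pricing to cache writes
and cache reads. A hypothetical Rust sketch (not code from this PR) of the
arithmetic such columns enable, assuming the new prices use the same
per-million-token convention and currency unit as the existing
price_per_million_input_tokens column:

// Hypothetical helper, not part of this PR: applies per-million-token
// prices to a month's token counts and returns the total spend in the
// same currency unit the price columns use.
fn monthly_spending(
    input_tokens: u64,
    cache_creation_input_tokens: u64,
    cache_read_input_tokens: u64,
    output_tokens: u64,
    prices_per_million: [u64; 4], // input, cache write, cache read, output
) -> u64 {
    let cost = |tokens: u64, price_per_million: u64| tokens * price_per_million / 1_000_000;
    cost(input_tokens, prices_per_million[0])
        + cost(cache_creation_input_tokens, prices_per_million[1])
        + cost(cache_read_input_tokens, prices_per_million[2])
        + cost(output_tokens, prices_per_million[3])
}
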
80 changes: 59 additions & 21 deletions crates/collab/src/llm.rs
@@ -318,22 +318,31 @@ async fn perform_completion(
             chunks
                 .map(move |event| {
                     let chunk = event?;
-                    let (input_tokens, output_tokens) = match &chunk {
+                    let (
+                        input_tokens,
+                        output_tokens,
+                        cache_creation_input_tokens,
+                        cache_read_input_tokens,
+                    ) = match &chunk {
                         anthropic::Event::MessageStart {
                             message: anthropic::Response { usage, .. },
                         }
                         | anthropic::Event::MessageDelta { usage, .. } => (
                             usage.input_tokens.unwrap_or(0) as usize,
                             usage.output_tokens.unwrap_or(0) as usize,
+                            usage.cache_creation_input_tokens.unwrap_or(0) as usize,
+                            usage.cache_read_input_tokens.unwrap_or(0) as usize,
                         ),
-                        _ => (0, 0),
+                        _ => (0, 0, 0, 0),
                     };
 
-                    anyhow::Ok((
-                        serde_json::to_vec(&chunk).unwrap(),
+                    anyhow::Ok(CompletionChunk {
+                        bytes: serde_json::to_vec(&chunk).unwrap(),
                         input_tokens,
                         output_tokens,
-                    ))
+                        cache_creation_input_tokens,
+                        cache_read_input_tokens,
+                    })
                 })
                 .boxed()
         }
@@ -359,11 +368,13 @@ async fn perform_completion(
                             chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
                         let output_tokens =
                             chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
-                        (
-                            serde_json::to_vec(&chunk).unwrap(),
+                        CompletionChunk {
+                            bytes: serde_json::to_vec(&chunk).unwrap(),
                             input_tokens,
                             output_tokens,
-                        )
+                            cache_creation_input_tokens: 0,
+                            cache_read_input_tokens: 0,
+                        }
                     })
                 })
                 .boxed()
@@ -387,13 +398,13 @@ async fn perform_completion(
                 .map(|event| {
                     event.map(|chunk| {
                         // TODO - implement token counting for Google AI
-                        let input_tokens = 0;
-                        let output_tokens = 0;
-                        (
-                            serde_json::to_vec(&chunk).unwrap(),
-                            input_tokens,
-                            output_tokens,
-                        )
+                        CompletionChunk {
+                            bytes: serde_json::to_vec(&chunk).unwrap(),
+                            input_tokens: 0,
+                            output_tokens: 0,
+                            cache_creation_input_tokens: 0,
+                            cache_read_input_tokens: 0,
+                        }
                     })
                 })
                 .boxed()
@@ -407,6 +418,8 @@ async fn perform_completion(
         model,
         input_tokens: 0,
         output_tokens: 0,
+        cache_creation_input_tokens: 0,
+        cache_read_input_tokens: 0,
         inner_stream: stream,
     })))
 }
@@ -551,29 +564,41 @@ async fn check_usage_limit(
     Ok(())
 }
 
+struct CompletionChunk {
+    bytes: Vec<u8>,
+    input_tokens: usize,
+    output_tokens: usize,
+    cache_creation_input_tokens: usize,
+    cache_read_input_tokens: usize,
+}
+
 struct TokenCountingStream<S> {
     state: Arc<LlmState>,
     claims: LlmTokenClaims,
     provider: LanguageModelProvider,
     model: String,
     input_tokens: usize,
     output_tokens: usize,
+    cache_creation_input_tokens: usize,
+    cache_read_input_tokens: usize,
     inner_stream: S,
 }
 
 impl<S> Stream for TokenCountingStream<S>
 where
-    S: Stream<Item = Result<(Vec<u8>, usize, usize), anyhow::Error>> + Unpin,
+    S: Stream<Item = Result<CompletionChunk, anyhow::Error>> + Unpin,
 {
     type Item = Result<Vec<u8>, anyhow::Error>;
 
     fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
         match Pin::new(&mut self.inner_stream).poll_next(cx) {
-            Poll::Ready(Some(Ok((mut bytes, input_tokens, output_tokens)))) => {
-                bytes.push(b'\n');
-                self.input_tokens += input_tokens;
-                self.output_tokens += output_tokens;
-                Poll::Ready(Some(Ok(bytes)))
+            Poll::Ready(Some(Ok(mut chunk))) => {
+                chunk.bytes.push(b'\n');
+                self.input_tokens += chunk.input_tokens;
+                self.output_tokens += chunk.output_tokens;
+                self.cache_creation_input_tokens += chunk.cache_creation_input_tokens;
+                self.cache_read_input_tokens += chunk.cache_read_input_tokens;
+                Poll::Ready(Some(Ok(chunk.bytes)))
             }
             Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
             Poll::Ready(None) => Poll::Ready(None),
@@ -590,6 +615,8 @@ impl<S> Drop for TokenCountingStream<S> {
         let model = std::mem::take(&mut self.model);
         let input_token_count = self.input_tokens;
         let output_token_count = self.output_tokens;
+        let cache_creation_input_token_count = self.cache_creation_input_tokens;
+        let cache_read_input_token_count = self.cache_read_input_tokens;
         self.state.executor.spawn_detached(async move {
             let usage = state
                 .db
@@ -599,6 +626,8 @@ impl<S> Drop for TokenCountingStream<S> {
                     provider,
                     &model,
                     input_token_count,
+                    cache_creation_input_token_count,
+                    cache_read_input_token_count,
                     output_token_count,
                     Utc::now(),
                 )
@@ -630,11 +659,20 @@ impl<S> Drop for TokenCountingStream<S> {
                         model,
                         provider: provider.to_string(),
                         input_token_count: input_token_count as u64,
+                        cache_creation_input_token_count: cache_creation_input_token_count
+                            as u64,
+                        cache_read_input_token_count: cache_read_input_token_count as u64,
                         output_token_count: output_token_count as u64,
                         requests_this_minute: usage.requests_this_minute as u64,
                         tokens_this_minute: usage.tokens_this_minute as u64,
                         tokens_this_day: usage.tokens_this_day as u64,
                         input_tokens_this_month: usage.input_tokens_this_month as u64,
+                        cache_creation_input_tokens_this_month: usage
+                            .cache_creation_input_tokens_this_month
+                            as u64,
+                        cache_read_input_tokens_this_month: usage
+                            .cache_read_input_tokens_this_month
+                            as u64,
                         output_tokens_this_month: usage.output_tokens_this_month as u64,
                         spending_this_month: usage.spending_this_month as u64,
                         lifetime_spending: usage.lifetime_spending as u64,
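
Recording the totals in Drop means usage is persisted even if the client
disconnects and the response stream is dropped mid-completion. A simplified,
synchronous analogue of that pattern (an illustration, not code from this PR):
Iterator stands in for Stream, and a println! stands in for the detached
record_usage call.

// Pass chunks through while accumulating their token counts; flush the
// totals exactly once in Drop, even if the consumer stops early.
struct CountingIter<I> {
    input_tokens: usize,
    cache_read_input_tokens: usize,
    inner: I,
}

impl<I: Iterator<Item = (Vec<u8>, usize, usize)>> Iterator for CountingIter<I> {
    type Item = Vec<u8>;

    fn next(&mut self) -> Option<Vec<u8>> {
        let (bytes, input_tokens, cache_read_input_tokens) = self.inner.next()?;
        self.input_tokens += input_tokens;
        self.cache_read_input_tokens += cache_read_input_tokens;
        Some(bytes)
    }
}

impl<I> Drop for CountingIter<I> {
    fn drop(&mut self) {
        // In collab, this is where a detached task calls record_usage.
        println!(
            "recording usage: {} input tokens, {} cache reads",
            self.input_tokens, self.cache_read_input_tokens
        );
    }
}

fn main() {
    let chunks = vec![(b"first".to_vec(), 10, 2048), (b"second".to_vec(), 0, 0)];
    let mut stream = CountingIter {
        input_tokens: 0,
        cache_read_input_tokens: 0,
        inner: chunks.into_iter(),
    };
    // Consume only the first chunk; Drop still records the partial totals.
    let _first = stream.next();
}

Running main consumes a single chunk, yet the Drop impl still reports the
partial totals, which is the same guarantee TokenCountingStream relies on.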