From d55f0259060a4461cf04ef60c04701285763bb28 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Mon, 7 Oct 2024 17:32:49 -0400 Subject: [PATCH] collab: Track cache writes/reads in LLM usage (#18834) This PR extends the LLM usage tracking to support tracking usage for cache writes and reads for Anthropic models. Release Notes: - N/A --------- Co-authored-by: Antonio Scandurra Co-authored-by: Antonio --- crates/anthropic/src/anthropic.rs | 4 + .../20241007173634_add_cache_token_counts.sql | 11 ++ crates/collab/src/llm.rs | 80 +++++++++---- crates/collab/src/llm/db/queries/usages.rs | 113 +++++++++++++++--- .../src/llm/db/tables/lifetime_usage.rs | 2 + crates/collab/src/llm/db/tables/model.rs | 2 + .../collab/src/llm/db/tables/usage_measure.rs | 2 + crates/collab/src/llm/db/tests/usage_tests.rs | 62 +++++++++- crates/collab/src/llm/telemetry.rs | 4 + 9 files changed, 241 insertions(+), 39 deletions(-) create mode 100644 crates/collab/migrations_llm/20241007173634_add_cache_token_counts.sql diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index 6b8972284208a1..08c8f27bd90276 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -521,6 +521,10 @@ pub struct Usage { pub input_tokens: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub output_tokens: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub cache_creation_input_tokens: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub cache_read_input_tokens: Option, } #[derive(Debug, Serialize, Deserialize)] diff --git a/crates/collab/migrations_llm/20241007173634_add_cache_token_counts.sql b/crates/collab/migrations_llm/20241007173634_add_cache_token_counts.sql new file mode 100644 index 00000000000000..855e46ab0224dc --- /dev/null +++ b/crates/collab/migrations_llm/20241007173634_add_cache_token_counts.sql @@ -0,0 +1,11 @@ +alter table models + add column price_per_million_cache_creation_input_tokens integer not null default 0, + add column price_per_million_cache_read_input_tokens integer not null default 0; + +alter table usages + add column cache_creation_input_tokens_this_month bigint not null default 0, + add column cache_read_input_tokens_this_month bigint not null default 0; + +alter table lifetime_usages + add column cache_creation_input_tokens bigint not null default 0, + add column cache_read_input_tokens bigint not null default 0; diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index 2d040cfa28e1a9..9809985ac72b23 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -318,22 +318,31 @@ async fn perform_completion( chunks .map(move |event| { let chunk = event?; - let (input_tokens, output_tokens) = match &chunk { + let ( + input_tokens, + output_tokens, + cache_creation_input_tokens, + cache_read_input_tokens, + ) = match &chunk { anthropic::Event::MessageStart { message: anthropic::Response { usage, .. }, } | anthropic::Event::MessageDelta { usage, .. } => ( usage.input_tokens.unwrap_or(0) as usize, usage.output_tokens.unwrap_or(0) as usize, + usage.cache_creation_input_tokens.unwrap_or(0) as usize, + usage.cache_read_input_tokens.unwrap_or(0) as usize, ), - _ => (0, 0), + _ => (0, 0, 0, 0), }; - anyhow::Ok(( - serde_json::to_vec(&chunk).unwrap(), + anyhow::Ok(CompletionChunk { + bytes: serde_json::to_vec(&chunk).unwrap(), input_tokens, output_tokens, - )) + cache_creation_input_tokens, + cache_read_input_tokens, + }) }) .boxed() } @@ -359,11 +368,13 @@ async fn perform_completion( chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize; let output_tokens = chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize; - ( - serde_json::to_vec(&chunk).unwrap(), + CompletionChunk { + bytes: serde_json::to_vec(&chunk).unwrap(), input_tokens, output_tokens, - ) + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + } }) }) .boxed() @@ -387,13 +398,13 @@ async fn perform_completion( .map(|event| { event.map(|chunk| { // TODO - implement token counting for Google AI - let input_tokens = 0; - let output_tokens = 0; - ( - serde_json::to_vec(&chunk).unwrap(), - input_tokens, - output_tokens, - ) + CompletionChunk { + bytes: serde_json::to_vec(&chunk).unwrap(), + input_tokens: 0, + output_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + } }) }) .boxed() @@ -407,6 +418,8 @@ async fn perform_completion( model, input_tokens: 0, output_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, inner_stream: stream, }))) } @@ -551,6 +564,14 @@ async fn check_usage_limit( Ok(()) } +struct CompletionChunk { + bytes: Vec, + input_tokens: usize, + output_tokens: usize, + cache_creation_input_tokens: usize, + cache_read_input_tokens: usize, +} + struct TokenCountingStream { state: Arc, claims: LlmTokenClaims, @@ -558,22 +579,26 @@ struct TokenCountingStream { model: String, input_tokens: usize, output_tokens: usize, + cache_creation_input_tokens: usize, + cache_read_input_tokens: usize, inner_stream: S, } impl Stream for TokenCountingStream where - S: Stream, usize, usize), anyhow::Error>> + Unpin, + S: Stream> + Unpin, { type Item = Result, anyhow::Error>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match Pin::new(&mut self.inner_stream).poll_next(cx) { - Poll::Ready(Some(Ok((mut bytes, input_tokens, output_tokens)))) => { - bytes.push(b'\n'); - self.input_tokens += input_tokens; - self.output_tokens += output_tokens; - Poll::Ready(Some(Ok(bytes))) + Poll::Ready(Some(Ok(mut chunk))) => { + chunk.bytes.push(b'\n'); + self.input_tokens += chunk.input_tokens; + self.output_tokens += chunk.output_tokens; + self.cache_creation_input_tokens += chunk.cache_creation_input_tokens; + self.cache_read_input_tokens += chunk.cache_read_input_tokens; + Poll::Ready(Some(Ok(chunk.bytes))) } Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), Poll::Ready(None) => Poll::Ready(None), @@ -590,6 +615,8 @@ impl Drop for TokenCountingStream { let model = std::mem::take(&mut self.model); let input_token_count = self.input_tokens; let output_token_count = self.output_tokens; + let cache_creation_input_token_count = self.cache_creation_input_tokens; + let cache_read_input_token_count = self.cache_read_input_tokens; self.state.executor.spawn_detached(async move { let usage = state .db @@ -599,6 +626,8 @@ impl Drop for TokenCountingStream { provider, &model, input_token_count, + cache_creation_input_token_count, + cache_read_input_token_count, output_token_count, Utc::now(), ) @@ -630,11 +659,20 @@ impl Drop for TokenCountingStream { model, provider: provider.to_string(), input_token_count: input_token_count as u64, + cache_creation_input_token_count: cache_creation_input_token_count + as u64, + cache_read_input_token_count: cache_read_input_token_count as u64, output_token_count: output_token_count as u64, requests_this_minute: usage.requests_this_minute as u64, tokens_this_minute: usage.tokens_this_minute as u64, tokens_this_day: usage.tokens_this_day as u64, input_tokens_this_month: usage.input_tokens_this_month as u64, + cache_creation_input_tokens_this_month: usage + .cache_creation_input_tokens_this_month + as u64, + cache_read_input_tokens_this_month: usage + .cache_read_input_tokens_this_month + as u64, output_tokens_this_month: usage.output_tokens_this_month as u64, spending_this_month: usage.spending_this_month as u64, lifetime_spending: usage.lifetime_spending as u64, diff --git a/crates/collab/src/llm/db/queries/usages.rs b/crates/collab/src/llm/db/queries/usages.rs index 65a0bd67345bd0..128a42bc58791d 100644 --- a/crates/collab/src/llm/db/queries/usages.rs +++ b/crates/collab/src/llm/db/queries/usages.rs @@ -14,6 +14,8 @@ pub struct Usage { pub tokens_this_minute: usize, pub tokens_this_day: usize, pub input_tokens_this_month: usize, + pub cache_creation_input_tokens_this_month: usize, + pub cache_read_input_tokens_this_month: usize, pub output_tokens_this_month: usize, pub spending_this_month: usize, pub lifetime_spending: usize, @@ -160,17 +162,14 @@ impl LlmDatabase { .all(&*tx) .await?; - let (lifetime_input_tokens, lifetime_output_tokens) = lifetime_usage::Entity::find() + let lifetime_usage = lifetime_usage::Entity::find() .filter( lifetime_usage::Column::UserId .eq(user_id) .and(lifetime_usage::Column::ModelId.eq(model.id)), ) .one(&*tx) - .await? - .map_or((0, 0), |usage| { - (usage.input_tokens as usize, usage.output_tokens as usize) - }); + .await?; let requests_this_minute = self.get_usage_for_measure(&usages, now, UsageMeasure::RequestsPerMinute)?; @@ -180,18 +179,44 @@ impl LlmDatabase { self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerDay)?; let input_tokens_this_month = self.get_usage_for_measure(&usages, now, UsageMeasure::InputTokensPerMonth)?; + let cache_creation_input_tokens_this_month = self.get_usage_for_measure( + &usages, + now, + UsageMeasure::CacheCreationInputTokensPerMonth, + )?; + let cache_read_input_tokens_this_month = self.get_usage_for_measure( + &usages, + now, + UsageMeasure::CacheReadInputTokensPerMonth, + )?; let output_tokens_this_month = self.get_usage_for_measure(&usages, now, UsageMeasure::OutputTokensPerMonth)?; - let spending_this_month = - calculate_spending(model, input_tokens_this_month, output_tokens_this_month); - let lifetime_spending = - calculate_spending(model, lifetime_input_tokens, lifetime_output_tokens); + let spending_this_month = calculate_spending( + model, + input_tokens_this_month, + cache_creation_input_tokens_this_month, + cache_read_input_tokens_this_month, + output_tokens_this_month, + ); + let lifetime_spending = if let Some(lifetime_usage) = lifetime_usage { + calculate_spending( + model, + lifetime_usage.input_tokens as usize, + lifetime_usage.cache_creation_input_tokens as usize, + lifetime_usage.cache_read_input_tokens as usize, + lifetime_usage.output_tokens as usize, + ) + } else { + 0 + }; Ok(Usage { requests_this_minute, tokens_this_minute, tokens_this_day, input_tokens_this_month, + cache_creation_input_tokens_this_month, + cache_read_input_tokens_this_month, output_tokens_this_month, spending_this_month, lifetime_spending, @@ -208,6 +233,8 @@ impl LlmDatabase { provider: LanguageModelProvider, model_name: &str, input_token_count: usize, + cache_creation_input_tokens: usize, + cache_read_input_tokens: usize, output_token_count: usize, now: DateTimeUtc, ) -> Result { @@ -235,6 +262,10 @@ impl LlmDatabase { &tx, ) .await?; + let total_token_count = input_token_count + + cache_read_input_tokens + + cache_creation_input_tokens + + output_token_count; let tokens_this_minute = self .update_usage_for_measure( user_id, @@ -243,7 +274,7 @@ impl LlmDatabase { &usages, UsageMeasure::TokensPerMinute, now, - input_token_count + output_token_count, + total_token_count, &tx, ) .await?; @@ -255,7 +286,7 @@ impl LlmDatabase { &usages, UsageMeasure::TokensPerDay, now, - input_token_count + output_token_count, + total_token_count, &tx, ) .await?; @@ -271,6 +302,30 @@ impl LlmDatabase { &tx, ) .await?; + let cache_creation_input_tokens_this_month = self + .update_usage_for_measure( + user_id, + is_staff, + model.id, + &usages, + UsageMeasure::CacheCreationInputTokensPerMonth, + now, + cache_creation_input_tokens, + &tx, + ) + .await?; + let cache_read_input_tokens_this_month = self + .update_usage_for_measure( + user_id, + is_staff, + model.id, + &usages, + UsageMeasure::CacheReadInputTokensPerMonth, + now, + cache_read_input_tokens, + &tx, + ) + .await?; let output_tokens_this_month = self .update_usage_for_measure( user_id, @@ -283,8 +338,13 @@ impl LlmDatabase { &tx, ) .await?; - let spending_this_month = - calculate_spending(model, input_tokens_this_month, output_tokens_this_month); + let spending_this_month = calculate_spending( + model, + input_tokens_this_month, + cache_creation_input_tokens_this_month, + cache_read_input_tokens_this_month, + output_tokens_this_month, + ); // Update lifetime usage let lifetime_usage = lifetime_usage::Entity::find() @@ -303,6 +363,12 @@ impl LlmDatabase { input_tokens: ActiveValue::set( usage.input_tokens + input_token_count as i64, ), + cache_creation_input_tokens: ActiveValue::set( + usage.cache_creation_input_tokens + cache_creation_input_tokens as i64, + ), + cache_read_input_tokens: ActiveValue::set( + usage.cache_read_input_tokens + cache_read_input_tokens as i64, + ), output_tokens: ActiveValue::set( usage.output_tokens + output_token_count as i64, ), @@ -327,6 +393,8 @@ impl LlmDatabase { let lifetime_spending = calculate_spending( model, lifetime_usage.input_tokens as usize, + lifetime_usage.cache_creation_input_tokens as usize, + lifetime_usage.cache_read_input_tokens as usize, lifetime_usage.output_tokens as usize, ); @@ -335,6 +403,8 @@ impl LlmDatabase { tokens_this_minute, tokens_this_day, input_tokens_this_month, + cache_creation_input_tokens_this_month, + cache_read_input_tokens_this_month, output_tokens_this_month, spending_this_month, lifetime_spending, @@ -501,13 +571,24 @@ impl LlmDatabase { fn calculate_spending( model: &model::Model, input_tokens_this_month: usize, + cache_creation_input_tokens_this_month: usize, + cache_read_input_tokens_this_month: usize, output_tokens_this_month: usize, ) -> usize { let input_token_cost = input_tokens_this_month * model.price_per_million_input_tokens as usize / 1_000_000; + let cache_creation_input_token_cost = cache_creation_input_tokens_this_month + * model.price_per_million_cache_creation_input_tokens as usize + / 1_000_000; + let cache_read_input_token_cost = cache_read_input_tokens_this_month + * model.price_per_million_cache_read_input_tokens as usize + / 1_000_000; let output_token_cost = output_tokens_this_month * model.price_per_million_output_tokens as usize / 1_000_000; - input_token_cost + output_token_cost + input_token_cost + + cache_creation_input_token_cost + + cache_read_input_token_cost + + output_token_cost } const MINUTE_BUCKET_COUNT: usize = 12; @@ -521,6 +602,8 @@ impl UsageMeasure { UsageMeasure::TokensPerMinute => MINUTE_BUCKET_COUNT, UsageMeasure::TokensPerDay => DAY_BUCKET_COUNT, UsageMeasure::InputTokensPerMonth => MONTH_BUCKET_COUNT, + UsageMeasure::CacheCreationInputTokensPerMonth => MONTH_BUCKET_COUNT, + UsageMeasure::CacheReadInputTokensPerMonth => MONTH_BUCKET_COUNT, UsageMeasure::OutputTokensPerMonth => MONTH_BUCKET_COUNT, } } @@ -531,6 +614,8 @@ impl UsageMeasure { UsageMeasure::TokensPerMinute => Duration::minutes(1), UsageMeasure::TokensPerDay => Duration::hours(24), UsageMeasure::InputTokensPerMonth => Duration::days(30), + UsageMeasure::CacheCreationInputTokensPerMonth => Duration::days(30), + UsageMeasure::CacheReadInputTokensPerMonth => Duration::days(30), UsageMeasure::OutputTokensPerMonth => Duration::days(30), } } diff --git a/crates/collab/src/llm/db/tables/lifetime_usage.rs b/crates/collab/src/llm/db/tables/lifetime_usage.rs index 05ad2d5e94c1fa..fc8354699b2309 100644 --- a/crates/collab/src/llm/db/tables/lifetime_usage.rs +++ b/crates/collab/src/llm/db/tables/lifetime_usage.rs @@ -9,6 +9,8 @@ pub struct Model { pub user_id: UserId, pub model_id: ModelId, pub input_tokens: i64, + pub cache_creation_input_tokens: i64, + pub cache_read_input_tokens: i64, pub output_tokens: i64, } diff --git a/crates/collab/src/llm/db/tables/model.rs b/crates/collab/src/llm/db/tables/model.rs index c87789f27e2fc6..4d7d2d8da9bce3 100644 --- a/crates/collab/src/llm/db/tables/model.rs +++ b/crates/collab/src/llm/db/tables/model.rs @@ -14,6 +14,8 @@ pub struct Model { pub max_tokens_per_minute: i64, pub max_tokens_per_day: i64, pub price_per_million_input_tokens: i32, + pub price_per_million_cache_creation_input_tokens: i32, + pub price_per_million_cache_read_input_tokens: i32, pub price_per_million_output_tokens: i32, } diff --git a/crates/collab/src/llm/db/tables/usage_measure.rs b/crates/collab/src/llm/db/tables/usage_measure.rs index 1105d997c2bcea..50c9501e54f1d6 100644 --- a/crates/collab/src/llm/db/tables/usage_measure.rs +++ b/crates/collab/src/llm/db/tables/usage_measure.rs @@ -10,6 +10,8 @@ pub enum UsageMeasure { TokensPerMinute, TokensPerDay, InputTokensPerMonth, + CacheCreationInputTokensPerMonth, + CacheReadInputTokensPerMonth, OutputTokensPerMonth, } diff --git a/crates/collab/src/llm/db/tests/usage_tests.rs b/crates/collab/src/llm/db/tests/usage_tests.rs index 905a3dda08101f..97bcc20e44d1f3 100644 --- a/crates/collab/src/llm/db/tests/usage_tests.rs +++ b/crates/collab/src/llm/db/tests/usage_tests.rs @@ -33,12 +33,12 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { let user_id = UserId::from_proto(123); let now = t0; - db.record_usage(user_id, false, provider, model, 1000, 0, now) + db.record_usage(user_id, false, provider, model, 1000, 0, 0, 0, now) .await .unwrap(); let now = t0 + Duration::seconds(10); - db.record_usage(user_id, false, provider, model, 2000, 0, now) + db.record_usage(user_id, false, provider, model, 2000, 0, 0, 0, now) .await .unwrap(); @@ -50,6 +50,8 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { tokens_this_minute: 3000, tokens_this_day: 3000, input_tokens_this_month: 3000, + cache_creation_input_tokens_this_month: 0, + cache_read_input_tokens_this_month: 0, output_tokens_this_month: 0, spending_this_month: 0, lifetime_spending: 0, @@ -65,6 +67,8 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { tokens_this_minute: 2000, tokens_this_day: 3000, input_tokens_this_month: 3000, + cache_creation_input_tokens_this_month: 0, + cache_read_input_tokens_this_month: 0, output_tokens_this_month: 0, spending_this_month: 0, lifetime_spending: 0, @@ -72,7 +76,7 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { ); let now = t0 + Duration::seconds(60); - db.record_usage(user_id, false, provider, model, 3000, 0, now) + db.record_usage(user_id, false, provider, model, 3000, 0, 0, 0, now) .await .unwrap(); @@ -84,6 +88,8 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { tokens_this_minute: 5000, tokens_this_day: 6000, input_tokens_this_month: 6000, + cache_creation_input_tokens_this_month: 0, + cache_read_input_tokens_this_month: 0, output_tokens_this_month: 0, spending_this_month: 0, lifetime_spending: 0, @@ -100,13 +106,15 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { tokens_this_minute: 0, tokens_this_day: 5000, input_tokens_this_month: 6000, + cache_creation_input_tokens_this_month: 0, + cache_read_input_tokens_this_month: 0, output_tokens_this_month: 0, spending_this_month: 0, lifetime_spending: 0, } ); - db.record_usage(user_id, false, provider, model, 4000, 0, now) + db.record_usage(user_id, false, provider, model, 4000, 0, 0, 0, now) .await .unwrap(); @@ -118,6 +126,8 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { tokens_this_minute: 4000, tokens_this_day: 9000, input_tokens_this_month: 10000, + cache_creation_input_tokens_this_month: 0, + cache_read_input_tokens_this_month: 0, output_tokens_this_month: 0, spending_this_month: 0, lifetime_spending: 0, @@ -134,6 +144,50 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { tokens_this_minute: 0, tokens_this_day: 0, input_tokens_this_month: 9000, + cache_creation_input_tokens_this_month: 0, + cache_read_input_tokens_this_month: 0, + output_tokens_this_month: 0, + spending_this_month: 0, + lifetime_spending: 0, + } + ); + + // Test cache creation input tokens + db.record_usage(user_id, false, provider, model, 1000, 500, 0, 0, now) + .await + .unwrap(); + + let usage = db.get_usage(user_id, provider, model, now).await.unwrap(); + assert_eq!( + usage, + Usage { + requests_this_minute: 1, + tokens_this_minute: 1500, + tokens_this_day: 1500, + input_tokens_this_month: 10000, + cache_creation_input_tokens_this_month: 500, + cache_read_input_tokens_this_month: 0, + output_tokens_this_month: 0, + spending_this_month: 0, + lifetime_spending: 0, + } + ); + + // Test cache read input tokens + db.record_usage(user_id, false, provider, model, 1000, 0, 300, 0, now) + .await + .unwrap(); + + let usage = db.get_usage(user_id, provider, model, now).await.unwrap(); + assert_eq!( + usage, + Usage { + requests_this_minute: 2, + tokens_this_minute: 2800, + tokens_this_day: 2800, + input_tokens_this_month: 11000, + cache_creation_input_tokens_this_month: 500, + cache_read_input_tokens_this_month: 300, output_tokens_this_month: 0, spending_this_month: 0, lifetime_spending: 0, diff --git a/crates/collab/src/llm/telemetry.rs b/crates/collab/src/llm/telemetry.rs index 17a2cb9cd3389d..9daaaf3032090e 100644 --- a/crates/collab/src/llm/telemetry.rs +++ b/crates/collab/src/llm/telemetry.rs @@ -12,11 +12,15 @@ pub struct LlmUsageEventRow { pub model: String, pub provider: String, pub input_token_count: u64, + pub cache_creation_input_token_count: u64, + pub cache_read_input_token_count: u64, pub output_token_count: u64, pub requests_this_minute: u64, pub tokens_this_minute: u64, pub tokens_this_day: u64, pub input_tokens_this_month: u64, + pub cache_creation_input_tokens_this_month: u64, + pub cache_read_input_tokens_this_month: u64, pub output_tokens_this_month: u64, pub spending_this_month: u64, pub lifetime_spending: u64,