diff --git a/crates/collab/migrations_llm/20241008155620_create_monthly_usages.sql b/crates/collab/migrations_llm/20241008155620_create_monthly_usages.sql new file mode 100644 index 00000000000000..2733552a3a16f2 --- /dev/null +++ b/crates/collab/migrations_llm/20241008155620_create_monthly_usages.sql @@ -0,0 +1,13 @@ +create table monthly_usages ( + id serial primary key, + user_id integer not null, + model_id integer not null references models (id) on delete cascade, + month integer not null, + year integer not null, + input_tokens bigint not null default 0, + cache_creation_input_tokens bigint not null default 0, + cache_read_input_tokens bigint not null default 0, + output_tokens bigint not null default 0 +); + +create unique index uix_monthly_usages_on_user_id_model_id_month_year on monthly_usages (user_id, model_id, month, year); diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index 23a16590cac630..b70fc1e3ba32d3 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -22,12 +22,15 @@ use stripe::{ }; use util::ResultExt; -use crate::db::billing_subscription::StripeSubscriptionStatus; +use crate::db::billing_subscription::{self, StripeSubscriptionStatus}; use crate::db::{ billing_customer, BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams, CreateProcessedStripeEventParams, UpdateBillingCustomerParams, UpdateBillingSubscriptionParams, }; +use crate::llm::db::LlmDatabase; +use crate::llm::MONTHLY_SPENDING_LIMIT_IN_CENTS; +use crate::rpc::ResultExt as _; use crate::{AppState, Error, Result}; pub fn router() -> Router { @@ -79,7 +82,7 @@ async fn list_billing_subscriptions( .into_iter() .map(|subscription| BillingSubscriptionJson { id: subscription.id, - name: "Zed Pro".to_string(), + name: "Zed LLM Usage".to_string(), status: subscription.stripe_subscription_status, cancel_at: subscription.stripe_cancel_at.map(|cancel_at| { cancel_at @@ -117,7 +120,7 @@ async fn create_billing_subscription( let Some((stripe_client, stripe_price_id)) = app .stripe_client .clone() - .zip(app.config.stripe_price_id.clone()) + .zip(app.config.stripe_llm_usage_price_id.clone()) else { log::error!("failed to retrieve Stripe client or price ID"); Err(Error::http( @@ -150,7 +153,7 @@ async fn create_billing_subscription( params.client_reference_id = Some(user.github_login.as_str()); params.line_items = Some(vec![CreateCheckoutSessionLineItems { price: Some(stripe_price_id.to_string()), - quantity: Some(1), + quantity: Some(0), ..Default::default() }]); let success_url = format!("{}/account", app.config.zed_dot_dev_url()); @@ -631,3 +634,95 @@ async fn find_or_create_billing_customer( Ok(Some(billing_customer)) } + +const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60); + +pub fn sync_llm_usage_with_stripe_periodically(app: Arc, llm_db: LlmDatabase) { + let Some(stripe_client) = app.stripe_client.clone() else { + log::warn!("failed to retrieve Stripe client"); + return; + }; + let Some(stripe_llm_usage_price_id) = app.config.stripe_llm_usage_price_id.clone() else { + log::warn!("failed to retrieve Stripe LLM usage price ID"); + return; + }; + + let executor = app.executor.clone(); + executor.spawn_detached({ + let executor = executor.clone(); + async move { + loop { + sync_with_stripe( + &app, + &llm_db, + &stripe_client, + stripe_llm_usage_price_id.clone(), + ) + .await + .trace_err(); + + executor.sleep(SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL).await; + } + } + }); +} + +async fn sync_with_stripe( + app: &Arc, + llm_db: &LlmDatabase, + stripe_client: &stripe::Client, + stripe_llm_usage_price_id: Arc, +) -> anyhow::Result<()> { + let subscriptions = app.db.get_active_billing_subscriptions().await?; + + for (customer, subscription) in subscriptions { + update_stripe_subscription( + llm_db, + stripe_client, + &stripe_llm_usage_price_id, + customer, + subscription, + ) + .await + .log_err(); + } + + Ok(()) +} + +async fn update_stripe_subscription( + llm_db: &LlmDatabase, + stripe_client: &stripe::Client, + stripe_llm_usage_price_id: &Arc, + customer: billing_customer::Model, + subscription: billing_subscription::Model, +) -> Result<(), anyhow::Error> { + let monthly_spending = llm_db + .get_user_spending_for_month(customer.user_id, Utc::now()) + .await?; + let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id) + .context("failed to parse subscription ID")?; + + let monthly_spending_over_free_tier = + monthly_spending.saturating_sub(MONTHLY_SPENDING_LIMIT_IN_CENTS); + + let new_quantity = (monthly_spending_over_free_tier as f32 / 100.).ceil(); + Subscription::update( + stripe_client, + &subscription_id, + stripe::UpdateSubscription { + items: Some(vec![stripe::UpdateSubscriptionItems { + // TODO: Do we need to send up the `id` if a subscription item + // with this price already exists, or will Stripe take care of + // it? + id: None, + price: Some(stripe_llm_usage_price_id.to_string()), + quantity: Some(new_quantity as u64), + ..Default::default() + }]), + ..Default::default() + }, + ) + .await?; + Ok(()) +} diff --git a/crates/collab/src/db/queries/billing_subscriptions.rs b/crates/collab/src/db/queries/billing_subscriptions.rs index 7a7ba31f166988..bcf093bebd4240 100644 --- a/crates/collab/src/db/queries/billing_subscriptions.rs +++ b/crates/collab/src/db/queries/billing_subscriptions.rs @@ -112,6 +112,29 @@ impl Database { .await } + pub async fn get_active_billing_subscriptions( + &self, + ) -> Result> { + self.transaction(|tx| async move { + let mut result = Vec::new(); + let mut rows = billing_subscription::Entity::find() + .inner_join(billing_customer::Entity) + .select_also(billing_customer::Entity) + .order_by_asc(billing_subscription::Column::Id) + .stream(&*tx) + .await?; + + while let Some(row) = rows.next().await { + if let (subscription, Some(customer)) = row? { + result.push((customer, subscription)); + } + } + + Ok(result) + }) + .await + } + /// Returns whether the user has an active billing subscription. pub async fn has_active_billing_subscription(&self, user_id: UserId) -> Result { Ok(self.count_active_billing_subscriptions(user_id).await? > 0) diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 6c32023a97a287..ccecf80087aacb 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -174,7 +174,7 @@ pub struct Config { pub slack_panics_webhook: Option, pub auto_join_channel_id: Option, pub stripe_api_key: Option, - pub stripe_price_id: Option>, + pub stripe_llm_usage_price_id: Option>, pub supermaven_admin_api_key: Option>, pub user_backfiller_github_access_token: Option>, } @@ -193,6 +193,10 @@ impl Config { } } + pub fn is_llm_billing_enabled(&self) -> bool { + self.stripe_llm_usage_price_id.is_some() + } + #[cfg(test)] pub fn test() -> Self { Self { @@ -231,7 +235,7 @@ impl Config { migrations_path: None, seed_path: None, stripe_api_key: None, - stripe_price_id: None, + stripe_llm_usage_price_id: None, supermaven_admin_api_key: None, user_backfiller_github_access_token: None, } diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index 9809985ac72b23..96413cf7c56710 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -436,6 +436,9 @@ fn normalize_model_name(known_models: Vec, name: String) -> String { } } +/// The maximum monthly spending an individual user can reach before they have to pay. +pub const MONTHLY_SPENDING_LIMIT_IN_CENTS: usize = 5 * 100; + /// The maximum lifetime spending an individual user can reach before being cut off. /// /// Represented in cents. @@ -458,6 +461,18 @@ async fn check_usage_limit( ) .await?; + if state.config.is_llm_billing_enabled() { + if usage.spending_this_month >= MONTHLY_SPENDING_LIMIT_IN_CENTS { + if !claims.has_llm_subscription.unwrap_or(false) { + return Err(Error::http( + StatusCode::PAYMENT_REQUIRED, + "Maximum spending limit reached for this month.".to_string(), + )); + } + } + } + + // TODO: Remove this once we've rolled out monthly spending limits. if usage.lifetime_spending >= LIFETIME_SPENDING_LIMIT_IN_CENTS { return Err(Error::http( StatusCode::FORBIDDEN, @@ -505,7 +520,6 @@ async fn check_usage_limit( UsageMeasure::RequestsPerMinute => "requests_per_minute", UsageMeasure::TokensPerMinute => "tokens_per_minute", UsageMeasure::TokensPerDay => "tokens_per_day", - _ => "", }; if let Some(client) = state.clickhouse_client.as_ref() { diff --git a/crates/collab/src/llm/db.rs b/crates/collab/src/llm/db.rs index d46f51bb0df594..996837116b190d 100644 --- a/crates/collab/src/llm/db.rs +++ b/crates/collab/src/llm/db.rs @@ -97,6 +97,14 @@ impl LlmDatabase { .ok_or_else(|| anyhow!("unknown model {provider:?}:{name}"))?) } + pub fn model_by_id(&self, id: ModelId) -> Result<&model::Model> { + Ok(self + .models + .values() + .find(|model| model.id == id) + .ok_or_else(|| anyhow!("no model for ID {id:?}"))?) + } + pub fn options(&self) -> &ConnectOptions { &self.options } diff --git a/crates/collab/src/llm/db/queries/usages.rs b/crates/collab/src/llm/db/queries/usages.rs index d703066913f8f5..1a98685bcddb4e 100644 --- a/crates/collab/src/llm/db/queries/usages.rs +++ b/crates/collab/src/llm/db/queries/usages.rs @@ -1,5 +1,5 @@ use crate::db::UserId; -use chrono::Duration; +use chrono::{Datelike, Duration}; use futures::StreamExt as _; use rpc::LanguageModelProvider; use sea_orm::QuerySelect; @@ -140,6 +140,46 @@ impl LlmDatabase { .await } + pub async fn get_user_spending_for_month( + &self, + user_id: UserId, + now: DateTimeUtc, + ) -> Result { + self.transaction(|tx| async move { + let month = now.date_naive().month() as i32; + let year = now.date_naive().year(); + + let mut monthly_usages = monthly_usage::Entity::find() + .filter( + monthly_usage::Column::UserId + .eq(user_id) + .and(monthly_usage::Column::Month.eq(month)) + .and(monthly_usage::Column::Year.eq(year)), + ) + .stream(&*tx) + .await?; + let mut monthly_spending_in_cents = 0; + + while let Some(usage) = monthly_usages.next().await { + let usage = usage?; + let Ok(model) = self.model_by_id(usage.model_id) else { + continue; + }; + + monthly_spending_in_cents += calculate_spending( + model, + usage.input_tokens as usize, + usage.cache_creation_input_tokens as usize, + usage.cache_read_input_tokens as usize, + usage.output_tokens as usize, + ); + } + + Ok(monthly_spending_in_cents) + }) + .await + } + pub async fn get_usage( &self, user_id: UserId, @@ -162,6 +202,18 @@ impl LlmDatabase { .all(&*tx) .await?; + let month = now.date_naive().month() as i32; + let year = now.date_naive().year(); + let monthly_usage = monthly_usage::Entity::find() + .filter( + monthly_usage::Column::UserId + .eq(user_id) + .and(monthly_usage::Column::ModelId.eq(model.id)) + .and(monthly_usage::Column::Month.eq(month)) + .and(monthly_usage::Column::Year.eq(year)), + ) + .one(&*tx) + .await?; let lifetime_usage = lifetime_usage::Entity::find() .filter( lifetime_usage::Column::UserId @@ -177,28 +229,18 @@ impl LlmDatabase { self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerMinute)?; let tokens_this_day = self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerDay)?; - let input_tokens_this_month = - self.get_usage_for_measure(&usages, now, UsageMeasure::InputTokensPerMonth)?; - let cache_creation_input_tokens_this_month = self.get_usage_for_measure( - &usages, - now, - UsageMeasure::CacheCreationInputTokensPerMonth, - )?; - let cache_read_input_tokens_this_month = self.get_usage_for_measure( - &usages, - now, - UsageMeasure::CacheReadInputTokensPerMonth, - )?; - let output_tokens_this_month = - self.get_usage_for_measure(&usages, now, UsageMeasure::OutputTokensPerMonth)?; - let spending_this_month = calculate_spending( - model, - input_tokens_this_month, - cache_creation_input_tokens_this_month, - cache_read_input_tokens_this_month, - output_tokens_this_month, - ); - let lifetime_spending = if let Some(lifetime_usage) = lifetime_usage { + let spending_this_month = if let Some(monthly_usage) = &monthly_usage { + calculate_spending( + model, + monthly_usage.input_tokens as usize, + monthly_usage.cache_creation_input_tokens as usize, + monthly_usage.cache_read_input_tokens as usize, + monthly_usage.output_tokens as usize, + ) + } else { + 0 + }; + let lifetime_spending = if let Some(lifetime_usage) = &lifetime_usage { calculate_spending( model, lifetime_usage.input_tokens as usize, @@ -214,10 +256,18 @@ impl LlmDatabase { requests_this_minute, tokens_this_minute, tokens_this_day, - input_tokens_this_month, - cache_creation_input_tokens_this_month, - cache_read_input_tokens_this_month, - output_tokens_this_month, + input_tokens_this_month: monthly_usage + .as_ref() + .map_or(0, |usage| usage.input_tokens as usize), + cache_creation_input_tokens_this_month: monthly_usage + .as_ref() + .map_or(0, |usage| usage.cache_creation_input_tokens as usize), + cache_read_input_tokens_this_month: monthly_usage + .as_ref() + .map_or(0, |usage| usage.cache_read_input_tokens as usize), + output_tokens_this_month: monthly_usage + .as_ref() + .map_or(0, |usage| usage.output_tokens as usize), spending_this_month, lifetime_spending, }) @@ -290,60 +340,68 @@ impl LlmDatabase { &tx, ) .await?; - let input_tokens_this_month = self - .update_usage_for_measure( - user_id, - is_staff, - model.id, - &usages, - UsageMeasure::InputTokensPerMonth, - now, - input_token_count, - &tx, - ) - .await?; - let cache_creation_input_tokens_this_month = self - .update_usage_for_measure( - user_id, - is_staff, - model.id, - &usages, - UsageMeasure::CacheCreationInputTokensPerMonth, - now, - cache_creation_input_tokens, - &tx, - ) - .await?; - let cache_read_input_tokens_this_month = self - .update_usage_for_measure( - user_id, - is_staff, - model.id, - &usages, - UsageMeasure::CacheReadInputTokensPerMonth, - now, - cache_read_input_tokens, - &tx, - ) - .await?; - let output_tokens_this_month = self - .update_usage_for_measure( - user_id, - is_staff, - model.id, - &usages, - UsageMeasure::OutputTokensPerMonth, - now, - output_token_count, - &tx, + + let month = now.date_naive().month() as i32; + let year = now.date_naive().year(); + + // Update monthly usage + let monthly_usage = monthly_usage::Entity::find() + .filter( + monthly_usage::Column::UserId + .eq(user_id) + .and(monthly_usage::Column::ModelId.eq(model.id)) + .and(monthly_usage::Column::Month.eq(month)) + .and(monthly_usage::Column::Year.eq(year)), ) + .one(&*tx) .await?; + + let monthly_usage = match monthly_usage { + Some(usage) => { + monthly_usage::Entity::update(monthly_usage::ActiveModel { + id: ActiveValue::unchanged(usage.id), + input_tokens: ActiveValue::set( + usage.input_tokens + input_token_count as i64, + ), + cache_creation_input_tokens: ActiveValue::set( + usage.cache_creation_input_tokens + cache_creation_input_tokens as i64, + ), + cache_read_input_tokens: ActiveValue::set( + usage.cache_read_input_tokens + cache_read_input_tokens as i64, + ), + output_tokens: ActiveValue::set( + usage.output_tokens + output_token_count as i64, + ), + ..Default::default() + }) + .exec(&*tx) + .await? + } + None => { + monthly_usage::ActiveModel { + user_id: ActiveValue::set(user_id), + model_id: ActiveValue::set(model.id), + month: ActiveValue::set(month), + year: ActiveValue::set(year), + input_tokens: ActiveValue::set(input_token_count as i64), + cache_creation_input_tokens: ActiveValue::set( + cache_creation_input_tokens as i64, + ), + cache_read_input_tokens: ActiveValue::set(cache_read_input_tokens as i64), + output_tokens: ActiveValue::set(output_token_count as i64), + ..Default::default() + } + .insert(&*tx) + .await? + } + }; + let spending_this_month = calculate_spending( model, - input_tokens_this_month, - cache_creation_input_tokens_this_month, - cache_read_input_tokens_this_month, - output_tokens_this_month, + monthly_usage.input_tokens as usize, + monthly_usage.cache_creation_input_tokens as usize, + monthly_usage.cache_read_input_tokens as usize, + monthly_usage.output_tokens as usize, ); // Update lifetime usage @@ -406,10 +464,11 @@ impl LlmDatabase { requests_this_minute, tokens_this_minute, tokens_this_day, - input_tokens_this_month, - cache_creation_input_tokens_this_month, - cache_read_input_tokens_this_month, - output_tokens_this_month, + input_tokens_this_month: monthly_usage.input_tokens as usize, + cache_creation_input_tokens_this_month: monthly_usage.cache_creation_input_tokens + as usize, + cache_read_input_tokens_this_month: monthly_usage.cache_read_input_tokens as usize, + output_tokens_this_month: monthly_usage.output_tokens as usize, spending_this_month, lifetime_spending, }) @@ -597,7 +656,6 @@ fn calculate_spending( const MINUTE_BUCKET_COUNT: usize = 12; const DAY_BUCKET_COUNT: usize = 48; -const MONTH_BUCKET_COUNT: usize = 30; impl UsageMeasure { fn bucket_count(&self) -> usize { @@ -605,10 +663,6 @@ impl UsageMeasure { UsageMeasure::RequestsPerMinute => MINUTE_BUCKET_COUNT, UsageMeasure::TokensPerMinute => MINUTE_BUCKET_COUNT, UsageMeasure::TokensPerDay => DAY_BUCKET_COUNT, - UsageMeasure::InputTokensPerMonth => MONTH_BUCKET_COUNT, - UsageMeasure::CacheCreationInputTokensPerMonth => MONTH_BUCKET_COUNT, - UsageMeasure::CacheReadInputTokensPerMonth => MONTH_BUCKET_COUNT, - UsageMeasure::OutputTokensPerMonth => MONTH_BUCKET_COUNT, } } @@ -617,10 +671,6 @@ impl UsageMeasure { UsageMeasure::RequestsPerMinute => Duration::minutes(1), UsageMeasure::TokensPerMinute => Duration::minutes(1), UsageMeasure::TokensPerDay => Duration::hours(24), - UsageMeasure::InputTokensPerMonth => Duration::days(30), - UsageMeasure::CacheCreationInputTokensPerMonth => Duration::days(30), - UsageMeasure::CacheReadInputTokensPerMonth => Duration::days(30), - UsageMeasure::OutputTokensPerMonth => Duration::days(30), } } diff --git a/crates/collab/src/llm/db/tables.rs b/crates/collab/src/llm/db/tables.rs index 4beefe2b5d45d7..57aded70e91d29 100644 --- a/crates/collab/src/llm/db/tables.rs +++ b/crates/collab/src/llm/db/tables.rs @@ -1,5 +1,6 @@ pub mod lifetime_usage; pub mod model; +pub mod monthly_usage; pub mod provider; pub mod revoked_access_token; pub mod usage; diff --git a/crates/collab/src/llm/db/tables/monthly_usage.rs b/crates/collab/src/llm/db/tables/monthly_usage.rs new file mode 100644 index 00000000000000..1e849f6aefc585 --- /dev/null +++ b/crates/collab/src/llm/db/tables/monthly_usage.rs @@ -0,0 +1,22 @@ +use crate::{db::UserId, llm::db::ModelId}; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] +#[sea_orm(table_name = "monthly_usages")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub user_id: UserId, + pub model_id: ModelId, + pub month: i32, + pub year: i32, + pub input_tokens: i64, + pub cache_creation_input_tokens: i64, + pub cache_read_input_tokens: i64, + pub output_tokens: i64, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/llm/db/tables/usage_measure.rs b/crates/collab/src/llm/db/tables/usage_measure.rs index 50c9501e54f1d6..b0e5b866447ed0 100644 --- a/crates/collab/src/llm/db/tables/usage_measure.rs +++ b/crates/collab/src/llm/db/tables/usage_measure.rs @@ -9,10 +9,6 @@ pub enum UsageMeasure { RequestsPerMinute, TokensPerMinute, TokensPerDay, - InputTokensPerMonth, - CacheCreationInputTokensPerMonth, - CacheReadInputTokensPerMonth, - OutputTokensPerMonth, } #[derive(Clone, Debug, PartialEq, DeriveEntityModel)] diff --git a/crates/collab/src/llm/db/tests/usage_tests.rs b/crates/collab/src/llm/db/tests/usage_tests.rs index 97bcc20e44d1f3..8e8dc0ff6b240b 100644 --- a/crates/collab/src/llm/db/tests/usage_tests.rs +++ b/crates/collab/src/llm/db/tests/usage_tests.rs @@ -6,7 +6,7 @@ use crate::{ }, test_llm_db, }; -use chrono::{Duration, Utc}; +use chrono::{DateTime, Duration, Utc}; use pretty_assertions::assert_eq; use rpc::LanguageModelProvider; @@ -29,7 +29,10 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { .await .unwrap(); - let t0 = Utc::now(); + // We're using a fixed datetime to prevent flakiness based on the clock. + let t0 = DateTime::parse_from_rfc3339("2024-08-08T22:46:33Z") + .unwrap() + .with_timezone(&Utc); let user_id = UserId::from_proto(123); let now = t0; @@ -134,23 +137,10 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { } ); - let t2 = t0 + Duration::days(30); - let now = t2; - let usage = db.get_usage(user_id, provider, model, now).await.unwrap(); - assert_eq!( - usage, - Usage { - requests_this_minute: 0, - tokens_this_minute: 0, - tokens_this_day: 0, - input_tokens_this_month: 9000, - cache_creation_input_tokens_this_month: 0, - cache_read_input_tokens_this_month: 0, - output_tokens_this_month: 0, - spending_this_month: 0, - lifetime_spending: 0, - } - ); + // We're using a fixed datetime to prevent flakiness based on the clock. + let now = DateTime::parse_from_rfc3339("2024-10-08T22:15:58Z") + .unwrap() + .with_timezone(&Utc); // Test cache creation input tokens db.record_usage(user_id, false, provider, model, 1000, 500, 0, 0, now) @@ -164,7 +154,7 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { requests_this_minute: 1, tokens_this_minute: 1500, tokens_this_day: 1500, - input_tokens_this_month: 10000, + input_tokens_this_month: 1000, cache_creation_input_tokens_this_month: 500, cache_read_input_tokens_this_month: 0, output_tokens_this_month: 0, @@ -185,7 +175,7 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { requests_this_minute: 2, tokens_this_minute: 2800, tokens_this_day: 2800, - input_tokens_this_month: 11000, + input_tokens_this_month: 2000, cache_creation_input_tokens_this_month: 500, cache_read_input_tokens_this_month: 300, output_tokens_this_month: 0, diff --git a/crates/collab/src/llm/token.rs b/crates/collab/src/llm/token.rs index e1e6c7332627dc..2f6ce6ee286ccc 100644 --- a/crates/collab/src/llm/token.rs +++ b/crates/collab/src/llm/token.rs @@ -22,6 +22,12 @@ pub struct LlmTokenClaims { pub is_staff: bool, #[serde(default)] pub has_llm_closed_beta_feature_flag: bool, + // This field is temporarily optional so it can be added + // in a backwards-compatible way. We can make it required + // once all of the LLM tokens have cycled (~1 hour after + // this change has been deployed). + #[serde(default)] + pub has_llm_subscription: Option, pub plan: rpc::proto::Plan, } @@ -33,6 +39,7 @@ impl LlmTokenClaims { github_user_login: String, is_staff: bool, has_llm_closed_beta_feature_flag: bool, + has_llm_subscription: bool, plan: rpc::proto::Plan, config: &Config, ) -> Result { @@ -50,6 +57,7 @@ impl LlmTokenClaims { github_user_login: Some(github_user_login), is_staff, has_llm_closed_beta_feature_flag, + has_llm_subscription: Some(has_llm_subscription), plan, }; diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index 0e6bb67d13db37..bbbd4e562cb901 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -6,6 +6,7 @@ use axum::{ routing::get, Extension, Router, }; +use collab::api::billing::sync_llm_usage_with_stripe_periodically; use collab::api::CloudflareIpCountryHeader; use collab::llm::{db::LlmDatabase, log_usage_periodically}; use collab::migrations::run_database_migrations; @@ -29,7 +30,7 @@ use tower_http::trace::TraceLayer; use tracing_subscriber::{ filter::EnvFilter, fmt::format::JsonFields, util::SubscriberInitExt, Layer, }; -use util::ResultExt as _; +use util::{maybe, ResultExt as _}; const VERSION: &str = env!("CARGO_PKG_VERSION"); const REVISION: Option<&'static str> = option_env!("GITHUB_SHA"); @@ -136,6 +137,28 @@ async fn main() -> Result<()> { fetch_extensions_from_blob_store_periodically(state.clone()); spawn_user_backfiller(state.clone()); + let llm_db = maybe!(async { + let database_url = state + .config + .llm_database_url + .as_ref() + .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?; + let max_connections = state + .config + .llm_database_max_connections + .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?; + + let mut db_options = db::ConnectOptions::new(database_url); + db_options.max_connections(max_connections); + LlmDatabase::new(db_options, state.executor.clone()).await + }) + .await + .trace_err(); + + if let Some(llm_db) = llm_db { + sync_llm_usage_with_stripe_periodically(state.clone(), llm_db); + } + app = app .merge(collab::api::events::router()) .merge(collab::api::extensions::router()) diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 27c95a5b44e1a8..e66c306c506ddb 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -191,16 +191,26 @@ impl Session { } } - pub async fn current_plan(&self, db: MutexGuard<'_, DbHandle>) -> anyhow::Result { + pub async fn has_llm_subscription( + &self, + db: &MutexGuard<'_, DbHandle>, + ) -> anyhow::Result { if self.is_staff() { - return Ok(proto::Plan::ZedPro); + return Ok(true); } let Some(user_id) = self.user_id() else { - return Ok(proto::Plan::Free); + return Ok(false); }; - if db.has_active_billing_subscription(user_id).await? { + Ok(db.has_active_billing_subscription(user_id).await?) + } + + pub async fn current_plan( + &self, + _db: &MutexGuard<'_, DbHandle>, + ) -> anyhow::Result { + if self.is_staff() { Ok(proto::Plan::ZedPro) } else { Ok(proto::Plan::Free) @@ -3471,7 +3481,7 @@ fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool { } async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> { - let plan = session.current_plan(session.db().await).await?; + let plan = session.current_plan(&session.db().await).await?; session .peer @@ -4471,7 +4481,7 @@ async fn count_language_model_tokens( }; authorize_access_to_legacy_llm_endpoints(&session).await?; - let rate_limit: Box = match session.current_plan(session.db().await).await? { + let rate_limit: Box = match session.current_plan(&session.db().await).await? { proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit), proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit), }; @@ -4592,7 +4602,7 @@ async fn compute_embeddings( let api_key = api_key.context("no OpenAI API key configured on the server")?; authorize_access_to_legacy_llm_endpoints(&session).await?; - let rate_limit: Box = match session.current_plan(session.db().await).await? { + let rate_limit: Box = match session.current_plan(&session.db().await).await? { proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit), proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit), }; @@ -4915,7 +4925,8 @@ async fn get_llm_api_token( user.github_login.clone(), session.is_staff(), has_llm_closed_beta_feature_flag, - session.current_plan(db).await?, + session.has_llm_subscription(&db).await?, + session.current_plan(&db).await?, &session.app_state.config, )?; response.send(proto::GetLlmTokenResponse { token })?; diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index 8d2396eef08181..55bc279c8eaf6e 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -677,7 +677,7 @@ impl TestServer { migrations_path: None, seed_path: None, stripe_api_key: None, - stripe_price_id: None, + stripe_llm_usage_price_id: None, supermaven_admin_api_key: None, user_backfiller_github_access_token: None, },