Skip to content

Commit

Permalink
collab: Update billing code for LLM usage billing (#18879)
Browse files Browse the repository at this point in the history
This PR reworks our existing billing code in preparation for charging
based on LLM usage.

We aren't yet exercising the new billing-related code outside of
development.

There are some noteworthy changes for our existing LLM usage tracking:

- A new `monthly_usages` table has been added for tracking usage
per-user, per-model, per-month
- The per-month usage measures have been removed, in favor of the
`monthly_usages` table
- All of the per-month metrics in the Clickhouse rows have been changed
from a rolling 30-day window to a calendar month

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <[email protected]>
Co-authored-by: Richard <[email protected]>
Co-authored-by: Max <[email protected]>
  • Loading branch information
4 people authored Oct 8, 2024
1 parent a95fb8f commit f861479
Show file tree
Hide file tree
Showing 15 changed files with 390 additions and 132 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
create table monthly_usages (
id serial primary key,
user_id integer not null,
model_id integer not null references models (id) on delete cascade,
month integer not null,
year integer not null,
input_tokens bigint not null default 0,
cache_creation_input_tokens bigint not null default 0,
cache_read_input_tokens bigint not null default 0,
output_tokens bigint not null default 0
);

create unique index uix_monthly_usages_on_user_id_model_id_month_year on monthly_usages (user_id, model_id, month, year);
103 changes: 99 additions & 4 deletions crates/collab/src/api/billing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,15 @@ use stripe::{
};
use util::ResultExt;

use crate::db::billing_subscription::StripeSubscriptionStatus;
use crate::db::billing_subscription::{self, StripeSubscriptionStatus};
use crate::db::{
billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
CreateBillingSubscriptionParams, CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
UpdateBillingSubscriptionParams,
};
use crate::llm::db::LlmDatabase;
use crate::llm::MONTHLY_SPENDING_LIMIT_IN_CENTS;
use crate::rpc::ResultExt as _;
use crate::{AppState, Error, Result};

pub fn router() -> Router {
Expand Down Expand Up @@ -79,7 +82,7 @@ async fn list_billing_subscriptions(
.into_iter()
.map(|subscription| BillingSubscriptionJson {
id: subscription.id,
name: "Zed Pro".to_string(),
name: "Zed LLM Usage".to_string(),
status: subscription.stripe_subscription_status,
cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
cancel_at
Expand Down Expand Up @@ -117,7 +120,7 @@ async fn create_billing_subscription(
let Some((stripe_client, stripe_price_id)) = app
.stripe_client
.clone()
.zip(app.config.stripe_price_id.clone())
.zip(app.config.stripe_llm_usage_price_id.clone())
else {
log::error!("failed to retrieve Stripe client or price ID");
Err(Error::http(
Expand Down Expand Up @@ -150,7 +153,7 @@ async fn create_billing_subscription(
params.client_reference_id = Some(user.github_login.as_str());
params.line_items = Some(vec![CreateCheckoutSessionLineItems {
price: Some(stripe_price_id.to_string()),
quantity: Some(1),
quantity: Some(0),
..Default::default()
}]);
let success_url = format!("{}/account", app.config.zed_dot_dev_url());
Expand Down Expand Up @@ -631,3 +634,95 @@ async fn find_or_create_billing_customer(

Ok(Some(billing_customer))
}

const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60);

pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>, llm_db: LlmDatabase) {
let Some(stripe_client) = app.stripe_client.clone() else {
log::warn!("failed to retrieve Stripe client");
return;
};
let Some(stripe_llm_usage_price_id) = app.config.stripe_llm_usage_price_id.clone() else {
log::warn!("failed to retrieve Stripe LLM usage price ID");
return;
};

let executor = app.executor.clone();
executor.spawn_detached({
let executor = executor.clone();
async move {
loop {
sync_with_stripe(
&app,
&llm_db,
&stripe_client,
stripe_llm_usage_price_id.clone(),
)
.await
.trace_err();

executor.sleep(SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL).await;
}
}
});
}

async fn sync_with_stripe(
app: &Arc<AppState>,
llm_db: &LlmDatabase,
stripe_client: &stripe::Client,
stripe_llm_usage_price_id: Arc<str>,
) -> anyhow::Result<()> {
let subscriptions = app.db.get_active_billing_subscriptions().await?;

for (customer, subscription) in subscriptions {
update_stripe_subscription(
llm_db,
stripe_client,
&stripe_llm_usage_price_id,
customer,
subscription,
)
.await
.log_err();
}

Ok(())
}

async fn update_stripe_subscription(
llm_db: &LlmDatabase,
stripe_client: &stripe::Client,
stripe_llm_usage_price_id: &Arc<str>,
customer: billing_customer::Model,
subscription: billing_subscription::Model,
) -> Result<(), anyhow::Error> {
let monthly_spending = llm_db
.get_user_spending_for_month(customer.user_id, Utc::now())
.await?;
let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
.context("failed to parse subscription ID")?;

let monthly_spending_over_free_tier =
monthly_spending.saturating_sub(MONTHLY_SPENDING_LIMIT_IN_CENTS);

let new_quantity = (monthly_spending_over_free_tier as f32 / 100.).ceil();
Subscription::update(
stripe_client,
&subscription_id,
stripe::UpdateSubscription {
items: Some(vec![stripe::UpdateSubscriptionItems {
// TODO: Do we need to send up the `id` if a subscription item
// with this price already exists, or will Stripe take care of
// it?
id: None,
price: Some(stripe_llm_usage_price_id.to_string()),
quantity: Some(new_quantity as u64),
..Default::default()
}]),
..Default::default()
},
)
.await?;
Ok(())
}
23 changes: 23 additions & 0 deletions crates/collab/src/db/queries/billing_subscriptions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,29 @@ impl Database {
.await
}

pub async fn get_active_billing_subscriptions(
&self,
) -> Result<Vec<(billing_customer::Model, billing_subscription::Model)>> {
self.transaction(|tx| async move {
let mut result = Vec::new();
let mut rows = billing_subscription::Entity::find()
.inner_join(billing_customer::Entity)
.select_also(billing_customer::Entity)
.order_by_asc(billing_subscription::Column::Id)
.stream(&*tx)
.await?;

while let Some(row) = rows.next().await {
if let (subscription, Some(customer)) = row? {
result.push((customer, subscription));
}
}

Ok(result)
})
.await
}

/// Returns whether the user has an active billing subscription.
pub async fn has_active_billing_subscription(&self, user_id: UserId) -> Result<bool> {
Ok(self.count_active_billing_subscriptions(user_id).await? > 0)
Expand Down
8 changes: 6 additions & 2 deletions crates/collab/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ pub struct Config {
pub slack_panics_webhook: Option<String>,
pub auto_join_channel_id: Option<ChannelId>,
pub stripe_api_key: Option<String>,
pub stripe_price_id: Option<Arc<str>>,
pub stripe_llm_usage_price_id: Option<Arc<str>>,
pub supermaven_admin_api_key: Option<Arc<str>>,
pub user_backfiller_github_access_token: Option<Arc<str>>,
}
Expand All @@ -193,6 +193,10 @@ impl Config {
}
}

pub fn is_llm_billing_enabled(&self) -> bool {
self.stripe_llm_usage_price_id.is_some()
}

#[cfg(test)]
pub fn test() -> Self {
Self {
Expand Down Expand Up @@ -231,7 +235,7 @@ impl Config {
migrations_path: None,
seed_path: None,
stripe_api_key: None,
stripe_price_id: None,
stripe_llm_usage_price_id: None,
supermaven_admin_api_key: None,
user_backfiller_github_access_token: None,
}
Expand Down
16 changes: 15 additions & 1 deletion crates/collab/src/llm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,9 @@ fn normalize_model_name(known_models: Vec<String>, name: String) -> String {
}
}

/// The maximum monthly spending an individual user can reach before they have to pay.
pub const MONTHLY_SPENDING_LIMIT_IN_CENTS: usize = 5 * 100;

/// The maximum lifetime spending an individual user can reach before being cut off.
///
/// Represented in cents.
Expand All @@ -458,6 +461,18 @@ async fn check_usage_limit(
)
.await?;

if state.config.is_llm_billing_enabled() {
if usage.spending_this_month >= MONTHLY_SPENDING_LIMIT_IN_CENTS {
if !claims.has_llm_subscription.unwrap_or(false) {
return Err(Error::http(
StatusCode::PAYMENT_REQUIRED,
"Maximum spending limit reached for this month.".to_string(),
));
}
}
}

// TODO: Remove this once we've rolled out monthly spending limits.
if usage.lifetime_spending >= LIFETIME_SPENDING_LIMIT_IN_CENTS {
return Err(Error::http(
StatusCode::FORBIDDEN,
Expand Down Expand Up @@ -505,7 +520,6 @@ async fn check_usage_limit(
UsageMeasure::RequestsPerMinute => "requests_per_minute",
UsageMeasure::TokensPerMinute => "tokens_per_minute",
UsageMeasure::TokensPerDay => "tokens_per_day",
_ => "",
};

if let Some(client) = state.clickhouse_client.as_ref() {
Expand Down
8 changes: 8 additions & 0 deletions crates/collab/src/llm/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,14 @@ impl LlmDatabase {
.ok_or_else(|| anyhow!("unknown model {provider:?}:{name}"))?)
}

pub fn model_by_id(&self, id: ModelId) -> Result<&model::Model> {
Ok(self
.models
.values()
.find(|model| model.id == id)
.ok_or_else(|| anyhow!("no model for ID {id:?}"))?)
}

pub fn options(&self) -> &ConnectOptions {
&self.options
}
Expand Down
Loading

0 comments on commit f861479

Please sign in to comment.