Detailed changes
@@ -839,9 +839,8 @@ dependencies = [
[[package]]
name = "async-stripe"
-version = "0.39.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "58d670cf4d47a1b8ffef54286a5625382e360a34ee76902fd93ad8c7032a0c30"
+version = "0.40.0"
+source = "git+https://github.com/zed-industries/async-stripe?rev=3672dd4efb7181aa597bf580bf5a2f5d23db6735#3672dd4efb7181aa597bf580bf5a2f5d23db6735"
dependencies = [
"chrono",
"futures-util",
@@ -480,7 +480,8 @@ which = "6.0.0"
wit-component = "0.201"
[workspace.dependencies.async-stripe]
-version = "0.39"
+git = "https://github.com/zed-industries/async-stripe"
+rev = "3672dd4efb7181aa597bf580bf5a2f5d23db6735"
default-features = false
features = [
"runtime-tokio-hyper-rustls",
@@ -0,0 +1,12 @@
+create table billing_events (
+ id serial primary key,
+ idempotency_key uuid not null default gen_random_uuid(),
+ user_id integer not null,
+ model_id integer not null references models (id) on delete cascade,
+ input_tokens bigint not null default 0,
+ input_cache_creation_tokens bigint not null default 0,
+ input_cache_read_tokens bigint not null default 0,
+ output_tokens bigint not null default 0
+);
+
+create index uix_billing_events_on_user_id_model_id on billing_events (user_id, model_id);
@@ -1,7 +1,3 @@
-use std::str::FromStr;
-use std::sync::Arc;
-use std::time::Duration;
-
use anyhow::{anyhow, bail, Context};
use axum::{
extract::{self, Query},
@@ -9,28 +5,35 @@ use axum::{
Extension, Json, Router,
};
use chrono::{DateTime, SecondsFormat, Utc};
+use collections::HashSet;
use reqwest::StatusCode;
use sea_orm::ActiveValue;
use serde::{Deserialize, Serialize};
+use std::{str::FromStr, sync::Arc, time::Duration};
use stripe::{
- BillingPortalSession, CheckoutSession, CreateBillingPortalSession,
- CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
+ BillingPortalSession, CreateBillingPortalSession, CreateBillingPortalSessionFlowData,
+ CreateBillingPortalSessionFlowDataAfterCompletion,
CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
- CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems,
- CreateCustomer, Customer, CustomerId, EventObject, EventType, Expandable, ListEvents,
- Subscription, SubscriptionId, SubscriptionStatus,
+ CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
+ EventType, Expandable, ListEvents, Subscription, SubscriptionId, SubscriptionStatus,
};
use util::ResultExt;
-use crate::db::billing_subscription::{self, StripeSubscriptionStatus};
-use crate::db::{
- billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
- CreateBillingSubscriptionParams, CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
- UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams,
-};
-use crate::llm::db::LlmDatabase;
-use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
+use crate::llm::DEFAULT_MAX_MONTHLY_SPEND;
use crate::rpc::ResultExt as _;
+use crate::{
+ db::{
+ billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
+ CreateBillingSubscriptionParams, CreateProcessedStripeEventParams,
+ UpdateBillingCustomerParams, UpdateBillingPreferencesParams,
+ UpdateBillingSubscriptionParams,
+ },
+ stripe_billing::StripeBilling,
+};
+use crate::{
+ db::{billing_subscription::StripeSubscriptionStatus, UserId},
+ llm::db::LlmDatabase,
+};
use crate::{AppState, Error, Result};
pub fn router() -> Router {
@@ -87,6 +90,7 @@ struct UpdateBillingPreferencesBody {
async fn update_billing_preferences(
Extension(app): Extension<Arc<AppState>>,
+ Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
) -> Result<Json<BillingPreferencesResponse>> {
let user = app
@@ -119,6 +123,8 @@ async fn update_billing_preferences(
.await?
};
+ rpc_server.refresh_llm_tokens_for_user(user.id).await;
+
Ok(Json(BillingPreferencesResponse {
max_monthly_llm_usage_spending_in_cents: billing_preferences
.max_monthly_llm_usage_spending_in_cents,
@@ -197,12 +203,15 @@ async fn create_billing_subscription(
.await?
.ok_or_else(|| anyhow!("user not found"))?;
- let Some((stripe_client, stripe_access_price_id)) = app
- .stripe_client
- .clone()
- .zip(app.config.stripe_llm_access_price_id.clone())
- else {
- log::error!("failed to retrieve Stripe client or price ID");
+ let Some(stripe_client) = app.stripe_client.clone() else {
+ log::error!("failed to retrieve Stripe client");
+ Err(Error::http(
+ StatusCode::NOT_IMPLEMENTED,
+ "not supported".into(),
+ ))?
+ };
+ let Some(llm_db) = app.llm_db.clone() else {
+ log::error!("failed to retrieve LLM database");
Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
"not supported".into(),
@@ -226,26 +235,15 @@ async fn create_billing_subscription(
customer.id
};
- let checkout_session = {
- let mut params = CreateCheckoutSession::new();
- params.mode = Some(stripe::CheckoutSessionMode::Subscription);
- params.customer = Some(customer_id);
- params.client_reference_id = Some(user.github_login.as_str());
- params.line_items = Some(vec![CreateCheckoutSessionLineItems {
- price: Some(stripe_access_price_id.to_string()),
- quantity: Some(1),
- ..Default::default()
- }]);
- let success_url = format!("{}/account", app.config.zed_dot_dev_url());
- params.success_url = Some(&success_url);
-
- CheckoutSession::create(&stripe_client, params).await?
- };
-
+ let default_model = llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-5-sonnet")?;
+ let mut stripe_billing = StripeBilling::new(stripe_client.clone()).await?;
+ let stripe_model = stripe_billing.register_model(default_model).await?;
+ let success_url = format!("{}/account", app.config.zed_dot_dev_url());
+ let checkout_session_url = stripe_billing
+ .checkout(customer_id, &user.github_login, &stripe_model, &success_url)
+ .await?;
Ok(Json(CreateBillingSubscriptionResponse {
- checkout_session_url: checkout_session
- .url
- .ok_or_else(|| anyhow!("no checkout session URL"))?,
+ checkout_session_url,
}))
}
@@ -715,15 +713,15 @@ async fn find_or_create_billing_customer(
Ok(Some(billing_customer))
}
-const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60);
+const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
-pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>, llm_db: LlmDatabase) {
+pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>) {
let Some(stripe_client) = app.stripe_client.clone() else {
log::warn!("failed to retrieve Stripe client");
return;
};
- let Some(stripe_llm_usage_price_id) = app.config.stripe_llm_usage_price_id.clone() else {
- log::warn!("failed to retrieve Stripe LLM usage price ID");
+ let Some(llm_db) = app.llm_db.clone() else {
+ log::warn!("failed to retrieve LLM database");
return;
};
@@ -732,15 +730,9 @@ pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>, llm_db: LlmDa
let executor = executor.clone();
async move {
loop {
- sync_with_stripe(
- &app,
- &llm_db,
- &stripe_client,
- stripe_llm_usage_price_id.clone(),
- )
- .await
- .trace_err();
-
+ sync_with_stripe(&app, &llm_db, &stripe_client)
+ .await
+ .trace_err();
executor.sleep(SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL).await;
}
}
@@ -749,71 +741,46 @@ pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>, llm_db: LlmDa
async fn sync_with_stripe(
app: &Arc<AppState>,
- llm_db: &LlmDatabase,
- stripe_client: &stripe::Client,
- stripe_llm_usage_price_id: Arc<str>,
+ llm_db: &Arc<LlmDatabase>,
+ stripe_client: &Arc<stripe::Client>,
) -> anyhow::Result<()> {
- let subscriptions = app.db.get_active_billing_subscriptions().await?;
-
- for (customer, subscription) in subscriptions {
- update_stripe_subscription(
- llm_db,
- stripe_client,
- &stripe_llm_usage_price_id,
- customer,
- subscription,
- )
- .await
- .log_err();
- }
-
- Ok(())
-}
-
-async fn update_stripe_subscription(
- llm_db: &LlmDatabase,
- stripe_client: &stripe::Client,
- stripe_llm_usage_price_id: &Arc<str>,
- customer: billing_customer::Model,
- subscription: billing_subscription::Model,
-) -> Result<(), anyhow::Error> {
- let monthly_spending = llm_db
- .get_user_spending_for_month(customer.user_id, Utc::now())
- .await?;
- let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
- .context("failed to parse subscription ID")?;
-
- let monthly_spending_over_free_tier =
- monthly_spending.saturating_sub(FREE_TIER_MONTHLY_SPENDING_LIMIT);
-
- let new_quantity = (monthly_spending_over_free_tier.0 as f32 / 100.).ceil();
- let current_subscription = Subscription::retrieve(stripe_client, &subscription_id, &[]).await?;
-
- let mut update_params = stripe::UpdateSubscription {
- proration_behavior: Some(
- stripe::generated::billing::subscription::SubscriptionProrationBehavior::None,
- ),
- ..Default::default()
- };
-
- if let Some(existing_item) = current_subscription.items.data.iter().find(|item| {
- item.price.as_ref().map_or(false, |price| {
- price.id == stripe_llm_usage_price_id.as_ref()
- })
- }) {
- update_params.items = Some(vec![stripe::UpdateSubscriptionItems {
- id: Some(existing_item.id.to_string()),
- quantity: Some(new_quantity as u64),
- ..Default::default()
- }]);
- } else {
- update_params.items = Some(vec![stripe::UpdateSubscriptionItems {
- price: Some(stripe_llm_usage_price_id.to_string()),
- quantity: Some(new_quantity as u64),
- ..Default::default()
- }]);
+ let mut stripe_billing = StripeBilling::new(stripe_client.clone()).await?;
+
+ let events = llm_db.get_billing_events().await?;
+ let user_ids = events
+ .iter()
+ .map(|(event, _)| event.user_id)
+ .collect::<HashSet<UserId>>();
+ let stripe_subscriptions = app.db.get_active_billing_subscriptions(user_ids).await?;
+
+ for (event, model) in events {
+ let Some((stripe_db_customer, stripe_db_subscription)) =
+ stripe_subscriptions.get(&event.user_id)
+ else {
+ tracing::warn!(
+ user_id = event.user_id.0,
+ "Registered billing event for user who is not a Stripe customer. Billing events should only be created for users who are Stripe customers, so this is a mistake on our side."
+ );
+ continue;
+ };
+ let stripe_subscription_id: stripe::SubscriptionId = stripe_db_subscription
+ .stripe_subscription_id
+ .parse()
+ .context("failed to parse stripe subscription id from db")?;
+ let stripe_customer_id: stripe::CustomerId = stripe_db_customer
+ .stripe_customer_id
+ .parse()
+ .context("failed to parse stripe customer id from db")?;
+
+ let stripe_model = stripe_billing.register_model(&model).await?;
+ stripe_billing
+ .subscribe_to_model(&stripe_subscription_id, &stripe_model)
+ .await?;
+ stripe_billing
+ .bill_model_usage(&stripe_customer_id, &stripe_model, &event)
+ .await?;
+ llm_db.consume_billing_event(event.id).await?;
}
- Subscription::update(stripe_client, &subscription_id, update_params).await?;
Ok(())
}
@@ -114,23 +114,31 @@ impl Database {
pub async fn get_active_billing_subscriptions(
&self,
- ) -> Result<Vec<(billing_customer::Model, billing_subscription::Model)>> {
- self.transaction(|tx| async move {
- let mut result = Vec::new();
- let mut rows = billing_subscription::Entity::find()
- .inner_join(billing_customer::Entity)
- .select_also(billing_customer::Entity)
- .order_by_asc(billing_subscription::Column::Id)
- .stream(&*tx)
- .await?;
-
- while let Some(row) = rows.next().await {
- if let (subscription, Some(customer)) = row? {
- result.push((customer, subscription));
+ user_ids: HashSet<UserId>,
+ ) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
+ self.transaction(|tx| {
+ let user_ids = user_ids.clone();
+ async move {
+ let mut rows = billing_subscription::Entity::find()
+ .inner_join(billing_customer::Entity)
+ .select_also(billing_customer::Entity)
+ .filter(billing_customer::Column::UserId.is_in(user_ids))
+ .filter(
+ billing_subscription::Column::StripeSubscriptionStatus
+ .eq(StripeSubscriptionStatus::Active),
+ )
+ .order_by_asc(billing_subscription::Column::Id)
+ .stream(&*tx)
+ .await?;
+
+ let mut subscriptions = HashMap::default();
+ while let Some(row) = rows.next().await {
+ if let (subscription, Some(customer)) = row? {
+ subscriptions.insert(customer.user_id, (customer, subscription));
+ }
}
+ Ok(subscriptions)
}
-
- Ok(result)
})
.await
}
@@ -10,6 +10,7 @@ pub mod migrations;
mod rate_limiter;
pub mod rpc;
pub mod seed;
+pub mod stripe_billing;
pub mod user_backfiller;
#[cfg(test)]
@@ -24,6 +25,7 @@ use axum::{
pub use cents::*;
use db::{ChannelId, Database};
use executor::Executor;
+use llm::db::LlmDatabase;
pub use rate_limiter::*;
use serde::Deserialize;
use std::{path::PathBuf, sync::Arc};
@@ -176,8 +178,6 @@ pub struct Config {
pub slack_panics_webhook: Option<String>,
pub auto_join_channel_id: Option<ChannelId>,
pub stripe_api_key: Option<String>,
- pub stripe_llm_access_price_id: Option<Arc<str>>,
- pub stripe_llm_usage_price_id: Option<Arc<str>>,
pub supermaven_admin_api_key: Option<Arc<str>>,
pub user_backfiller_github_access_token: Option<Arc<str>>,
}
@@ -197,7 +197,7 @@ impl Config {
}
pub fn is_llm_billing_enabled(&self) -> bool {
- self.stripe_llm_usage_price_id.is_some()
+ self.stripe_api_key.is_some()
}
#[cfg(test)]
@@ -238,8 +238,6 @@ impl Config {
migrations_path: None,
seed_path: None,
stripe_api_key: None,
- stripe_llm_access_price_id: None,
- stripe_llm_usage_price_id: None,
supermaven_admin_api_key: None,
user_backfiller_github_access_token: None,
}
@@ -272,6 +270,7 @@ impl ServiceMode {
pub struct AppState {
pub db: Arc<Database>,
+ pub llm_db: Option<Arc<LlmDatabase>>,
pub live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
pub blob_store_client: Option<aws_sdk_s3::Client>,
pub stripe_client: Option<Arc<stripe::Client>>,
@@ -288,6 +287,20 @@ impl AppState {
let mut db = Database::new(db_options, Executor::Production).await?;
db.initialize_notification_kinds().await?;
+ let llm_db = if let Some((llm_database_url, llm_database_max_connections)) = config
+ .llm_database_url
+ .clone()
+ .zip(config.llm_database_max_connections)
+ {
+ let mut llm_db_options = db::ConnectOptions::new(llm_database_url);
+ llm_db_options.max_connections(llm_database_max_connections);
+ let mut llm_db = LlmDatabase::new(llm_db_options, executor.clone()).await?;
+ llm_db.initialize().await?;
+ Some(Arc::new(llm_db))
+ } else {
+ None
+ };
+
let live_kit_client = if let Some(((server, key), secret)) = config
.live_kit_server
.as_ref()
@@ -306,9 +319,10 @@ impl AppState {
let db = Arc::new(db);
let this = Self {
db: db.clone(),
+ llm_db,
live_kit_client,
blob_store_client: build_blob_store_client(&config).await.log_err(),
- stripe_client: build_stripe_client(&config).await.map(Arc::new).log_err(),
+ stripe_client: build_stripe_client(&config).map(Arc::new).log_err(),
rate_limiter: Arc::new(RateLimiter::new(db)),
executor,
clickhouse_client: config
@@ -321,12 +335,11 @@ impl AppState {
}
}
-async fn build_stripe_client(config: &Config) -> anyhow::Result<stripe::Client> {
+fn build_stripe_client(config: &Config) -> anyhow::Result<stripe::Client> {
let api_key = config
.stripe_api_key
.as_ref()
.ok_or_else(|| anyhow!("missing stripe_api_key"))?;
-
Ok(stripe::Client::new(api_key))
}
@@ -20,13 +20,14 @@ use axum::{
};
use chrono::{DateTime, Duration, Utc};
use collections::HashMap;
+use db::TokenUsage;
use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase};
use futures::{Stream, StreamExt as _};
use isahc_http_client::IsahcHttpClient;
-use rpc::ListModelsResponse;
use rpc::{
proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
};
+use rpc::{ListModelsResponse, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME};
use std::{
pin::Pin,
sync::Arc,
@@ -418,10 +419,7 @@ async fn perform_completion(
claims,
provider: params.provider,
model,
- input_tokens: 0,
- output_tokens: 0,
- cache_creation_input_tokens: 0,
- cache_read_input_tokens: 0,
+ tokens: TokenUsage::default(),
inner_stream: stream,
})))
}
@@ -476,6 +474,19 @@ async fn check_usage_limit(
"Maximum spending limit reached for this month.".to_string(),
));
}
+
+ if usage.spending_this_month >= Cents(claims.max_monthly_spend_in_cents) {
+ return Err(Error::Http(
+ StatusCode::FORBIDDEN,
+ "Maximum spending limit reached for this month.".to_string(),
+ [(
+ HeaderName::from_static(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME),
+ HeaderValue::from_static("true"),
+ )]
+ .into_iter()
+ .collect(),
+ ));
+ }
}
}
@@ -598,10 +609,7 @@ struct TokenCountingStream<S> {
claims: LlmTokenClaims,
provider: LanguageModelProvider,
model: String,
- input_tokens: usize,
- output_tokens: usize,
- cache_creation_input_tokens: usize,
- cache_read_input_tokens: usize,
+ tokens: TokenUsage,
inner_stream: S,
}
@@ -615,10 +623,10 @@ where
match Pin::new(&mut self.inner_stream).poll_next(cx) {
Poll::Ready(Some(Ok(mut chunk))) => {
chunk.bytes.push(b'\n');
- self.input_tokens += chunk.input_tokens;
- self.output_tokens += chunk.output_tokens;
- self.cache_creation_input_tokens += chunk.cache_creation_input_tokens;
- self.cache_read_input_tokens += chunk.cache_read_input_tokens;
+ self.tokens.input += chunk.input_tokens;
+ self.tokens.output += chunk.output_tokens;
+ self.tokens.input_cache_creation += chunk.cache_creation_input_tokens;
+ self.tokens.input_cache_read += chunk.cache_read_input_tokens;
Poll::Ready(Some(Ok(chunk.bytes)))
}
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
@@ -634,10 +642,7 @@ impl<S> Drop for TokenCountingStream<S> {
let claims = self.claims.clone();
let provider = self.provider;
let model = std::mem::take(&mut self.model);
- let input_token_count = self.input_tokens;
- let output_token_count = self.output_tokens;
- let cache_creation_input_token_count = self.cache_creation_input_tokens;
- let cache_read_input_token_count = self.cache_read_input_tokens;
+ let tokens = self.tokens;
self.state.executor.spawn_detached(async move {
let usage = state
.db
@@ -646,10 +651,9 @@ impl<S> Drop for TokenCountingStream<S> {
claims.is_staff,
provider,
&model,
- input_token_count,
- cache_creation_input_token_count,
- cache_read_input_token_count,
- output_token_count,
+ tokens,
+ claims.has_llm_subscription,
+ Cents(claims.max_monthly_spend_in_cents),
Utc::now(),
)
.await
@@ -679,22 +683,23 @@ impl<S> Drop for TokenCountingStream<S> {
},
model,
provider: provider.to_string(),
- input_token_count: input_token_count as u64,
- cache_creation_input_token_count: cache_creation_input_token_count
- as u64,
- cache_read_input_token_count: cache_read_input_token_count as u64,
- output_token_count: output_token_count as u64,
+ input_token_count: tokens.input as u64,
+ cache_creation_input_token_count: tokens.input_cache_creation as u64,
+ cache_read_input_token_count: tokens.input_cache_read as u64,
+ output_token_count: tokens.output as u64,
requests_this_minute: usage.requests_this_minute as u64,
tokens_this_minute: usage.tokens_this_minute as u64,
tokens_this_day: usage.tokens_this_day as u64,
- input_tokens_this_month: usage.input_tokens_this_month as u64,
+ input_tokens_this_month: usage.tokens_this_month.input as u64,
cache_creation_input_tokens_this_month: usage
- .cache_creation_input_tokens_this_month
+ .tokens_this_month
+ .input_cache_creation
as u64,
cache_read_input_tokens_this_month: usage
- .cache_read_input_tokens_this_month
+ .tokens_this_month
+ .input_cache_read
as u64,
- output_tokens_this_month: usage.output_tokens_this_month as u64,
+ output_tokens_this_month: usage.tokens_this_month.output as u64,
spending_this_month: usage.spending_this_month.0 as u64,
lifetime_spending: usage.lifetime_spending.0 as u64,
},
@@ -20,7 +20,7 @@ use std::future::Future;
use std::sync::Arc;
use anyhow::anyhow;
-pub use queries::usages::ActiveUserCount;
+pub use queries::usages::{ActiveUserCount, TokenUsage};
use sea_orm::prelude::*;
pub use sea_orm::ConnectOptions;
use sea_orm::{
@@ -8,3 +8,4 @@ id_type!(ProviderId);
id_type!(UsageId);
id_type!(UsageMeasureId);
id_type!(RevokedAccessTokenId);
+id_type!(BillingEventId);
@@ -1,5 +1,6 @@
use super::*;
+pub mod billing_events;
pub mod providers;
pub mod revoked_access_tokens;
pub mod usages;
@@ -0,0 +1,31 @@
+use super::*;
+use crate::Result;
+use anyhow::Context as _;
+
+impl LlmDatabase {
+ pub async fn get_billing_events(&self) -> Result<Vec<(billing_event::Model, model::Model)>> {
+ self.transaction(|tx| async move {
+ let events_with_models = billing_event::Entity::find()
+ .find_also_related(model::Entity)
+ .all(&*tx)
+ .await?;
+ events_with_models
+ .into_iter()
+ .map(|(event, model)| {
+ let model =
+ model.context("could not find model associated with billing event")?;
+ Ok((event, model))
+ })
+ .collect()
+ })
+ .await
+ }
+
+ pub async fn consume_billing_event(&self, id: BillingEventId) -> Result<()> {
+ self.transaction(|tx| async move {
+ billing_event::Entity::delete_by_id(id).exec(&*tx).await?;
+ Ok(())
+ })
+ .await
+ }
+}
@@ -1,5 +1,5 @@
-use crate::db::UserId;
use crate::llm::Cents;
+use crate::{db::UserId, llm::FREE_TIER_MONTHLY_SPENDING_LIMIT};
use chrono::{Datelike, Duration};
use futures::StreamExt as _;
use rpc::LanguageModelProvider;
@@ -9,15 +9,26 @@ use strum::IntoEnumIterator as _;
use super::*;
+#[derive(Debug, PartialEq, Clone, Copy, Default)]
+pub struct TokenUsage {
+ pub input: usize,
+ pub input_cache_creation: usize,
+ pub input_cache_read: usize,
+ pub output: usize,
+}
+
+impl TokenUsage {
+ pub fn total(&self) -> usize {
+ self.input + self.input_cache_creation + self.input_cache_read + self.output
+ }
+}
+
#[derive(Debug, PartialEq, Clone, Copy)]
pub struct Usage {
pub requests_this_minute: usize,
pub tokens_this_minute: usize,
pub tokens_this_day: usize,
- pub input_tokens_this_month: usize,
- pub cache_creation_input_tokens_this_month: usize,
- pub cache_read_input_tokens_this_month: usize,
- pub output_tokens_this_month: usize,
+ pub tokens_this_month: TokenUsage,
pub spending_this_month: Cents,
pub lifetime_spending: Cents,
}
@@ -257,18 +268,20 @@ impl LlmDatabase {
requests_this_minute,
tokens_this_minute,
tokens_this_day,
- input_tokens_this_month: monthly_usage
- .as_ref()
- .map_or(0, |usage| usage.input_tokens as usize),
- cache_creation_input_tokens_this_month: monthly_usage
- .as_ref()
- .map_or(0, |usage| usage.cache_creation_input_tokens as usize),
- cache_read_input_tokens_this_month: monthly_usage
- .as_ref()
- .map_or(0, |usage| usage.cache_read_input_tokens as usize),
- output_tokens_this_month: monthly_usage
- .as_ref()
- .map_or(0, |usage| usage.output_tokens as usize),
+ tokens_this_month: TokenUsage {
+ input: monthly_usage
+ .as_ref()
+ .map_or(0, |usage| usage.input_tokens as usize),
+ input_cache_creation: monthly_usage
+ .as_ref()
+ .map_or(0, |usage| usage.cache_creation_input_tokens as usize),
+ input_cache_read: monthly_usage
+ .as_ref()
+ .map_or(0, |usage| usage.cache_read_input_tokens as usize),
+ output: monthly_usage
+ .as_ref()
+ .map_or(0, |usage| usage.output_tokens as usize),
+ },
spending_this_month,
lifetime_spending,
})
@@ -283,10 +296,9 @@ impl LlmDatabase {
is_staff: bool,
provider: LanguageModelProvider,
model_name: &str,
- input_token_count: usize,
- cache_creation_input_tokens: usize,
- cache_read_input_tokens: usize,
- output_token_count: usize,
+ tokens: TokenUsage,
+ has_llm_subscription: bool,
+ max_monthly_spend: Cents,
now: DateTimeUtc,
) -> Result<Usage> {
self.transaction(|tx| async move {
@@ -313,10 +325,6 @@ impl LlmDatabase {
&tx,
)
.await?;
- let total_token_count = input_token_count
- + cache_read_input_tokens
- + cache_creation_input_tokens
- + output_token_count;
let tokens_this_minute = self
.update_usage_for_measure(
user_id,
@@ -325,7 +333,7 @@ impl LlmDatabase {
&usages,
UsageMeasure::TokensPerMinute,
now,
- total_token_count,
+ tokens.total(),
&tx,
)
.await?;
@@ -337,7 +345,7 @@ impl LlmDatabase {
&usages,
UsageMeasure::TokensPerDay,
now,
- total_token_count,
+ tokens.total(),
&tx,
)
.await?;
@@ -361,18 +369,14 @@ impl LlmDatabase {
Some(usage) => {
monthly_usage::Entity::update(monthly_usage::ActiveModel {
id: ActiveValue::unchanged(usage.id),
- input_tokens: ActiveValue::set(
- usage.input_tokens + input_token_count as i64,
- ),
+ input_tokens: ActiveValue::set(usage.input_tokens + tokens.input as i64),
cache_creation_input_tokens: ActiveValue::set(
- usage.cache_creation_input_tokens + cache_creation_input_tokens as i64,
+ usage.cache_creation_input_tokens + tokens.input_cache_creation as i64,
),
cache_read_input_tokens: ActiveValue::set(
- usage.cache_read_input_tokens + cache_read_input_tokens as i64,
- ),
- output_tokens: ActiveValue::set(
- usage.output_tokens + output_token_count as i64,
+ usage.cache_read_input_tokens + tokens.input_cache_read as i64,
),
+ output_tokens: ActiveValue::set(usage.output_tokens + tokens.output as i64),
..Default::default()
})
.exec(&*tx)
@@ -384,12 +388,12 @@ impl LlmDatabase {
model_id: ActiveValue::set(model.id),
month: ActiveValue::set(month),
year: ActiveValue::set(year),
- input_tokens: ActiveValue::set(input_token_count as i64),
+ input_tokens: ActiveValue::set(tokens.input as i64),
cache_creation_input_tokens: ActiveValue::set(
- cache_creation_input_tokens as i64,
+ tokens.input_cache_creation as i64,
),
- cache_read_input_tokens: ActiveValue::set(cache_read_input_tokens as i64),
- output_tokens: ActiveValue::set(output_token_count as i64),
+ cache_read_input_tokens: ActiveValue::set(tokens.input_cache_read as i64),
+ output_tokens: ActiveValue::set(tokens.output as i64),
..Default::default()
}
.insert(&*tx)
@@ -405,6 +409,26 @@ impl LlmDatabase {
monthly_usage.output_tokens as usize,
);
+ if spending_this_month > FREE_TIER_MONTHLY_SPENDING_LIMIT
+ && has_llm_subscription
+ && spending_this_month <= max_monthly_spend
+ {
+ billing_event::ActiveModel {
+ id: ActiveValue::not_set(),
+ idempotency_key: ActiveValue::not_set(),
+ user_id: ActiveValue::set(user_id),
+ model_id: ActiveValue::set(model.id),
+ input_tokens: ActiveValue::set(tokens.input as i64),
+ input_cache_creation_tokens: ActiveValue::set(
+ tokens.input_cache_creation as i64,
+ ),
+ input_cache_read_tokens: ActiveValue::set(tokens.input_cache_read as i64),
+ output_tokens: ActiveValue::set(tokens.output as i64),
+ }
+ .insert(&*tx)
+ .await?;
+ }
+
// Update lifetime usage
let lifetime_usage = lifetime_usage::Entity::find()
.filter(
@@ -419,18 +443,14 @@ impl LlmDatabase {
Some(usage) => {
lifetime_usage::Entity::update(lifetime_usage::ActiveModel {
id: ActiveValue::unchanged(usage.id),
- input_tokens: ActiveValue::set(
- usage.input_tokens + input_token_count as i64,
- ),
+ input_tokens: ActiveValue::set(usage.input_tokens + tokens.input as i64),
cache_creation_input_tokens: ActiveValue::set(
- usage.cache_creation_input_tokens + cache_creation_input_tokens as i64,
+ usage.cache_creation_input_tokens + tokens.input_cache_creation as i64,
),
cache_read_input_tokens: ActiveValue::set(
- usage.cache_read_input_tokens + cache_read_input_tokens as i64,
- ),
- output_tokens: ActiveValue::set(
- usage.output_tokens + output_token_count as i64,
+ usage.cache_read_input_tokens + tokens.input_cache_read as i64,
),
+ output_tokens: ActiveValue::set(usage.output_tokens + tokens.output as i64),
..Default::default()
})
.exec(&*tx)
@@ -440,12 +460,12 @@ impl LlmDatabase {
lifetime_usage::ActiveModel {
user_id: ActiveValue::set(user_id),
model_id: ActiveValue::set(model.id),
- input_tokens: ActiveValue::set(input_token_count as i64),
+ input_tokens: ActiveValue::set(tokens.input as i64),
cache_creation_input_tokens: ActiveValue::set(
- cache_creation_input_tokens as i64,
+ tokens.input_cache_creation as i64,
),
- cache_read_input_tokens: ActiveValue::set(cache_read_input_tokens as i64),
- output_tokens: ActiveValue::set(output_token_count as i64),
+ cache_read_input_tokens: ActiveValue::set(tokens.input_cache_read as i64),
+ output_tokens: ActiveValue::set(tokens.output as i64),
..Default::default()
}
.insert(&*tx)
@@ -465,11 +485,12 @@ impl LlmDatabase {
requests_this_minute,
tokens_this_minute,
tokens_this_day,
- input_tokens_this_month: monthly_usage.input_tokens as usize,
- cache_creation_input_tokens_this_month: monthly_usage.cache_creation_input_tokens
- as usize,
- cache_read_input_tokens_this_month: monthly_usage.cache_read_input_tokens as usize,
- output_tokens_this_month: monthly_usage.output_tokens as usize,
+ tokens_this_month: TokenUsage {
+ input: monthly_usage.input_tokens as usize,
+ input_cache_creation: monthly_usage.cache_creation_input_tokens as usize,
+ input_cache_read: monthly_usage.cache_read_input_tokens as usize,
+ output: monthly_usage.output_tokens as usize,
+ },
spending_this_month,
lifetime_spending,
})
@@ -1,3 +1,4 @@
+pub mod billing_event;
pub mod lifetime_usage;
pub mod model;
pub mod monthly_usage;
@@ -0,0 +1,37 @@
+use crate::{
+ db::UserId,
+ llm::db::{BillingEventId, ModelId},
+};
+use sea_orm::entity::prelude::*;
+
+#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
+#[sea_orm(table_name = "billing_events")]
+pub struct Model {
+ #[sea_orm(primary_key)]
+ pub id: BillingEventId,
+ pub idempotency_key: Uuid,
+ pub user_id: UserId,
+ pub model_id: ModelId,
+ pub input_tokens: i64,
+ pub input_cache_creation_tokens: i64,
+ pub input_cache_read_tokens: i64,
+ pub output_tokens: i64,
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {
+ #[sea_orm(
+ belongs_to = "super::model::Entity",
+ from = "Column::ModelId",
+ to = "super::model::Column::Id"
+ )]
+ Model,
+}
+
+impl Related<super::model::Entity> for Entity {
+ fn to() -> RelationDef {
+ Relation::Model.def()
+ }
+}
+
+impl ActiveModelBehavior for ActiveModel {}
@@ -29,6 +29,8 @@ pub enum Relation {
Provider,
#[sea_orm(has_many = "super::usage::Entity")]
Usages,
+ #[sea_orm(has_many = "super::billing_event::Entity")]
+ BillingEvents,
}
impl Related<super::provider::Entity> for Entity {
@@ -43,4 +45,10 @@ impl Related<super::usage::Entity> for Entity {
}
}
+impl Related<super::billing_event::Entity> for Entity {
+ fn to() -> RelationDef {
+ Relation::BillingEvents.def()
+ }
+}
+
impl ActiveModelBehavior for ActiveModel {}
@@ -2,7 +2,7 @@ use crate::{
db::UserId,
llm::db::{
queries::{providers::ModelParams, usages::Usage},
- LlmDatabase,
+ LlmDatabase, TokenUsage,
},
test_llm_db, Cents,
};
@@ -36,14 +36,42 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
let user_id = UserId::from_proto(123);
let now = t0;
- db.record_usage(user_id, false, provider, model, 1000, 0, 0, 0, now)
- .await
- .unwrap();
+ db.record_usage(
+ user_id,
+ false,
+ provider,
+ model,
+ TokenUsage {
+ input: 1000,
+ input_cache_creation: 0,
+ input_cache_read: 0,
+ output: 0,
+ },
+ false,
+ Cents::ZERO,
+ now,
+ )
+ .await
+ .unwrap();
let now = t0 + Duration::seconds(10);
- db.record_usage(user_id, false, provider, model, 2000, 0, 0, 0, now)
- .await
- .unwrap();
+ db.record_usage(
+ user_id,
+ false,
+ provider,
+ model,
+ TokenUsage {
+ input: 2000,
+ input_cache_creation: 0,
+ input_cache_read: 0,
+ output: 0,
+ },
+ false,
+ Cents::ZERO,
+ now,
+ )
+ .await
+ .unwrap();
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
@@ -52,10 +80,12 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
requests_this_minute: 2,
tokens_this_minute: 3000,
tokens_this_day: 3000,
- input_tokens_this_month: 3000,
- cache_creation_input_tokens_this_month: 0,
- cache_read_input_tokens_this_month: 0,
- output_tokens_this_month: 0,
+ tokens_this_month: TokenUsage {
+ input: 3000,
+ input_cache_creation: 0,
+ input_cache_read: 0,
+ output: 0,
+ },
spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO,
}
@@ -69,19 +99,35 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
requests_this_minute: 1,
tokens_this_minute: 2000,
tokens_this_day: 3000,
- input_tokens_this_month: 3000,
- cache_creation_input_tokens_this_month: 0,
- cache_read_input_tokens_this_month: 0,
- output_tokens_this_month: 0,
+ tokens_this_month: TokenUsage {
+ input: 3000,
+ input_cache_creation: 0,
+ input_cache_read: 0,
+ output: 0,
+ },
spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO,
}
);
let now = t0 + Duration::seconds(60);
- db.record_usage(user_id, false, provider, model, 3000, 0, 0, 0, now)
- .await
- .unwrap();
+ db.record_usage(
+ user_id,
+ false,
+ provider,
+ model,
+ TokenUsage {
+ input: 3000,
+ input_cache_creation: 0,
+ input_cache_read: 0,
+ output: 0,
+ },
+ false,
+ Cents::ZERO,
+ now,
+ )
+ .await
+ .unwrap();
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
@@ -90,10 +136,12 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
requests_this_minute: 2,
tokens_this_minute: 5000,
tokens_this_day: 6000,
- input_tokens_this_month: 6000,
- cache_creation_input_tokens_this_month: 0,
- cache_read_input_tokens_this_month: 0,
- output_tokens_this_month: 0,
+ tokens_this_month: TokenUsage {
+ input: 6000,
+ input_cache_creation: 0,
+ input_cache_read: 0,
+ output: 0,
+ },
spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO,
}
@@ -108,18 +156,34 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
requests_this_minute: 0,
tokens_this_minute: 0,
tokens_this_day: 5000,
- input_tokens_this_month: 6000,
- cache_creation_input_tokens_this_month: 0,
- cache_read_input_tokens_this_month: 0,
- output_tokens_this_month: 0,
+ tokens_this_month: TokenUsage {
+ input: 6000,
+ input_cache_creation: 0,
+ input_cache_read: 0,
+ output: 0,
+ },
spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO,
}
);
- db.record_usage(user_id, false, provider, model, 4000, 0, 0, 0, now)
- .await
- .unwrap();
+ db.record_usage(
+ user_id,
+ false,
+ provider,
+ model,
+ TokenUsage {
+ input: 4000,
+ input_cache_creation: 0,
+ input_cache_read: 0,
+ output: 0,
+ },
+ false,
+ Cents::ZERO,
+ now,
+ )
+ .await
+ .unwrap();
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
@@ -128,10 +192,12 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
requests_this_minute: 1,
tokens_this_minute: 4000,
tokens_this_day: 9000,
- input_tokens_this_month: 10000,
- cache_creation_input_tokens_this_month: 0,
- cache_read_input_tokens_this_month: 0,
- output_tokens_this_month: 0,
+ tokens_this_month: TokenUsage {
+ input: 10000,
+ input_cache_creation: 0,
+ input_cache_read: 0,
+ output: 0,
+ },
spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO,
}
@@ -143,9 +209,23 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
.with_timezone(&Utc);
// Test cache creation input tokens
- db.record_usage(user_id, false, provider, model, 1000, 500, 0, 0, now)
- .await
- .unwrap();
+ db.record_usage(
+ user_id,
+ false,
+ provider,
+ model,
+ TokenUsage {
+ input: 1000,
+ input_cache_creation: 500,
+ input_cache_read: 0,
+ output: 0,
+ },
+ false,
+ Cents::ZERO,
+ now,
+ )
+ .await
+ .unwrap();
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
@@ -154,19 +234,35 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
requests_this_minute: 1,
tokens_this_minute: 1500,
tokens_this_day: 1500,
- input_tokens_this_month: 1000,
- cache_creation_input_tokens_this_month: 500,
- cache_read_input_tokens_this_month: 0,
- output_tokens_this_month: 0,
+ tokens_this_month: TokenUsage {
+ input: 1000,
+ input_cache_creation: 500,
+ input_cache_read: 0,
+ output: 0,
+ },
spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO,
}
);
// Test cache read input tokens
- db.record_usage(user_id, false, provider, model, 1000, 0, 300, 0, now)
- .await
- .unwrap();
+ db.record_usage(
+ user_id,
+ false,
+ provider,
+ model,
+ TokenUsage {
+ input: 1000,
+ input_cache_creation: 0,
+ input_cache_read: 300,
+ output: 0,
+ },
+ false,
+ Cents::ZERO,
+ now,
+ )
+ .await
+ .unwrap();
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
@@ -175,10 +271,12 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
requests_this_minute: 2,
tokens_this_minute: 2800,
tokens_this_day: 2800,
- input_tokens_this_month: 2000,
- cache_creation_input_tokens_this_month: 500,
- cache_read_input_tokens_this_month: 300,
- output_tokens_this_month: 0,
+ tokens_this_month: TokenUsage {
+ input: 2000,
+ input_cache_creation: 500,
+ input_cache_read: 300,
+ output: 0,
+ },
spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO,
}
@@ -157,7 +157,7 @@ async fn main() -> Result<()> {
if let Some(mut llm_db) = llm_db {
llm_db.initialize().await?;
- sync_llm_usage_with_stripe_periodically(state.clone(), llm_db);
+ sync_llm_usage_with_stripe_periodically(state.clone());
}
app = app
@@ -1218,6 +1218,15 @@ impl Server {
Ok(())
}
+ pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
+ let pool = self.connection_pool.lock();
+ for connection_id in pool.user_connection_ids(user_id) {
+ self.peer
+ .send(connection_id, proto::RefreshLlmToken {})
+ .trace_err();
+ }
+ }
+
pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
ServerSnapshot {
connection_pool: ConnectionPoolGuard {
@@ -0,0 +1,427 @@
+use std::sync::Arc;
+
+use crate::{llm, Cents, Result};
+use anyhow::Context;
+use chrono::Utc;
+use collections::HashMap;
+use serde::{Deserialize, Serialize};
+
+pub struct StripeBilling {
+ meters_by_event_name: HashMap<String, StripeMeter>,
+ price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
+ client: Arc<stripe::Client>,
+}
+
+pub struct StripeModel {
+ input_tokens_price: StripeBillingPrice,
+ input_cache_creation_tokens_price: StripeBillingPrice,
+ input_cache_read_tokens_price: StripeBillingPrice,
+ output_tokens_price: StripeBillingPrice,
+}
+
+struct StripeBillingPrice {
+ id: stripe::PriceId,
+ meter_event_name: String,
+}
+
+impl StripeBilling {
+ pub async fn new(client: Arc<stripe::Client>) -> Result<Self> {
+ let mut meters_by_event_name = HashMap::default();
+ for meter in StripeMeter::list(&client).await?.data {
+ meters_by_event_name.insert(meter.event_name.clone(), meter);
+ }
+
+ let mut price_ids_by_meter_id = HashMap::default();
+ for price in stripe::Price::list(&client, &stripe::ListPrices::default())
+ .await?
+ .data
+ {
+ if let Some(recurring) = price.recurring {
+ if let Some(meter) = recurring.meter {
+ price_ids_by_meter_id.insert(meter, price.id);
+ }
+ }
+ }
+
+ Ok(Self {
+ meters_by_event_name,
+ price_ids_by_meter_id,
+ client,
+ })
+ }
+
+ pub async fn register_model(&mut self, model: &llm::db::model::Model) -> Result<StripeModel> {
+ let input_tokens_price = self
+ .get_or_insert_price(
+ &format!("model_{}/input_tokens", model.id),
+ &format!("{} (Input Tokens)", model.name),
+ Cents::new(model.price_per_million_input_tokens as u32),
+ )
+ .await?;
+ let input_cache_creation_tokens_price = self
+ .get_or_insert_price(
+ &format!("model_{}/input_cache_creation_tokens", model.id),
+ &format!("{} (Input Cache Creation Tokens)", model.name),
+ Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
+ )
+ .await?;
+ let input_cache_read_tokens_price = self
+ .get_or_insert_price(
+ &format!("model_{}/input_cache_read_tokens", model.id),
+ &format!("{} (Input Cache Read Tokens)", model.name),
+ Cents::new(model.price_per_million_cache_read_input_tokens as u32),
+ )
+ .await?;
+ let output_tokens_price = self
+ .get_or_insert_price(
+ &format!("model_{}/output_tokens", model.id),
+ &format!("{} (Output Tokens)", model.name),
+ Cents::new(model.price_per_million_output_tokens as u32),
+ )
+ .await?;
+ Ok(StripeModel {
+ input_tokens_price,
+ input_cache_creation_tokens_price,
+ input_cache_read_tokens_price,
+ output_tokens_price,
+ })
+ }
+
+ async fn get_or_insert_price(
+ &mut self,
+ meter_event_name: &str,
+ price_description: &str,
+ price_per_million_tokens: Cents,
+ ) -> Result<StripeBillingPrice> {
+ let meter = if let Some(meter) = self.meters_by_event_name.get(meter_event_name) {
+ meter.clone()
+ } else {
+ let meter = StripeMeter::create(
+ &self.client,
+ StripeCreateMeterParams {
+ default_aggregation: DefaultAggregation { formula: "sum" },
+ display_name: price_description.to_string(),
+ event_name: meter_event_name,
+ },
+ )
+ .await?;
+ self.meters_by_event_name
+ .insert(meter_event_name.to_string(), meter.clone());
+ meter
+ };
+
+ let price_id = if let Some(price_id) = self.price_ids_by_meter_id.get(&meter.id) {
+ price_id.clone()
+ } else {
+ let price = stripe::Price::create(
+ &self.client,
+ stripe::CreatePrice {
+ active: Some(true),
+ billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
+ currency: stripe::Currency::USD,
+ currency_options: None,
+ custom_unit_amount: None,
+ expand: &[],
+ lookup_key: None,
+ metadata: None,
+ nickname: None,
+ product: None,
+ product_data: Some(stripe::CreatePriceProductData {
+ id: None,
+ active: Some(true),
+ metadata: None,
+ name: price_description.to_string(),
+ statement_descriptor: None,
+ tax_code: None,
+ unit_label: None,
+ }),
+ recurring: Some(stripe::CreatePriceRecurring {
+ aggregate_usage: None,
+ interval: stripe::CreatePriceRecurringInterval::Month,
+ interval_count: None,
+ trial_period_days: None,
+ usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
+ meter: Some(meter.id.clone()),
+ }),
+ tax_behavior: None,
+ tiers: None,
+ tiers_mode: None,
+ transfer_lookup_key: None,
+ transform_quantity: None,
+ unit_amount: None,
+ unit_amount_decimal: Some(&format!(
+ "{:.12}",
+ price_per_million_tokens.0 as f64 / 1_000_000f64
+ )),
+ },
+ )
+ .await?;
+ self.price_ids_by_meter_id
+ .insert(meter.id, price.id.clone());
+ price.id
+ };
+
+ Ok(StripeBillingPrice {
+ id: price_id,
+ meter_event_name: meter_event_name.to_string(),
+ })
+ }
+
+ pub async fn subscribe_to_model(
+ &self,
+ subscription_id: &stripe::SubscriptionId,
+ model: &StripeModel,
+ ) -> Result<()> {
+ let subscription =
+ stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
+
+ let mut items = Vec::new();
+
+ if !subscription_contains_price(&subscription, &model.input_tokens_price.id) {
+ items.push(stripe::UpdateSubscriptionItems {
+ price: Some(model.input_tokens_price.id.to_string()),
+ ..Default::default()
+ });
+ }
+
+ if !subscription_contains_price(&subscription, &model.input_cache_creation_tokens_price.id)
+ {
+ items.push(stripe::UpdateSubscriptionItems {
+ price: Some(model.input_cache_creation_tokens_price.id.to_string()),
+ ..Default::default()
+ });
+ }
+
+ if !subscription_contains_price(&subscription, &model.input_cache_read_tokens_price.id) {
+ items.push(stripe::UpdateSubscriptionItems {
+ price: Some(model.input_cache_read_tokens_price.id.to_string()),
+ ..Default::default()
+ });
+ }
+
+ if !subscription_contains_price(&subscription, &model.output_tokens_price.id) {
+ items.push(stripe::UpdateSubscriptionItems {
+ price: Some(model.output_tokens_price.id.to_string()),
+ ..Default::default()
+ });
+ }
+
+ if !items.is_empty() {
+ items.extend(subscription.items.data.iter().map(|item| {
+ stripe::UpdateSubscriptionItems {
+ id: Some(item.id.to_string()),
+ ..Default::default()
+ }
+ }));
+
+ stripe::Subscription::update(
+ &self.client,
+ subscription_id,
+ stripe::UpdateSubscription {
+ items: Some(items),
+ ..Default::default()
+ },
+ )
+ .await?;
+ }
+
+ Ok(())
+ }
+
+ pub async fn bill_model_usage(
+ &self,
+ customer_id: &stripe::CustomerId,
+ model: &StripeModel,
+ event: &llm::db::billing_event::Model,
+ ) -> Result<()> {
+ let timestamp = Utc::now().timestamp();
+
+ if event.input_tokens > 0 {
+ StripeMeterEvent::create(
+ &self.client,
+ StripeCreateMeterEventParams {
+ identifier: &format!("input_tokens/{}", event.idempotency_key),
+ event_name: &model.input_tokens_price.meter_event_name,
+ payload: StripeCreateMeterEventPayload {
+ value: event.input_tokens as u64,
+ stripe_customer_id: customer_id,
+ },
+ timestamp: Some(timestamp),
+ },
+ )
+ .await?;
+ }
+
+ if event.input_cache_creation_tokens > 0 {
+ StripeMeterEvent::create(
+ &self.client,
+ StripeCreateMeterEventParams {
+ identifier: &format!("input_cache_creation_tokens/{}", event.idempotency_key),
+ event_name: &model.input_cache_creation_tokens_price.meter_event_name,
+ payload: StripeCreateMeterEventPayload {
+ value: event.input_cache_creation_tokens as u64,
+ stripe_customer_id: customer_id,
+ },
+ timestamp: Some(timestamp),
+ },
+ )
+ .await?;
+ }
+
+ if event.input_cache_read_tokens > 0 {
+ StripeMeterEvent::create(
+ &self.client,
+ StripeCreateMeterEventParams {
+ identifier: &format!("input_cache_read_tokens/{}", event.idempotency_key),
+ event_name: &model.input_cache_read_tokens_price.meter_event_name,
+ payload: StripeCreateMeterEventPayload {
+ value: event.input_cache_read_tokens as u64,
+ stripe_customer_id: customer_id,
+ },
+ timestamp: Some(timestamp),
+ },
+ )
+ .await?;
+ }
+
+ if event.output_tokens > 0 {
+ StripeMeterEvent::create(
+ &self.client,
+ StripeCreateMeterEventParams {
+ identifier: &format!("output_tokens/{}", event.idempotency_key),
+ event_name: &model.output_tokens_price.meter_event_name,
+ payload: StripeCreateMeterEventPayload {
+ value: event.output_tokens as u64,
+ stripe_customer_id: customer_id,
+ },
+ timestamp: Some(timestamp),
+ },
+ )
+ .await?;
+ }
+
+ Ok(())
+ }
+
+ pub async fn checkout(
+ &self,
+ customer_id: stripe::CustomerId,
+ github_login: &str,
+ model: &StripeModel,
+ success_url: &str,
+ ) -> Result<String> {
+ let mut params = stripe::CreateCheckoutSession::new();
+ params.mode = Some(stripe::CheckoutSessionMode::Subscription);
+ params.customer = Some(customer_id);
+ params.client_reference_id = Some(github_login);
+ params.line_items = Some(
+ [
+ &model.input_tokens_price.id,
+ &model.input_cache_creation_tokens_price.id,
+ &model.input_cache_read_tokens_price.id,
+ &model.output_tokens_price.id,
+ ]
+ .into_iter()
+ .map(|price_id| stripe::CreateCheckoutSessionLineItems {
+ price: Some(price_id.to_string()),
+ ..Default::default()
+ })
+ .collect(),
+ );
+ params.success_url = Some(success_url);
+
+ let session = stripe::CheckoutSession::create(&self.client, params).await?;
+ Ok(session.url.context("no checkout session URL")?)
+ }
+}
+
+#[derive(Serialize)]
+struct DefaultAggregation {
+ formula: &'static str,
+}
+
+#[derive(Serialize)]
+struct StripeCreateMeterParams<'a> {
+ default_aggregation: DefaultAggregation,
+ display_name: String,
+ event_name: &'a str,
+}
+
+#[derive(Clone, Deserialize)]
+struct StripeMeter {
+ id: String,
+ event_name: String,
+}
+
+impl StripeMeter {
+ pub fn create(
+ client: &stripe::Client,
+ params: StripeCreateMeterParams,
+ ) -> stripe::Response<Self> {
+ client.post_form("/billing/meters", params)
+ }
+
+ pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
+ #[derive(Serialize)]
+ struct Params {}
+
+ client.get_query("/billing/meters", Params {})
+ }
+}
+
+#[derive(Deserialize)]
+struct StripeMeterEvent {
+ identifier: String,
+}
+
+impl StripeMeterEvent {
+ pub async fn create(
+ client: &stripe::Client,
+ params: StripeCreateMeterEventParams<'_>,
+ ) -> Result<Self, stripe::StripeError> {
+ let identifier = params.identifier;
+ match client.post_form("/billing/meter_events", params).await {
+ Ok(event) => Ok(event),
+ Err(stripe::StripeError::Stripe(error)) => {
+ if error.http_status == 400
+ && error
+ .message
+ .as_ref()
+ .map_or(false, |message| message.contains(identifier))
+ {
+ Ok(Self {
+ identifier: identifier.to_string(),
+ })
+ } else {
+ Err(stripe::StripeError::Stripe(error))
+ }
+ }
+ Err(error) => Err(error),
+ }
+ }
+}
+
+#[derive(Serialize)]
+struct StripeCreateMeterEventParams<'a> {
+ identifier: &'a str,
+ event_name: &'a str,
+ payload: StripeCreateMeterEventPayload<'a>,
+ timestamp: Option<i64>,
+}
+
+#[derive(Serialize)]
+struct StripeCreateMeterEventPayload<'a> {
+ value: u64,
+ stripe_customer_id: &'a stripe::CustomerId,
+}
+
+fn subscription_contains_price(
+ subscription: &stripe::Subscription,
+ price_id: &stripe::PriceId,
+) -> bool {
+ subscription.items.data.iter().any(|item| {
+ item.price
+ .as_ref()
+ .map_or(false, |price| price.id == *price_id)
+ })
+}
@@ -635,6 +635,7 @@ impl TestServer {
) -> Arc<AppState> {
Arc::new(AppState {
db: test_db.db().clone(),
+ llm_db: None,
live_kit_client: Some(Arc::new(live_kit_test_server.create_api_client())),
blob_store_client: None,
stripe_client: None,
@@ -677,8 +678,6 @@ impl TestServer {
migrations_path: None,
seed_path: None,
stripe_api_key: None,
- stripe_llm_access_price_id: None,
- stripe_llm_usage_price_id: None,
supermaven_admin_api_key: None,
user_backfiller_github_access_token: None,
},