Detailed changes
@@ -0,0 +1,13 @@
+create table monthly_usages (
+ id serial primary key,
+ user_id integer not null,
+ model_id integer not null references models (id) on delete cascade,
+ month integer not null,
+ year integer not null,
+ input_tokens bigint not null default 0,
+ cache_creation_input_tokens bigint not null default 0,
+ cache_read_input_tokens bigint not null default 0,
+ output_tokens bigint not null default 0
+);
+
+create unique index uix_monthly_usages_on_user_id_model_id_month_year on monthly_usages (user_id, model_id, month, year);
@@ -22,12 +22,15 @@ use stripe::{
};
use util::ResultExt;
-use crate::db::billing_subscription::StripeSubscriptionStatus;
+use crate::db::billing_subscription::{self, StripeSubscriptionStatus};
use crate::db::{
billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
CreateBillingSubscriptionParams, CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
UpdateBillingSubscriptionParams,
};
+use crate::llm::db::LlmDatabase;
+use crate::llm::MONTHLY_SPENDING_LIMIT_IN_CENTS;
+use crate::rpc::ResultExt as _;
use crate::{AppState, Error, Result};
pub fn router() -> Router {
@@ -79,7 +82,7 @@ async fn list_billing_subscriptions(
.into_iter()
.map(|subscription| BillingSubscriptionJson {
id: subscription.id,
- name: "Zed Pro".to_string(),
+ name: "Zed LLM Usage".to_string(),
status: subscription.stripe_subscription_status,
cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
cancel_at
@@ -117,7 +120,7 @@ async fn create_billing_subscription(
let Some((stripe_client, stripe_price_id)) = app
.stripe_client
.clone()
- .zip(app.config.stripe_price_id.clone())
+ .zip(app.config.stripe_llm_usage_price_id.clone())
else {
log::error!("failed to retrieve Stripe client or price ID");
Err(Error::http(
@@ -150,7 +153,7 @@ async fn create_billing_subscription(
params.client_reference_id = Some(user.github_login.as_str());
params.line_items = Some(vec![CreateCheckoutSessionLineItems {
price: Some(stripe_price_id.to_string()),
- quantity: Some(1),
+ quantity: Some(0),
..Default::default()
}]);
let success_url = format!("{}/account", app.config.zed_dot_dev_url());
@@ -631,3 +634,95 @@ async fn find_or_create_billing_customer(
Ok(Some(billing_customer))
}
+
+const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60);
+
+pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>, llm_db: LlmDatabase) {
+ let Some(stripe_client) = app.stripe_client.clone() else {
+ log::warn!("failed to retrieve Stripe client");
+ return;
+ };
+ let Some(stripe_llm_usage_price_id) = app.config.stripe_llm_usage_price_id.clone() else {
+ log::warn!("failed to retrieve Stripe LLM usage price ID");
+ return;
+ };
+
+ let executor = app.executor.clone();
+ executor.spawn_detached({
+ let executor = executor.clone();
+ async move {
+ loop {
+ sync_with_stripe(
+ &app,
+ &llm_db,
+ &stripe_client,
+ stripe_llm_usage_price_id.clone(),
+ )
+ .await
+ .trace_err();
+
+ executor.sleep(SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL).await;
+ }
+ }
+ });
+}
+
+async fn sync_with_stripe(
+ app: &Arc<AppState>,
+ llm_db: &LlmDatabase,
+ stripe_client: &stripe::Client,
+ stripe_llm_usage_price_id: Arc<str>,
+) -> anyhow::Result<()> {
+ let subscriptions = app.db.get_active_billing_subscriptions().await?;
+
+ for (customer, subscription) in subscriptions {
+ update_stripe_subscription(
+ llm_db,
+ stripe_client,
+ &stripe_llm_usage_price_id,
+ customer,
+ subscription,
+ )
+ .await
+ .log_err();
+ }
+
+ Ok(())
+}
+
+async fn update_stripe_subscription(
+ llm_db: &LlmDatabase,
+ stripe_client: &stripe::Client,
+ stripe_llm_usage_price_id: &Arc<str>,
+ customer: billing_customer::Model,
+ subscription: billing_subscription::Model,
+) -> Result<(), anyhow::Error> {
+ let monthly_spending = llm_db
+ .get_user_spending_for_month(customer.user_id, Utc::now())
+ .await?;
+ let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
+ .context("failed to parse subscription ID")?;
+
+ let monthly_spending_over_free_tier =
+ monthly_spending.saturating_sub(MONTHLY_SPENDING_LIMIT_IN_CENTS);
+
+ let new_quantity = (monthly_spending_over_free_tier as f32 / 100.).ceil();
+ Subscription::update(
+ stripe_client,
+ &subscription_id,
+ stripe::UpdateSubscription {
+ items: Some(vec![stripe::UpdateSubscriptionItems {
+ // TODO: Do we need to send up the `id` if a subscription item
+ // with this price already exists, or will Stripe take care of
+ // it?
+ id: None,
+ price: Some(stripe_llm_usage_price_id.to_string()),
+ quantity: Some(new_quantity as u64),
+ ..Default::default()
+ }]),
+ ..Default::default()
+ },
+ )
+ .await?;
+ Ok(())
+}
@@ -112,6 +112,29 @@ impl Database {
.await
}
+ pub async fn get_active_billing_subscriptions(
+ &self,
+ ) -> Result<Vec<(billing_customer::Model, billing_subscription::Model)>> {
+ self.transaction(|tx| async move {
+ let mut result = Vec::new();
+ let mut rows = billing_subscription::Entity::find()
+ .inner_join(billing_customer::Entity)
+ .select_also(billing_customer::Entity)
+ .order_by_asc(billing_subscription::Column::Id)
+ .stream(&*tx)
+ .await?;
+
+ while let Some(row) = rows.next().await {
+ if let (subscription, Some(customer)) = row? {
+ result.push((customer, subscription));
+ }
+ }
+
+ Ok(result)
+ })
+ .await
+ }
+
/// Returns whether the user has an active billing subscription.
pub async fn has_active_billing_subscription(&self, user_id: UserId) -> Result<bool> {
Ok(self.count_active_billing_subscriptions(user_id).await? > 0)
@@ -174,7 +174,7 @@ pub struct Config {
pub slack_panics_webhook: Option<String>,
pub auto_join_channel_id: Option<ChannelId>,
pub stripe_api_key: Option<String>,
- pub stripe_price_id: Option<Arc<str>>,
+ pub stripe_llm_usage_price_id: Option<Arc<str>>,
pub supermaven_admin_api_key: Option<Arc<str>>,
pub user_backfiller_github_access_token: Option<Arc<str>>,
}
@@ -193,6 +193,10 @@ impl Config {
}
}
+ pub fn is_llm_billing_enabled(&self) -> bool {
+ self.stripe_llm_usage_price_id.is_some()
+ }
+
#[cfg(test)]
pub fn test() -> Self {
Self {
@@ -231,7 +235,7 @@ impl Config {
migrations_path: None,
seed_path: None,
stripe_api_key: None,
- stripe_price_id: None,
+ stripe_llm_usage_price_id: None,
supermaven_admin_api_key: None,
user_backfiller_github_access_token: None,
}
@@ -436,6 +436,9 @@ fn normalize_model_name(known_models: Vec<String>, name: String) -> String {
}
}
+/// The maximum monthly spending an individual user can reach before they have to pay.
+pub const MONTHLY_SPENDING_LIMIT_IN_CENTS: usize = 5 * 100;
+
/// The maximum lifetime spending an individual user can reach before being cut off.
///
/// Represented in cents.
@@ -458,6 +461,18 @@ async fn check_usage_limit(
)
.await?;
+ if state.config.is_llm_billing_enabled() {
+ if usage.spending_this_month >= MONTHLY_SPENDING_LIMIT_IN_CENTS {
+ if !claims.has_llm_subscription.unwrap_or(false) {
+ return Err(Error::http(
+ StatusCode::PAYMENT_REQUIRED,
+ "Maximum spending limit reached for this month.".to_string(),
+ ));
+ }
+ }
+ }
+
+ // TODO: Remove this once we've rolled out monthly spending limits.
if usage.lifetime_spending >= LIFETIME_SPENDING_LIMIT_IN_CENTS {
return Err(Error::http(
StatusCode::FORBIDDEN,
@@ -505,7 +520,6 @@ async fn check_usage_limit(
UsageMeasure::RequestsPerMinute => "requests_per_minute",
UsageMeasure::TokensPerMinute => "tokens_per_minute",
UsageMeasure::TokensPerDay => "tokens_per_day",
- _ => "",
};
if let Some(client) = state.clickhouse_client.as_ref() {
@@ -97,6 +97,14 @@ impl LlmDatabase {
.ok_or_else(|| anyhow!("unknown model {provider:?}:{name}"))?)
}
+ pub fn model_by_id(&self, id: ModelId) -> Result<&model::Model> {
+ Ok(self
+ .models
+ .values()
+ .find(|model| model.id == id)
+ .ok_or_else(|| anyhow!("no model for ID {id:?}"))?)
+ }
+
pub fn options(&self) -> &ConnectOptions {
&self.options
}
@@ -1,5 +1,5 @@
use crate::db::UserId;
-use chrono::Duration;
+use chrono::{Datelike, Duration};
use futures::StreamExt as _;
use rpc::LanguageModelProvider;
use sea_orm::QuerySelect;
@@ -140,6 +140,46 @@ impl LlmDatabase {
.await
}
+ pub async fn get_user_spending_for_month(
+ &self,
+ user_id: UserId,
+ now: DateTimeUtc,
+ ) -> Result<usize> {
+ self.transaction(|tx| async move {
+ let month = now.date_naive().month() as i32;
+ let year = now.date_naive().year();
+
+ let mut monthly_usages = monthly_usage::Entity::find()
+ .filter(
+ monthly_usage::Column::UserId
+ .eq(user_id)
+ .and(monthly_usage::Column::Month.eq(month))
+ .and(monthly_usage::Column::Year.eq(year)),
+ )
+ .stream(&*tx)
+ .await?;
+ let mut monthly_spending_in_cents = 0;
+
+ while let Some(usage) = monthly_usages.next().await {
+ let usage = usage?;
+ let Ok(model) = self.model_by_id(usage.model_id) else {
+ continue;
+ };
+
+ monthly_spending_in_cents += calculate_spending(
+ model,
+ usage.input_tokens as usize,
+ usage.cache_creation_input_tokens as usize,
+ usage.cache_read_input_tokens as usize,
+ usage.output_tokens as usize,
+ );
+ }
+
+ Ok(monthly_spending_in_cents)
+ })
+ .await
+ }
+
pub async fn get_usage(
&self,
user_id: UserId,
@@ -162,6 +202,18 @@ impl LlmDatabase {
.all(&*tx)
.await?;
+ let month = now.date_naive().month() as i32;
+ let year = now.date_naive().year();
+ let monthly_usage = monthly_usage::Entity::find()
+ .filter(
+ monthly_usage::Column::UserId
+ .eq(user_id)
+ .and(monthly_usage::Column::ModelId.eq(model.id))
+ .and(monthly_usage::Column::Month.eq(month))
+ .and(monthly_usage::Column::Year.eq(year)),
+ )
+ .one(&*tx)
+ .await?;
let lifetime_usage = lifetime_usage::Entity::find()
.filter(
lifetime_usage::Column::UserId
@@ -177,28 +229,18 @@ impl LlmDatabase {
self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerMinute)?;
let tokens_this_day =
self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerDay)?;
- let input_tokens_this_month =
- self.get_usage_for_measure(&usages, now, UsageMeasure::InputTokensPerMonth)?;
- let cache_creation_input_tokens_this_month = self.get_usage_for_measure(
- &usages,
- now,
- UsageMeasure::CacheCreationInputTokensPerMonth,
- )?;
- let cache_read_input_tokens_this_month = self.get_usage_for_measure(
- &usages,
- now,
- UsageMeasure::CacheReadInputTokensPerMonth,
- )?;
- let output_tokens_this_month =
- self.get_usage_for_measure(&usages, now, UsageMeasure::OutputTokensPerMonth)?;
- let spending_this_month = calculate_spending(
- model,
- input_tokens_this_month,
- cache_creation_input_tokens_this_month,
- cache_read_input_tokens_this_month,
- output_tokens_this_month,
- );
- let lifetime_spending = if let Some(lifetime_usage) = lifetime_usage {
+ let spending_this_month = if let Some(monthly_usage) = &monthly_usage {
+ calculate_spending(
+ model,
+ monthly_usage.input_tokens as usize,
+ monthly_usage.cache_creation_input_tokens as usize,
+ monthly_usage.cache_read_input_tokens as usize,
+ monthly_usage.output_tokens as usize,
+ )
+ } else {
+ 0
+ };
+ let lifetime_spending = if let Some(lifetime_usage) = &lifetime_usage {
calculate_spending(
model,
lifetime_usage.input_tokens as usize,
@@ -214,10 +256,18 @@ impl LlmDatabase {
requests_this_minute,
tokens_this_minute,
tokens_this_day,
- input_tokens_this_month,
- cache_creation_input_tokens_this_month,
- cache_read_input_tokens_this_month,
- output_tokens_this_month,
+ input_tokens_this_month: monthly_usage
+ .as_ref()
+ .map_or(0, |usage| usage.input_tokens as usize),
+ cache_creation_input_tokens_this_month: monthly_usage
+ .as_ref()
+ .map_or(0, |usage| usage.cache_creation_input_tokens as usize),
+ cache_read_input_tokens_this_month: monthly_usage
+ .as_ref()
+ .map_or(0, |usage| usage.cache_read_input_tokens as usize),
+ output_tokens_this_month: monthly_usage
+ .as_ref()
+ .map_or(0, |usage| usage.output_tokens as usize),
spending_this_month,
lifetime_spending,
})
@@ -290,60 +340,68 @@ impl LlmDatabase {
&tx,
)
.await?;
- let input_tokens_this_month = self
- .update_usage_for_measure(
- user_id,
- is_staff,
- model.id,
- &usages,
- UsageMeasure::InputTokensPerMonth,
- now,
- input_token_count,
- &tx,
- )
- .await?;
- let cache_creation_input_tokens_this_month = self
- .update_usage_for_measure(
- user_id,
- is_staff,
- model.id,
- &usages,
- UsageMeasure::CacheCreationInputTokensPerMonth,
- now,
- cache_creation_input_tokens,
- &tx,
- )
- .await?;
- let cache_read_input_tokens_this_month = self
- .update_usage_for_measure(
- user_id,
- is_staff,
- model.id,
- &usages,
- UsageMeasure::CacheReadInputTokensPerMonth,
- now,
- cache_read_input_tokens,
- &tx,
- )
- .await?;
- let output_tokens_this_month = self
- .update_usage_for_measure(
- user_id,
- is_staff,
- model.id,
- &usages,
- UsageMeasure::OutputTokensPerMonth,
- now,
- output_token_count,
- &tx,
+
+ let month = now.date_naive().month() as i32;
+ let year = now.date_naive().year();
+
+ // Update monthly usage
+ let monthly_usage = monthly_usage::Entity::find()
+ .filter(
+ monthly_usage::Column::UserId
+ .eq(user_id)
+ .and(monthly_usage::Column::ModelId.eq(model.id))
+ .and(monthly_usage::Column::Month.eq(month))
+ .and(monthly_usage::Column::Year.eq(year)),
)
+ .one(&*tx)
.await?;
+
+ let monthly_usage = match monthly_usage {
+ Some(usage) => {
+ monthly_usage::Entity::update(monthly_usage::ActiveModel {
+ id: ActiveValue::unchanged(usage.id),
+ input_tokens: ActiveValue::set(
+ usage.input_tokens + input_token_count as i64,
+ ),
+ cache_creation_input_tokens: ActiveValue::set(
+ usage.cache_creation_input_tokens + cache_creation_input_tokens as i64,
+ ),
+ cache_read_input_tokens: ActiveValue::set(
+ usage.cache_read_input_tokens + cache_read_input_tokens as i64,
+ ),
+ output_tokens: ActiveValue::set(
+ usage.output_tokens + output_token_count as i64,
+ ),
+ ..Default::default()
+ })
+ .exec(&*tx)
+ .await?
+ }
+ None => {
+ monthly_usage::ActiveModel {
+ user_id: ActiveValue::set(user_id),
+ model_id: ActiveValue::set(model.id),
+ month: ActiveValue::set(month),
+ year: ActiveValue::set(year),
+ input_tokens: ActiveValue::set(input_token_count as i64),
+ cache_creation_input_tokens: ActiveValue::set(
+ cache_creation_input_tokens as i64,
+ ),
+ cache_read_input_tokens: ActiveValue::set(cache_read_input_tokens as i64),
+ output_tokens: ActiveValue::set(output_token_count as i64),
+ ..Default::default()
+ }
+ .insert(&*tx)
+ .await?
+ }
+ };
+
let spending_this_month = calculate_spending(
model,
- input_tokens_this_month,
- cache_creation_input_tokens_this_month,
- cache_read_input_tokens_this_month,
- output_tokens_this_month,
+ monthly_usage.input_tokens as usize,
+ monthly_usage.cache_creation_input_tokens as usize,
+ monthly_usage.cache_read_input_tokens as usize,
+ monthly_usage.output_tokens as usize,
);
// Update lifetime usage
@@ -406,10 +464,11 @@ impl LlmDatabase {
requests_this_minute,
tokens_this_minute,
tokens_this_day,
- input_tokens_this_month,
- cache_creation_input_tokens_this_month,
- cache_read_input_tokens_this_month,
- output_tokens_this_month,
+ input_tokens_this_month: monthly_usage.input_tokens as usize,
+ cache_creation_input_tokens_this_month: monthly_usage.cache_creation_input_tokens
+ as usize,
+ cache_read_input_tokens_this_month: monthly_usage.cache_read_input_tokens as usize,
+ output_tokens_this_month: monthly_usage.output_tokens as usize,
spending_this_month,
lifetime_spending,
})
@@ -597,7 +656,6 @@ fn calculate_spending(
const MINUTE_BUCKET_COUNT: usize = 12;
const DAY_BUCKET_COUNT: usize = 48;
-const MONTH_BUCKET_COUNT: usize = 30;
impl UsageMeasure {
fn bucket_count(&self) -> usize {
@@ -605,10 +663,6 @@ impl UsageMeasure {
UsageMeasure::RequestsPerMinute => MINUTE_BUCKET_COUNT,
UsageMeasure::TokensPerMinute => MINUTE_BUCKET_COUNT,
UsageMeasure::TokensPerDay => DAY_BUCKET_COUNT,
- UsageMeasure::InputTokensPerMonth => MONTH_BUCKET_COUNT,
- UsageMeasure::CacheCreationInputTokensPerMonth => MONTH_BUCKET_COUNT,
- UsageMeasure::CacheReadInputTokensPerMonth => MONTH_BUCKET_COUNT,
- UsageMeasure::OutputTokensPerMonth => MONTH_BUCKET_COUNT,
}
}
@@ -617,10 +671,6 @@ impl UsageMeasure {
UsageMeasure::RequestsPerMinute => Duration::minutes(1),
UsageMeasure::TokensPerMinute => Duration::minutes(1),
UsageMeasure::TokensPerDay => Duration::hours(24),
- UsageMeasure::InputTokensPerMonth => Duration::days(30),
- UsageMeasure::CacheCreationInputTokensPerMonth => Duration::days(30),
- UsageMeasure::CacheReadInputTokensPerMonth => Duration::days(30),
- UsageMeasure::OutputTokensPerMonth => Duration::days(30),
}
}
@@ -1,5 +1,6 @@
pub mod lifetime_usage;
pub mod model;
+pub mod monthly_usage;
pub mod provider;
pub mod revoked_access_token;
pub mod usage;
@@ -0,0 +1,22 @@
+use crate::{db::UserId, llm::db::ModelId};
+use sea_orm::entity::prelude::*;
+
+#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
+#[sea_orm(table_name = "monthly_usages")]
+pub struct Model {
+ #[sea_orm(primary_key)]
+ pub id: i32,
+ pub user_id: UserId,
+ pub model_id: ModelId,
+ pub month: i32,
+ pub year: i32,
+ pub input_tokens: i64,
+ pub cache_creation_input_tokens: i64,
+ pub cache_read_input_tokens: i64,
+ pub output_tokens: i64,
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {}
+
+impl ActiveModelBehavior for ActiveModel {}
@@ -9,10 +9,6 @@ pub enum UsageMeasure {
RequestsPerMinute,
TokensPerMinute,
TokensPerDay,
- InputTokensPerMonth,
- CacheCreationInputTokensPerMonth,
- CacheReadInputTokensPerMonth,
- OutputTokensPerMonth,
}
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
@@ -6,7 +6,7 @@ use crate::{
},
test_llm_db,
};
-use chrono::{Duration, Utc};
+use chrono::{DateTime, Duration, Utc};
use pretty_assertions::assert_eq;
use rpc::LanguageModelProvider;
@@ -29,7 +29,10 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
.await
.unwrap();
- let t0 = Utc::now();
+ // We're using a fixed datetime to prevent flakiness based on the clock.
+ let t0 = DateTime::parse_from_rfc3339("2024-08-08T22:46:33Z")
+ .unwrap()
+ .with_timezone(&Utc);
let user_id = UserId::from_proto(123);
let now = t0;
@@ -134,23 +137,10 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
}
);
- let t2 = t0 + Duration::days(30);
- let now = t2;
- let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
- assert_eq!(
- usage,
- Usage {
- requests_this_minute: 0,
- tokens_this_minute: 0,
- tokens_this_day: 0,
- input_tokens_this_month: 9000,
- cache_creation_input_tokens_this_month: 0,
- cache_read_input_tokens_this_month: 0,
- output_tokens_this_month: 0,
- spending_this_month: 0,
- lifetime_spending: 0,
- }
- );
+ // We're using a fixed datetime to prevent flakiness based on the clock.
+ let now = DateTime::parse_from_rfc3339("2024-10-08T22:15:58Z")
+ .unwrap()
+ .with_timezone(&Utc);
// Test cache creation input tokens
db.record_usage(user_id, false, provider, model, 1000, 500, 0, 0, now)
@@ -164,7 +154,7 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
requests_this_minute: 1,
tokens_this_minute: 1500,
tokens_this_day: 1500,
- input_tokens_this_month: 10000,
+ input_tokens_this_month: 1000,
cache_creation_input_tokens_this_month: 500,
cache_read_input_tokens_this_month: 0,
output_tokens_this_month: 0,
@@ -185,7 +175,7 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
requests_this_minute: 2,
tokens_this_minute: 2800,
tokens_this_day: 2800,
- input_tokens_this_month: 11000,
+ input_tokens_this_month: 2000,
cache_creation_input_tokens_this_month: 500,
cache_read_input_tokens_this_month: 300,
output_tokens_this_month: 0,
@@ -22,6 +22,12 @@ pub struct LlmTokenClaims {
pub is_staff: bool,
#[serde(default)]
pub has_llm_closed_beta_feature_flag: bool,
+ // This field is temporarily optional so it can be added
+ // in a backwards-compatible way. We can make it required
+ // once all of the LLM tokens have cycled (~1 hour after
+ // this change has been deployed).
+ #[serde(default)]
+ pub has_llm_subscription: Option<bool>,
pub plan: rpc::proto::Plan,
}
@@ -33,6 +39,7 @@ impl LlmTokenClaims {
github_user_login: String,
is_staff: bool,
has_llm_closed_beta_feature_flag: bool,
+ has_llm_subscription: bool,
plan: rpc::proto::Plan,
config: &Config,
) -> Result<String> {
@@ -50,6 +57,7 @@ impl LlmTokenClaims {
github_user_login: Some(github_user_login),
is_staff,
has_llm_closed_beta_feature_flag,
+ has_llm_subscription: Some(has_llm_subscription),
plan,
};
@@ -6,6 +6,7 @@ use axum::{
routing::get,
Extension, Router,
};
+use collab::api::billing::sync_llm_usage_with_stripe_periodically;
use collab::api::CloudflareIpCountryHeader;
use collab::llm::{db::LlmDatabase, log_usage_periodically};
use collab::migrations::run_database_migrations;
@@ -29,7 +30,7 @@ use tower_http::trace::TraceLayer;
use tracing_subscriber::{
filter::EnvFilter, fmt::format::JsonFields, util::SubscriberInitExt, Layer,
};
-use util::ResultExt as _;
+use util::{maybe, ResultExt as _};
const VERSION: &str = env!("CARGO_PKG_VERSION");
const REVISION: Option<&'static str> = option_env!("GITHUB_SHA");
@@ -136,6 +137,28 @@ async fn main() -> Result<()> {
fetch_extensions_from_blob_store_periodically(state.clone());
spawn_user_backfiller(state.clone());
+ let llm_db = maybe!(async {
+ let database_url = state
+ .config
+ .llm_database_url
+ .as_ref()
+ .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
+ let max_connections = state
+ .config
+ .llm_database_max_connections
+ .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;
+
+ let mut db_options = db::ConnectOptions::new(database_url);
+ db_options.max_connections(max_connections);
+ LlmDatabase::new(db_options, state.executor.clone()).await
+ })
+ .await
+ .trace_err();
+
+ if let Some(llm_db) = llm_db {
+ sync_llm_usage_with_stripe_periodically(state.clone(), llm_db);
+ }
+
app = app
.merge(collab::api::events::router())
.merge(collab::api::extensions::router())
@@ -191,16 +191,26 @@ impl Session {
}
}
- pub async fn current_plan(&self, db: MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
+ pub async fn has_llm_subscription(
+ &self,
+ db: &MutexGuard<'_, DbHandle>,
+ ) -> anyhow::Result<bool> {
if self.is_staff() {
- return Ok(proto::Plan::ZedPro);
+ return Ok(true);
}
let Some(user_id) = self.user_id() else {
- return Ok(proto::Plan::Free);
+ return Ok(false);
};
- if db.has_active_billing_subscription(user_id).await? {
+ Ok(db.has_active_billing_subscription(user_id).await?)
+ }
+
+ pub async fn current_plan(
+ &self,
+ _db: &MutexGuard<'_, DbHandle>,
+ ) -> anyhow::Result<proto::Plan> {
+ if self.is_staff() {
Ok(proto::Plan::ZedPro)
} else {
Ok(proto::Plan::Free)
@@ -3471,7 +3481,7 @@ fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
}
async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
- let plan = session.current_plan(session.db().await).await?;
+ let plan = session.current_plan(&session.db().await).await?;
session
.peer
@@ -4471,7 +4481,7 @@ async fn count_language_model_tokens(
};
authorize_access_to_legacy_llm_endpoints(&session).await?;
- let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
+ let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
};
@@ -4592,7 +4602,7 @@ async fn compute_embeddings(
let api_key = api_key.context("no OpenAI API key configured on the server")?;
authorize_access_to_legacy_llm_endpoints(&session).await?;
- let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
+ let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
};
@@ -4915,7 +4925,8 @@ async fn get_llm_api_token(
user.github_login.clone(),
session.is_staff(),
has_llm_closed_beta_feature_flag,
- session.current_plan(db).await?,
+ session.has_llm_subscription(&db).await?,
+ session.current_plan(&db).await?,
&session.app_state.config,
)?;
response.send(proto::GetLlmTokenResponse { token })?;
@@ -677,7 +677,7 @@ impl TestServer {
migrations_path: None,
seed_path: None,
stripe_api_key: None,
- stripe_price_id: None,
+ stripe_llm_usage_price_id: None,
supermaven_admin_api_key: None,
user_backfiller_github_access_token: None,
},