@@ -1,31 +1,23 @@
use anyhow::{Context as _, bail};
use chrono::{DateTime, Utc};
-use cloud_llm_client::LanguageModelProvider;
-use collections::{HashMap, HashSet};
use sea_orm::ActiveValue;
use std::{sync::Arc, time::Duration};
use stripe::{CancellationDetailsReason, EventObject, EventType, ListEvents, SubscriptionStatus};
-use util::{ResultExt, maybe};
+use util::ResultExt;
use crate::AppState;
use crate::db::billing_subscription::{
StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
};
-use crate::llm::db::subscription_usage_meter::{self, CompletionMode};
+use crate::db::{
+ CreateBillingCustomerParams, CreateBillingSubscriptionParams, CreateProcessedStripeEventParams,
+ UpdateBillingCustomerParams, UpdateBillingSubscriptionParams, billing_customer,
+};
use crate::rpc::{ResultExt as _, Server};
use crate::stripe_client::{
StripeCancellationDetailsReason, StripeClient, StripeCustomerId, StripeSubscription,
StripeSubscriptionId,
};
-use crate::{db::UserId, llm::db::LlmDatabase};
-use crate::{
- db::{
- CreateBillingCustomerParams, CreateBillingSubscriptionParams,
- CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
- UpdateBillingSubscriptionParams, billing_customer,
- },
- stripe_billing::StripeBilling,
-};
/// The amount of time we wait in between each poll of Stripe events.
///
@@ -542,194 +534,3 @@ pub async fn find_or_create_billing_customer(
Ok(Some(billing_customer))
}
-
-const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
-
-pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc<AppState>) {
- let Some(stripe_billing) = app.stripe_billing.clone() else {
- log::warn!("failed to retrieve Stripe billing object");
- return;
- };
- let Some(llm_db) = app.llm_db.clone() else {
- log::warn!("failed to retrieve LLM database");
- return;
- };
-
- let executor = app.executor.clone();
- executor.spawn_detached({
- let executor = executor.clone();
- async move {
- loop {
- sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing)
- .await
- .context("failed to sync LLM request usage to Stripe")
- .trace_err();
- executor
- .sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL)
- .await;
- }
- }
- });
-}
-
-async fn sync_model_request_usage_with_stripe(
- app: &Arc<AppState>,
- llm_db: &Arc<LlmDatabase>,
- stripe_billing: &Arc<StripeBilling>,
-) -> anyhow::Result<()> {
- let feature_flags = app.db.list_feature_flags().await?;
- let sync_model_request_usage_using_cloud = feature_flags
- .iter()
- .any(|flag| flag.flag == "cloud-stripe-usage-meters-sync" && flag.enabled_for_all);
- if sync_model_request_usage_using_cloud {
- return Ok(());
- }
-
- log::info!("Stripe usage sync: Starting");
- let started_at = Utc::now();
-
- let staff_users = app.db.get_staff_users().await?;
- let staff_user_ids = staff_users
- .iter()
- .map(|user| user.id)
- .collect::<HashSet<UserId>>();
-
- let usage_meters = llm_db
- .get_current_subscription_usage_meters(Utc::now())
- .await?;
- let mut usage_meters_by_user_id =
- HashMap::<UserId, Vec<subscription_usage_meter::Model>>::default();
- for (usage_meter, usage) in usage_meters {
- let meters = usage_meters_by_user_id.entry(usage.user_id).or_default();
- meters.push(usage_meter);
- }
-
- log::info!("Stripe usage sync: Retrieving Zed Pro subscriptions");
- let get_zed_pro_subscriptions_started_at = Utc::now();
- let billing_subscriptions = app.db.get_active_zed_pro_billing_subscriptions().await?;
- log::info!(
- "Stripe usage sync: Retrieved {} Zed Pro subscriptions in {}",
- billing_subscriptions.len(),
- Utc::now() - get_zed_pro_subscriptions_started_at
- );
-
- let claude_sonnet_4 = stripe_billing
- .find_price_by_lookup_key("claude-sonnet-4-requests")
- .await?;
- let claude_sonnet_4_max = stripe_billing
- .find_price_by_lookup_key("claude-sonnet-4-requests-max")
- .await?;
- let claude_opus_4 = stripe_billing
- .find_price_by_lookup_key("claude-opus-4-requests")
- .await?;
- let claude_opus_4_max = stripe_billing
- .find_price_by_lookup_key("claude-opus-4-requests-max")
- .await?;
- let claude_3_5_sonnet = stripe_billing
- .find_price_by_lookup_key("claude-3-5-sonnet-requests")
- .await?;
- let claude_3_7_sonnet = stripe_billing
- .find_price_by_lookup_key("claude-3-7-sonnet-requests")
- .await?;
- let claude_3_7_sonnet_max = stripe_billing
- .find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
- .await?;
-
- let model_mode_combinations = [
- ("claude-opus-4", CompletionMode::Max),
- ("claude-opus-4", CompletionMode::Normal),
- ("claude-sonnet-4", CompletionMode::Max),
- ("claude-sonnet-4", CompletionMode::Normal),
- ("claude-3-7-sonnet", CompletionMode::Max),
- ("claude-3-7-sonnet", CompletionMode::Normal),
- ("claude-3-5-sonnet", CompletionMode::Normal),
- ];
-
- let billing_subscription_count = billing_subscriptions.len();
-
- log::info!("Stripe usage sync: Syncing {billing_subscription_count} Zed Pro subscriptions");
-
- for (user_id, (billing_customer, billing_subscription)) in billing_subscriptions {
- maybe!(async {
- if staff_user_ids.contains(&user_id) {
- return anyhow::Ok(());
- }
-
- let stripe_customer_id =
- StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
- let stripe_subscription_id =
- StripeSubscriptionId(billing_subscription.stripe_subscription_id.clone().into());
-
- let usage_meters = usage_meters_by_user_id.get(&user_id);
-
- for (model, mode) in &model_mode_combinations {
- let Ok(model) =
- llm_db.model(LanguageModelProvider::Anthropic, model)
- else {
- log::warn!("Failed to load model for user {user_id}: {model}");
- continue;
- };
-
- let (price, meter_event_name) = match model.name.as_str() {
- "claude-opus-4" => match mode {
- CompletionMode::Normal => (&claude_opus_4, "claude_opus_4/requests"),
- CompletionMode::Max => (&claude_opus_4_max, "claude_opus_4/requests/max"),
- },
- "claude-sonnet-4" => match mode {
- CompletionMode::Normal => (&claude_sonnet_4, "claude_sonnet_4/requests"),
- CompletionMode::Max => {
- (&claude_sonnet_4_max, "claude_sonnet_4/requests/max")
- }
- },
- "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
- "claude-3-7-sonnet" => match mode {
- CompletionMode::Normal => {
- (&claude_3_7_sonnet, "claude_3_7_sonnet/requests")
- }
- CompletionMode::Max => {
- (&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max")
- }
- },
- model_name => {
- bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
- }
- };
-
- let model_requests = usage_meters
- .and_then(|usage_meters| {
- usage_meters
- .iter()
- .find(|meter| meter.model_id == model.id && meter.mode == *mode)
- })
- .map(|usage_meter| usage_meter.requests)
- .unwrap_or(0);
-
- if model_requests > 0 {
- stripe_billing
- .subscribe_to_price(&stripe_subscription_id, price)
- .await?;
- }
-
- stripe_billing
- .bill_model_request_usage(&stripe_customer_id, meter_event_name, model_requests)
- .await
- .with_context(|| {
- format!(
- "Failed to bill model request usage of {model_requests} for {stripe_customer_id}: {meter_event_name}",
- )
- })?;
- }
-
- Ok(())
- })
- .await
- .log_err();
- }
-
- log::info!(
- "Stripe usage sync: Synced {billing_subscription_count} Zed Pro subscriptions in {}",
- Utc::now() - started_at
- );
-
- Ok(())
-}
@@ -8,7 +8,6 @@ use axum::{
};
use collab::api::CloudflareIpCountryHeader;
-use collab::api::billing::sync_llm_request_usage_with_stripe_periodically;
use collab::llm::db::LlmDatabase;
use collab::migrations::run_database_migrations;
use collab::user_backfiller::spawn_user_backfiller;
@@ -31,7 +30,7 @@ use tower_http::trace::TraceLayer;
use tracing_subscriber::{
Layer, filter::EnvFilter, fmt::format::JsonFields, util::SubscriberInitExt,
};
-use util::{ResultExt as _, maybe};
+use util::ResultExt as _;
const VERSION: &str = env!("CARGO_PKG_VERSION");
const REVISION: Option<&'static str> = option_env!("GITHUB_SHA");
@@ -133,29 +132,6 @@ async fn main() -> Result<()> {
fetch_extensions_from_blob_store_periodically(state.clone());
spawn_user_backfiller(state.clone());
- let llm_db = maybe!(async {
- let database_url = state
- .config
- .llm_database_url
- .as_ref()
- .context("missing LLM_DATABASE_URL")?;
- let max_connections = state
- .config
- .llm_database_max_connections
- .context("missing LLM_DATABASE_MAX_CONNECTIONS")?;
-
- let mut db_options = db::ConnectOptions::new(database_url);
- db_options.max_connections(max_connections);
- LlmDatabase::new(db_options, state.executor.clone()).await
- })
- .await
- .trace_err();
-
- if let Some(mut llm_db) = llm_db {
- llm_db.initialize().await?;
- sync_llm_request_usage_with_stripe_periodically(state.clone());
- }
-
app = app
.merge(collab::api::events::router())
.merge(collab::api::extensions::router())