diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index 0e46e8453d95f0927b2215586becf30b6fef379b..6db40dd187788db103a961cb0d2fcdefeb47acaa 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -20,7 +20,7 @@ use stripe::{ use util::ResultExt; use crate::llm::DEFAULT_MAX_MONTHLY_SPEND; -use crate::rpc::ResultExt as _; +use crate::rpc::{ResultExt as _, Server}; use crate::{ db::{ billing_customer, BillingSubscriptionId, CreateBillingCustomerParams, @@ -404,7 +404,7 @@ const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4; /// Polls the Stripe events API periodically to reconcile the records in our /// database with the data in Stripe. -pub fn poll_stripe_events_periodically(app: Arc) { +pub fn poll_stripe_events_periodically(app: Arc, rpc_server: Arc) { let Some(stripe_client) = app.stripe_client.clone() else { log::warn!("failed to retrieve Stripe client"); return; @@ -415,7 +415,9 @@ pub fn poll_stripe_events_periodically(app: Arc) { let executor = executor.clone(); async move { loop { - poll_stripe_events(&app, &stripe_client).await.log_err(); + poll_stripe_events(&app, &rpc_server, &stripe_client) + .await + .log_err(); executor.sleep(POLL_EVENTS_INTERVAL).await; } @@ -425,6 +427,7 @@ pub fn poll_stripe_events_periodically(app: Arc) { async fn poll_stripe_events( app: &Arc, + rpc_server: &Arc, stripe_client: &stripe::Client, ) -> anyhow::Result<()> { fn event_type_to_string(event_type: EventType) -> String { @@ -541,7 +544,7 @@ async fn poll_stripe_events( | EventType::CustomerSubscriptionPaused | EventType::CustomerSubscriptionResumed | EventType::CustomerSubscriptionDeleted => { - handle_customer_subscription_event(app, stripe_client, event).await + handle_customer_subscription_event(app, rpc_server, stripe_client, event).await } _ => Ok(()), }; @@ -609,6 +612,7 @@ async fn handle_customer_event( async fn handle_customer_subscription_event( app: &Arc, + rpc_server: &Arc, stripe_client: &stripe::Client, event: stripe::Event, ) -> anyhow::Result<()> { @@ -654,6 +658,12 @@ async fn handle_customer_subscription_event( .await?; } + // When the user's subscription changes, we want to refresh their LLM tokens + // to either grant/revoke access. + rpc_server + .refresh_llm_tokens_for_user(billing_customer.user_id) + .await; + Ok(()) } diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index de77be21eb4e01416cf7b946019c56b54098ec51..ee95b6d41f53500f1d3288efba03292e4d505fec 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -132,6 +132,8 @@ async fn main() -> Result<()> { let rpc_server = collab::rpc::Server::new(epoch, state.clone()); rpc_server.start().await?; + poll_stripe_events_periodically(state.clone(), rpc_server.clone()); + app = app .merge(collab::api::routes(rpc_server.clone())) .merge(collab::rpc::routes(rpc_server.clone())); @@ -140,7 +142,6 @@ async fn main() -> Result<()> { } if mode.is_api() { - poll_stripe_events_periodically(state.clone()); fetch_extensions_from_blob_store_periodically(state.clone()); spawn_user_backfiller(state.clone());