From 598939d186f38f850f68b2e5a974d15f9103862f Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Wed, 16 Oct 2024 10:58:28 -0400 Subject: [PATCH] collab: Refresh the user's LLM token when their subscription changes (#19281) This PR makes it so collab will trigger a refresh for a user's LLM token whenever their subscription changes. This allows us to proactively push down changes to their subscription. In order to facilitate this, the Stripe event processing has been moved from the `api` service to the `collab` service in order to access the RPC server. Release Notes: - N/A --- crates/collab/src/api/billing.rs | 18 ++++++++++++++---- crates/collab/src/main.rs | 3 ++- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index 0e46e8453d95f0927b2215586becf30b6fef379b..6db40dd187788db103a961cb0d2fcdefeb47acaa 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -20,7 +20,7 @@ use stripe::{ use util::ResultExt; use crate::llm::DEFAULT_MAX_MONTHLY_SPEND; -use crate::rpc::ResultExt as _; +use crate::rpc::{ResultExt as _, Server}; use crate::{ db::{ billing_customer, BillingSubscriptionId, CreateBillingCustomerParams, @@ -404,7 +404,7 @@ const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4; /// Polls the Stripe events API periodically to reconcile the records in our /// database with the data in Stripe. -pub fn poll_stripe_events_periodically(app: Arc) { +pub fn poll_stripe_events_periodically(app: Arc, rpc_server: Arc) { let Some(stripe_client) = app.stripe_client.clone() else { log::warn!("failed to retrieve Stripe client"); return; @@ -415,7 +415,9 @@ pub fn poll_stripe_events_periodically(app: Arc) { let executor = executor.clone(); async move { loop { - poll_stripe_events(&app, &stripe_client).await.log_err(); + poll_stripe_events(&app, &rpc_server, &stripe_client) + .await + .log_err(); executor.sleep(POLL_EVENTS_INTERVAL).await; } @@ -425,6 +427,7 @@ pub fn poll_stripe_events_periodically(app: Arc) { async fn poll_stripe_events( app: &Arc, + rpc_server: &Arc, stripe_client: &stripe::Client, ) -> anyhow::Result<()> { fn event_type_to_string(event_type: EventType) -> String { @@ -541,7 +544,7 @@ async fn poll_stripe_events( | EventType::CustomerSubscriptionPaused | EventType::CustomerSubscriptionResumed | EventType::CustomerSubscriptionDeleted => { - handle_customer_subscription_event(app, stripe_client, event).await + handle_customer_subscription_event(app, rpc_server, stripe_client, event).await } _ => Ok(()), }; @@ -609,6 +612,7 @@ async fn handle_customer_event( async fn handle_customer_subscription_event( app: &Arc, + rpc_server: &Arc, stripe_client: &stripe::Client, event: stripe::Event, ) -> anyhow::Result<()> { @@ -654,6 +658,12 @@ async fn handle_customer_subscription_event( .await?; } + // When the user's subscription changes, we want to refresh their LLM tokens + // to either grant/revoke access. + rpc_server + .refresh_llm_tokens_for_user(billing_customer.user_id) + .await; + Ok(()) } diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index de77be21eb4e01416cf7b946019c56b54098ec51..ee95b6d41f53500f1d3288efba03292e4d505fec 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -132,6 +132,8 @@ async fn main() -> Result<()> { let rpc_server = collab::rpc::Server::new(epoch, state.clone()); rpc_server.start().await?; + poll_stripe_events_periodically(state.clone(), rpc_server.clone()); + app = app .merge(collab::api::routes(rpc_server.clone())) .merge(collab::rpc::routes(rpc_server.clone())); @@ -140,7 +142,6 @@ async fn main() -> Result<()> { } if mode.is_api() { - poll_stripe_events_periodically(state.clone()); fetch_extensions_from_blob_store_periodically(state.clone()); spawn_user_backfiller(state.clone());