@@ -20,7 +20,7 @@ use stripe::{
use util::ResultExt;
use crate::llm::DEFAULT_MAX_MONTHLY_SPEND;
-use crate::rpc::ResultExt as _;
+use crate::rpc::{ResultExt as _, Server};
use crate::{
db::{
billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
@@ -404,7 +404,7 @@ const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
/// Polls the Stripe events API periodically to reconcile the records in our
/// database with the data in Stripe.
-pub fn poll_stripe_events_periodically(app: Arc<AppState>) {
+pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
let Some(stripe_client) = app.stripe_client.clone() else {
log::warn!("failed to retrieve Stripe client");
return;
@@ -415,7 +415,9 @@ pub fn poll_stripe_events_periodically(app: Arc<AppState>) {
let executor = executor.clone();
async move {
loop {
- poll_stripe_events(&app, &stripe_client).await.log_err();
+ poll_stripe_events(&app, &rpc_server, &stripe_client)
+ .await
+ .log_err();
executor.sleep(POLL_EVENTS_INTERVAL).await;
}
@@ -425,6 +427,7 @@ pub fn poll_stripe_events_periodically(app: Arc<AppState>) {
async fn poll_stripe_events(
app: &Arc<AppState>,
+ rpc_server: &Arc<Server>,
stripe_client: &stripe::Client,
) -> anyhow::Result<()> {
fn event_type_to_string(event_type: EventType) -> String {
@@ -541,7 +544,7 @@ async fn poll_stripe_events(
| EventType::CustomerSubscriptionPaused
| EventType::CustomerSubscriptionResumed
| EventType::CustomerSubscriptionDeleted => {
- handle_customer_subscription_event(app, stripe_client, event).await
+ handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
}
_ => Ok(()),
};
@@ -609,6 +612,7 @@ async fn handle_customer_event(
async fn handle_customer_subscription_event(
app: &Arc<AppState>,
+ rpc_server: &Arc<Server>,
stripe_client: &stripe::Client,
event: stripe::Event,
) -> anyhow::Result<()> {
@@ -654,6 +658,12 @@ async fn handle_customer_subscription_event(
.await?;
}
+ // When the user's subscription changes, we want to refresh their LLM tokens
+ // to either grant/revoke access.
+ rpc_server
+ .refresh_llm_tokens_for_user(billing_customer.user_id)
+ .await;
+
Ok(())
}
@@ -132,6 +132,8 @@ async fn main() -> Result<()> {
let rpc_server = collab::rpc::Server::new(epoch, state.clone());
rpc_server.start().await?;
+ poll_stripe_events_periodically(state.clone(), rpc_server.clone());
+
app = app
.merge(collab::api::routes(rpc_server.clone()))
.merge(collab::rpc::routes(rpc_server.clone()));
@@ -140,7 +142,6 @@ async fn main() -> Result<()> {
}
if mode.is_api() {
- poll_stripe_events_periodically(state.clone());
fetch_extensions_from_blob_store_periodically(state.clone());
spawn_user_backfiller(state.clone());