@@ -309,7 +309,6 @@ async-dispatcher = "0.1"
async-fs = "1.6"
async-pipe = { git = "https://github.com/zed-industries/async-pipe-rs", rev = "82d00a04211cf4e1236029aa03e6b6ce2a74c553" }
async-recursion = "1.0.0"
-async-stripe = { version = "0.37", default-features = false, features = ["runtime-tokio-hyper-rustls", "billing", "checkout"] }
async-tar = "0.4.2"
async-trait = "0.1"
async-tungstenite = "0.23"
@@ -461,6 +460,19 @@ wasmtime-wasi = "21.0.1"
which = "6.0.0"
wit-component = "0.201"
+[workspace.dependencies.async-stripe]
+version = "0.37"
+default-features = false
+features = [
+ "runtime-tokio-hyper-rustls",
+ "billing",
+ "checkout",
+ "events",
+ # The features below are only enabled to get the `events` feature to build.
+ "chrono",
+ "connect",
+]
+
[workspace.dependencies.windows]
version = "0.58"
features = [
@@ -1,7 +1,8 @@
use std::str::FromStr;
use std::sync::Arc;
+use std::time::Duration;
-use anyhow::{anyhow, Context};
+use anyhow::{anyhow, bail, Context};
use axum::{extract, routing::post, Extension, Json, Router};
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
@@ -10,10 +11,16 @@ use stripe::{
CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems,
- CreateCustomer, Customer, CustomerId,
+ CreateCustomer, Customer, CustomerId, EventObject, EventType, Expandable, ListEvents,
+ SubscriptionStatus,
};
+use util::ResultExt;
-use crate::db::BillingSubscriptionId;
+use crate::db::billing_subscription::StripeSubscriptionStatus;
+use crate::db::{
+ billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
+ CreateBillingSubscriptionParams,
+};
use crate::{AppState, Error, Result};
pub fn router() -> Router {
@@ -194,3 +201,184 @@ async fn manage_billing_subscription(
billing_portal_session_url: session.url,
}))
}
+
+const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5 * 60);
+
+/// Polls the Stripe events API periodically to reconcile the records in our
+/// database with the data in Stripe.
+pub fn poll_stripe_events_periodically(app: Arc<AppState>) {
+ let Some(stripe_client) = app.stripe_client.clone() else {
+ log::warn!("failed to retrieve Stripe client");
+ return;
+ };
+
+ let executor = app.executor.clone();
+ executor.spawn_detached({
+ let executor = executor.clone();
+ async move {
+ loop {
+ poll_stripe_events(&app, &stripe_client).await.log_err();
+
+ executor.sleep(POLL_EVENTS_INTERVAL).await;
+ }
+ }
+ });
+}
+
+async fn poll_stripe_events(
+ app: &Arc<AppState>,
+ stripe_client: &stripe::Client,
+) -> anyhow::Result<()> {
+ let event_types = [
+ EventType::CustomerCreated.to_string(),
+ EventType::CustomerSubscriptionCreated.to_string(),
+ EventType::CustomerSubscriptionUpdated.to_string(),
+ EventType::CustomerSubscriptionPaused.to_string(),
+ EventType::CustomerSubscriptionResumed.to_string(),
+ EventType::CustomerSubscriptionDeleted.to_string(),
+ ]
+ .into_iter()
+ .map(|event_type| {
+ // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
+ // so we need to unquote it.
+ event_type.trim_matches('"').to_string()
+ })
+ .collect::<Vec<_>>();
+
+ loop {
+ log::info!("retrieving events from Stripe: {}", event_types.join(", "));
+
+ let mut params = ListEvents::new();
+ params.types = Some(event_types.clone());
+ params.limit = Some(100);
+
+ let events = stripe::Event::list(stripe_client, ¶ms).await?;
+ for event in events.data {
+ match event.type_ {
+ EventType::CustomerCreated => {
+ handle_customer_event(app, stripe_client, event)
+ .await
+ .log_err();
+ }
+ EventType::CustomerSubscriptionCreated
+ | EventType::CustomerSubscriptionUpdated
+ | EventType::CustomerSubscriptionPaused
+ | EventType::CustomerSubscriptionResumed
+ | EventType::CustomerSubscriptionDeleted => {
+ handle_customer_subscription_event(app, stripe_client, event)
+ .await
+ .log_err();
+ }
+ _ => {}
+ }
+ }
+
+ if !events.has_more {
+ break;
+ }
+ }
+
+ Ok(())
+}
+
+async fn handle_customer_event(
+ app: &Arc<AppState>,
+ stripe_client: &stripe::Client,
+ event: stripe::Event,
+) -> anyhow::Result<()> {
+ let EventObject::Customer(customer) = event.data.object else {
+ bail!("unexpected event payload for {}", event.id);
+ };
+
+ find_or_create_billing_customer(app, stripe_client, Expandable::Object(Box::new(customer)))
+ .await?;
+
+ Ok(())
+}
+
+async fn handle_customer_subscription_event(
+ app: &Arc<AppState>,
+ stripe_client: &stripe::Client,
+ event: stripe::Event,
+) -> anyhow::Result<()> {
+ let EventObject::Subscription(subscription) = event.data.object else {
+ bail!("unexpected event payload for {}", event.id);
+ };
+
+ let billing_customer =
+ find_or_create_billing_customer(app, stripe_client, subscription.customer)
+ .await?
+ .ok_or_else(|| anyhow!("billing customer not found"))?;
+
+ app.db
+ .upsert_billing_subscription_by_stripe_subscription_id(&CreateBillingSubscriptionParams {
+ billing_customer_id: billing_customer.id,
+ stripe_subscription_id: subscription.id.to_string(),
+ stripe_subscription_status: subscription.status.into(),
+ })
+ .await?;
+
+ Ok(())
+}
+
+impl From<SubscriptionStatus> for StripeSubscriptionStatus {
+ fn from(value: SubscriptionStatus) -> Self {
+ match value {
+ SubscriptionStatus::Incomplete => Self::Incomplete,
+ SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
+ SubscriptionStatus::Trialing => Self::Trialing,
+ SubscriptionStatus::Active => Self::Active,
+ SubscriptionStatus::PastDue => Self::PastDue,
+ SubscriptionStatus::Canceled => Self::Canceled,
+ SubscriptionStatus::Unpaid => Self::Unpaid,
+ SubscriptionStatus::Paused => Self::Paused,
+ }
+ }
+}
+
+/// Finds or creates a billing customer using the provided customer.
+async fn find_or_create_billing_customer(
+ app: &Arc<AppState>,
+ stripe_client: &stripe::Client,
+ customer_or_id: Expandable<Customer>,
+) -> anyhow::Result<Option<billing_customer::Model>> {
+ let customer_id = match &customer_or_id {
+ Expandable::Id(id) => id,
+ Expandable::Object(customer) => customer.id.as_ref(),
+ };
+
+ // If we already have a billing customer record associated with the Stripe customer,
+ // there's nothing more we need to do.
+ if let Some(billing_customer) = app
+ .db
+ .get_billing_customer_by_stripe_customer_id(&customer_id)
+ .await?
+ {
+ return Ok(Some(billing_customer));
+ }
+
+ // If all we have is a customer ID, resolve it to a full customer record by
+ // hitting the Stripe API.
+ let customer = match customer_or_id {
+ Expandable::Id(id) => Customer::retrieve(&stripe_client, &id, &[]).await?,
+ Expandable::Object(customer) => *customer,
+ };
+
+ let Some(email) = customer.email else {
+ return Ok(None);
+ };
+
+ let Some(user) = app.db.get_user_by_email(&email).await? else {
+ return Ok(None);
+ };
+
+ let billing_customer = app
+ .db
+ .create_billing_customer(&CreateBillingCustomerParams {
+ user_id: user.id,
+ stripe_customer_id: customer.id.to_string(),
+ })
+ .await?;
+
+ Ok(Some(billing_customer))
+}
@@ -5,6 +5,7 @@ use axum::{
routing::get,
Extension, Router,
};
+use collab::api::billing::poll_stripe_events_periodically;
use collab::{
api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor,
rpc::ResultExt, AppState, Config, RateLimiter, Result,
@@ -95,6 +96,7 @@ async fn main() -> Result<()> {
}
if is_api {
+ poll_stripe_events_periodically(state.clone());
fetch_extensions_from_blob_store_periodically(state.clone());
}