diff --git a/Cargo.lock b/Cargo.lock index a8baa9788bb1ffe05b391a6dd5a92134699a6cee..2a40cfe6927868e4806ae226fafa0b83f37627d4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -840,6 +840,7 @@ version = "0.37.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2f14b5943a52cf051bbbbb68538e93a69d1e291934174121e769f4b181113f5" dependencies = [ + "chrono", "futures-util", "http-types", "hyper", diff --git a/Cargo.toml b/Cargo.toml index a5be5bc027016fb747fd7326bf97c90152848065..ca20aa13843add7313a8ceb3f06963920282acf5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -309,7 +309,6 @@ async-dispatcher = "0.1" async-fs = "1.6" async-pipe = { git = "https://github.com/zed-industries/async-pipe-rs", rev = "82d00a04211cf4e1236029aa03e6b6ce2a74c553" } async-recursion = "1.0.0" -async-stripe = { version = "0.37", default-features = false, features = ["runtime-tokio-hyper-rustls", "billing", "checkout"] } async-tar = "0.4.2" async-trait = "0.1" async-tungstenite = "0.23" @@ -461,6 +460,19 @@ wasmtime-wasi = "21.0.1" which = "6.0.0" wit-component = "0.201" +[workspace.dependencies.async-stripe] +version = "0.37" +default-features = false +features = [ + "runtime-tokio-hyper-rustls", + "billing", + "checkout", + "events", + # The features below are only enabled to get the `events` feature to build. + "chrono", + "connect", +] + [workspace.dependencies.windows] version = "0.58" features = [ diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index 0db4cf062b7832f64fca334a3b451d012dc14ce6..e6860942d81d88f6b12d157a277823386880990e 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -1,7 +1,8 @@ use std::str::FromStr; use std::sync::Arc; +use std::time::Duration; -use anyhow::{anyhow, Context}; +use anyhow::{anyhow, bail, Context}; use axum::{extract, routing::post, Extension, Json, Router}; use reqwest::StatusCode; use serde::{Deserialize, Serialize}; @@ -10,10 +11,16 @@ use stripe::{ CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion, CreateBillingPortalSessionFlowDataAfterCompletionRedirect, CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems, - CreateCustomer, Customer, CustomerId, + CreateCustomer, Customer, CustomerId, EventObject, EventType, Expandable, ListEvents, + SubscriptionStatus, }; +use util::ResultExt; -use crate::db::BillingSubscriptionId; +use crate::db::billing_subscription::StripeSubscriptionStatus; +use crate::db::{ + billing_customer, BillingSubscriptionId, CreateBillingCustomerParams, + CreateBillingSubscriptionParams, +}; use crate::{AppState, Error, Result}; pub fn router() -> Router { @@ -194,3 +201,184 @@ async fn manage_billing_subscription( billing_portal_session_url: session.url, })) } + +const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5 * 60); + +/// Polls the Stripe events API periodically to reconcile the records in our +/// database with the data in Stripe. +pub fn poll_stripe_events_periodically(app: Arc) { + let Some(stripe_client) = app.stripe_client.clone() else { + log::warn!("failed to retrieve Stripe client"); + return; + }; + + let executor = app.executor.clone(); + executor.spawn_detached({ + let executor = executor.clone(); + async move { + loop { + poll_stripe_events(&app, &stripe_client).await.log_err(); + + executor.sleep(POLL_EVENTS_INTERVAL).await; + } + } + }); +} + +async fn poll_stripe_events( + app: &Arc, + stripe_client: &stripe::Client, +) -> anyhow::Result<()> { + let event_types = [ + EventType::CustomerCreated.to_string(), + EventType::CustomerSubscriptionCreated.to_string(), + EventType::CustomerSubscriptionUpdated.to_string(), + EventType::CustomerSubscriptionPaused.to_string(), + EventType::CustomerSubscriptionResumed.to_string(), + EventType::CustomerSubscriptionDeleted.to_string(), + ] + .into_iter() + .map(|event_type| { + // Calling `to_string` on `stripe::EventType` members gives us a quoted string, + // so we need to unquote it. + event_type.trim_matches('"').to_string() + }) + .collect::>(); + + loop { + log::info!("retrieving events from Stripe: {}", event_types.join(", ")); + + let mut params = ListEvents::new(); + params.types = Some(event_types.clone()); + params.limit = Some(100); + + let events = stripe::Event::list(stripe_client, ¶ms).await?; + for event in events.data { + match event.type_ { + EventType::CustomerCreated => { + handle_customer_event(app, stripe_client, event) + .await + .log_err(); + } + EventType::CustomerSubscriptionCreated + | EventType::CustomerSubscriptionUpdated + | EventType::CustomerSubscriptionPaused + | EventType::CustomerSubscriptionResumed + | EventType::CustomerSubscriptionDeleted => { + handle_customer_subscription_event(app, stripe_client, event) + .await + .log_err(); + } + _ => {} + } + } + + if !events.has_more { + break; + } + } + + Ok(()) +} + +async fn handle_customer_event( + app: &Arc, + stripe_client: &stripe::Client, + event: stripe::Event, +) -> anyhow::Result<()> { + let EventObject::Customer(customer) = event.data.object else { + bail!("unexpected event payload for {}", event.id); + }; + + find_or_create_billing_customer(app, stripe_client, Expandable::Object(Box::new(customer))) + .await?; + + Ok(()) +} + +async fn handle_customer_subscription_event( + app: &Arc, + stripe_client: &stripe::Client, + event: stripe::Event, +) -> anyhow::Result<()> { + let EventObject::Subscription(subscription) = event.data.object else { + bail!("unexpected event payload for {}", event.id); + }; + + let billing_customer = + find_or_create_billing_customer(app, stripe_client, subscription.customer) + .await? + .ok_or_else(|| anyhow!("billing customer not found"))?; + + app.db + .upsert_billing_subscription_by_stripe_subscription_id(&CreateBillingSubscriptionParams { + billing_customer_id: billing_customer.id, + stripe_subscription_id: subscription.id.to_string(), + stripe_subscription_status: subscription.status.into(), + }) + .await?; + + Ok(()) +} + +impl From for StripeSubscriptionStatus { + fn from(value: SubscriptionStatus) -> Self { + match value { + SubscriptionStatus::Incomplete => Self::Incomplete, + SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired, + SubscriptionStatus::Trialing => Self::Trialing, + SubscriptionStatus::Active => Self::Active, + SubscriptionStatus::PastDue => Self::PastDue, + SubscriptionStatus::Canceled => Self::Canceled, + SubscriptionStatus::Unpaid => Self::Unpaid, + SubscriptionStatus::Paused => Self::Paused, + } + } +} + +/// Finds or creates a billing customer using the provided customer. +async fn find_or_create_billing_customer( + app: &Arc, + stripe_client: &stripe::Client, + customer_or_id: Expandable, +) -> anyhow::Result> { + let customer_id = match &customer_or_id { + Expandable::Id(id) => id, + Expandable::Object(customer) => customer.id.as_ref(), + }; + + // If we already have a billing customer record associated with the Stripe customer, + // there's nothing more we need to do. + if let Some(billing_customer) = app + .db + .get_billing_customer_by_stripe_customer_id(&customer_id) + .await? + { + return Ok(Some(billing_customer)); + } + + // If all we have is a customer ID, resolve it to a full customer record by + // hitting the Stripe API. + let customer = match customer_or_id { + Expandable::Id(id) => Customer::retrieve(&stripe_client, &id, &[]).await?, + Expandable::Object(customer) => *customer, + }; + + let Some(email) = customer.email else { + return Ok(None); + }; + + let Some(user) = app.db.get_user_by_email(&email).await? else { + return Ok(None); + }; + + let billing_customer = app + .db + .create_billing_customer(&CreateBillingCustomerParams { + user_id: user.id, + stripe_customer_id: customer.id.to_string(), + }) + .await?; + + Ok(Some(billing_customer)) +} diff --git a/crates/collab/src/db/queries/billing_customers.rs b/crates/collab/src/db/queries/billing_customers.rs index 9a0c00a0d334d2fdd9bc3c529527a6e33950e256..fd6dc8e7a1fb6c65e97958259ed6bfc1fdd15262 100644 --- a/crates/collab/src/db/queries/billing_customers.rs +++ b/crates/collab/src/db/queries/billing_customers.rs @@ -39,4 +39,18 @@ impl Database { }) .await } + + /// Returns the billing customer for the user with the specified Stripe customer ID. + pub async fn get_billing_customer_by_stripe_customer_id( + &self, + stripe_customer_id: &str, + ) -> Result> { + self.transaction(|tx| async move { + Ok(billing_customer::Entity::find() + .filter(billing_customer::Column::StripeCustomerId.eq(stripe_customer_id)) + .one(&*tx) + .await?) + }) + .await + } } diff --git a/crates/collab/src/db/queries/billing_subscriptions.rs b/crates/collab/src/db/queries/billing_subscriptions.rs index 0b11f25aa3fd9efd31223835c379e81b89ac4357..85e2766a7484751123056b725f122403622911e0 100644 --- a/crates/collab/src/db/queries/billing_subscriptions.rs +++ b/crates/collab/src/db/queries/billing_subscriptions.rs @@ -30,6 +30,31 @@ impl Database { .await } + /// Upserts the billing subscription by its Stripe subscription ID. + pub async fn upsert_billing_subscription_by_stripe_subscription_id( + &self, + params: &CreateBillingSubscriptionParams, + ) -> Result<()> { + self.transaction(|tx| async move { + billing_subscription::Entity::insert(billing_subscription::ActiveModel { + billing_customer_id: ActiveValue::set(params.billing_customer_id), + stripe_subscription_id: ActiveValue::set(params.stripe_subscription_id.clone()), + stripe_subscription_status: ActiveValue::set(params.stripe_subscription_status), + ..Default::default() + }) + .on_conflict( + OnConflict::columns([billing_subscription::Column::StripeSubscriptionId]) + .update_columns([billing_subscription::Column::StripeSubscriptionStatus]) + .to_owned(), + ) + .exec_with_returning(&*tx) + .await?; + + Ok(()) + }) + .await + } + /// Returns the billing subscription with the specified ID. pub async fn get_billing_subscription_by_id( &self, diff --git a/crates/collab/src/db/queries/users.rs b/crates/collab/src/db/queries/users.rs index fde6aa0b12611cb531094e9f6582b16df43d7624..60c4aa8e3ccb808f43967f08bd5b56f6354a8e23 100644 --- a/crates/collab/src/db/queries/users.rs +++ b/crates/collab/src/db/queries/users.rs @@ -61,6 +61,17 @@ impl Database { .await } + /// Returns a user by email address. There are no access checks here, so this should only be used internally. + pub async fn get_user_by_email(&self, email: &str) -> Result> { + self.transaction(|tx| async move { + Ok(user::Entity::find() + .filter(user::Column::EmailAddress.eq(email)) + .one(&*tx) + .await?) + }) + .await + } + /// Returns a user by GitHub user ID. There are no access checks here, so this should only be used internally. pub async fn get_user_by_github_user_id(&self, github_user_id: i32) -> Result> { self.transaction(|tx| async move { diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index b8c01b8619b36d9935cc312efd57c2fdca7c3f60..7725988f99f6c679aeec422a8f8d05164e39e05e 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -5,6 +5,7 @@ use axum::{ routing::get, Extension, Router, }; +use collab::api::billing::poll_stripe_events_periodically; use collab::{ api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor, rpc::ResultExt, AppState, Config, RateLimiter, Result, @@ -95,6 +96,7 @@ async fn main() -> Result<()> { } if is_api { + poll_stripe_events_periodically(state.clone()); fetch_extensions_from_blob_store_periodically(state.clone()); }