From d93891ba630345d7e28b324aa027e650f4e7832f Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Mon, 29 Jul 2024 23:50:07 -0400 Subject: [PATCH] collab: Lay groundwork for reconciling with Stripe using the events API (#15459) This PR lays the initial groundwork for using the Stripe events API to reconcile the data in our system with what's in Stripe. We're using the events API over webhooks so that we don't need to stand up the associated infrastructure needed to handle webhooks effectively (namely an asynchronous job queue). Since we haven't configured the Stripe API keys yet, we won't actually spawn the reconciliation background task yet, so this is currently a no-op. Release Notes: - N/A --- Cargo.lock | 1 + Cargo.toml | 14 +- crates/collab/src/api/billing.rs | 194 +++++++++++++++++- .../src/db/queries/billing_customers.rs | 14 ++ .../src/db/queries/billing_subscriptions.rs | 25 +++ crates/collab/src/db/queries/users.rs | 11 + crates/collab/src/main.rs | 2 + 7 files changed, 257 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a8baa9788bb1ffe05b391a6dd5a92134699a6cee..2a40cfe6927868e4806ae226fafa0b83f37627d4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -840,6 +840,7 @@ version = "0.37.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2f14b5943a52cf051bbbbb68538e93a69d1e291934174121e769f4b181113f5" dependencies = [ + "chrono", "futures-util", "http-types", "hyper", diff --git a/Cargo.toml b/Cargo.toml index a5be5bc027016fb747fd7326bf97c90152848065..ca20aa13843add7313a8ceb3f06963920282acf5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -309,7 +309,6 @@ async-dispatcher = "0.1" async-fs = "1.6" async-pipe = { git = "https://github.com/zed-industries/async-pipe-rs", rev = "82d00a04211cf4e1236029aa03e6b6ce2a74c553" } async-recursion = "1.0.0" -async-stripe = { version = "0.37", default-features = false, features = ["runtime-tokio-hyper-rustls", "billing", "checkout"] } async-tar = "0.4.2" async-trait = "0.1" async-tungstenite = "0.23" @@ -461,6 +460,19 @@ wasmtime-wasi = "21.0.1" which = "6.0.0" wit-component = "0.201" +[workspace.dependencies.async-stripe] +version = "0.37" +default-features = false +features = [ + "runtime-tokio-hyper-rustls", + "billing", + "checkout", + "events", + # The features below are only enabled to get the `events` feature to build. + "chrono", + "connect", +] + [workspace.dependencies.windows] version = "0.58" features = [ diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index 0db4cf062b7832f64fca334a3b451d012dc14ce6..e6860942d81d88f6b12d157a277823386880990e 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -1,7 +1,8 @@ use std::str::FromStr; use std::sync::Arc; +use std::time::Duration; -use anyhow::{anyhow, Context}; +use anyhow::{anyhow, bail, Context}; use axum::{extract, routing::post, Extension, Json, Router}; use reqwest::StatusCode; use serde::{Deserialize, Serialize}; @@ -10,10 +11,16 @@ use stripe::{ CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion, CreateBillingPortalSessionFlowDataAfterCompletionRedirect, CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems, - CreateCustomer, Customer, CustomerId, + CreateCustomer, Customer, CustomerId, EventObject, EventType, Expandable, ListEvents, + SubscriptionStatus, }; +use util::ResultExt; -use crate::db::BillingSubscriptionId; +use crate::db::billing_subscription::StripeSubscriptionStatus; +use crate::db::{ + billing_customer, BillingSubscriptionId, CreateBillingCustomerParams, + CreateBillingSubscriptionParams, +}; use crate::{AppState, Error, Result}; pub fn router() -> Router { @@ -194,3 +201,184 @@ async fn manage_billing_subscription( billing_portal_session_url: session.url, })) } + +const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5 * 60); + +/// Polls the Stripe events API periodically to reconcile the records in our +/// database with the data in Stripe. +pub fn poll_stripe_events_periodically(app: Arc) { + let Some(stripe_client) = app.stripe_client.clone() else { + log::warn!("failed to retrieve Stripe client"); + return; + }; + + let executor = app.executor.clone(); + executor.spawn_detached({ + let executor = executor.clone(); + async move { + loop { + poll_stripe_events(&app, &stripe_client).await.log_err(); + + executor.sleep(POLL_EVENTS_INTERVAL).await; + } + } + }); +} + +async fn poll_stripe_events( + app: &Arc, + stripe_client: &stripe::Client, +) -> anyhow::Result<()> { + let event_types = [ + EventType::CustomerCreated.to_string(), + EventType::CustomerSubscriptionCreated.to_string(), + EventType::CustomerSubscriptionUpdated.to_string(), + EventType::CustomerSubscriptionPaused.to_string(), + EventType::CustomerSubscriptionResumed.to_string(), + EventType::CustomerSubscriptionDeleted.to_string(), + ] + .into_iter() + .map(|event_type| { + // Calling `to_string` on `stripe::EventType` members gives us a quoted string, + // so we need to unquote it. + event_type.trim_matches('"').to_string() + }) + .collect::>(); + + loop { + log::info!("retrieving events from Stripe: {}", event_types.join(", ")); + + let mut params = ListEvents::new(); + params.types = Some(event_types.clone()); + params.limit = Some(100); + + let events = stripe::Event::list(stripe_client, ¶ms).await?; + for event in events.data { + match event.type_ { + EventType::CustomerCreated => { + handle_customer_event(app, stripe_client, event) + .await + .log_err(); + } + EventType::CustomerSubscriptionCreated + | EventType::CustomerSubscriptionUpdated + | EventType::CustomerSubscriptionPaused + | EventType::CustomerSubscriptionResumed + | EventType::CustomerSubscriptionDeleted => { + handle_customer_subscription_event(app, stripe_client, event) + .await + .log_err(); + } + _ => {} + } + } + + if !events.has_more { + break; + } + } + + Ok(()) +} + +async fn handle_customer_event( + app: &Arc, + stripe_client: &stripe::Client, + event: stripe::Event, +) -> anyhow::Result<()> { + let EventObject::Customer(customer) = event.data.object else { + bail!("unexpected event payload for {}", event.id); + }; + + find_or_create_billing_customer(app, stripe_client, Expandable::Object(Box::new(customer))) + .await?; + + Ok(()) +} + +async fn handle_customer_subscription_event( + app: &Arc, + stripe_client: &stripe::Client, + event: stripe::Event, +) -> anyhow::Result<()> { + let EventObject::Subscription(subscription) = event.data.object else { + bail!("unexpected event payload for {}", event.id); + }; + + let billing_customer = + find_or_create_billing_customer(app, stripe_client, subscription.customer) + .await? + .ok_or_else(|| anyhow!("billing customer not found"))?; + + app.db + .upsert_billing_subscription_by_stripe_subscription_id(&CreateBillingSubscriptionParams { + billing_customer_id: billing_customer.id, + stripe_subscription_id: subscription.id.to_string(), + stripe_subscription_status: subscription.status.into(), + }) + .await?; + + Ok(()) +} + +impl From for StripeSubscriptionStatus { + fn from(value: SubscriptionStatus) -> Self { + match value { + SubscriptionStatus::Incomplete => Self::Incomplete, + SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired, + SubscriptionStatus::Trialing => Self::Trialing, + SubscriptionStatus::Active => Self::Active, + SubscriptionStatus::PastDue => Self::PastDue, + SubscriptionStatus::Canceled => Self::Canceled, + SubscriptionStatus::Unpaid => Self::Unpaid, + SubscriptionStatus::Paused => Self::Paused, + } + } +} + +/// Finds or creates a billing customer using the provided customer. +async fn find_or_create_billing_customer( + app: &Arc, + stripe_client: &stripe::Client, + customer_or_id: Expandable, +) -> anyhow::Result> { + let customer_id = match &customer_or_id { + Expandable::Id(id) => id, + Expandable::Object(customer) => customer.id.as_ref(), + }; + + // If we already have a billing customer record associated with the Stripe customer, + // there's nothing more we need to do. + if let Some(billing_customer) = app + .db + .get_billing_customer_by_stripe_customer_id(&customer_id) + .await? + { + return Ok(Some(billing_customer)); + } + + // If all we have is a customer ID, resolve it to a full customer record by + // hitting the Stripe API. + let customer = match customer_or_id { + Expandable::Id(id) => Customer::retrieve(&stripe_client, &id, &[]).await?, + Expandable::Object(customer) => *customer, + }; + + let Some(email) = customer.email else { + return Ok(None); + }; + + let Some(user) = app.db.get_user_by_email(&email).await? else { + return Ok(None); + }; + + let billing_customer = app + .db + .create_billing_customer(&CreateBillingCustomerParams { + user_id: user.id, + stripe_customer_id: customer.id.to_string(), + }) + .await?; + + Ok(Some(billing_customer)) +} diff --git a/crates/collab/src/db/queries/billing_customers.rs b/crates/collab/src/db/queries/billing_customers.rs index 9a0c00a0d334d2fdd9bc3c529527a6e33950e256..fd6dc8e7a1fb6c65e97958259ed6bfc1fdd15262 100644 --- a/crates/collab/src/db/queries/billing_customers.rs +++ b/crates/collab/src/db/queries/billing_customers.rs @@ -39,4 +39,18 @@ impl Database { }) .await } + + /// Returns the billing customer for the user with the specified Stripe customer ID. + pub async fn get_billing_customer_by_stripe_customer_id( + &self, + stripe_customer_id: &str, + ) -> Result> { + self.transaction(|tx| async move { + Ok(billing_customer::Entity::find() + .filter(billing_customer::Column::StripeCustomerId.eq(stripe_customer_id)) + .one(&*tx) + .await?) + }) + .await + } } diff --git a/crates/collab/src/db/queries/billing_subscriptions.rs b/crates/collab/src/db/queries/billing_subscriptions.rs index 0b11f25aa3fd9efd31223835c379e81b89ac4357..85e2766a7484751123056b725f122403622911e0 100644 --- a/crates/collab/src/db/queries/billing_subscriptions.rs +++ b/crates/collab/src/db/queries/billing_subscriptions.rs @@ -30,6 +30,31 @@ impl Database { .await } + /// Upserts the billing subscription by its Stripe subscription ID. + pub async fn upsert_billing_subscription_by_stripe_subscription_id( + &self, + params: &CreateBillingSubscriptionParams, + ) -> Result<()> { + self.transaction(|tx| async move { + billing_subscription::Entity::insert(billing_subscription::ActiveModel { + billing_customer_id: ActiveValue::set(params.billing_customer_id), + stripe_subscription_id: ActiveValue::set(params.stripe_subscription_id.clone()), + stripe_subscription_status: ActiveValue::set(params.stripe_subscription_status), + ..Default::default() + }) + .on_conflict( + OnConflict::columns([billing_subscription::Column::StripeSubscriptionId]) + .update_columns([billing_subscription::Column::StripeSubscriptionStatus]) + .to_owned(), + ) + .exec_with_returning(&*tx) + .await?; + + Ok(()) + }) + .await + } + /// Returns the billing subscription with the specified ID. pub async fn get_billing_subscription_by_id( &self, diff --git a/crates/collab/src/db/queries/users.rs b/crates/collab/src/db/queries/users.rs index fde6aa0b12611cb531094e9f6582b16df43d7624..60c4aa8e3ccb808f43967f08bd5b56f6354a8e23 100644 --- a/crates/collab/src/db/queries/users.rs +++ b/crates/collab/src/db/queries/users.rs @@ -61,6 +61,17 @@ impl Database { .await } + /// Returns a user by email address. There are no access checks here, so this should only be used internally. + pub async fn get_user_by_email(&self, email: &str) -> Result> { + self.transaction(|tx| async move { + Ok(user::Entity::find() + .filter(user::Column::EmailAddress.eq(email)) + .one(&*tx) + .await?) + }) + .await + } + /// Returns a user by GitHub user ID. There are no access checks here, so this should only be used internally. pub async fn get_user_by_github_user_id(&self, github_user_id: i32) -> Result> { self.transaction(|tx| async move { diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index b8c01b8619b36d9935cc312efd57c2fdca7c3f60..7725988f99f6c679aeec422a8f8d05164e39e05e 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -5,6 +5,7 @@ use axum::{ routing::get, Extension, Router, }; +use collab::api::billing::poll_stripe_events_periodically; use collab::{ api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor, rpc::ResultExt, AppState, Config, RateLimiter, Result, @@ -95,6 +96,7 @@ async fn main() -> Result<()> { } if is_api { + poll_stripe_events_periodically(state.clone()); fetch_extensions_from_blob_store_periodically(state.clone()); }