collab: Introduce `StripeClient` trait to abstract over Stripe interactions (#31615)

Marshall Bowers created

This PR introduces a new `StripeClient` trait to abstract over
interacting with the Stripe API.

This will allow us to more easily test our billing code.

This initial cut is small and focuses just on making
`StripeBilling::find_or_create_customer_by_email` testable. I'll follow
up with using the `StripeClient` in more places.

Release Notes:

- N/A

Change summary

crates/collab/Cargo.toml                              |  1 
crates/collab/src/api/billing.rs                      |  1 
crates/collab/src/lib.rs                              |  1 
crates/collab/src/rpc.rs                              |  3 
crates/collab/src/stripe_billing.rs                   | 69 ++++++-----
crates/collab/src/stripe_client.rs                    | 33 +++++
crates/collab/src/stripe_client/fake_stripe_client.rs | 47 ++++++++
crates/collab/src/stripe_client/real_stripe_client.rs | 74 +++++++++++++
crates/collab/src/tests.rs                            |  1 
crates/collab/src/tests/stripe_billing_tests.rs       | 60 ++++++++++
10 files changed, 257 insertions(+), 33 deletions(-)

Detailed changes

crates/collab/Cargo.toml 🔗

@@ -20,6 +20,7 @@ test-support = ["sqlite"]
 [dependencies]
 anyhow.workspace = true
 async-stripe.workspace = true
+async-trait.workspace = true
 async-tungstenite.workspace = true
 aws-config = { version = "1.1.5" }
 aws-sdk-s3 = { version = "1.15.0" }

crates/collab/src/api/billing.rs 🔗

@@ -344,6 +344,7 @@ async fn create_billing_subscription(
         stripe_billing
             .find_or_create_customer_by_email(user.email_address.as_deref())
             .await?
+            .try_into()?
     };
 
     let success_url = format!(

crates/collab/src/lib.rs 🔗

@@ -9,6 +9,7 @@ pub mod migrations;
 pub mod rpc;
 pub mod seed;
 pub mod stripe_billing;
+pub mod stripe_client;
 pub mod user_backfiller;
 
 #[cfg(test)]

crates/collab/src/rpc.rs 🔗

@@ -4039,7 +4039,8 @@ async fn get_llm_api_token(
         } else {
             let customer_id = stripe_billing
                 .find_or_create_customer_by_email(user.email_address.as_deref())
-                .await?;
+                .await?
+                .try_into()?;
 
             find_or_create_billing_customer(
                 &session.app_state,

crates/collab/src/stripe_billing.rs 🔗

@@ -1,19 +1,22 @@
 use std::sync::Arc;
 
-use crate::Result;
-use crate::db::billing_subscription::SubscriptionKind;
-use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
 use anyhow::{Context as _, anyhow};
 use chrono::Utc;
 use collections::HashMap;
 use serde::{Deserialize, Serialize};
-use stripe::{CreateCustomer, Customer, CustomerId, PriceId, SubscriptionStatus};
+use stripe::{PriceId, SubscriptionStatus};
 use tokio::sync::RwLock;
 use uuid::Uuid;
 
+use crate::Result;
+use crate::db::billing_subscription::SubscriptionKind;
+use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
+use crate::stripe_client::{RealStripeClient, StripeClient, StripeCustomerId};
+
 pub struct StripeBilling {
     state: RwLock<StripeBillingState>,
-    client: Arc<stripe::Client>,
+    real_client: Arc<stripe::Client>,
+    client: Arc<dyn StripeClient>,
 }
 
 #[derive(Default)]
@@ -26,6 +29,17 @@ struct StripeBillingState {
 impl StripeBilling {
     pub fn new(client: Arc<stripe::Client>) -> Self {
         Self {
+            client: Arc::new(RealStripeClient::new(client.clone())),
+            real_client: client,
+            state: RwLock::default(),
+        }
+    }
+
+    #[cfg(test)]
+    pub fn test(client: Arc<crate::stripe_client::FakeStripeClient>) -> Self {
+        Self {
+            // This is just temporary until we can remove all usages of the real Stripe client.
+            real_client: Arc::new(stripe::Client::new("sk_test")),
             client,
             state: RwLock::default(),
         }
@@ -37,9 +51,9 @@ impl StripeBilling {
         let mut state = self.state.write().await;
 
         let (meters, prices) = futures::try_join!(
-            StripeMeter::list(&self.client),
+            StripeMeter::list(&self.real_client),
             stripe::Price::list(
-                &self.client,
+                &self.real_client,
                 &stripe::ListPrices {
                     limit: Some(100),
                     ..Default::default()
@@ -129,18 +143,11 @@ impl StripeBilling {
     pub async fn find_or_create_customer_by_email(
         &self,
         email_address: Option<&str>,
-    ) -> Result<CustomerId> {
+    ) -> Result<StripeCustomerId> {
         let existing_customer = if let Some(email) = email_address {
-            let customers = Customer::list(
-                &self.client,
-                &stripe::ListCustomers {
-                    email: Some(email),
-                    ..Default::default()
-                },
-            )
-            .await?;
+            let customers = self.client.list_customers_by_email(email).await?;
 
-            customers.data.first().cloned()
+            customers.first().cloned()
         } else {
             None
         };
@@ -148,14 +155,12 @@ impl StripeBilling {
         let customer_id = if let Some(existing_customer) = existing_customer {
             existing_customer.id
         } else {
-            let customer = Customer::create(
-                &self.client,
-                CreateCustomer {
+            let customer = self
+                .client
+                .create_customer(crate::stripe_client::CreateCustomerParams {
                     email: email_address,
-                    ..Default::default()
-                },
-            )
-            .await?;
+                })
+                .await?;
 
             customer.id
         };
@@ -169,7 +174,7 @@ impl StripeBilling {
         price: &stripe::Price,
     ) -> Result<()> {
         let subscription =
-            stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
+            stripe::Subscription::retrieve(&self.real_client, &subscription_id, &[]).await?;
 
         if subscription_contains_price(&subscription, &price.id) {
             return Ok(());
@@ -181,7 +186,7 @@ impl StripeBilling {
         let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
 
         stripe::Subscription::update(
-            &self.client,
+            &self.real_client,
             subscription_id,
             stripe::UpdateSubscription {
                 items: Some(vec![stripe::UpdateSubscriptionItems {
@@ -211,7 +216,7 @@ impl StripeBilling {
         let idempotency_key = Uuid::new_v4();
 
         StripeMeterEvent::create(
-            &self.client,
+            &self.real_client,
             StripeCreateMeterEventParams {
                 identifier: &format!("model_requests/{}", idempotency_key),
                 event_name,
@@ -246,7 +251,7 @@ impl StripeBilling {
         }]);
         params.success_url = Some(success_url);
 
-        let session = stripe::CheckoutSession::create(&self.client, params).await?;
+        let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
         Ok(session.url.context("no checkout session URL")?)
     }
 
@@ -300,7 +305,7 @@ impl StripeBilling {
         }]);
         params.success_url = Some(success_url);
 
-        let session = stripe::CheckoutSession::create(&self.client, params).await?;
+        let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
         Ok(session.url.context("no checkout session URL")?)
     }
 
@@ -311,7 +316,7 @@ impl StripeBilling {
         let zed_free_price_id = self.zed_free_price_id().await?;
 
         let existing_subscriptions = stripe::Subscription::list(
-            &self.client,
+            &self.real_client,
             &stripe::ListSubscriptions {
                 customer: Some(customer_id.clone()),
                 status: None,
@@ -339,7 +344,7 @@ impl StripeBilling {
             ..Default::default()
         }]);
 
-        let subscription = stripe::Subscription::create(&self.client, params).await?;
+        let subscription = stripe::Subscription::create(&self.real_client, params).await?;
 
         Ok(subscription)
     }
@@ -365,7 +370,7 @@ impl StripeBilling {
         }]);
         params.success_url = Some(success_url);
 
-        let session = stripe::CheckoutSession::create(&self.client, params).await?;
+        let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
         Ok(session.url.context("no checkout session URL")?)
     }
 }

crates/collab/src/stripe_client.rs 🔗

@@ -0,0 +1,33 @@
+#[cfg(test)]
+mod fake_stripe_client;
+mod real_stripe_client;
+
+use std::sync::Arc;
+
+use anyhow::Result;
+use async_trait::async_trait;
+
+#[cfg(test)]
+pub use fake_stripe_client::*;
+pub use real_stripe_client::*;
+
+#[derive(Debug, PartialEq, Eq, Hash, Clone)]
+pub struct StripeCustomerId(pub Arc<str>);
+
+#[derive(Debug, Clone)]
+pub struct StripeCustomer {
+    pub id: StripeCustomerId,
+    pub email: Option<String>,
+}
+
+#[derive(Debug)]
+pub struct CreateCustomerParams<'a> {
+    pub email: Option<&'a str>,
+}
+
+#[async_trait]
+pub trait StripeClient: Send + Sync {
+    async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>>;
+
+    async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer>;
+}

crates/collab/src/stripe_client/fake_stripe_client.rs 🔗

@@ -0,0 +1,47 @@
+use std::sync::Arc;
+
+use anyhow::Result;
+use async_trait::async_trait;
+use collections::HashMap;
+use parking_lot::Mutex;
+use uuid::Uuid;
+
+use crate::stripe_client::{CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId};
+
+pub struct FakeStripeClient {
+    pub customers: Arc<Mutex<HashMap<StripeCustomerId, StripeCustomer>>>,
+}
+
+impl FakeStripeClient {
+    pub fn new() -> Self {
+        Self {
+            customers: Arc::new(Mutex::new(HashMap::default())),
+        }
+    }
+}
+
+#[async_trait]
+impl StripeClient for FakeStripeClient {
+    async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>> {
+        Ok(self
+            .customers
+            .lock()
+            .values()
+            .filter(|customer| customer.email.as_deref() == Some(email))
+            .cloned()
+            .collect())
+    }
+
+    async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer> {
+        let customer = StripeCustomer {
+            id: StripeCustomerId(format!("cus_{}", Uuid::new_v4()).into()),
+            email: params.email.map(|email| email.to_string()),
+        };
+
+        self.customers
+            .lock()
+            .insert(customer.id.clone(), customer.clone());
+
+        Ok(customer)
+    }
+}

crates/collab/src/stripe_client/real_stripe_client.rs 🔗

@@ -0,0 +1,74 @@
+use std::str::FromStr as _;
+use std::sync::Arc;
+
+use anyhow::{Context as _, Result};
+use async_trait::async_trait;
+use stripe::{CreateCustomer, Customer, CustomerId, ListCustomers};
+
+use crate::stripe_client::{CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId};
+
+pub struct RealStripeClient {
+    client: Arc<stripe::Client>,
+}
+
+impl RealStripeClient {
+    pub fn new(client: Arc<stripe::Client>) -> Self {
+        Self { client }
+    }
+}
+
+#[async_trait]
+impl StripeClient for RealStripeClient {
+    async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>> {
+        let response = Customer::list(
+            &self.client,
+            &ListCustomers {
+                email: Some(email),
+                ..Default::default()
+            },
+        )
+        .await?;
+
+        Ok(response
+            .data
+            .into_iter()
+            .map(StripeCustomer::from)
+            .collect())
+    }
+
+    async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer> {
+        let customer = Customer::create(
+            &self.client,
+            CreateCustomer {
+                email: params.email,
+                ..Default::default()
+            },
+        )
+        .await?;
+
+        Ok(StripeCustomer::from(customer))
+    }
+}
+
+impl From<CustomerId> for StripeCustomerId {
+    fn from(value: CustomerId) -> Self {
+        Self(value.as_str().into())
+    }
+}
+
+impl TryFrom<StripeCustomerId> for CustomerId {
+    type Error = anyhow::Error;
+
+    fn try_from(value: StripeCustomerId) -> Result<Self, Self::Error> {
+        Self::from_str(value.0.as_ref()).context("failed to parse Stripe customer ID")
+    }
+}
+
+impl From<Customer> for StripeCustomer {
+    fn from(value: Customer) -> Self {
+        StripeCustomer {
+            id: value.id.into(),
+            email: value.email,
+        }
+    }
+}

crates/collab/src/tests.rs 🔗

@@ -18,6 +18,7 @@ mod random_channel_buffer_tests;
 mod random_project_collaboration_tests;
 mod randomized_test_helpers;
 mod remote_editing_collaboration_tests;
+mod stripe_billing_tests;
 mod test_server;
 
 use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};

crates/collab/src/tests/stripe_billing_tests.rs 🔗

@@ -0,0 +1,60 @@
+use std::sync::Arc;
+
+use pretty_assertions::assert_eq;
+
+use crate::stripe_billing::StripeBilling;
+use crate::stripe_client::FakeStripeClient;
+
+fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
+    let stripe_client = Arc::new(FakeStripeClient::new());
+    let stripe_billing = StripeBilling::test(stripe_client.clone());
+
+    (stripe_billing, stripe_client)
+}
+
+#[gpui::test]
+async fn test_find_or_create_customer_by_email() {
+    let (stripe_billing, stripe_client) = make_stripe_billing();
+
+    // Create a customer with an email that doesn't yet correspond to a customer.
+    {
+        let email = "user@example.com";
+
+        let customer_id = stripe_billing
+            .find_or_create_customer_by_email(Some(email))
+            .await
+            .unwrap();
+
+        let customer = stripe_client
+            .customers
+            .lock()
+            .get(&customer_id)
+            .unwrap()
+            .clone();
+        assert_eq!(customer.email.as_deref(), Some(email));
+    }
+
+    // Create a customer with an email that corresponds to an existing customer.
+    {
+        let email = "user2@example.com";
+
+        let existing_customer_id = stripe_billing
+            .find_or_create_customer_by_email(Some(email))
+            .await
+            .unwrap();
+
+        let customer_id = stripe_billing
+            .find_or_create_customer_by_email(Some(email))
+            .await
+            .unwrap();
+        assert_eq!(customer_id, existing_customer_id);
+
+        let customer = stripe_client
+            .customers
+            .lock()
+            .get(&customer_id)
+            .unwrap()
+            .clone();
+        assert_eq!(customer.email.as_deref(), Some(email));
+    }
+}