collab: Fully move `StripeBilling` over to using `StripeClient` (#31722)

Marshall Bowers created

This PR moves over the last method on `StripeBilling` to use the
`StripeClient` trait, allowing us to fully mock out Stripe behaviors for
`StripeBilling` in tests.

Release Notes:

- N/A

Change summary

crates/collab/src/api/billing.rs                      |   6 
crates/collab/src/rpc.rs                              |   7 
crates/collab/src/stripe_billing.rs                   |  52 ++---
crates/collab/src/stripe_client.rs                    |  35 +++
crates/collab/src/stripe_client/fake_stripe_client.rs |  54 ++++++
crates/collab/src/stripe_client/real_stripe_client.rs |  60 ++++++
crates/collab/src/tests/stripe_billing_tests.rs       | 109 +++++++++++++
7 files changed, 273 insertions(+), 50 deletions(-)

Detailed changes

crates/collab/src/api/billing.rs 🔗

@@ -1182,10 +1182,8 @@ async fn sync_subscription(
                 .has_active_billing_subscription(billing_customer.user_id)
                 .await?;
             if !already_has_active_billing_subscription {
-                let stripe_customer_id = billing_customer
-                    .stripe_customer_id
-                    .parse::<stripe::CustomerId>()
-                    .context("failed to parse Stripe customer ID from database")?;
+                let stripe_customer_id =
+                    StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
 
                 stripe_billing
                     .subscribe_to_zed_free(stripe_customer_id)

crates/collab/src/rpc.rs 🔗

@@ -5,6 +5,7 @@ use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
 use crate::db::billing_subscription::SubscriptionKind;
 use crate::llm::db::LlmDatabase;
 use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
+use crate::stripe_client::StripeCustomerId;
 use crate::{
     AppState, Error, Result, auth,
     db::{
@@ -4055,10 +4056,8 @@ async fn get_llm_api_token(
         if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
             billing_subscription
         } else {
-            let stripe_customer_id = billing_customer
-                .stripe_customer_id
-                .parse::<stripe::CustomerId>()
-                .context("failed to parse Stripe customer ID from database")?;
+            let stripe_customer_id =
+                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
 
             let stripe_subscription = stripe_billing
                 .subscribe_to_zed_free(stripe_customer_id)

crates/collab/src/stripe_billing.rs 🔗

@@ -14,8 +14,9 @@ use crate::stripe_client::{
     RealStripeClient, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection,
     StripeClient, StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
     StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
-    StripeCreateMeterEventPayload, StripeCustomerId, StripeMeter, StripePrice, StripePriceId,
-    StripeSubscription, StripeSubscriptionId, StripeSubscriptionTrialSettings,
+    StripeCreateMeterEventPayload, StripeCreateSubscriptionItems, StripeCreateSubscriptionParams,
+    StripeCustomerId, StripeMeter, StripePrice, StripePriceId, StripeSubscription,
+    StripeSubscriptionId, StripeSubscriptionTrialSettings,
     StripeSubscriptionTrialSettingsEndBehavior,
     StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
     UpdateSubscriptionParams,
@@ -23,7 +24,6 @@ use crate::stripe_client::{
 
 pub struct StripeBilling {
     state: RwLock<StripeBillingState>,
-    real_client: Arc<stripe::Client>,
     client: Arc<dyn StripeClient>,
 }
 
@@ -38,7 +38,6 @@ impl StripeBilling {
     pub fn new(client: Arc<stripe::Client>) -> Self {
         Self {
             client: Arc::new(RealStripeClient::new(client.clone())),
-            real_client: client,
             state: RwLock::default(),
         }
     }
@@ -46,8 +45,6 @@ impl StripeBilling {
     #[cfg(test)]
     pub fn test(client: Arc<crate::stripe_client::FakeStripeClient>) -> Self {
         Self {
-            // This is just temporary until we can remove all usages of the real Stripe client.
-            real_client: Arc::new(stripe::Client::new("sk_test")),
             client,
             state: RwLock::default(),
         }
@@ -306,40 +303,33 @@ impl StripeBilling {
 
     pub async fn subscribe_to_zed_free(
         &self,
-        customer_id: stripe::CustomerId,
-    ) -> Result<stripe::Subscription> {
+        customer_id: StripeCustomerId,
+    ) -> Result<StripeSubscription> {
         let zed_free_price_id = self.zed_free_price_id().await?;
 
-        let existing_subscriptions = stripe::Subscription::list(
-            &self.real_client,
-            &stripe::ListSubscriptions {
-                customer: Some(customer_id.clone()),
-                status: None,
-                ..Default::default()
-            },
-        )
-        .await?;
+        let existing_subscriptions = self
+            .client
+            .list_subscriptions_for_customer(&customer_id)
+            .await?;
 
         let existing_active_subscription =
-            existing_subscriptions
-                .data
-                .into_iter()
-                .find(|subscription| {
-                    subscription.status == SubscriptionStatus::Active
-                        || subscription.status == SubscriptionStatus::Trialing
-                });
+            existing_subscriptions.into_iter().find(|subscription| {
+                subscription.status == SubscriptionStatus::Active
+                    || subscription.status == SubscriptionStatus::Trialing
+            });
         if let Some(subscription) = existing_active_subscription {
             return Ok(subscription);
         }
 
-        let mut params = stripe::CreateSubscription::new(customer_id);
-        params.items = Some(vec![stripe::CreateSubscriptionItems {
-            price: Some(zed_free_price_id.to_string()),
-            quantity: Some(1),
-            ..Default::default()
-        }]);
+        let params = StripeCreateSubscriptionParams {
+            customer: customer_id,
+            items: vec![StripeCreateSubscriptionItems {
+                price: Some(zed_free_price_id),
+                quantity: Some(1),
+            }],
+        };
 
-        let subscription = stripe::Subscription::create(&self.real_client, params).await?;
+        let subscription = self.client.create_subscription(params).await?;
 
         Ok(subscription)
     }

crates/collab/src/stripe_client.rs 🔗

@@ -30,21 +30,38 @@ pub struct CreateCustomerParams<'a> {
 #[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
 pub struct StripeSubscriptionId(pub Arc<str>);
 
-#[derive(Debug, Clone)]
+#[derive(Debug, PartialEq, Clone)]
 pub struct StripeSubscription {
     pub id: StripeSubscriptionId,
+    pub customer: StripeCustomerId,
+    // TODO: Create our own version of this enum.
+    pub status: stripe::SubscriptionStatus,
+    pub current_period_end: i64,
+    pub current_period_start: i64,
     pub items: Vec<StripeSubscriptionItem>,
 }
 
 #[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
 pub struct StripeSubscriptionItemId(pub Arc<str>);
 
-#[derive(Debug, Clone)]
+#[derive(Debug, PartialEq, Clone)]
 pub struct StripeSubscriptionItem {
     pub id: StripeSubscriptionItemId,
     pub price: Option<StripePrice>,
 }
 
+#[derive(Debug)]
+pub struct StripeCreateSubscriptionParams {
+    pub customer: StripeCustomerId,
+    pub items: Vec<StripeCreateSubscriptionItems>,
+}
+
+#[derive(Debug)]
+pub struct StripeCreateSubscriptionItems {
+    pub price: Option<StripePriceId>,
+    pub quantity: Option<u64>,
+}
+
 #[derive(Debug, Clone)]
 pub struct UpdateSubscriptionParams {
     pub items: Option<Vec<UpdateSubscriptionItems>>,
@@ -76,7 +93,7 @@ pub enum StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod {
 #[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
 pub struct StripePriceId(pub Arc<str>);
 
-#[derive(Debug, Clone)]
+#[derive(Debug, PartialEq, Clone)]
 pub struct StripePrice {
     pub id: StripePriceId,
     pub unit_amount: Option<i64>,
@@ -84,7 +101,7 @@ pub struct StripePrice {
     pub recurring: Option<StripePriceRecurring>,
 }
 
-#[derive(Debug, Clone)]
+#[derive(Debug, PartialEq, Clone)]
 pub struct StripePriceRecurring {
     pub meter: Option<String>,
 }
@@ -160,11 +177,21 @@ pub trait StripeClient: Send + Sync {
 
     async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer>;
 
+    async fn list_subscriptions_for_customer(
+        &self,
+        customer_id: &StripeCustomerId,
+    ) -> Result<Vec<StripeSubscription>>;
+
     async fn get_subscription(
         &self,
         subscription_id: &StripeSubscriptionId,
     ) -> Result<StripeSubscription>;
 
+    async fn create_subscription(
+        &self,
+        params: StripeCreateSubscriptionParams,
+    ) -> Result<StripeSubscription>;
+
     async fn update_subscription(
         &self,
         subscription_id: &StripeSubscriptionId,

crates/collab/src/stripe_client/fake_stripe_client.rs 🔗

@@ -2,6 +2,7 @@ use std::sync::Arc;
 
 use anyhow::{Result, anyhow};
 use async_trait::async_trait;
+use chrono::{Duration, Utc};
 use collections::HashMap;
 use parking_lot::Mutex;
 use uuid::Uuid;
@@ -10,9 +11,10 @@ use crate::stripe_client::{
     CreateCustomerParams, StripeCheckoutSession, StripeCheckoutSessionMode,
     StripeCheckoutSessionPaymentMethodCollection, StripeClient,
     StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
-    StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, StripeCustomer,
-    StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripeSubscription,
-    StripeSubscriptionId, UpdateSubscriptionParams,
+    StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
+    StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeMeter, StripeMeterId,
+    StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem,
+    StripeSubscriptionItemId, UpdateSubscriptionParams,
 };
 
 #[derive(Debug, Clone)]
@@ -85,6 +87,21 @@ impl StripeClient for FakeStripeClient {
         Ok(customer)
     }
 
+    async fn list_subscriptions_for_customer(
+        &self,
+        customer_id: &StripeCustomerId,
+    ) -> Result<Vec<StripeSubscription>> {
+        let subscriptions = self
+            .subscriptions
+            .lock()
+            .values()
+            .filter(|subscription| subscription.customer == *customer_id)
+            .cloned()
+            .collect();
+
+        Ok(subscriptions)
+    }
+
     async fn get_subscription(
         &self,
         subscription_id: &StripeSubscriptionId,
@@ -96,6 +113,37 @@ impl StripeClient for FakeStripeClient {
             .ok_or_else(|| anyhow!("no subscription found for {subscription_id:?}"))
     }
 
+    async fn create_subscription(
+        &self,
+        params: StripeCreateSubscriptionParams,
+    ) -> Result<StripeSubscription> {
+        let now = Utc::now();
+
+        let subscription = StripeSubscription {
+            id: StripeSubscriptionId(format!("sub_{}", Uuid::new_v4()).into()),
+            customer: params.customer,
+            status: stripe::SubscriptionStatus::Active,
+            current_period_start: now.timestamp(),
+            current_period_end: (now + Duration::days(30)).timestamp(),
+            items: params
+                .items
+                .into_iter()
+                .map(|item| StripeSubscriptionItem {
+                    id: StripeSubscriptionItemId(format!("si_{}", Uuid::new_v4()).into()),
+                    price: item
+                        .price
+                        .and_then(|price_id| self.prices.lock().get(&price_id).cloned()),
+                })
+                .collect(),
+        };
+
+        self.subscriptions
+            .lock()
+            .insert(subscription.id.clone(), subscription.clone());
+
+        Ok(subscription)
+    }
+
     async fn update_subscription(
         &self,
         subscription_id: &StripeSubscriptionId,

crates/collab/src/stripe_client/real_stripe_client.rs 🔗

@@ -20,10 +20,11 @@ use crate::stripe_client::{
     CreateCustomerParams, StripeCheckoutSession, StripeCheckoutSessionMode,
     StripeCheckoutSessionPaymentMethodCollection, StripeClient,
     StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
-    StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, StripeCustomer,
-    StripeCustomerId, StripeMeter, StripePrice, StripePriceId, StripePriceRecurring,
-    StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId,
-    StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior,
+    StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
+    StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeMeter, StripePrice,
+    StripePriceId, StripePriceRecurring, StripeSubscription, StripeSubscriptionId,
+    StripeSubscriptionItem, StripeSubscriptionItemId, StripeSubscriptionTrialSettings,
+    StripeSubscriptionTrialSettingsEndBehavior,
     StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionParams,
 };
 
@@ -69,6 +70,29 @@ impl StripeClient for RealStripeClient {
         Ok(StripeCustomer::from(customer))
     }
 
+    async fn list_subscriptions_for_customer(
+        &self,
+        customer_id: &StripeCustomerId,
+    ) -> Result<Vec<StripeSubscription>> {
+        let customer_id = customer_id.try_into()?;
+
+        let subscriptions = stripe::Subscription::list(
+            &self.client,
+            &stripe::ListSubscriptions {
+                customer: Some(customer_id),
+                status: None,
+                ..Default::default()
+            },
+        )
+        .await?;
+
+        Ok(subscriptions
+            .data
+            .into_iter()
+            .map(StripeSubscription::from)
+            .collect())
+    }
+
     async fn get_subscription(
         &self,
         subscription_id: &StripeSubscriptionId,
@@ -80,6 +104,30 @@ impl StripeClient for RealStripeClient {
         Ok(StripeSubscription::from(subscription))
     }
 
+    async fn create_subscription(
+        &self,
+        params: StripeCreateSubscriptionParams,
+    ) -> Result<StripeSubscription> {
+        let customer_id = params.customer.try_into()?;
+
+        let mut create_subscription = stripe::CreateSubscription::new(customer_id);
+        create_subscription.items = Some(
+            params
+                .items
+                .into_iter()
+                .map(|item| stripe::CreateSubscriptionItems {
+                    price: item.price.map(|price| price.to_string()),
+                    quantity: item.quantity,
+                    ..Default::default()
+                })
+                .collect(),
+        );
+
+        let subscription = Subscription::create(&self.client, create_subscription).await?;
+
+        Ok(StripeSubscription::from(subscription))
+    }
+
     async fn update_subscription(
         &self,
         subscription_id: &StripeSubscriptionId,
@@ -220,6 +268,10 @@ impl From<Subscription> for StripeSubscription {
     fn from(value: Subscription) -> Self {
         Self {
             id: value.id.into(),
+            customer: value.customer.id().into(),
+            status: value.status,
+            current_period_start: value.current_period_start,
+            current_period_end: value.current_period_end,
             items: value.items.data.into_iter().map(Into::into).collect(),
         }
     }

crates/collab/src/tests/stripe_billing_tests.rs 🔗

@@ -1,5 +1,6 @@
 use std::sync::Arc;
 
+use chrono::{Duration, Utc};
 use pretty_assertions::assert_eq;
 
 use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
@@ -163,8 +164,13 @@ async fn test_subscribe_to_price() {
         .lock()
         .insert(price.id.clone(), price.clone());
 
+    let now = Utc::now();
     let subscription = StripeSubscription {
         id: StripeSubscriptionId("sub_test".into()),
+        customer: StripeCustomerId("cus_test".into()),
+        status: stripe::SubscriptionStatus::Active,
+        current_period_start: now.timestamp(),
+        current_period_end: (now + Duration::days(30)).timestamp(),
         items: vec![],
     };
     stripe_client
@@ -194,8 +200,13 @@ async fn test_subscribe_to_price() {
 
     // Subscribing to a price that is already on the subscription is a no-op.
     {
+        let now = Utc::now();
         let subscription = StripeSubscription {
             id: StripeSubscriptionId("sub_test".into()),
+            customer: StripeCustomerId("cus_test".into()),
+            status: stripe::SubscriptionStatus::Active,
+            current_period_start: now.timestamp(),
+            current_period_end: (now + Duration::days(30)).timestamp(),
             items: vec![StripeSubscriptionItem {
                 id: StripeSubscriptionItemId("si_test".into()),
                 price: Some(price.clone()),
@@ -215,6 +226,104 @@ async fn test_subscribe_to_price() {
     }
 }
 
+#[gpui::test]
+async fn test_subscribe_to_zed_free() {
+    let (stripe_billing, stripe_client) = make_stripe_billing();
+
+    let zed_pro_price = StripePrice {
+        id: StripePriceId("price_1".into()),
+        unit_amount: Some(0),
+        lookup_key: Some("zed-pro".to_string()),
+        recurring: None,
+    };
+    stripe_client
+        .prices
+        .lock()
+        .insert(zed_pro_price.id.clone(), zed_pro_price.clone());
+    let zed_free_price = StripePrice {
+        id: StripePriceId("price_2".into()),
+        unit_amount: Some(0),
+        lookup_key: Some("zed-free".to_string()),
+        recurring: None,
+    };
+    stripe_client
+        .prices
+        .lock()
+        .insert(zed_free_price.id.clone(), zed_free_price.clone());
+
+    stripe_billing.initialize().await.unwrap();
+
+    // Customer is subscribed to Zed Free when not already subscribed to a plan.
+    {
+        let customer_id = StripeCustomerId("cus_no_plan".into());
+
+        let subscription = stripe_billing
+            .subscribe_to_zed_free(customer_id)
+            .await
+            .unwrap();
+
+        assert_eq!(subscription.items[0].price.as_ref(), Some(&zed_free_price));
+    }
+
+    // Customer is not subscribed to Zed Free when they already have an active subscription.
+    {
+        let customer_id = StripeCustomerId("cus_active_subscription".into());
+
+        let now = Utc::now();
+        let existing_subscription = StripeSubscription {
+            id: StripeSubscriptionId("sub_existing_active".into()),
+            customer: customer_id.clone(),
+            status: stripe::SubscriptionStatus::Active,
+            current_period_start: now.timestamp(),
+            current_period_end: (now + Duration::days(30)).timestamp(),
+            items: vec![StripeSubscriptionItem {
+                id: StripeSubscriptionItemId("si_test".into()),
+                price: Some(zed_pro_price.clone()),
+            }],
+        };
+        stripe_client.subscriptions.lock().insert(
+            existing_subscription.id.clone(),
+            existing_subscription.clone(),
+        );
+
+        let subscription = stripe_billing
+            .subscribe_to_zed_free(customer_id)
+            .await
+            .unwrap();
+
+        assert_eq!(subscription, existing_subscription);
+    }
+
+    // Customer is not subscribed to Zed Free when they already have a trial subscription.
+    {
+        let customer_id = StripeCustomerId("cus_trial_subscription".into());
+
+        let now = Utc::now();
+        let existing_subscription = StripeSubscription {
+            id: StripeSubscriptionId("sub_existing_trial".into()),
+            customer: customer_id.clone(),
+            status: stripe::SubscriptionStatus::Trialing,
+            current_period_start: now.timestamp(),
+            current_period_end: (now + Duration::days(14)).timestamp(),
+            items: vec![StripeSubscriptionItem {
+                id: StripeSubscriptionItemId("si_test".into()),
+                price: Some(zed_pro_price.clone()),
+            }],
+        };
+        stripe_client.subscriptions.lock().insert(
+            existing_subscription.id.clone(),
+            existing_subscription.clone(),
+        );
+
+        let subscription = stripe_billing
+            .subscribe_to_zed_free(customer_id)
+            .await
+            .unwrap();
+
+        assert_eq!(subscription, existing_subscription);
+    }
+}
+
 #[gpui::test]
 async fn test_bill_model_request_usage() {
     let (stripe_billing, stripe_client) = make_stripe_billing();