collab: Use `StripeClient` in `StripeBilling::subscribe_to_price` (#31631)

Marshall Bowers created

This PR updates the `StripeBilling::subscribe_to_price` method to use
the `StripeClient` trait.

Release Notes:

- N/A

Change summary

crates/collab/src/api/billing.rs                      |   2 
crates/collab/src/stripe_billing.rs                   |  49 ++--
crates/collab/src/stripe_client.rs                    |  57 +++++
crates/collab/src/stripe_client/fake_stripe_client.rs |  35 +++
crates/collab/src/stripe_client/real_stripe_client.rs | 124 ++++++++++++
crates/collab/src/tests/stripe_billing_tests.rs       |  69 +++++++
6 files changed, 306 insertions(+), 30 deletions(-)

Detailed changes

crates/collab/src/api/billing.rs 🔗

@@ -1578,7 +1578,7 @@ async fn sync_model_request_usage_with_stripe(
             };
 
             stripe_billing
-                .subscribe_to_price(&stripe_subscription_id, price)
+                .subscribe_to_price(&stripe_subscription_id.into(), price)
                 .await?;
             stripe_billing
                 .bill_model_request_usage(

crates/collab/src/stripe_billing.rs 🔗

@@ -13,6 +13,9 @@ use crate::db::billing_subscription::SubscriptionKind;
 use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
 use crate::stripe_client::{
     RealStripeClient, StripeClient, StripeCustomerId, StripeMeter, StripePrice, StripePriceId,
+    StripeSubscription, StripeSubscriptionId, UpdateSubscriptionItems, UpdateSubscriptionParams,
+    UpdateSubscriptionTrialSettings, UpdateSubscriptionTrialSettingsEndBehavior,
+    UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
 };
 
 pub struct StripeBilling {
@@ -166,14 +169,12 @@ impl StripeBilling {
 
     pub async fn subscribe_to_price(
         &self,
-        subscription_id: &stripe::SubscriptionId,
+        subscription_id: &StripeSubscriptionId,
         price: &StripePrice,
     ) -> Result<()> {
-        let subscription =
-            stripe::Subscription::retrieve(&self.real_client, &subscription_id, &[]).await?;
+        let subscription = self.client.get_subscription(subscription_id).await?;
 
-        let price_id = price.id.clone().try_into()?;
-        if subscription_contains_price(&subscription, &price_id) {
+        if subscription_contains_price(&subscription, &price.id) {
             return Ok(());
         }
 
@@ -182,23 +183,21 @@ impl StripeBilling {
         let price_per_unit = price.unit_amount.unwrap_or_default();
         let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
 
-        stripe::Subscription::update(
-            &self.real_client,
-            subscription_id,
-            stripe::UpdateSubscription {
-                items: Some(vec![stripe::UpdateSubscriptionItems {
-                    price: Some(price.id.to_string()),
-                    ..Default::default()
-                }]),
-                trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
-                    end_behavior: stripe::UpdateSubscriptionTrialSettingsEndBehavior {
-                        missing_payment_method: stripe::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
-                    },
-                }),
-                ..Default::default()
-            },
-        )
-        .await?;
+        self.client
+            .update_subscription(
+                subscription_id,
+                UpdateSubscriptionParams {
+                    items: Some(vec![UpdateSubscriptionItems {
+                        price: Some(price.id.clone()),
+                    }]),
+                    trial_settings: Some(UpdateSubscriptionTrialSettings {
+                        end_behavior: UpdateSubscriptionTrialSettingsEndBehavior {
+                            missing_payment_method: UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel
+                        },
+                    }),
+                },
+            )
+            .await?;
 
         Ok(())
     }
@@ -419,10 +418,10 @@ struct StripeCreateMeterEventPayload<'a> {
 }
 
 fn subscription_contains_price(
-    subscription: &stripe::Subscription,
-    price_id: &stripe::PriceId,
+    subscription: &StripeSubscription,
+    price_id: &StripePriceId,
 ) -> bool {
-    subscription.items.data.iter().any(|item| {
+    subscription.items.iter().any(|item| {
         item.price
             .as_ref()
             .map_or(false, |price| price.id == *price_id)

crates/collab/src/stripe_client.rs 🔗

@@ -26,6 +26,52 @@ pub struct CreateCustomerParams<'a> {
     pub email: Option<&'a str>,
 }
 
+#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
+pub struct StripeSubscriptionId(pub Arc<str>);
+
+#[derive(Debug, Clone)]
+pub struct StripeSubscription {
+    pub id: StripeSubscriptionId,
+    pub items: Vec<StripeSubscriptionItem>,
+}
+
+#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
+pub struct StripeSubscriptionItemId(pub Arc<str>);
+
+#[derive(Debug, Clone)]
+pub struct StripeSubscriptionItem {
+    pub id: StripeSubscriptionItemId,
+    pub price: Option<StripePrice>,
+}
+
+#[derive(Debug, Clone)]
+pub struct UpdateSubscriptionParams {
+    pub items: Option<Vec<UpdateSubscriptionItems>>,
+    pub trial_settings: Option<UpdateSubscriptionTrialSettings>,
+}
+
+#[derive(Debug, PartialEq, Clone)]
+pub struct UpdateSubscriptionItems {
+    pub price: Option<StripePriceId>,
+}
+
+#[derive(Debug, Clone)]
+pub struct UpdateSubscriptionTrialSettings {
+    pub end_behavior: UpdateSubscriptionTrialSettingsEndBehavior,
+}
+
+#[derive(Debug, Clone)]
+pub struct UpdateSubscriptionTrialSettingsEndBehavior {
+    pub missing_payment_method: UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
+}
+
+#[derive(Debug, PartialEq, Eq, Clone, Copy)]
+pub enum UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod {
+    Cancel,
+    CreateInvoice,
+    Pause,
+}
+
 #[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
 pub struct StripePriceId(pub Arc<str>);
 
@@ -57,6 +103,17 @@ pub trait StripeClient: Send + Sync {
 
     async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer>;
 
+    async fn get_subscription(
+        &self,
+        subscription_id: &StripeSubscriptionId,
+    ) -> Result<StripeSubscription>;
+
+    async fn update_subscription(
+        &self,
+        subscription_id: &StripeSubscriptionId,
+        params: UpdateSubscriptionParams,
+    ) -> Result<()>;
+
     async fn list_prices(&self) -> Result<Vec<StripePrice>>;
 
     async fn list_meters(&self) -> Result<Vec<StripeMeter>>;

crates/collab/src/stripe_client/fake_stripe_client.rs 🔗

@@ -1,6 +1,6 @@
 use std::sync::Arc;
 
-use anyhow::Result;
+use anyhow::{Result, anyhow};
 use async_trait::async_trait;
 use collections::HashMap;
 use parking_lot::Mutex;
@@ -8,11 +8,15 @@ use uuid::Uuid;
 
 use crate::stripe_client::{
     CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId, StripeMeter,
-    StripeMeterId, StripePrice, StripePriceId,
+    StripeMeterId, StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId,
+    UpdateSubscriptionParams,
 };
 
 pub struct FakeStripeClient {
     pub customers: Arc<Mutex<HashMap<StripeCustomerId, StripeCustomer>>>,
+    pub subscriptions: Arc<Mutex<HashMap<StripeSubscriptionId, StripeSubscription>>>,
+    pub update_subscription_calls:
+        Arc<Mutex<Vec<(StripeSubscriptionId, UpdateSubscriptionParams)>>>,
     pub prices: Arc<Mutex<HashMap<StripePriceId, StripePrice>>>,
     pub meters: Arc<Mutex<HashMap<StripeMeterId, StripeMeter>>>,
 }
@@ -21,6 +25,8 @@ impl FakeStripeClient {
     pub fn new() -> Self {
         Self {
             customers: Arc::new(Mutex::new(HashMap::default())),
+            subscriptions: Arc::new(Mutex::new(HashMap::default())),
+            update_subscription_calls: Arc::new(Mutex::new(Vec::new())),
             prices: Arc::new(Mutex::new(HashMap::default())),
             meters: Arc::new(Mutex::new(HashMap::default())),
         }
@@ -52,6 +58,31 @@ impl StripeClient for FakeStripeClient {
         Ok(customer)
     }
 
+    async fn get_subscription(
+        &self,
+        subscription_id: &StripeSubscriptionId,
+    ) -> Result<StripeSubscription> {
+        self.subscriptions
+            .lock()
+            .get(subscription_id)
+            .cloned()
+            .ok_or_else(|| anyhow!("no subscription found for {subscription_id:?}"))
+    }
+
+    async fn update_subscription(
+        &self,
+        subscription_id: &StripeSubscriptionId,
+        params: UpdateSubscriptionParams,
+    ) -> Result<()> {
+        let subscription = self.get_subscription(subscription_id).await?;
+
+        self.update_subscription_calls
+            .lock()
+            .push((subscription.id, params));
+
+        Ok(())
+    }
+
     async fn list_prices(&self) -> Result<Vec<StripePrice>> {
         let prices = self.prices.lock().values().cloned().collect();
 

crates/collab/src/stripe_client/real_stripe_client.rs 🔗

@@ -4,11 +4,17 @@ use std::sync::Arc;
 use anyhow::{Context as _, Result};
 use async_trait::async_trait;
 use serde::Serialize;
-use stripe::{CreateCustomer, Customer, CustomerId, ListCustomers, Price, PriceId, Recurring};
+use stripe::{
+    CreateCustomer, Customer, CustomerId, ListCustomers, Price, PriceId, Recurring, Subscription,
+    SubscriptionId, SubscriptionItem, SubscriptionItemId, UpdateSubscriptionItems,
+    UpdateSubscriptionTrialSettings, UpdateSubscriptionTrialSettingsEndBehavior,
+    UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
+};
 
 use crate::stripe_client::{
     CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId, StripeMeter, StripePrice,
-    StripePriceId, StripePriceRecurring,
+    StripePriceId, StripePriceRecurring, StripeSubscription, StripeSubscriptionId,
+    StripeSubscriptionItem, StripeSubscriptionItemId, UpdateSubscriptionParams,
 };
 
 pub struct RealStripeClient {
@@ -53,6 +59,46 @@ impl StripeClient for RealStripeClient {
         Ok(StripeCustomer::from(customer))
     }
 
+    async fn get_subscription(
+        &self,
+        subscription_id: &StripeSubscriptionId,
+    ) -> Result<StripeSubscription> {
+        let subscription_id = subscription_id.try_into()?;
+
+        let subscription = Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
+
+        Ok(StripeSubscription::from(subscription))
+    }
+
+    async fn update_subscription(
+        &self,
+        subscription_id: &StripeSubscriptionId,
+        params: UpdateSubscriptionParams,
+    ) -> Result<()> {
+        let subscription_id = subscription_id.try_into()?;
+
+        stripe::Subscription::update(
+            &self.client,
+            &subscription_id,
+            stripe::UpdateSubscription {
+                items: params.items.map(|items| {
+                    items
+                        .into_iter()
+                        .map(|item| UpdateSubscriptionItems {
+                            price: item.price.map(|price| price.to_string()),
+                            ..Default::default()
+                        })
+                        .collect()
+                }),
+                trial_settings: params.trial_settings.map(Into::into),
+                ..Default::default()
+            },
+        )
+        .await?;
+
+        Ok(())
+    }
+
     async fn list_prices(&self) -> Result<Vec<StripePrice>> {
         let response = stripe::Price::list(
             &self.client,
@@ -108,6 +154,80 @@ impl From<Customer> for StripeCustomer {
     }
 }
 
+impl From<SubscriptionId> for StripeSubscriptionId {
+    fn from(value: SubscriptionId) -> Self {
+        Self(value.as_str().into())
+    }
+}
+
+impl TryFrom<&StripeSubscriptionId> for SubscriptionId {
+    type Error = anyhow::Error;
+
+    fn try_from(value: &StripeSubscriptionId) -> Result<Self, Self::Error> {
+        Self::from_str(value.0.as_ref()).context("failed to parse Stripe subscription ID")
+    }
+}
+
+impl From<Subscription> for StripeSubscription {
+    fn from(value: Subscription) -> Self {
+        Self {
+            id: value.id.into(),
+            items: value.items.data.into_iter().map(Into::into).collect(),
+        }
+    }
+}
+
+impl From<SubscriptionItemId> for StripeSubscriptionItemId {
+    fn from(value: SubscriptionItemId) -> Self {
+        Self(value.as_str().into())
+    }
+}
+
+impl From<SubscriptionItem> for StripeSubscriptionItem {
+    fn from(value: SubscriptionItem) -> Self {
+        Self {
+            id: value.id.into(),
+            price: value.price.map(Into::into),
+        }
+    }
+}
+
+impl From<crate::stripe_client::UpdateSubscriptionTrialSettings>
+    for UpdateSubscriptionTrialSettings
+{
+    fn from(value: crate::stripe_client::UpdateSubscriptionTrialSettings) -> Self {
+        Self {
+            end_behavior: value.end_behavior.into(),
+        }
+    }
+}
+
+impl From<crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehavior>
+    for UpdateSubscriptionTrialSettingsEndBehavior
+{
+    fn from(value: crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehavior) -> Self {
+        Self {
+            missing_payment_method: value.missing_payment_method.into(),
+        }
+    }
+}
+
+impl From<crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod>
+    for UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod
+{
+    fn from(
+        value: crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
+    ) -> Self {
+        match value {
+            crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel,
+            crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => {
+                Self::CreateInvoice
+            }
+            crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause,
+        }
+    }
+}
+
 impl From<PriceId> for StripePriceId {
     fn from(value: PriceId) -> Self {
         Self(value.as_str().into())

crates/collab/src/tests/stripe_billing_tests.rs 🔗

@@ -5,6 +5,8 @@ use pretty_assertions::assert_eq;
 use crate::stripe_billing::StripeBilling;
 use crate::stripe_client::{
     FakeStripeClient, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripePriceRecurring,
+    StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId,
+    UpdateSubscriptionItems,
 };
 
 fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
@@ -141,3 +143,70 @@ async fn test_find_or_create_customer_by_email() {
         assert_eq!(customer.email.as_deref(), Some(email));
     }
 }
+
+#[gpui::test]
+async fn test_subscribe_to_price() {
+    let (stripe_billing, stripe_client) = make_stripe_billing();
+
+    let price = StripePrice {
+        id: StripePriceId("price_test".into()),
+        unit_amount: Some(2000),
+        lookup_key: Some("test-price".to_string()),
+        recurring: None,
+    };
+    stripe_client
+        .prices
+        .lock()
+        .insert(price.id.clone(), price.clone());
+
+    let subscription = StripeSubscription {
+        id: StripeSubscriptionId("sub_test".into()),
+        items: vec![],
+    };
+    stripe_client
+        .subscriptions
+        .lock()
+        .insert(subscription.id.clone(), subscription.clone());
+
+    stripe_billing
+        .subscribe_to_price(&subscription.id, &price)
+        .await
+        .unwrap();
+
+    let update_subscription_calls = stripe_client
+        .update_subscription_calls
+        .lock()
+        .iter()
+        .map(|(id, params)| (id.clone(), params.clone()))
+        .collect::<Vec<_>>();
+    assert_eq!(update_subscription_calls.len(), 1);
+    assert_eq!(update_subscription_calls[0].0, subscription.id);
+    assert_eq!(
+        update_subscription_calls[0].1.items,
+        Some(vec![UpdateSubscriptionItems {
+            price: Some(price.id.clone())
+        }])
+    );
+
+    // Subscribing to a price that is already on the subscription is a no-op.
+    {
+        let subscription = StripeSubscription {
+            id: StripeSubscriptionId("sub_test".into()),
+            items: vec![StripeSubscriptionItem {
+                id: StripeSubscriptionItemId("si_test".into()),
+                price: Some(price.clone()),
+            }],
+        };
+        stripe_client
+            .subscriptions
+            .lock()
+            .insert(subscription.id.clone(), subscription.clone());
+
+        stripe_billing
+            .subscribe_to_price(&subscription.id, &price)
+            .await
+            .unwrap();
+
+        assert_eq!(stripe_client.update_subscription_calls.lock().len(), 1);
+    }
+}