collab: Use `StripeClient` in `sync_subscription` (#31761)

Marshall Bowers and Ben Brandt created

This PR updates the `sync_subscription` function to use the
`StripeClient` trait instead of using `stripe::Client` directly.

Release Notes:

- N/A

---------

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>

Change summary

crates/collab/src/api/billing.rs                      | 90 ++++++------
crates/collab/src/db/tables/billing_subscription.rs   | 15 ++
crates/collab/src/lib.rs                              | 10 +
crates/collab/src/rpc.rs                              | 24 +--
crates/collab/src/stripe_billing.rs                   | 10 
crates/collab/src/stripe_client.rs                    | 18 ++
crates/collab/src/stripe_client/fake_stripe_client.rs | 17 ++
crates/collab/src/stripe_client/real_stripe_client.rs | 56 +++++++
crates/collab/src/tests/stripe_billing_tests.rs       |  8 +
crates/collab/src/tests/test_server.rs                |  4 
10 files changed, 177 insertions(+), 75 deletions(-)

Detailed changes

crates/collab/src/api/billing.rs 🔗

@@ -17,8 +17,8 @@ use stripe::{
     CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
     CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
     CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
-    CreateBillingPortalSessionFlowDataType, Customer, CustomerId, EventObject, EventType,
-    Expandable, ListEvents, PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus,
+    CreateBillingPortalSessionFlowDataType, CustomerId, EventObject, EventType, ListEvents,
+    PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus,
 };
 use util::{ResultExt, maybe};
 
@@ -29,7 +29,10 @@ use crate::db::billing_subscription::{
 use crate::llm::db::subscription_usage_meter::CompletionMode;
 use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND};
 use crate::rpc::{ResultExt as _, Server};
-use crate::stripe_client::{StripeCustomerId, StripeSubscriptionId};
+use crate::stripe_client::{
+    StripeCancellationDetailsReason, StripeClient, StripeCustomerId, StripeSubscription,
+    StripeSubscriptionId,
+};
 use crate::{AppState, Error, Result};
 use crate::{db::UserId, llm::db::LlmDatabase};
 use crate::{
@@ -426,7 +429,7 @@ async fn manage_billing_subscription(
         .await?
         .context("user not found")?;
 
-    let Some(stripe_client) = app.stripe_client.clone() else {
+    let Some(stripe_client) = app.real_stripe_client.clone() else {
         log::error!("failed to retrieve Stripe client");
         Err(Error::http(
             StatusCode::NOT_IMPLEMENTED,
@@ -644,7 +647,7 @@ async fn migrate_to_new_billing(
     Extension(app): Extension<Arc<AppState>>,
     extract::Json(body): extract::Json<MigrateToNewBillingBody>,
 ) -> Result<Json<MigrateToNewBillingResponse>> {
-    let Some(stripe_client) = app.stripe_client.clone() else {
+    let Some(stripe_client) = app.real_stripe_client.clone() else {
         log::error!("failed to retrieve Stripe client");
         Err(Error::http(
             StatusCode::NOT_IMPLEMENTED,
@@ -723,6 +726,13 @@ async fn sync_billing_subscription(
     Extension(app): Extension<Arc<AppState>>,
     extract::Json(body): extract::Json<SyncBillingSubscriptionBody>,
 ) -> Result<Json<SyncBillingSubscriptionResponse>> {
+    let Some(real_stripe_client) = app.real_stripe_client.clone() else {
+        log::error!("failed to retrieve Stripe client");
+        Err(Error::http(
+            StatusCode::NOT_IMPLEMENTED,
+            "not supported".into(),
+        ))?
+    };
     let Some(stripe_client) = app.stripe_client.clone() else {
         log::error!("failed to retrieve Stripe client");
         Err(Error::http(
@@ -748,7 +758,7 @@ async fn sync_billing_subscription(
         .context("failed to parse Stripe customer ID from database")?;
 
     let subscriptions = Subscription::list(
-        &stripe_client,
+        &real_stripe_client,
         &stripe::ListSubscriptions {
             customer: Some(stripe_customer_id),
             // Sync all non-canceled subscriptions.
@@ -761,7 +771,7 @@ async fn sync_billing_subscription(
     for subscription in subscriptions.data {
         let subscription_id = subscription.id.clone();
 
-        sync_subscription(&app, &stripe_client, subscription)
+        sync_subscription(&app, &stripe_client, subscription.into())
             .await
             .with_context(|| {
                 format!(
@@ -806,6 +816,10 @@ const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
 /// Polls the Stripe events API periodically to reconcile the records in our
 /// database with the data in Stripe.
 pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
+    let Some(real_stripe_client) = app.real_stripe_client.clone() else {
+        log::warn!("failed to retrieve Stripe client");
+        return;
+    };
     let Some(stripe_client) = app.stripe_client.clone() else {
         log::warn!("failed to retrieve Stripe client");
         return;
@@ -816,7 +830,7 @@ pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Serve
         let executor = executor.clone();
         async move {
             loop {
-                poll_stripe_events(&app, &rpc_server, &stripe_client)
+                poll_stripe_events(&app, &rpc_server, &stripe_client, &real_stripe_client)
                     .await
                     .log_err();
 
@@ -829,7 +843,8 @@ pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Serve
 async fn poll_stripe_events(
     app: &Arc<AppState>,
     rpc_server: &Arc<Server>,
-    stripe_client: &stripe::Client,
+    stripe_client: &Arc<dyn StripeClient>,
+    real_stripe_client: &stripe::Client,
 ) -> anyhow::Result<()> {
     fn event_type_to_string(event_type: EventType) -> String {
         // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
@@ -861,7 +876,7 @@ async fn poll_stripe_events(
     params.types = Some(event_types.clone());
     params.limit = Some(EVENTS_LIMIT_PER_PAGE);
 
-    let mut event_pages = stripe::Event::list(&stripe_client, &params)
+    let mut event_pages = stripe::Event::list(&real_stripe_client, &params)
         .await?
         .paginate(params);
 
@@ -905,7 +920,7 @@ async fn poll_stripe_events(
                 break;
             } else {
                 log::info!("Stripe events: retrieving next page");
-                event_pages = event_pages.next(&stripe_client).await?;
+                event_pages = event_pages.next(&real_stripe_client).await?;
             }
         } else {
             break;
@@ -945,7 +960,7 @@ async fn poll_stripe_events(
 
         let process_result = match event.type_ {
             EventType::CustomerCreated | EventType::CustomerUpdated => {
-                handle_customer_event(app, stripe_client, event).await
+                handle_customer_event(app, real_stripe_client, event).await
             }
             EventType::CustomerSubscriptionCreated
             | EventType::CustomerSubscriptionUpdated
@@ -1020,8 +1035,8 @@ async fn handle_customer_event(
 
 async fn sync_subscription(
     app: &Arc<AppState>,
-    stripe_client: &stripe::Client,
-    subscription: stripe::Subscription,
+    stripe_client: &Arc<dyn StripeClient>,
+    subscription: StripeSubscription,
 ) -> anyhow::Result<billing_customer::Model> {
     let subscription_kind = if let Some(stripe_billing) = &app.stripe_billing {
         stripe_billing
@@ -1032,7 +1047,7 @@ async fn sync_subscription(
     };
 
     let billing_customer =
-        find_or_create_billing_customer(app, stripe_client, subscription.customer)
+        find_or_create_billing_customer(app, stripe_client.as_ref(), &subscription.customer)
             .await?
             .context("billing customer not found")?;
 
@@ -1060,7 +1075,7 @@ async fn sync_subscription(
             .as_ref()
             .and_then(|details| details.reason)
             .map_or(false, |reason| {
-                reason == CancellationDetailsReason::PaymentFailed
+                reason == StripeCancellationDetailsReason::PaymentFailed
             });
 
     if was_canceled_due_to_payment_failure {
@@ -1077,7 +1092,7 @@ async fn sync_subscription(
 
     if let Some(existing_subscription) = app
         .db
-        .get_billing_subscription_by_stripe_subscription_id(&subscription.id)
+        .get_billing_subscription_by_stripe_subscription_id(subscription.id.0.as_ref())
         .await?
     {
         app.db
@@ -1118,20 +1133,13 @@ async fn sync_subscription(
             if existing_subscription.kind == Some(SubscriptionKind::ZedFree)
                 && subscription_kind == Some(SubscriptionKind::ZedProTrial)
             {
-                let stripe_subscription_id = existing_subscription
-                    .stripe_subscription_id
-                    .parse::<stripe::SubscriptionId>()
-                    .context("failed to parse Stripe subscription ID from database")?;
+                let stripe_subscription_id = StripeSubscriptionId(
+                    existing_subscription.stripe_subscription_id.clone().into(),
+                );
 
-                Subscription::cancel(
-                    &stripe_client,
-                    &stripe_subscription_id,
-                    stripe::CancelSubscription {
-                        invoice_now: None,
-                        ..Default::default()
-                    },
-                )
-                .await?;
+                stripe_client
+                    .cancel_subscription(&stripe_subscription_id)
+                    .await?;
             } else {
                 // If the user already has an active billing subscription, ignore the
                 // event and return an `Ok` to signal that it was processed
@@ -1198,7 +1206,7 @@ async fn sync_subscription(
 async fn handle_customer_subscription_event(
     app: &Arc<AppState>,
     rpc_server: &Arc<Server>,
-    stripe_client: &stripe::Client,
+    stripe_client: &Arc<dyn StripeClient>,
     event: stripe::Event,
 ) -> anyhow::Result<()> {
     let EventObject::Subscription(subscription) = event.data.object else {
@@ -1207,7 +1215,7 @@ async fn handle_customer_subscription_event(
 
     log::info!("handling Stripe {} event: {}", event.type_, event.id);
 
-    let billing_customer = sync_subscription(app, stripe_client, subscription).await?;
+    let billing_customer = sync_subscription(app, stripe_client, subscription.into()).await?;
 
     // When the user's subscription changes, push down any changes to their plan.
     rpc_server
@@ -1403,30 +1411,20 @@ impl From<CancellationDetailsReason> for StripeCancellationReason {
 /// Finds or creates a billing customer using the provided customer.
 pub async fn find_or_create_billing_customer(
     app: &Arc<AppState>,
-    stripe_client: &stripe::Client,
-    customer_or_id: Expandable<Customer>,
+    stripe_client: &dyn StripeClient,
+    customer_id: &StripeCustomerId,
 ) -> anyhow::Result<Option<billing_customer::Model>> {
-    let customer_id = match &customer_or_id {
-        Expandable::Id(id) => id,
-        Expandable::Object(customer) => customer.id.as_ref(),
-    };
-
     // If we already have a billing customer record associated with the Stripe customer,
     // there's nothing more we need to do.
     if let Some(billing_customer) = app
         .db
-        .get_billing_customer_by_stripe_customer_id(customer_id)
+        .get_billing_customer_by_stripe_customer_id(customer_id.0.as_ref())
         .await?
     {
         return Ok(Some(billing_customer));
     }
 
-    // If all we have is a customer ID, resolve it to a full customer record by
-    // hitting the Stripe API.
-    let customer = match customer_or_id {
-        Expandable::Id(id) => Customer::retrieve(stripe_client, &id, &[]).await?,
-        Expandable::Object(customer) => *customer,
-    };
+    let customer = stripe_client.get_customer(customer_id).await?;
 
     let Some(email) = customer.email else {
         return Ok(None);

crates/collab/src/db/tables/billing_subscription.rs 🔗

@@ -1,4 +1,5 @@
 use crate::db::{BillingCustomerId, BillingSubscriptionId};
+use crate::stripe_client;
 use chrono::{Datelike as _, NaiveDate, Utc};
 use sea_orm::entity::prelude::*;
 use serde::Serialize;
@@ -159,3 +160,17 @@ pub enum StripeCancellationReason {
     #[sea_orm(string_value = "payment_failed")]
     PaymentFailed,
 }
+
+impl From<stripe_client::StripeCancellationDetailsReason> for StripeCancellationReason {
+    fn from(value: stripe_client::StripeCancellationDetailsReason) -> Self {
+        match value {
+            stripe_client::StripeCancellationDetailsReason::CancellationRequested => {
+                Self::CancellationRequested
+            }
+            stripe_client::StripeCancellationDetailsReason::PaymentDisputed => {
+                Self::PaymentDisputed
+            }
+            stripe_client::StripeCancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
+        }
+    }
+}

crates/collab/src/lib.rs 🔗

@@ -30,6 +30,7 @@ use std::{path::PathBuf, sync::Arc};
 use util::ResultExt;
 
 use crate::stripe_billing::StripeBilling;
+use crate::stripe_client::{RealStripeClient, StripeClient};
 
 pub type Result<T, E = Error> = std::result::Result<T, E>;
 
@@ -270,7 +271,10 @@ pub struct AppState {
     pub llm_db: Option<Arc<LlmDatabase>>,
     pub livekit_client: Option<Arc<dyn livekit_api::Client>>,
     pub blob_store_client: Option<aws_sdk_s3::Client>,
-    pub stripe_client: Option<Arc<stripe::Client>>,
+    /// This is a real instance of the Stripe client; we're working to replace references to this with the
+    /// [`StripeClient`] trait.
+    pub real_stripe_client: Option<Arc<stripe::Client>>,
+    pub stripe_client: Option<Arc<dyn StripeClient>>,
     pub stripe_billing: Option<Arc<StripeBilling>>,
     pub executor: Executor,
     pub kinesis_client: Option<::aws_sdk_kinesis::Client>,
@@ -323,7 +327,9 @@ impl AppState {
             stripe_billing: stripe_client
                 .clone()
                 .map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))),
-            stripe_client,
+            real_stripe_client: stripe_client.clone(),
+            stripe_client: stripe_client
+                .map(|stripe_client| Arc::new(RealStripeClient::new(stripe_client)) as _),
             executor,
             kinesis_client: if config.kinesis_access_key.is_some() {
                 build_kinesis_client(&config).await.log_err()

crates/collab/src/rpc.rs 🔗

@@ -4034,23 +4034,19 @@ async fn get_llm_api_token(
         .as_ref()
         .context("failed to retrieve Stripe billing object")?;
 
-    let billing_customer =
-        if let Some(billing_customer) = db.get_billing_customer_by_user_id(user.id).await? {
-            billing_customer
-        } else {
-            let customer_id = stripe_billing
-                .find_or_create_customer_by_email(user.email_address.as_deref())
-                .await?
-                .try_into()?;
+    let billing_customer = if let Some(billing_customer) =
+        db.get_billing_customer_by_user_id(user.id).await?
+    {
+        billing_customer
+    } else {
+        let customer_id = stripe_billing
+            .find_or_create_customer_by_email(user.email_address.as_deref())
+            .await?;
 
-            find_or_create_billing_customer(
-                &session.app_state,
-                &stripe_client,
-                stripe::Expandable::Id(customer_id),
-            )
+        find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
             .await?
             .context("billing customer not found")?
-        };
+    };
 
     let billing_subscription =
         if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {

crates/collab/src/stripe_billing.rs 🔗

@@ -111,14 +111,12 @@ impl StripeBilling {
 
     pub async fn determine_subscription_kind(
         &self,
-        subscription: &stripe::Subscription,
+        subscription: &StripeSubscription,
     ) -> Option<SubscriptionKind> {
-        let zed_pro_price_id: stripe::PriceId =
-            self.zed_pro_price_id().await.ok()?.try_into().ok()?;
-        let zed_free_price_id: stripe::PriceId =
-            self.zed_free_price_id().await.ok()?.try_into().ok()?;
+        let zed_pro_price_id = self.zed_pro_price_id().await.ok()?;
+        let zed_free_price_id = self.zed_free_price_id().await.ok()?;
 
-        subscription.items.data.iter().find_map(|item| {
+        subscription.items.iter().find_map(|item| {
             let price = item.price.as_ref()?;
 
             if price.id == zed_pro_price_id {

crates/collab/src/stripe_client.rs 🔗

@@ -39,6 +39,8 @@ pub struct StripeSubscription {
     pub current_period_end: i64,
     pub current_period_start: i64,
     pub items: Vec<StripeSubscriptionItem>,
+    pub cancel_at: Option<i64>,
+    pub cancellation_details: Option<StripeCancellationDetails>,
 }
 
 #[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
@@ -50,6 +52,18 @@ pub struct StripeSubscriptionItem {
     pub price: Option<StripePrice>,
 }
 
+#[derive(Debug, Clone, PartialEq)]
+pub struct StripeCancellationDetails {
+    pub reason: Option<StripeCancellationDetailsReason>,
+}
+
+#[derive(Debug, PartialEq, Eq, Clone, Copy)]
+pub enum StripeCancellationDetailsReason {
+    CancellationRequested,
+    PaymentDisputed,
+    PaymentFailed,
+}
+
 #[derive(Debug)]
 pub struct StripeCreateSubscriptionParams {
     pub customer: StripeCustomerId,
@@ -175,6 +189,8 @@ pub struct StripeCheckoutSession {
 pub trait StripeClient: Send + Sync {
     async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>>;
 
+    async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result<StripeCustomer>;
+
     async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer>;
 
     async fn list_subscriptions_for_customer(
@@ -198,6 +214,8 @@ pub trait StripeClient: Send + Sync {
         params: UpdateSubscriptionParams,
     ) -> Result<()>;
 
+    async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()>;
+
     async fn list_prices(&self) -> Result<Vec<StripePrice>>;
 
     async fn list_meters(&self) -> Result<Vec<StripeMeter>>;

crates/collab/src/stripe_client/fake_stripe_client.rs 🔗

@@ -74,6 +74,14 @@ impl StripeClient for FakeStripeClient {
             .collect())
     }
 
+    async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result<StripeCustomer> {
+        self.customers
+            .lock()
+            .get(customer_id)
+            .cloned()
+            .ok_or_else(|| anyhow!("no customer found for {customer_id:?}"))
+    }
+
     async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer> {
         let customer = StripeCustomer {
             id: StripeCustomerId(format!("cus_{}", Uuid::new_v4()).into()),
@@ -135,6 +143,8 @@ impl StripeClient for FakeStripeClient {
                         .and_then(|price_id| self.prices.lock().get(&price_id).cloned()),
                 })
                 .collect(),
+            cancel_at: None,
+            cancellation_details: None,
         };
 
         self.subscriptions
@@ -158,6 +168,13 @@ impl StripeClient for FakeStripeClient {
         Ok(())
     }
 
+    async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()> {
+        // TODO: Implement fake subscription cancellation.
+        let _ = subscription_id;
+
+        Ok(())
+    }
+
     async fn list_prices(&self) -> Result<Vec<StripePrice>> {
         let prices = self.prices.lock().values().cloned().collect();
 

crates/collab/src/stripe_client/real_stripe_client.rs 🔗

@@ -5,9 +5,9 @@ use anyhow::{Context as _, Result, anyhow};
 use async_trait::async_trait;
 use serde::Serialize;
 use stripe::{
-    CheckoutSession, CheckoutSessionMode, CheckoutSessionPaymentMethodCollection,
-    CreateCheckoutSession, CreateCheckoutSessionLineItems, CreateCheckoutSessionSubscriptionData,
-    CreateCheckoutSessionSubscriptionDataTrialSettings,
+    CancellationDetails, CancellationDetailsReason, CheckoutSession, CheckoutSessionMode,
+    CheckoutSessionPaymentMethodCollection, CreateCheckoutSession, CreateCheckoutSessionLineItems,
+    CreateCheckoutSessionSubscriptionData, CreateCheckoutSessionSubscriptionDataTrialSettings,
     CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior,
     CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod,
     CreateCustomer, Customer, CustomerId, ListCustomers, Price, PriceId, Recurring, Subscription,
@@ -17,9 +17,9 @@ use stripe::{
 };
 
 use crate::stripe_client::{
-    CreateCustomerParams, StripeCheckoutSession, StripeCheckoutSessionMode,
-    StripeCheckoutSessionPaymentMethodCollection, StripeClient,
-    StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
+    CreateCustomerParams, StripeCancellationDetails, StripeCancellationDetailsReason,
+    StripeCheckoutSession, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection,
+    StripeClient, StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
     StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
     StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeMeter, StripePrice,
     StripePriceId, StripePriceRecurring, StripeSubscription, StripeSubscriptionId,
@@ -57,6 +57,14 @@ impl StripeClient for RealStripeClient {
             .collect())
     }
 
+    async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result<StripeCustomer> {
+        let customer_id = customer_id.try_into()?;
+
+        let customer = Customer::retrieve(&self.client, &customer_id, &[]).await?;
+
+        Ok(StripeCustomer::from(customer))
+    }
+
     async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer> {
         let customer = Customer::create(
             &self.client,
@@ -157,6 +165,22 @@ impl StripeClient for RealStripeClient {
         Ok(())
     }
 
+    async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()> {
+        let subscription_id = subscription_id.try_into()?;
+
+        Subscription::cancel(
+            &self.client,
+            &subscription_id,
+            stripe::CancelSubscription {
+                invoice_now: None,
+                ..Default::default()
+            },
+        )
+        .await?;
+
+        Ok(())
+    }
+
     async fn list_prices(&self) -> Result<Vec<StripePrice>> {
         let response = stripe::Price::list(
             &self.client,
@@ -273,6 +297,26 @@ impl From<Subscription> for StripeSubscription {
             current_period_start: value.current_period_start,
             current_period_end: value.current_period_end,
             items: value.items.data.into_iter().map(Into::into).collect(),
+            cancel_at: value.cancel_at,
+            cancellation_details: value.cancellation_details.map(Into::into),
+        }
+    }
+}
+
+impl From<CancellationDetails> for StripeCancellationDetails {
+    fn from(value: CancellationDetails) -> Self {
+        Self {
+            reason: value.reason.map(Into::into),
+        }
+    }
+}
+
+impl From<CancellationDetailsReason> for StripeCancellationDetailsReason {
+    fn from(value: CancellationDetailsReason) -> Self {
+        match value {
+            CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
+            CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
+            CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
         }
     }
 }

crates/collab/src/tests/stripe_billing_tests.rs 🔗

@@ -172,6 +172,8 @@ async fn test_subscribe_to_price() {
         current_period_start: now.timestamp(),
         current_period_end: (now + Duration::days(30)).timestamp(),
         items: vec![],
+        cancel_at: None,
+        cancellation_details: None,
     };
     stripe_client
         .subscriptions
@@ -211,6 +213,8 @@ async fn test_subscribe_to_price() {
                 id: StripeSubscriptionItemId("si_test".into()),
                 price: Some(price.clone()),
             }],
+            cancel_at: None,
+            cancellation_details: None,
         };
         stripe_client
             .subscriptions
@@ -280,6 +284,8 @@ async fn test_subscribe_to_zed_free() {
                 id: StripeSubscriptionItemId("si_test".into()),
                 price: Some(zed_pro_price.clone()),
             }],
+            cancel_at: None,
+            cancellation_details: None,
         };
         stripe_client.subscriptions.lock().insert(
             existing_subscription.id.clone(),
@@ -309,6 +315,8 @@ async fn test_subscribe_to_zed_free() {
                 id: StripeSubscriptionItemId("si_test".into()),
                 price: Some(zed_pro_price.clone()),
             }],
+            cancel_at: None,
+            cancellation_details: None,
         };
         stripe_client.subscriptions.lock().insert(
             existing_subscription.id.clone(),

crates/collab/src/tests/test_server.rs 🔗

@@ -1,3 +1,4 @@
+use crate::stripe_client::FakeStripeClient;
 use crate::{
     AppState, Config,
     db::{NewUserParams, UserId, tests::TestDb},
@@ -522,7 +523,8 @@ impl TestServer {
             llm_db: None,
             livekit_client: Some(Arc::new(livekit_test_server.create_api_client())),
             blob_store_client: None,
-            stripe_client: None,
+            real_stripe_client: None,
+            stripe_client: Some(Arc::new(FakeStripeClient::new())),
             stripe_billing: None,
             executor,
             kinesis_client: None,