From c7047d5f0a2d4567273e677e6d04a3048e0f68fb Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Thu, 29 May 2025 19:49:14 -0400 Subject: [PATCH] collab: Fully move `StripeBilling` over to using `StripeClient` (#31722) This PR moves over the last method on `StripeBilling` to use the `StripeClient` trait, allowing us to fully mock out Stripe behaviors for `StripeBilling` in tests. Release Notes: - N/A --- crates/collab/src/api/billing.rs | 6 +- crates/collab/src/rpc.rs | 7 +- crates/collab/src/stripe_billing.rs | 52 ++++----- crates/collab/src/stripe_client.rs | 35 +++++- .../src/stripe_client/fake_stripe_client.rs | 54 ++++++++- .../src/stripe_client/real_stripe_client.rs | 60 +++++++++- .../collab/src/tests/stripe_billing_tests.rs | 109 ++++++++++++++++++ 7 files changed, 273 insertions(+), 50 deletions(-) diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index 7fd35fc0300f3c15baca5ace5650d4a0e2a22780..b4642d023d26356e42e8a6e7531cc43c6306eb6a 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -1182,10 +1182,8 @@ async fn sync_subscription( .has_active_billing_subscription(billing_customer.user_id) .await?; if !already_has_active_billing_subscription { - let stripe_customer_id = billing_customer - .stripe_customer_id - .parse::() - .context("failed to parse Stripe customer ID from database")?; + let stripe_customer_id = + StripeCustomerId(billing_customer.stripe_customer_id.clone().into()); stripe_billing .subscribe_to_zed_free(stripe_customer_id) diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 5316304cb0be5516c6de0fd5de0eae1fb8ee521d..0dba1b3e65a17cb7d693e2aedad67cc2bd600144 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -5,6 +5,7 @@ use crate::api::{CloudflareIpCountryHeader, SystemIdHeader}; use crate::db::billing_subscription::SubscriptionKind; use crate::llm::db::LlmDatabase; use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims}; +use crate::stripe_client::StripeCustomerId; use crate::{ AppState, Error, Result, auth, db::{ @@ -4055,10 +4056,8 @@ async fn get_llm_api_token( if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? { billing_subscription } else { - let stripe_customer_id = billing_customer - .stripe_customer_id - .parse::() - .context("failed to parse Stripe customer ID from database")?; + let stripe_customer_id = + StripeCustomerId(billing_customer.stripe_customer_id.clone().into()); let stripe_subscription = stripe_billing .subscribe_to_zed_free(stripe_customer_id) diff --git a/crates/collab/src/stripe_billing.rs b/crates/collab/src/stripe_billing.rs index 4a8bb41c2cb71ee72378b37a1bc1485041e9f84d..30e7e1b87a217523e693f364457c0628453e2dca 100644 --- a/crates/collab/src/stripe_billing.rs +++ b/crates/collab/src/stripe_billing.rs @@ -14,8 +14,9 @@ use crate::stripe_client::{ RealStripeClient, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection, StripeClient, StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams, StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, - StripeCreateMeterEventPayload, StripeCustomerId, StripeMeter, StripePrice, StripePriceId, - StripeSubscription, StripeSubscriptionId, StripeSubscriptionTrialSettings, + StripeCreateMeterEventPayload, StripeCreateSubscriptionItems, StripeCreateSubscriptionParams, + StripeCustomerId, StripeMeter, StripePrice, StripePriceId, StripeSubscription, + StripeSubscriptionId, StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior, StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems, UpdateSubscriptionParams, @@ -23,7 +24,6 @@ use crate::stripe_client::{ pub struct StripeBilling { state: RwLock, - real_client: Arc, client: Arc, } @@ -38,7 +38,6 @@ impl StripeBilling { pub fn new(client: Arc) -> Self { Self { client: Arc::new(RealStripeClient::new(client.clone())), - real_client: client, state: RwLock::default(), } } @@ -46,8 +45,6 @@ impl StripeBilling { #[cfg(test)] pub fn test(client: Arc) -> Self { Self { - // This is just temporary until we can remove all usages of the real Stripe client. - real_client: Arc::new(stripe::Client::new("sk_test")), client, state: RwLock::default(), } @@ -306,40 +303,33 @@ impl StripeBilling { pub async fn subscribe_to_zed_free( &self, - customer_id: stripe::CustomerId, - ) -> Result { + customer_id: StripeCustomerId, + ) -> Result { let zed_free_price_id = self.zed_free_price_id().await?; - let existing_subscriptions = stripe::Subscription::list( - &self.real_client, - &stripe::ListSubscriptions { - customer: Some(customer_id.clone()), - status: None, - ..Default::default() - }, - ) - .await?; + let existing_subscriptions = self + .client + .list_subscriptions_for_customer(&customer_id) + .await?; let existing_active_subscription = - existing_subscriptions - .data - .into_iter() - .find(|subscription| { - subscription.status == SubscriptionStatus::Active - || subscription.status == SubscriptionStatus::Trialing - }); + existing_subscriptions.into_iter().find(|subscription| { + subscription.status == SubscriptionStatus::Active + || subscription.status == SubscriptionStatus::Trialing + }); if let Some(subscription) = existing_active_subscription { return Ok(subscription); } - let mut params = stripe::CreateSubscription::new(customer_id); - params.items = Some(vec![stripe::CreateSubscriptionItems { - price: Some(zed_free_price_id.to_string()), - quantity: Some(1), - ..Default::default() - }]); + let params = StripeCreateSubscriptionParams { + customer: customer_id, + items: vec![StripeCreateSubscriptionItems { + price: Some(zed_free_price_id), + quantity: Some(1), + }], + }; - let subscription = stripe::Subscription::create(&self.real_client, params).await?; + let subscription = self.client.create_subscription(params).await?; Ok(subscription) } diff --git a/crates/collab/src/stripe_client.rs b/crates/collab/src/stripe_client.rs index 91ffd7a3d93299678daa0237d1bcf4f6f1da00e3..b009f5bd2c4b98c203af7e31c90a67eb811ff4fb 100644 --- a/crates/collab/src/stripe_client.rs +++ b/crates/collab/src/stripe_client.rs @@ -30,21 +30,38 @@ pub struct CreateCustomerParams<'a> { #[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)] pub struct StripeSubscriptionId(pub Arc); -#[derive(Debug, Clone)] +#[derive(Debug, PartialEq, Clone)] pub struct StripeSubscription { pub id: StripeSubscriptionId, + pub customer: StripeCustomerId, + // TODO: Create our own version of this enum. + pub status: stripe::SubscriptionStatus, + pub current_period_end: i64, + pub current_period_start: i64, pub items: Vec, } #[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)] pub struct StripeSubscriptionItemId(pub Arc); -#[derive(Debug, Clone)] +#[derive(Debug, PartialEq, Clone)] pub struct StripeSubscriptionItem { pub id: StripeSubscriptionItemId, pub price: Option, } +#[derive(Debug)] +pub struct StripeCreateSubscriptionParams { + pub customer: StripeCustomerId, + pub items: Vec, +} + +#[derive(Debug)] +pub struct StripeCreateSubscriptionItems { + pub price: Option, + pub quantity: Option, +} + #[derive(Debug, Clone)] pub struct UpdateSubscriptionParams { pub items: Option>, @@ -76,7 +93,7 @@ pub enum StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod { #[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)] pub struct StripePriceId(pub Arc); -#[derive(Debug, Clone)] +#[derive(Debug, PartialEq, Clone)] pub struct StripePrice { pub id: StripePriceId, pub unit_amount: Option, @@ -84,7 +101,7 @@ pub struct StripePrice { pub recurring: Option, } -#[derive(Debug, Clone)] +#[derive(Debug, PartialEq, Clone)] pub struct StripePriceRecurring { pub meter: Option, } @@ -160,11 +177,21 @@ pub trait StripeClient: Send + Sync { async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result; + async fn list_subscriptions_for_customer( + &self, + customer_id: &StripeCustomerId, + ) -> Result>; + async fn get_subscription( &self, subscription_id: &StripeSubscriptionId, ) -> Result; + async fn create_subscription( + &self, + params: StripeCreateSubscriptionParams, + ) -> Result; + async fn update_subscription( &self, subscription_id: &StripeSubscriptionId, diff --git a/crates/collab/src/stripe_client/fake_stripe_client.rs b/crates/collab/src/stripe_client/fake_stripe_client.rs index 3a2d2c8590dc58c5bd2221491c9f7fec03c9e7aa..43b03a2d9584aed59a5f9a33a438f77375826a7d 100644 --- a/crates/collab/src/stripe_client/fake_stripe_client.rs +++ b/crates/collab/src/stripe_client/fake_stripe_client.rs @@ -2,6 +2,7 @@ use std::sync::Arc; use anyhow::{Result, anyhow}; use async_trait::async_trait; +use chrono::{Duration, Utc}; use collections::HashMap; use parking_lot::Mutex; use uuid::Uuid; @@ -10,9 +11,10 @@ use crate::stripe_client::{ CreateCustomerParams, StripeCheckoutSession, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection, StripeClient, StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams, - StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, StripeCustomer, - StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripeSubscription, - StripeSubscriptionId, UpdateSubscriptionParams, + StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, + StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeMeter, StripeMeterId, + StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, + StripeSubscriptionItemId, UpdateSubscriptionParams, }; #[derive(Debug, Clone)] @@ -85,6 +87,21 @@ impl StripeClient for FakeStripeClient { Ok(customer) } + async fn list_subscriptions_for_customer( + &self, + customer_id: &StripeCustomerId, + ) -> Result> { + let subscriptions = self + .subscriptions + .lock() + .values() + .filter(|subscription| subscription.customer == *customer_id) + .cloned() + .collect(); + + Ok(subscriptions) + } + async fn get_subscription( &self, subscription_id: &StripeSubscriptionId, @@ -96,6 +113,37 @@ impl StripeClient for FakeStripeClient { .ok_or_else(|| anyhow!("no subscription found for {subscription_id:?}")) } + async fn create_subscription( + &self, + params: StripeCreateSubscriptionParams, + ) -> Result { + let now = Utc::now(); + + let subscription = StripeSubscription { + id: StripeSubscriptionId(format!("sub_{}", Uuid::new_v4()).into()), + customer: params.customer, + status: stripe::SubscriptionStatus::Active, + current_period_start: now.timestamp(), + current_period_end: (now + Duration::days(30)).timestamp(), + items: params + .items + .into_iter() + .map(|item| StripeSubscriptionItem { + id: StripeSubscriptionItemId(format!("si_{}", Uuid::new_v4()).into()), + price: item + .price + .and_then(|price_id| self.prices.lock().get(&price_id).cloned()), + }) + .collect(), + }; + + self.subscriptions + .lock() + .insert(subscription.id.clone(), subscription.clone()); + + Ok(subscription) + } + async fn update_subscription( &self, subscription_id: &StripeSubscriptionId, diff --git a/crates/collab/src/stripe_client/real_stripe_client.rs b/crates/collab/src/stripe_client/real_stripe_client.rs index 724e4c64c3d9d98123e76b09abd5012cf44d0aa4..e76e9df821cb8c6058b2a317967e9b7298d5be3f 100644 --- a/crates/collab/src/stripe_client/real_stripe_client.rs +++ b/crates/collab/src/stripe_client/real_stripe_client.rs @@ -20,10 +20,11 @@ use crate::stripe_client::{ CreateCustomerParams, StripeCheckoutSession, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection, StripeClient, StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams, - StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, StripeCustomer, - StripeCustomerId, StripeMeter, StripePrice, StripePriceId, StripePriceRecurring, - StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId, - StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior, + StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, + StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeMeter, StripePrice, + StripePriceId, StripePriceRecurring, StripeSubscription, StripeSubscriptionId, + StripeSubscriptionItem, StripeSubscriptionItemId, StripeSubscriptionTrialSettings, + StripeSubscriptionTrialSettingsEndBehavior, StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionParams, }; @@ -69,6 +70,29 @@ impl StripeClient for RealStripeClient { Ok(StripeCustomer::from(customer)) } + async fn list_subscriptions_for_customer( + &self, + customer_id: &StripeCustomerId, + ) -> Result> { + let customer_id = customer_id.try_into()?; + + let subscriptions = stripe::Subscription::list( + &self.client, + &stripe::ListSubscriptions { + customer: Some(customer_id), + status: None, + ..Default::default() + }, + ) + .await?; + + Ok(subscriptions + .data + .into_iter() + .map(StripeSubscription::from) + .collect()) + } + async fn get_subscription( &self, subscription_id: &StripeSubscriptionId, @@ -80,6 +104,30 @@ impl StripeClient for RealStripeClient { Ok(StripeSubscription::from(subscription)) } + async fn create_subscription( + &self, + params: StripeCreateSubscriptionParams, + ) -> Result { + let customer_id = params.customer.try_into()?; + + let mut create_subscription = stripe::CreateSubscription::new(customer_id); + create_subscription.items = Some( + params + .items + .into_iter() + .map(|item| stripe::CreateSubscriptionItems { + price: item.price.map(|price| price.to_string()), + quantity: item.quantity, + ..Default::default() + }) + .collect(), + ); + + let subscription = Subscription::create(&self.client, create_subscription).await?; + + Ok(StripeSubscription::from(subscription)) + } + async fn update_subscription( &self, subscription_id: &StripeSubscriptionId, @@ -220,6 +268,10 @@ impl From for StripeSubscription { fn from(value: Subscription) -> Self { Self { id: value.id.into(), + customer: value.customer.id().into(), + status: value.status, + current_period_start: value.current_period_start, + current_period_end: value.current_period_end, items: value.items.data.into_iter().map(Into::into).collect(), } } diff --git a/crates/collab/src/tests/stripe_billing_tests.rs b/crates/collab/src/tests/stripe_billing_tests.rs index da18d2e7a2398247acf768e10df3c86c0332e70d..45133923dee70af44c5e5ddbb4475042f2f85a9c 100644 --- a/crates/collab/src/tests/stripe_billing_tests.rs +++ b/crates/collab/src/tests/stripe_billing_tests.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use chrono::{Duration, Utc}; use pretty_assertions::assert_eq; use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG; @@ -163,8 +164,13 @@ async fn test_subscribe_to_price() { .lock() .insert(price.id.clone(), price.clone()); + let now = Utc::now(); let subscription = StripeSubscription { id: StripeSubscriptionId("sub_test".into()), + customer: StripeCustomerId("cus_test".into()), + status: stripe::SubscriptionStatus::Active, + current_period_start: now.timestamp(), + current_period_end: (now + Duration::days(30)).timestamp(), items: vec![], }; stripe_client @@ -194,8 +200,13 @@ async fn test_subscribe_to_price() { // Subscribing to a price that is already on the subscription is a no-op. { + let now = Utc::now(); let subscription = StripeSubscription { id: StripeSubscriptionId("sub_test".into()), + customer: StripeCustomerId("cus_test".into()), + status: stripe::SubscriptionStatus::Active, + current_period_start: now.timestamp(), + current_period_end: (now + Duration::days(30)).timestamp(), items: vec![StripeSubscriptionItem { id: StripeSubscriptionItemId("si_test".into()), price: Some(price.clone()), @@ -215,6 +226,104 @@ async fn test_subscribe_to_price() { } } +#[gpui::test] +async fn test_subscribe_to_zed_free() { + let (stripe_billing, stripe_client) = make_stripe_billing(); + + let zed_pro_price = StripePrice { + id: StripePriceId("price_1".into()), + unit_amount: Some(0), + lookup_key: Some("zed-pro".to_string()), + recurring: None, + }; + stripe_client + .prices + .lock() + .insert(zed_pro_price.id.clone(), zed_pro_price.clone()); + let zed_free_price = StripePrice { + id: StripePriceId("price_2".into()), + unit_amount: Some(0), + lookup_key: Some("zed-free".to_string()), + recurring: None, + }; + stripe_client + .prices + .lock() + .insert(zed_free_price.id.clone(), zed_free_price.clone()); + + stripe_billing.initialize().await.unwrap(); + + // Customer is subscribed to Zed Free when not already subscribed to a plan. + { + let customer_id = StripeCustomerId("cus_no_plan".into()); + + let subscription = stripe_billing + .subscribe_to_zed_free(customer_id) + .await + .unwrap(); + + assert_eq!(subscription.items[0].price.as_ref(), Some(&zed_free_price)); + } + + // Customer is not subscribed to Zed Free when they already have an active subscription. + { + let customer_id = StripeCustomerId("cus_active_subscription".into()); + + let now = Utc::now(); + let existing_subscription = StripeSubscription { + id: StripeSubscriptionId("sub_existing_active".into()), + customer: customer_id.clone(), + status: stripe::SubscriptionStatus::Active, + current_period_start: now.timestamp(), + current_period_end: (now + Duration::days(30)).timestamp(), + items: vec![StripeSubscriptionItem { + id: StripeSubscriptionItemId("si_test".into()), + price: Some(zed_pro_price.clone()), + }], + }; + stripe_client.subscriptions.lock().insert( + existing_subscription.id.clone(), + existing_subscription.clone(), + ); + + let subscription = stripe_billing + .subscribe_to_zed_free(customer_id) + .await + .unwrap(); + + assert_eq!(subscription, existing_subscription); + } + + // Customer is not subscribed to Zed Free when they already have a trial subscription. + { + let customer_id = StripeCustomerId("cus_trial_subscription".into()); + + let now = Utc::now(); + let existing_subscription = StripeSubscription { + id: StripeSubscriptionId("sub_existing_trial".into()), + customer: customer_id.clone(), + status: stripe::SubscriptionStatus::Trialing, + current_period_start: now.timestamp(), + current_period_end: (now + Duration::days(14)).timestamp(), + items: vec![StripeSubscriptionItem { + id: StripeSubscriptionItemId("si_test".into()), + price: Some(zed_pro_price.clone()), + }], + }; + stripe_client.subscriptions.lock().insert( + existing_subscription.id.clone(), + existing_subscription.clone(), + ); + + let subscription = stripe_billing + .subscribe_to_zed_free(customer_id) + .await + .unwrap(); + + assert_eq!(subscription, existing_subscription); + } +} + #[gpui::test] async fn test_bill_model_request_usage() { let (stripe_billing, stripe_client) = make_stripe_billing();