collab: Use `StripeClient` when creating Stripe Checkout sessions (#31644)

Marshall Bowers created

This PR updates the `StripeBilling::checkout_with_zed_pro` and
`StripeBilling::checkout_with_zed_pro_trial` methods to use the
`StripeClient` trait instead of using `stripe::Client` directly.

Release Notes:

- N/A

Change summary

crates/collab/src/api/billing.rs                      |   8 
crates/collab/src/stripe_billing.rs                   |  55 +-
crates/collab/src/stripe_client.rs                    |  64 +++
crates/collab/src/stripe_client/fake_stripe_client.rs |  39 ++
crates/collab/src/stripe_client/real_stripe_client.rs | 153 ++++++++
crates/collab/src/tests/stripe_billing_tests.rs       | 211 ++++++++++++
6 files changed, 471 insertions(+), 59 deletions(-)

Detailed changes

crates/collab/src/api/billing.rs 🔗

@@ -338,13 +338,11 @@ async fn create_billing_subscription(
     }
 
     let customer_id = if let Some(existing_customer) = &existing_billing_customer {
-        CustomerId::from_str(&existing_customer.stripe_customer_id)
-            .context("failed to parse customer ID")?
+        StripeCustomerId(existing_customer.stripe_customer_id.clone().into())
     } else {
         stripe_billing
             .find_or_create_customer_by_email(user.email_address.as_deref())
             .await?
-            .try_into()?
     };
 
     let success_url = format!(
@@ -355,7 +353,7 @@ async fn create_billing_subscription(
     let checkout_session_url = match body.product {
         ProductCode::ZedPro => {
             stripe_billing
-                .checkout_with_zed_pro(customer_id, &user.github_login, &success_url)
+                .checkout_with_zed_pro(&customer_id, &user.github_login, &success_url)
                 .await?
         }
         ProductCode::ZedProTrial => {
@@ -372,7 +370,7 @@ async fn create_billing_subscription(
 
             stripe_billing
                 .checkout_with_zed_pro_trial(
-                    customer_id,
+                    &customer_id,
                     &user.github_login,
                     feature_flags,
                     &success_url,

crates/collab/src/stripe_billing.rs 🔗

@@ -11,11 +11,14 @@ use crate::Result;
 use crate::db::billing_subscription::SubscriptionKind;
 use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
 use crate::stripe_client::{
-    RealStripeClient, StripeClient, StripeCreateMeterEventParams, StripeCreateMeterEventPayload,
-    StripeCustomerId, StripeMeter, StripePrice, StripePriceId, StripeSubscription,
-    StripeSubscriptionId, UpdateSubscriptionItems, UpdateSubscriptionParams,
-    UpdateSubscriptionTrialSettings, UpdateSubscriptionTrialSettingsEndBehavior,
-    UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
+    RealStripeClient, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection,
+    StripeClient, StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
+    StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
+    StripeCreateMeterEventPayload, StripeCustomerId, StripeMeter, StripePrice, StripePriceId,
+    StripeSubscription, StripeSubscriptionId, StripeSubscriptionTrialSettings,
+    StripeSubscriptionTrialSettingsEndBehavior,
+    StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
+    UpdateSubscriptionParams,
 };
 
 pub struct StripeBilling {
@@ -190,9 +193,9 @@ impl StripeBilling {
                     items: Some(vec![UpdateSubscriptionItems {
                         price: Some(price.id.clone()),
                     }]),
-                    trial_settings: Some(UpdateSubscriptionTrialSettings {
-                        end_behavior: UpdateSubscriptionTrialSettingsEndBehavior {
-                            missing_payment_method: UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel
+                    trial_settings: Some(StripeSubscriptionTrialSettings {
+                        end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
+                            missing_payment_method: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel
                         },
                     }),
                 },
@@ -228,30 +231,29 @@ impl StripeBilling {
 
     pub async fn checkout_with_zed_pro(
         &self,
-        customer_id: stripe::CustomerId,
+        customer_id: &StripeCustomerId,
         github_login: &str,
         success_url: &str,
     ) -> Result<String> {
         let zed_pro_price_id = self.zed_pro_price_id().await?;
 
-        let mut params = stripe::CreateCheckoutSession::new();
-        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
+        let mut params = StripeCreateCheckoutSessionParams::default();
+        params.mode = Some(StripeCheckoutSessionMode::Subscription);
         params.customer = Some(customer_id);
         params.client_reference_id = Some(github_login);
-        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
+        params.line_items = Some(vec![StripeCreateCheckoutSessionLineItems {
             price: Some(zed_pro_price_id.to_string()),
             quantity: Some(1),
-            ..Default::default()
         }]);
         params.success_url = Some(success_url);
 
-        let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
+        let session = self.client.create_checkout_session(params).await?;
         Ok(session.url.context("no checkout session URL")?)
     }
 
     pub async fn checkout_with_zed_pro_trial(
         &self,
-        customer_id: stripe::CustomerId,
+        customer_id: &StripeCustomerId,
         github_login: &str,
         feature_flags: Vec<String>,
         success_url: &str,
@@ -272,34 +274,33 @@ impl StripeBilling {
             );
         }
 
-        let mut params = stripe::CreateCheckoutSession::new();
-        params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
+        let mut params = StripeCreateCheckoutSessionParams::default();
+        params.subscription_data = Some(StripeCreateCheckoutSessionSubscriptionData {
             trial_period_days: Some(trial_period_days),
-            trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
-                end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
-                    missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
-                }
+            trial_settings: Some(StripeSubscriptionTrialSettings {
+                end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
+                    missing_payment_method:
+                        StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
+                },
             }),
             metadata: if !subscription_metadata.is_empty() {
                 Some(subscription_metadata)
             } else {
                 None
             },
-            ..Default::default()
         });
-        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
+        params.mode = Some(StripeCheckoutSessionMode::Subscription);
         params.payment_method_collection =
-            Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
+            Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired);
         params.customer = Some(customer_id);
         params.client_reference_id = Some(github_login);
-        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
+        params.line_items = Some(vec![StripeCreateCheckoutSessionLineItems {
             price: Some(zed_pro_price_id.to_string()),
             quantity: Some(1),
-            ..Default::default()
         }]);
         params.success_url = Some(success_url);
 
-        let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
+        let session = self.client.create_checkout_session(params).await?;
         Ok(session.url.context("no checkout session URL")?)
     }
 

crates/collab/src/stripe_client.rs 🔗

@@ -2,6 +2,7 @@
 mod fake_stripe_client;
 mod real_stripe_client;
 
+use std::collections::HashMap;
 use std::sync::Arc;
 
 use anyhow::Result;
@@ -47,7 +48,7 @@ pub struct StripeSubscriptionItem {
 #[derive(Debug, Clone)]
 pub struct UpdateSubscriptionParams {
     pub items: Option<Vec<UpdateSubscriptionItems>>,
-    pub trial_settings: Option<UpdateSubscriptionTrialSettings>,
+    pub trial_settings: Option<StripeSubscriptionTrialSettings>,
 }
 
 #[derive(Debug, PartialEq, Clone)]
@@ -55,18 +56,18 @@ pub struct UpdateSubscriptionItems {
     pub price: Option<StripePriceId>,
 }
 
-#[derive(Debug, Clone)]
-pub struct UpdateSubscriptionTrialSettings {
-    pub end_behavior: UpdateSubscriptionTrialSettingsEndBehavior,
+#[derive(Debug, PartialEq, Clone)]
+pub struct StripeSubscriptionTrialSettings {
+    pub end_behavior: StripeSubscriptionTrialSettingsEndBehavior,
 }
 
-#[derive(Debug, Clone)]
-pub struct UpdateSubscriptionTrialSettingsEndBehavior {
-    pub missing_payment_method: UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
+#[derive(Debug, PartialEq, Clone)]
+pub struct StripeSubscriptionTrialSettingsEndBehavior {
+    pub missing_payment_method: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
 }
 
 #[derive(Debug, PartialEq, Eq, Clone, Copy)]
-pub enum UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod {
+pub enum StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod {
     Cancel,
     CreateInvoice,
     Pause,
@@ -111,6 +112,48 @@ pub struct StripeCreateMeterEventPayload<'a> {
     pub stripe_customer_id: &'a StripeCustomerId,
 }
 
+#[derive(Debug, Default)]
+pub struct StripeCreateCheckoutSessionParams<'a> {
+    pub customer: Option<&'a StripeCustomerId>,
+    pub client_reference_id: Option<&'a str>,
+    pub mode: Option<StripeCheckoutSessionMode>,
+    pub line_items: Option<Vec<StripeCreateCheckoutSessionLineItems>>,
+    pub payment_method_collection: Option<StripeCheckoutSessionPaymentMethodCollection>,
+    pub subscription_data: Option<StripeCreateCheckoutSessionSubscriptionData>,
+    pub success_url: Option<&'a str>,
+}
+
+#[derive(Debug, PartialEq, Eq, Clone, Copy)]
+pub enum StripeCheckoutSessionMode {
+    Payment,
+    Setup,
+    Subscription,
+}
+
+#[derive(Debug, PartialEq, Clone)]
+pub struct StripeCreateCheckoutSessionLineItems {
+    pub price: Option<String>,
+    pub quantity: Option<u64>,
+}
+
+#[derive(Debug, PartialEq, Eq, Clone, Copy)]
+pub enum StripeCheckoutSessionPaymentMethodCollection {
+    Always,
+    IfRequired,
+}
+
+#[derive(Debug, PartialEq, Clone)]
+pub struct StripeCreateCheckoutSessionSubscriptionData {
+    pub metadata: Option<HashMap<String, String>>,
+    pub trial_period_days: Option<u32>,
+    pub trial_settings: Option<StripeSubscriptionTrialSettings>,
+}
+
+#[derive(Debug)]
+pub struct StripeCheckoutSession {
+    pub url: Option<String>,
+}
+
 #[async_trait]
 pub trait StripeClient: Send + Sync {
     async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>>;
@@ -133,4 +176,9 @@ pub trait StripeClient: Send + Sync {
     async fn list_meters(&self) -> Result<Vec<StripeMeter>>;
 
     async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()>;
+
+    async fn create_checkout_session(
+        &self,
+        params: StripeCreateCheckoutSessionParams<'_>,
+    ) -> Result<StripeCheckoutSession>;
 }

crates/collab/src/stripe_client/fake_stripe_client.rs 🔗

@@ -7,7 +7,10 @@ use parking_lot::Mutex;
 use uuid::Uuid;
 
 use crate::stripe_client::{
-    CreateCustomerParams, StripeClient, StripeCreateMeterEventParams, StripeCustomer,
+    CreateCustomerParams, StripeCheckoutSession, StripeCheckoutSessionMode,
+    StripeCheckoutSessionPaymentMethodCollection, StripeClient,
+    StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
+    StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, StripeCustomer,
     StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripeSubscription,
     StripeSubscriptionId, UpdateSubscriptionParams,
 };
@@ -21,6 +24,17 @@ pub struct StripeCreateMeterEventCall {
     pub timestamp: Option<i64>,
 }
 
+#[derive(Debug, Clone)]
+pub struct StripeCreateCheckoutSessionCall {
+    pub customer: Option<StripeCustomerId>,
+    pub client_reference_id: Option<String>,
+    pub mode: Option<StripeCheckoutSessionMode>,
+    pub line_items: Option<Vec<StripeCreateCheckoutSessionLineItems>>,
+    pub payment_method_collection: Option<StripeCheckoutSessionPaymentMethodCollection>,
+    pub subscription_data: Option<StripeCreateCheckoutSessionSubscriptionData>,
+    pub success_url: Option<String>,
+}
+
 pub struct FakeStripeClient {
     pub customers: Arc<Mutex<HashMap<StripeCustomerId, StripeCustomer>>>,
     pub subscriptions: Arc<Mutex<HashMap<StripeSubscriptionId, StripeSubscription>>>,
@@ -29,6 +43,7 @@ pub struct FakeStripeClient {
     pub prices: Arc<Mutex<HashMap<StripePriceId, StripePrice>>>,
     pub meters: Arc<Mutex<HashMap<StripeMeterId, StripeMeter>>>,
     pub create_meter_event_calls: Arc<Mutex<Vec<StripeCreateMeterEventCall>>>,
+    pub create_checkout_session_calls: Arc<Mutex<Vec<StripeCreateCheckoutSessionCall>>>,
 }
 
 impl FakeStripeClient {
@@ -40,6 +55,7 @@ impl FakeStripeClient {
             prices: Arc::new(Mutex::new(HashMap::default())),
             meters: Arc::new(Mutex::new(HashMap::default())),
             create_meter_event_calls: Arc::new(Mutex::new(Vec::new())),
+            create_checkout_session_calls: Arc::new(Mutex::new(Vec::new())),
         }
     }
 }
@@ -119,4 +135,25 @@ impl StripeClient for FakeStripeClient {
 
         Ok(())
     }
+
+    async fn create_checkout_session(
+        &self,
+        params: StripeCreateCheckoutSessionParams<'_>,
+    ) -> Result<StripeCheckoutSession> {
+        self.create_checkout_session_calls
+            .lock()
+            .push(StripeCreateCheckoutSessionCall {
+                customer: params.customer.cloned(),
+                client_reference_id: params.client_reference_id.map(|id| id.to_string()),
+                mode: params.mode,
+                line_items: params.line_items,
+                payment_method_collection: params.payment_method_collection,
+                subscription_data: params.subscription_data,
+                success_url: params.success_url.map(|url| url.to_string()),
+            });
+
+        Ok(StripeCheckoutSession {
+            url: Some("https://checkout.stripe.com/c/pay/cs_test_1".to_string()),
+        })
+    }
 }

crates/collab/src/stripe_client/real_stripe_client.rs 🔗

@@ -5,6 +5,11 @@ use anyhow::{Context as _, Result, anyhow};
 use async_trait::async_trait;
 use serde::Serialize;
 use stripe::{
+    CheckoutSession, CheckoutSessionMode, CheckoutSessionPaymentMethodCollection,
+    CreateCheckoutSession, CreateCheckoutSessionLineItems, CreateCheckoutSessionSubscriptionData,
+    CreateCheckoutSessionSubscriptionDataTrialSettings,
+    CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior,
+    CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod,
     CreateCustomer, Customer, CustomerId, ListCustomers, Price, PriceId, Recurring, Subscription,
     SubscriptionId, SubscriptionItem, SubscriptionItemId, UpdateSubscriptionItems,
     UpdateSubscriptionTrialSettings, UpdateSubscriptionTrialSettingsEndBehavior,
@@ -12,10 +17,14 @@ use stripe::{
 };
 
 use crate::stripe_client::{
-    CreateCustomerParams, StripeClient, StripeCreateMeterEventParams, StripeCustomer,
+    CreateCustomerParams, StripeCheckoutSession, StripeCheckoutSessionMode,
+    StripeCheckoutSessionPaymentMethodCollection, StripeClient,
+    StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
+    StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, StripeCustomer,
     StripeCustomerId, StripeMeter, StripePrice, StripePriceId, StripePriceRecurring,
     StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId,
-    UpdateSubscriptionParams,
+    StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior,
+    StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionParams,
 };
 
 pub struct RealStripeClient {
@@ -150,6 +159,16 @@ impl StripeClient for RealStripeClient {
             Err(error) => Err(anyhow!(error)),
         }
     }
+
+    async fn create_checkout_session(
+        &self,
+        params: StripeCreateCheckoutSessionParams<'_>,
+    ) -> Result<StripeCheckoutSession> {
+        let params = params.try_into()?;
+        let session = CheckoutSession::create(&self.client, params).await?;
+
+        Ok(session.into())
+    }
 }
 
 impl From<CustomerId> for StripeCustomerId {
@@ -166,6 +185,14 @@ impl TryFrom<StripeCustomerId> for CustomerId {
     }
 }
 
+impl TryFrom<&StripeCustomerId> for CustomerId {
+    type Error = anyhow::Error;
+
+    fn try_from(value: &StripeCustomerId) -> Result<Self, Self::Error> {
+        Self::from_str(value.0.as_ref()).context("failed to parse Stripe customer ID")
+    }
+}
+
 impl From<Customer> for StripeCustomer {
     fn from(value: Customer) -> Self {
         StripeCustomer {
@@ -213,38 +240,34 @@ impl From<SubscriptionItem> for StripeSubscriptionItem {
     }
 }
 
-impl From<crate::stripe_client::UpdateSubscriptionTrialSettings>
-    for UpdateSubscriptionTrialSettings
-{
-    fn from(value: crate::stripe_client::UpdateSubscriptionTrialSettings) -> Self {
+impl From<StripeSubscriptionTrialSettings> for UpdateSubscriptionTrialSettings {
+    fn from(value: StripeSubscriptionTrialSettings) -> Self {
         Self {
             end_behavior: value.end_behavior.into(),
         }
     }
 }
 
-impl From<crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehavior>
+impl From<StripeSubscriptionTrialSettingsEndBehavior>
     for UpdateSubscriptionTrialSettingsEndBehavior
 {
-    fn from(value: crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehavior) -> Self {
+    fn from(value: StripeSubscriptionTrialSettingsEndBehavior) -> Self {
         Self {
             missing_payment_method: value.missing_payment_method.into(),
         }
     }
 }
 
-impl From<crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod>
+impl From<StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod>
     for UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod
 {
-    fn from(
-        value: crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
-    ) -> Self {
+    fn from(value: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod) -> Self {
         match value {
-            crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel,
-            crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => {
+            StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel,
+            StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => {
                 Self::CreateInvoice
             }
-            crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause,
+            StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause,
         }
     }
 }
@@ -279,3 +302,103 @@ impl From<Recurring> for StripePriceRecurring {
         Self { meter: value.meter }
     }
 }
+
+impl<'a> TryFrom<StripeCreateCheckoutSessionParams<'a>> for CreateCheckoutSession<'a> {
+    type Error = anyhow::Error;
+
+    fn try_from(value: StripeCreateCheckoutSessionParams<'a>) -> Result<Self, Self::Error> {
+        Ok(Self {
+            customer: value
+                .customer
+                .map(|customer_id| customer_id.try_into())
+                .transpose()?,
+            client_reference_id: value.client_reference_id,
+            mode: value.mode.map(Into::into),
+            line_items: value
+                .line_items
+                .map(|line_items| line_items.into_iter().map(Into::into).collect()),
+            payment_method_collection: value.payment_method_collection.map(Into::into),
+            subscription_data: value.subscription_data.map(Into::into),
+            success_url: value.success_url,
+            ..Default::default()
+        })
+    }
+}
+
+impl From<StripeCheckoutSessionMode> for CheckoutSessionMode {
+    fn from(value: StripeCheckoutSessionMode) -> Self {
+        match value {
+            StripeCheckoutSessionMode::Payment => Self::Payment,
+            StripeCheckoutSessionMode::Setup => Self::Setup,
+            StripeCheckoutSessionMode::Subscription => Self::Subscription,
+        }
+    }
+}
+
+impl From<StripeCreateCheckoutSessionLineItems> for CreateCheckoutSessionLineItems {
+    fn from(value: StripeCreateCheckoutSessionLineItems) -> Self {
+        Self {
+            price: value.price,
+            quantity: value.quantity,
+            ..Default::default()
+        }
+    }
+}
+
+impl From<StripeCheckoutSessionPaymentMethodCollection> for CheckoutSessionPaymentMethodCollection {
+    fn from(value: StripeCheckoutSessionPaymentMethodCollection) -> Self {
+        match value {
+            StripeCheckoutSessionPaymentMethodCollection::Always => Self::Always,
+            StripeCheckoutSessionPaymentMethodCollection::IfRequired => Self::IfRequired,
+        }
+    }
+}
+
+impl From<StripeCreateCheckoutSessionSubscriptionData> for CreateCheckoutSessionSubscriptionData {
+    fn from(value: StripeCreateCheckoutSessionSubscriptionData) -> Self {
+        Self {
+            trial_period_days: value.trial_period_days,
+            trial_settings: value.trial_settings.map(Into::into),
+            metadata: value.metadata,
+            ..Default::default()
+        }
+    }
+}
+
+impl From<StripeSubscriptionTrialSettings> for CreateCheckoutSessionSubscriptionDataTrialSettings {
+    fn from(value: StripeSubscriptionTrialSettings) -> Self {
+        Self {
+            end_behavior: value.end_behavior.into(),
+        }
+    }
+}
+
+impl From<StripeSubscriptionTrialSettingsEndBehavior>
+    for CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior
+{
+    fn from(value: StripeSubscriptionTrialSettingsEndBehavior) -> Self {
+        Self {
+            missing_payment_method: value.missing_payment_method.into(),
+        }
+    }
+}
+
+impl From<StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod>
+    for CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod
+{
+    fn from(value: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod) -> Self {
+        match value {
+            StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel,
+            StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => {
+                Self::CreateInvoice
+            }
+            StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause,
+        }
+    }
+}
+
+impl From<CheckoutSession> for StripeCheckoutSession {
+    fn from(value: CheckoutSession) -> Self {
+        Self { url: value.url }
+    }
+}

crates/collab/src/tests/stripe_billing_tests.rs 🔗

@@ -2,11 +2,15 @@ use std::sync::Arc;
 
 use pretty_assertions::assert_eq;
 
+use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
 use crate::stripe_billing::StripeBilling;
 use crate::stripe_client::{
-    FakeStripeClient, StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId,
-    StripePriceRecurring, StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem,
-    StripeSubscriptionItemId, UpdateSubscriptionItems,
+    FakeStripeClient, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection,
+    StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionSubscriptionData,
+    StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripePriceRecurring,
+    StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId,
+    StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior,
+    StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
 };
 
 fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
@@ -241,3 +245,204 @@ async fn test_bill_model_request_usage() {
     );
     assert_eq!(create_meter_event_calls[0].value, 73);
 }
+
+#[gpui::test]
+async fn test_checkout_with_zed_pro() {
+    let (stripe_billing, stripe_client) = make_stripe_billing();
+
+    let customer_id = StripeCustomerId("cus_test".into());
+    let github_login = "zeduser1";
+    let success_url = "https://example.com/success";
+
+    // It returns an error when the Zed Pro price doesn't exist.
+    {
+        let result = stripe_billing
+            .checkout_with_zed_pro(&customer_id, github_login, success_url)
+            .await;
+
+        assert!(result.is_err());
+        assert_eq!(
+            result.err().unwrap().to_string(),
+            r#"no price ID found for "zed-pro""#
+        );
+    }
+
+    // Successful checkout.
+    {
+        let price = StripePrice {
+            id: StripePriceId("price_1".into()),
+            unit_amount: Some(2000),
+            lookup_key: Some("zed-pro".to_string()),
+            recurring: None,
+        };
+        stripe_client
+            .prices
+            .lock()
+            .insert(price.id.clone(), price.clone());
+
+        stripe_billing.initialize().await.unwrap();
+
+        let checkout_url = stripe_billing
+            .checkout_with_zed_pro(&customer_id, github_login, success_url)
+            .await
+            .unwrap();
+
+        assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
+
+        let create_checkout_session_calls = stripe_client
+            .create_checkout_session_calls
+            .lock()
+            .drain(..)
+            .collect::<Vec<_>>();
+        assert_eq!(create_checkout_session_calls.len(), 1);
+        let call = create_checkout_session_calls.into_iter().next().unwrap();
+        assert_eq!(call.customer, Some(customer_id));
+        assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
+        assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
+        assert_eq!(
+            call.line_items,
+            Some(vec![StripeCreateCheckoutSessionLineItems {
+                price: Some(price.id.to_string()),
+                quantity: Some(1)
+            }])
+        );
+        assert_eq!(call.payment_method_collection, None);
+        assert_eq!(call.subscription_data, None);
+        assert_eq!(call.success_url.as_deref(), Some(success_url));
+    }
+}
+
+#[gpui::test]
+async fn test_checkout_with_zed_pro_trial() {
+    let (stripe_billing, stripe_client) = make_stripe_billing();
+
+    let customer_id = StripeCustomerId("cus_test".into());
+    let github_login = "zeduser1";
+    let success_url = "https://example.com/success";
+
+    // It returns an error when the Zed Pro price doesn't exist.
+    {
+        let result = stripe_billing
+            .checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
+            .await;
+
+        assert!(result.is_err());
+        assert_eq!(
+            result.err().unwrap().to_string(),
+            r#"no price ID found for "zed-pro""#
+        );
+    }
+
+    let price = StripePrice {
+        id: StripePriceId("price_1".into()),
+        unit_amount: Some(2000),
+        lookup_key: Some("zed-pro".to_string()),
+        recurring: None,
+    };
+    stripe_client
+        .prices
+        .lock()
+        .insert(price.id.clone(), price.clone());
+
+    stripe_billing.initialize().await.unwrap();
+
+    // Successful checkout.
+    {
+        let checkout_url = stripe_billing
+            .checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
+            .await
+            .unwrap();
+
+        assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
+
+        let create_checkout_session_calls = stripe_client
+            .create_checkout_session_calls
+            .lock()
+            .drain(..)
+            .collect::<Vec<_>>();
+        assert_eq!(create_checkout_session_calls.len(), 1);
+        let call = create_checkout_session_calls.into_iter().next().unwrap();
+        assert_eq!(call.customer.as_ref(), Some(&customer_id));
+        assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
+        assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
+        assert_eq!(
+            call.line_items,
+            Some(vec![StripeCreateCheckoutSessionLineItems {
+                price: Some(price.id.to_string()),
+                quantity: Some(1)
+            }])
+        );
+        assert_eq!(
+            call.payment_method_collection,
+            Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
+        );
+        assert_eq!(
+            call.subscription_data,
+            Some(StripeCreateCheckoutSessionSubscriptionData {
+                trial_period_days: Some(14),
+                trial_settings: Some(StripeSubscriptionTrialSettings {
+                    end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
+                        missing_payment_method:
+                            StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
+                    },
+                }),
+                metadata: None,
+            })
+        );
+        assert_eq!(call.success_url.as_deref(), Some(success_url));
+    }
+
+    // Successful checkout with extended trial.
+    {
+        let checkout_url = stripe_billing
+            .checkout_with_zed_pro_trial(
+                &customer_id,
+                github_login,
+                vec![AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string()],
+                success_url,
+            )
+            .await
+            .unwrap();
+
+        assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
+
+        let create_checkout_session_calls = stripe_client
+            .create_checkout_session_calls
+            .lock()
+            .drain(..)
+            .collect::<Vec<_>>();
+        assert_eq!(create_checkout_session_calls.len(), 1);
+        let call = create_checkout_session_calls.into_iter().next().unwrap();
+        assert_eq!(call.customer, Some(customer_id));
+        assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
+        assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
+        assert_eq!(
+            call.line_items,
+            Some(vec![StripeCreateCheckoutSessionLineItems {
+                price: Some(price.id.to_string()),
+                quantity: Some(1)
+            }])
+        );
+        assert_eq!(
+            call.payment_method_collection,
+            Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
+        );
+        assert_eq!(
+            call.subscription_data,
+            Some(StripeCreateCheckoutSessionSubscriptionData {
+                trial_period_days: Some(60),
+                trial_settings: Some(StripeSubscriptionTrialSettings {
+                    end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
+                        missing_payment_method:
+                            StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
+                    },
+                }),
+                metadata: Some(std::collections::HashMap::from_iter([(
+                    "promo_feature_flag".into(),
+                    AGENT_EXTENDED_TRIAL_FEATURE_FLAG.into()
+                )])),
+            })
+        );
+        assert_eq!(call.success_url.as_deref(), Some(success_url));
+    }
+}