diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index 5bd2c3cc6de1fa1fad641352073718b7eb3114da..7fd35fc0300f3c15baca5ace5650d4a0e2a22780 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -338,13 +338,11 @@ async fn create_billing_subscription( } let customer_id = if let Some(existing_customer) = &existing_billing_customer { - CustomerId::from_str(&existing_customer.stripe_customer_id) - .context("failed to parse customer ID")? + StripeCustomerId(existing_customer.stripe_customer_id.clone().into()) } else { stripe_billing .find_or_create_customer_by_email(user.email_address.as_deref()) .await? - .try_into()? }; let success_url = format!( @@ -355,7 +353,7 @@ async fn create_billing_subscription( let checkout_session_url = match body.product { ProductCode::ZedPro => { stripe_billing - .checkout_with_zed_pro(customer_id, &user.github_login, &success_url) + .checkout_with_zed_pro(&customer_id, &user.github_login, &success_url) .await? } ProductCode::ZedProTrial => { @@ -372,7 +370,7 @@ async fn create_billing_subscription( stripe_billing .checkout_with_zed_pro_trial( - customer_id, + &customer_id, &user.github_login, feature_flags, &success_url, diff --git a/crates/collab/src/stripe_billing.rs b/crates/collab/src/stripe_billing.rs index ded117dc3d09e9c98065d5feac0520a725c7679d..4a8bb41c2cb71ee72378b37a1bc1485041e9f84d 100644 --- a/crates/collab/src/stripe_billing.rs +++ b/crates/collab/src/stripe_billing.rs @@ -11,11 +11,14 @@ use crate::Result; use crate::db::billing_subscription::SubscriptionKind; use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG; use crate::stripe_client::{ - RealStripeClient, StripeClient, StripeCreateMeterEventParams, StripeCreateMeterEventPayload, - StripeCustomerId, StripeMeter, StripePrice, StripePriceId, StripeSubscription, - StripeSubscriptionId, UpdateSubscriptionItems, UpdateSubscriptionParams, - UpdateSubscriptionTrialSettings, UpdateSubscriptionTrialSettingsEndBehavior, - UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, + RealStripeClient, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection, + StripeClient, StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams, + StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, + StripeCreateMeterEventPayload, StripeCustomerId, StripeMeter, StripePrice, StripePriceId, + StripeSubscription, StripeSubscriptionId, StripeSubscriptionTrialSettings, + StripeSubscriptionTrialSettingsEndBehavior, + StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems, + UpdateSubscriptionParams, }; pub struct StripeBilling { @@ -190,9 +193,9 @@ impl StripeBilling { items: Some(vec![UpdateSubscriptionItems { price: Some(price.id.clone()), }]), - trial_settings: Some(UpdateSubscriptionTrialSettings { - end_behavior: UpdateSubscriptionTrialSettingsEndBehavior { - missing_payment_method: UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel + trial_settings: Some(StripeSubscriptionTrialSettings { + end_behavior: StripeSubscriptionTrialSettingsEndBehavior { + missing_payment_method: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel }, }), }, @@ -228,30 +231,29 @@ impl StripeBilling { pub async fn checkout_with_zed_pro( &self, - customer_id: stripe::CustomerId, + customer_id: &StripeCustomerId, github_login: &str, success_url: &str, ) -> Result { let zed_pro_price_id = self.zed_pro_price_id().await?; - let mut params = stripe::CreateCheckoutSession::new(); - params.mode = Some(stripe::CheckoutSessionMode::Subscription); + let mut params = StripeCreateCheckoutSessionParams::default(); + params.mode = Some(StripeCheckoutSessionMode::Subscription); params.customer = Some(customer_id); params.client_reference_id = Some(github_login); - params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems { + params.line_items = Some(vec![StripeCreateCheckoutSessionLineItems { price: Some(zed_pro_price_id.to_string()), quantity: Some(1), - ..Default::default() }]); params.success_url = Some(success_url); - let session = stripe::CheckoutSession::create(&self.real_client, params).await?; + let session = self.client.create_checkout_session(params).await?; Ok(session.url.context("no checkout session URL")?) } pub async fn checkout_with_zed_pro_trial( &self, - customer_id: stripe::CustomerId, + customer_id: &StripeCustomerId, github_login: &str, feature_flags: Vec, success_url: &str, @@ -272,34 +274,33 @@ impl StripeBilling { ); } - let mut params = stripe::CreateCheckoutSession::new(); - params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData { + let mut params = StripeCreateCheckoutSessionParams::default(); + params.subscription_data = Some(StripeCreateCheckoutSessionSubscriptionData { trial_period_days: Some(trial_period_days), - trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings { - end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior { - missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Cancel, - } + trial_settings: Some(StripeSubscriptionTrialSettings { + end_behavior: StripeSubscriptionTrialSettingsEndBehavior { + missing_payment_method: + StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel, + }, }), metadata: if !subscription_metadata.is_empty() { Some(subscription_metadata) } else { None }, - ..Default::default() }); - params.mode = Some(stripe::CheckoutSessionMode::Subscription); + params.mode = Some(StripeCheckoutSessionMode::Subscription); params.payment_method_collection = - Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired); + Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired); params.customer = Some(customer_id); params.client_reference_id = Some(github_login); - params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems { + params.line_items = Some(vec![StripeCreateCheckoutSessionLineItems { price: Some(zed_pro_price_id.to_string()), quantity: Some(1), - ..Default::default() }]); params.success_url = Some(success_url); - let session = stripe::CheckoutSession::create(&self.real_client, params).await?; + let session = self.client.create_checkout_session(params).await?; Ok(session.url.context("no checkout session URL")?) } diff --git a/crates/collab/src/stripe_client.rs b/crates/collab/src/stripe_client.rs index f15e373a9e8261d55a23b5ca1f4d9068ed380881..91ffd7a3d93299678daa0237d1bcf4f6f1da00e3 100644 --- a/crates/collab/src/stripe_client.rs +++ b/crates/collab/src/stripe_client.rs @@ -2,6 +2,7 @@ mod fake_stripe_client; mod real_stripe_client; +use std::collections::HashMap; use std::sync::Arc; use anyhow::Result; @@ -47,7 +48,7 @@ pub struct StripeSubscriptionItem { #[derive(Debug, Clone)] pub struct UpdateSubscriptionParams { pub items: Option>, - pub trial_settings: Option, + pub trial_settings: Option, } #[derive(Debug, PartialEq, Clone)] @@ -55,18 +56,18 @@ pub struct UpdateSubscriptionItems { pub price: Option, } -#[derive(Debug, Clone)] -pub struct UpdateSubscriptionTrialSettings { - pub end_behavior: UpdateSubscriptionTrialSettingsEndBehavior, +#[derive(Debug, PartialEq, Clone)] +pub struct StripeSubscriptionTrialSettings { + pub end_behavior: StripeSubscriptionTrialSettingsEndBehavior, } -#[derive(Debug, Clone)] -pub struct UpdateSubscriptionTrialSettingsEndBehavior { - pub missing_payment_method: UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, +#[derive(Debug, PartialEq, Clone)] +pub struct StripeSubscriptionTrialSettingsEndBehavior { + pub missing_payment_method: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, } #[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod { +pub enum StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod { Cancel, CreateInvoice, Pause, @@ -111,6 +112,48 @@ pub struct StripeCreateMeterEventPayload<'a> { pub stripe_customer_id: &'a StripeCustomerId, } +#[derive(Debug, Default)] +pub struct StripeCreateCheckoutSessionParams<'a> { + pub customer: Option<&'a StripeCustomerId>, + pub client_reference_id: Option<&'a str>, + pub mode: Option, + pub line_items: Option>, + pub payment_method_collection: Option, + pub subscription_data: Option, + pub success_url: Option<&'a str>, +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum StripeCheckoutSessionMode { + Payment, + Setup, + Subscription, +} + +#[derive(Debug, PartialEq, Clone)] +pub struct StripeCreateCheckoutSessionLineItems { + pub price: Option, + pub quantity: Option, +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum StripeCheckoutSessionPaymentMethodCollection { + Always, + IfRequired, +} + +#[derive(Debug, PartialEq, Clone)] +pub struct StripeCreateCheckoutSessionSubscriptionData { + pub metadata: Option>, + pub trial_period_days: Option, + pub trial_settings: Option, +} + +#[derive(Debug)] +pub struct StripeCheckoutSession { + pub url: Option, +} + #[async_trait] pub trait StripeClient: Send + Sync { async fn list_customers_by_email(&self, email: &str) -> Result>; @@ -133,4 +176,9 @@ pub trait StripeClient: Send + Sync { async fn list_meters(&self) -> Result>; async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()>; + + async fn create_checkout_session( + &self, + params: StripeCreateCheckoutSessionParams<'_>, + ) -> Result; } diff --git a/crates/collab/src/stripe_client/fake_stripe_client.rs b/crates/collab/src/stripe_client/fake_stripe_client.rs index ddcdaacc3d6cce00854e996ab9795c406e66101b..3a2d2c8590dc58c5bd2221491c9f7fec03c9e7aa 100644 --- a/crates/collab/src/stripe_client/fake_stripe_client.rs +++ b/crates/collab/src/stripe_client/fake_stripe_client.rs @@ -7,7 +7,10 @@ use parking_lot::Mutex; use uuid::Uuid; use crate::stripe_client::{ - CreateCustomerParams, StripeClient, StripeCreateMeterEventParams, StripeCustomer, + CreateCustomerParams, StripeCheckoutSession, StripeCheckoutSessionMode, + StripeCheckoutSessionPaymentMethodCollection, StripeClient, + StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams, + StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, StripeCustomer, StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId, UpdateSubscriptionParams, }; @@ -21,6 +24,17 @@ pub struct StripeCreateMeterEventCall { pub timestamp: Option, } +#[derive(Debug, Clone)] +pub struct StripeCreateCheckoutSessionCall { + pub customer: Option, + pub client_reference_id: Option, + pub mode: Option, + pub line_items: Option>, + pub payment_method_collection: Option, + pub subscription_data: Option, + pub success_url: Option, +} + pub struct FakeStripeClient { pub customers: Arc>>, pub subscriptions: Arc>>, @@ -29,6 +43,7 @@ pub struct FakeStripeClient { pub prices: Arc>>, pub meters: Arc>>, pub create_meter_event_calls: Arc>>, + pub create_checkout_session_calls: Arc>>, } impl FakeStripeClient { @@ -40,6 +55,7 @@ impl FakeStripeClient { prices: Arc::new(Mutex::new(HashMap::default())), meters: Arc::new(Mutex::new(HashMap::default())), create_meter_event_calls: Arc::new(Mutex::new(Vec::new())), + create_checkout_session_calls: Arc::new(Mutex::new(Vec::new())), } } } @@ -119,4 +135,25 @@ impl StripeClient for FakeStripeClient { Ok(()) } + + async fn create_checkout_session( + &self, + params: StripeCreateCheckoutSessionParams<'_>, + ) -> Result { + self.create_checkout_session_calls + .lock() + .push(StripeCreateCheckoutSessionCall { + customer: params.customer.cloned(), + client_reference_id: params.client_reference_id.map(|id| id.to_string()), + mode: params.mode, + line_items: params.line_items, + payment_method_collection: params.payment_method_collection, + subscription_data: params.subscription_data, + success_url: params.success_url.map(|url| url.to_string()), + }); + + Ok(StripeCheckoutSession { + url: Some("https://checkout.stripe.com/c/pay/cs_test_1".to_string()), + }) + } } diff --git a/crates/collab/src/stripe_client/real_stripe_client.rs b/crates/collab/src/stripe_client/real_stripe_client.rs index fa0b08790d7ac9d664884ff1c9a2270aa5cea36a..724e4c64c3d9d98123e76b09abd5012cf44d0aa4 100644 --- a/crates/collab/src/stripe_client/real_stripe_client.rs +++ b/crates/collab/src/stripe_client/real_stripe_client.rs @@ -5,6 +5,11 @@ use anyhow::{Context as _, Result, anyhow}; use async_trait::async_trait; use serde::Serialize; use stripe::{ + CheckoutSession, CheckoutSessionMode, CheckoutSessionPaymentMethodCollection, + CreateCheckoutSession, CreateCheckoutSessionLineItems, CreateCheckoutSessionSubscriptionData, + CreateCheckoutSessionSubscriptionDataTrialSettings, + CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior, + CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod, CreateCustomer, Customer, CustomerId, ListCustomers, Price, PriceId, Recurring, Subscription, SubscriptionId, SubscriptionItem, SubscriptionItemId, UpdateSubscriptionItems, UpdateSubscriptionTrialSettings, UpdateSubscriptionTrialSettingsEndBehavior, @@ -12,10 +17,14 @@ use stripe::{ }; use crate::stripe_client::{ - CreateCustomerParams, StripeClient, StripeCreateMeterEventParams, StripeCustomer, + CreateCustomerParams, StripeCheckoutSession, StripeCheckoutSessionMode, + StripeCheckoutSessionPaymentMethodCollection, StripeClient, + StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams, + StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, StripeCustomer, StripeCustomerId, StripeMeter, StripePrice, StripePriceId, StripePriceRecurring, StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId, - UpdateSubscriptionParams, + StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior, + StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionParams, }; pub struct RealStripeClient { @@ -150,6 +159,16 @@ impl StripeClient for RealStripeClient { Err(error) => Err(anyhow!(error)), } } + + async fn create_checkout_session( + &self, + params: StripeCreateCheckoutSessionParams<'_>, + ) -> Result { + let params = params.try_into()?; + let session = CheckoutSession::create(&self.client, params).await?; + + Ok(session.into()) + } } impl From for StripeCustomerId { @@ -166,6 +185,14 @@ impl TryFrom for CustomerId { } } +impl TryFrom<&StripeCustomerId> for CustomerId { + type Error = anyhow::Error; + + fn try_from(value: &StripeCustomerId) -> Result { + Self::from_str(value.0.as_ref()).context("failed to parse Stripe customer ID") + } +} + impl From for StripeCustomer { fn from(value: Customer) -> Self { StripeCustomer { @@ -213,38 +240,34 @@ impl From for StripeSubscriptionItem { } } -impl From - for UpdateSubscriptionTrialSettings -{ - fn from(value: crate::stripe_client::UpdateSubscriptionTrialSettings) -> Self { +impl From for UpdateSubscriptionTrialSettings { + fn from(value: StripeSubscriptionTrialSettings) -> Self { Self { end_behavior: value.end_behavior.into(), } } } -impl From +impl From for UpdateSubscriptionTrialSettingsEndBehavior { - fn from(value: crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehavior) -> Self { + fn from(value: StripeSubscriptionTrialSettingsEndBehavior) -> Self { Self { missing_payment_method: value.missing_payment_method.into(), } } } -impl From +impl From for UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod { - fn from( - value: crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, - ) -> Self { + fn from(value: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod) -> Self { match value { - crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel, - crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => { + StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel, + StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => { Self::CreateInvoice } - crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause, + StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause, } } } @@ -279,3 +302,103 @@ impl From for StripePriceRecurring { Self { meter: value.meter } } } + +impl<'a> TryFrom> for CreateCheckoutSession<'a> { + type Error = anyhow::Error; + + fn try_from(value: StripeCreateCheckoutSessionParams<'a>) -> Result { + Ok(Self { + customer: value + .customer + .map(|customer_id| customer_id.try_into()) + .transpose()?, + client_reference_id: value.client_reference_id, + mode: value.mode.map(Into::into), + line_items: value + .line_items + .map(|line_items| line_items.into_iter().map(Into::into).collect()), + payment_method_collection: value.payment_method_collection.map(Into::into), + subscription_data: value.subscription_data.map(Into::into), + success_url: value.success_url, + ..Default::default() + }) + } +} + +impl From for CheckoutSessionMode { + fn from(value: StripeCheckoutSessionMode) -> Self { + match value { + StripeCheckoutSessionMode::Payment => Self::Payment, + StripeCheckoutSessionMode::Setup => Self::Setup, + StripeCheckoutSessionMode::Subscription => Self::Subscription, + } + } +} + +impl From for CreateCheckoutSessionLineItems { + fn from(value: StripeCreateCheckoutSessionLineItems) -> Self { + Self { + price: value.price, + quantity: value.quantity, + ..Default::default() + } + } +} + +impl From for CheckoutSessionPaymentMethodCollection { + fn from(value: StripeCheckoutSessionPaymentMethodCollection) -> Self { + match value { + StripeCheckoutSessionPaymentMethodCollection::Always => Self::Always, + StripeCheckoutSessionPaymentMethodCollection::IfRequired => Self::IfRequired, + } + } +} + +impl From for CreateCheckoutSessionSubscriptionData { + fn from(value: StripeCreateCheckoutSessionSubscriptionData) -> Self { + Self { + trial_period_days: value.trial_period_days, + trial_settings: value.trial_settings.map(Into::into), + metadata: value.metadata, + ..Default::default() + } + } +} + +impl From for CreateCheckoutSessionSubscriptionDataTrialSettings { + fn from(value: StripeSubscriptionTrialSettings) -> Self { + Self { + end_behavior: value.end_behavior.into(), + } + } +} + +impl From + for CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior +{ + fn from(value: StripeSubscriptionTrialSettingsEndBehavior) -> Self { + Self { + missing_payment_method: value.missing_payment_method.into(), + } + } +} + +impl From + for CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod +{ + fn from(value: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod) -> Self { + match value { + StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel, + StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => { + Self::CreateInvoice + } + StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause, + } + } +} + +impl From for StripeCheckoutSession { + fn from(value: CheckoutSession) -> Self { + Self { url: value.url } + } +} diff --git a/crates/collab/src/tests/stripe_billing_tests.rs b/crates/collab/src/tests/stripe_billing_tests.rs index 6a8bab90feea3f0e2e13f9295b352f0bdc4e768d..da18d2e7a2398247acf768e10df3c86c0332e70d 100644 --- a/crates/collab/src/tests/stripe_billing_tests.rs +++ b/crates/collab/src/tests/stripe_billing_tests.rs @@ -2,11 +2,15 @@ use std::sync::Arc; use pretty_assertions::assert_eq; +use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG; use crate::stripe_billing::StripeBilling; use crate::stripe_client::{ - FakeStripeClient, StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId, - StripePriceRecurring, StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, - StripeSubscriptionItemId, UpdateSubscriptionItems, + FakeStripeClient, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection, + StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionSubscriptionData, + StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripePriceRecurring, + StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId, + StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior, + StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems, }; fn make_stripe_billing() -> (StripeBilling, Arc) { @@ -241,3 +245,204 @@ async fn test_bill_model_request_usage() { ); assert_eq!(create_meter_event_calls[0].value, 73); } + +#[gpui::test] +async fn test_checkout_with_zed_pro() { + let (stripe_billing, stripe_client) = make_stripe_billing(); + + let customer_id = StripeCustomerId("cus_test".into()); + let github_login = "zeduser1"; + let success_url = "https://example.com/success"; + + // It returns an error when the Zed Pro price doesn't exist. + { + let result = stripe_billing + .checkout_with_zed_pro(&customer_id, github_login, success_url) + .await; + + assert!(result.is_err()); + assert_eq!( + result.err().unwrap().to_string(), + r#"no price ID found for "zed-pro""# + ); + } + + // Successful checkout. + { + let price = StripePrice { + id: StripePriceId("price_1".into()), + unit_amount: Some(2000), + lookup_key: Some("zed-pro".to_string()), + recurring: None, + }; + stripe_client + .prices + .lock() + .insert(price.id.clone(), price.clone()); + + stripe_billing.initialize().await.unwrap(); + + let checkout_url = stripe_billing + .checkout_with_zed_pro(&customer_id, github_login, success_url) + .await + .unwrap(); + + assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay")); + + let create_checkout_session_calls = stripe_client + .create_checkout_session_calls + .lock() + .drain(..) + .collect::>(); + assert_eq!(create_checkout_session_calls.len(), 1); + let call = create_checkout_session_calls.into_iter().next().unwrap(); + assert_eq!(call.customer, Some(customer_id)); + assert_eq!(call.client_reference_id.as_deref(), Some(github_login)); + assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription)); + assert_eq!( + call.line_items, + Some(vec![StripeCreateCheckoutSessionLineItems { + price: Some(price.id.to_string()), + quantity: Some(1) + }]) + ); + assert_eq!(call.payment_method_collection, None); + assert_eq!(call.subscription_data, None); + assert_eq!(call.success_url.as_deref(), Some(success_url)); + } +} + +#[gpui::test] +async fn test_checkout_with_zed_pro_trial() { + let (stripe_billing, stripe_client) = make_stripe_billing(); + + let customer_id = StripeCustomerId("cus_test".into()); + let github_login = "zeduser1"; + let success_url = "https://example.com/success"; + + // It returns an error when the Zed Pro price doesn't exist. + { + let result = stripe_billing + .checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url) + .await; + + assert!(result.is_err()); + assert_eq!( + result.err().unwrap().to_string(), + r#"no price ID found for "zed-pro""# + ); + } + + let price = StripePrice { + id: StripePriceId("price_1".into()), + unit_amount: Some(2000), + lookup_key: Some("zed-pro".to_string()), + recurring: None, + }; + stripe_client + .prices + .lock() + .insert(price.id.clone(), price.clone()); + + stripe_billing.initialize().await.unwrap(); + + // Successful checkout. + { + let checkout_url = stripe_billing + .checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url) + .await + .unwrap(); + + assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay")); + + let create_checkout_session_calls = stripe_client + .create_checkout_session_calls + .lock() + .drain(..) + .collect::>(); + assert_eq!(create_checkout_session_calls.len(), 1); + let call = create_checkout_session_calls.into_iter().next().unwrap(); + assert_eq!(call.customer.as_ref(), Some(&customer_id)); + assert_eq!(call.client_reference_id.as_deref(), Some(github_login)); + assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription)); + assert_eq!( + call.line_items, + Some(vec![StripeCreateCheckoutSessionLineItems { + price: Some(price.id.to_string()), + quantity: Some(1) + }]) + ); + assert_eq!( + call.payment_method_collection, + Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired) + ); + assert_eq!( + call.subscription_data, + Some(StripeCreateCheckoutSessionSubscriptionData { + trial_period_days: Some(14), + trial_settings: Some(StripeSubscriptionTrialSettings { + end_behavior: StripeSubscriptionTrialSettingsEndBehavior { + missing_payment_method: + StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel, + }, + }), + metadata: None, + }) + ); + assert_eq!(call.success_url.as_deref(), Some(success_url)); + } + + // Successful checkout with extended trial. + { + let checkout_url = stripe_billing + .checkout_with_zed_pro_trial( + &customer_id, + github_login, + vec![AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string()], + success_url, + ) + .await + .unwrap(); + + assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay")); + + let create_checkout_session_calls = stripe_client + .create_checkout_session_calls + .lock() + .drain(..) + .collect::>(); + assert_eq!(create_checkout_session_calls.len(), 1); + let call = create_checkout_session_calls.into_iter().next().unwrap(); + assert_eq!(call.customer, Some(customer_id)); + assert_eq!(call.client_reference_id.as_deref(), Some(github_login)); + assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription)); + assert_eq!( + call.line_items, + Some(vec![StripeCreateCheckoutSessionLineItems { + price: Some(price.id.to_string()), + quantity: Some(1) + }]) + ); + assert_eq!( + call.payment_method_collection, + Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired) + ); + assert_eq!( + call.subscription_data, + Some(StripeCreateCheckoutSessionSubscriptionData { + trial_period_days: Some(60), + trial_settings: Some(StripeSubscriptionTrialSettings { + end_behavior: StripeSubscriptionTrialSettingsEndBehavior { + missing_payment_method: + StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel, + }, + }), + metadata: Some(std::collections::HashMap::from_iter([( + "promo_feature_flag".into(), + AGENT_EXTENDED_TRIAL_FEATURE_FLAG.into() + )])), + }) + ); + assert_eq!(call.success_url.as_deref(), Some(success_url)); + } +}