Detailed changes
@@ -338,13 +338,11 @@ async fn create_billing_subscription(
}
let customer_id = if let Some(existing_customer) = &existing_billing_customer {
- CustomerId::from_str(&existing_customer.stripe_customer_id)
- .context("failed to parse customer ID")?
+ StripeCustomerId(existing_customer.stripe_customer_id.clone().into())
} else {
stripe_billing
.find_or_create_customer_by_email(user.email_address.as_deref())
.await?
- .try_into()?
};
let success_url = format!(
@@ -355,7 +353,7 @@ async fn create_billing_subscription(
let checkout_session_url = match body.product {
ProductCode::ZedPro => {
stripe_billing
- .checkout_with_zed_pro(customer_id, &user.github_login, &success_url)
+ .checkout_with_zed_pro(&customer_id, &user.github_login, &success_url)
.await?
}
ProductCode::ZedProTrial => {
@@ -372,7 +370,7 @@ async fn create_billing_subscription(
stripe_billing
.checkout_with_zed_pro_trial(
- customer_id,
+ &customer_id,
&user.github_login,
feature_flags,
&success_url,
@@ -11,11 +11,14 @@ use crate::Result;
use crate::db::billing_subscription::SubscriptionKind;
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
use crate::stripe_client::{
- RealStripeClient, StripeClient, StripeCreateMeterEventParams, StripeCreateMeterEventPayload,
- StripeCustomerId, StripeMeter, StripePrice, StripePriceId, StripeSubscription,
- StripeSubscriptionId, UpdateSubscriptionItems, UpdateSubscriptionParams,
- UpdateSubscriptionTrialSettings, UpdateSubscriptionTrialSettingsEndBehavior,
- UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
+ RealStripeClient, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection,
+ StripeClient, StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
+ StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
+ StripeCreateMeterEventPayload, StripeCustomerId, StripeMeter, StripePrice, StripePriceId,
+ StripeSubscription, StripeSubscriptionId, StripeSubscriptionTrialSettings,
+ StripeSubscriptionTrialSettingsEndBehavior,
+ StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
+ UpdateSubscriptionParams,
};
pub struct StripeBilling {
@@ -190,9 +193,9 @@ impl StripeBilling {
items: Some(vec![UpdateSubscriptionItems {
price: Some(price.id.clone()),
}]),
- trial_settings: Some(UpdateSubscriptionTrialSettings {
- end_behavior: UpdateSubscriptionTrialSettingsEndBehavior {
- missing_payment_method: UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel
+ trial_settings: Some(StripeSubscriptionTrialSettings {
+ end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
+ missing_payment_method: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel
},
}),
},
@@ -228,30 +231,29 @@ impl StripeBilling {
pub async fn checkout_with_zed_pro(
&self,
- customer_id: stripe::CustomerId,
+ customer_id: &StripeCustomerId,
github_login: &str,
success_url: &str,
) -> Result<String> {
let zed_pro_price_id = self.zed_pro_price_id().await?;
- let mut params = stripe::CreateCheckoutSession::new();
- params.mode = Some(stripe::CheckoutSessionMode::Subscription);
+ let mut params = StripeCreateCheckoutSessionParams::default();
+ params.mode = Some(StripeCheckoutSessionMode::Subscription);
params.customer = Some(customer_id);
params.client_reference_id = Some(github_login);
- params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
+ params.line_items = Some(vec![StripeCreateCheckoutSessionLineItems {
price: Some(zed_pro_price_id.to_string()),
quantity: Some(1),
- ..Default::default()
}]);
params.success_url = Some(success_url);
- let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
+ let session = self.client.create_checkout_session(params).await?;
Ok(session.url.context("no checkout session URL")?)
}
pub async fn checkout_with_zed_pro_trial(
&self,
- customer_id: stripe::CustomerId,
+ customer_id: &StripeCustomerId,
github_login: &str,
feature_flags: Vec<String>,
success_url: &str,
@@ -272,34 +274,33 @@ impl StripeBilling {
);
}
- let mut params = stripe::CreateCheckoutSession::new();
- params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
+ let mut params = StripeCreateCheckoutSessionParams::default();
+ params.subscription_data = Some(StripeCreateCheckoutSessionSubscriptionData {
trial_period_days: Some(trial_period_days),
- trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
- end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
- missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
- }
+ trial_settings: Some(StripeSubscriptionTrialSettings {
+ end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
+ missing_payment_method:
+ StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
+ },
}),
metadata: if !subscription_metadata.is_empty() {
Some(subscription_metadata)
} else {
None
},
- ..Default::default()
});
- params.mode = Some(stripe::CheckoutSessionMode::Subscription);
+ params.mode = Some(StripeCheckoutSessionMode::Subscription);
params.payment_method_collection =
- Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
+ Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired);
params.customer = Some(customer_id);
params.client_reference_id = Some(github_login);
- params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
+ params.line_items = Some(vec![StripeCreateCheckoutSessionLineItems {
price: Some(zed_pro_price_id.to_string()),
quantity: Some(1),
- ..Default::default()
}]);
params.success_url = Some(success_url);
- let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
+ let session = self.client.create_checkout_session(params).await?;
Ok(session.url.context("no checkout session URL")?)
}
@@ -2,6 +2,7 @@
mod fake_stripe_client;
mod real_stripe_client;
+use std::collections::HashMap;
use std::sync::Arc;
use anyhow::Result;
@@ -47,7 +48,7 @@ pub struct StripeSubscriptionItem {
#[derive(Debug, Clone)]
pub struct UpdateSubscriptionParams {
pub items: Option<Vec<UpdateSubscriptionItems>>,
- pub trial_settings: Option<UpdateSubscriptionTrialSettings>,
+ pub trial_settings: Option<StripeSubscriptionTrialSettings>,
}
#[derive(Debug, PartialEq, Clone)]
@@ -55,18 +56,18 @@ pub struct UpdateSubscriptionItems {
pub price: Option<StripePriceId>,
}
-#[derive(Debug, Clone)]
-pub struct UpdateSubscriptionTrialSettings {
- pub end_behavior: UpdateSubscriptionTrialSettingsEndBehavior,
+#[derive(Debug, PartialEq, Clone)]
+pub struct StripeSubscriptionTrialSettings {
+ pub end_behavior: StripeSubscriptionTrialSettingsEndBehavior,
}
-#[derive(Debug, Clone)]
-pub struct UpdateSubscriptionTrialSettingsEndBehavior {
- pub missing_payment_method: UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
+#[derive(Debug, PartialEq, Clone)]
+pub struct StripeSubscriptionTrialSettingsEndBehavior {
+ pub missing_payment_method: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
-pub enum UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod {
+pub enum StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod {
Cancel,
CreateInvoice,
Pause,
@@ -111,6 +112,48 @@ pub struct StripeCreateMeterEventPayload<'a> {
pub stripe_customer_id: &'a StripeCustomerId,
}
+#[derive(Debug, Default)]
+pub struct StripeCreateCheckoutSessionParams<'a> {
+ pub customer: Option<&'a StripeCustomerId>,
+ pub client_reference_id: Option<&'a str>,
+ pub mode: Option<StripeCheckoutSessionMode>,
+ pub line_items: Option<Vec<StripeCreateCheckoutSessionLineItems>>,
+ pub payment_method_collection: Option<StripeCheckoutSessionPaymentMethodCollection>,
+ pub subscription_data: Option<StripeCreateCheckoutSessionSubscriptionData>,
+ pub success_url: Option<&'a str>,
+}
+
+#[derive(Debug, PartialEq, Eq, Clone, Copy)]
+pub enum StripeCheckoutSessionMode {
+ Payment,
+ Setup,
+ Subscription,
+}
+
+#[derive(Debug, PartialEq, Clone)]
+pub struct StripeCreateCheckoutSessionLineItems {
+ pub price: Option<String>,
+ pub quantity: Option<u64>,
+}
+
+#[derive(Debug, PartialEq, Eq, Clone, Copy)]
+pub enum StripeCheckoutSessionPaymentMethodCollection {
+ Always,
+ IfRequired,
+}
+
+#[derive(Debug, PartialEq, Clone)]
+pub struct StripeCreateCheckoutSessionSubscriptionData {
+ pub metadata: Option<HashMap<String, String>>,
+ pub trial_period_days: Option<u32>,
+ pub trial_settings: Option<StripeSubscriptionTrialSettings>,
+}
+
+#[derive(Debug)]
+pub struct StripeCheckoutSession {
+ pub url: Option<String>,
+}
+
#[async_trait]
pub trait StripeClient: Send + Sync {
async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>>;
@@ -133,4 +176,9 @@ pub trait StripeClient: Send + Sync {
async fn list_meters(&self) -> Result<Vec<StripeMeter>>;
async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()>;
+
+ async fn create_checkout_session(
+ &self,
+ params: StripeCreateCheckoutSessionParams<'_>,
+ ) -> Result<StripeCheckoutSession>;
}
@@ -7,7 +7,10 @@ use parking_lot::Mutex;
use uuid::Uuid;
use crate::stripe_client::{
- CreateCustomerParams, StripeClient, StripeCreateMeterEventParams, StripeCustomer,
+ CreateCustomerParams, StripeCheckoutSession, StripeCheckoutSessionMode,
+ StripeCheckoutSessionPaymentMethodCollection, StripeClient,
+ StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
+ StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, StripeCustomer,
StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripeSubscription,
StripeSubscriptionId, UpdateSubscriptionParams,
};
@@ -21,6 +24,17 @@ pub struct StripeCreateMeterEventCall {
pub timestamp: Option<i64>,
}
+#[derive(Debug, Clone)]
+pub struct StripeCreateCheckoutSessionCall {
+ pub customer: Option<StripeCustomerId>,
+ pub client_reference_id: Option<String>,
+ pub mode: Option<StripeCheckoutSessionMode>,
+ pub line_items: Option<Vec<StripeCreateCheckoutSessionLineItems>>,
+ pub payment_method_collection: Option<StripeCheckoutSessionPaymentMethodCollection>,
+ pub subscription_data: Option<StripeCreateCheckoutSessionSubscriptionData>,
+ pub success_url: Option<String>,
+}
+
pub struct FakeStripeClient {
pub customers: Arc<Mutex<HashMap<StripeCustomerId, StripeCustomer>>>,
pub subscriptions: Arc<Mutex<HashMap<StripeSubscriptionId, StripeSubscription>>>,
@@ -29,6 +43,7 @@ pub struct FakeStripeClient {
pub prices: Arc<Mutex<HashMap<StripePriceId, StripePrice>>>,
pub meters: Arc<Mutex<HashMap<StripeMeterId, StripeMeter>>>,
pub create_meter_event_calls: Arc<Mutex<Vec<StripeCreateMeterEventCall>>>,
+ pub create_checkout_session_calls: Arc<Mutex<Vec<StripeCreateCheckoutSessionCall>>>,
}
impl FakeStripeClient {
@@ -40,6 +55,7 @@ impl FakeStripeClient {
prices: Arc::new(Mutex::new(HashMap::default())),
meters: Arc::new(Mutex::new(HashMap::default())),
create_meter_event_calls: Arc::new(Mutex::new(Vec::new())),
+ create_checkout_session_calls: Arc::new(Mutex::new(Vec::new())),
}
}
}
@@ -119,4 +135,25 @@ impl StripeClient for FakeStripeClient {
Ok(())
}
+
+ async fn create_checkout_session(
+ &self,
+ params: StripeCreateCheckoutSessionParams<'_>,
+ ) -> Result<StripeCheckoutSession> {
+ self.create_checkout_session_calls
+ .lock()
+ .push(StripeCreateCheckoutSessionCall {
+ customer: params.customer.cloned(),
+ client_reference_id: params.client_reference_id.map(|id| id.to_string()),
+ mode: params.mode,
+ line_items: params.line_items,
+ payment_method_collection: params.payment_method_collection,
+ subscription_data: params.subscription_data,
+ success_url: params.success_url.map(|url| url.to_string()),
+ });
+
+ Ok(StripeCheckoutSession {
+ url: Some("https://checkout.stripe.com/c/pay/cs_test_1".to_string()),
+ })
+ }
}
@@ -5,6 +5,11 @@ use anyhow::{Context as _, Result, anyhow};
use async_trait::async_trait;
use serde::Serialize;
use stripe::{
+ CheckoutSession, CheckoutSessionMode, CheckoutSessionPaymentMethodCollection,
+ CreateCheckoutSession, CreateCheckoutSessionLineItems, CreateCheckoutSessionSubscriptionData,
+ CreateCheckoutSessionSubscriptionDataTrialSettings,
+ CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior,
+ CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod,
CreateCustomer, Customer, CustomerId, ListCustomers, Price, PriceId, Recurring, Subscription,
SubscriptionId, SubscriptionItem, SubscriptionItemId, UpdateSubscriptionItems,
UpdateSubscriptionTrialSettings, UpdateSubscriptionTrialSettingsEndBehavior,
@@ -12,10 +17,14 @@ use stripe::{
};
use crate::stripe_client::{
- CreateCustomerParams, StripeClient, StripeCreateMeterEventParams, StripeCustomer,
+ CreateCustomerParams, StripeCheckoutSession, StripeCheckoutSessionMode,
+ StripeCheckoutSessionPaymentMethodCollection, StripeClient,
+ StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
+ StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, StripeCustomer,
StripeCustomerId, StripeMeter, StripePrice, StripePriceId, StripePriceRecurring,
StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId,
- UpdateSubscriptionParams,
+ StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior,
+ StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionParams,
};
pub struct RealStripeClient {
@@ -150,6 +159,16 @@ impl StripeClient for RealStripeClient {
Err(error) => Err(anyhow!(error)),
}
}
+
+ async fn create_checkout_session(
+ &self,
+ params: StripeCreateCheckoutSessionParams<'_>,
+ ) -> Result<StripeCheckoutSession> {
+ let params = params.try_into()?;
+ let session = CheckoutSession::create(&self.client, params).await?;
+
+ Ok(session.into())
+ }
}
impl From<CustomerId> for StripeCustomerId {
@@ -166,6 +185,14 @@ impl TryFrom<StripeCustomerId> for CustomerId {
}
}
+impl TryFrom<&StripeCustomerId> for CustomerId {
+ type Error = anyhow::Error;
+
+ fn try_from(value: &StripeCustomerId) -> Result<Self, Self::Error> {
+ Self::from_str(value.0.as_ref()).context("failed to parse Stripe customer ID")
+ }
+}
+
impl From<Customer> for StripeCustomer {
fn from(value: Customer) -> Self {
StripeCustomer {
@@ -213,38 +240,34 @@ impl From<SubscriptionItem> for StripeSubscriptionItem {
}
}
-impl From<crate::stripe_client::UpdateSubscriptionTrialSettings>
- for UpdateSubscriptionTrialSettings
-{
- fn from(value: crate::stripe_client::UpdateSubscriptionTrialSettings) -> Self {
+impl From<StripeSubscriptionTrialSettings> for UpdateSubscriptionTrialSettings {
+ fn from(value: StripeSubscriptionTrialSettings) -> Self {
Self {
end_behavior: value.end_behavior.into(),
}
}
}
-impl From<crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehavior>
+impl From<StripeSubscriptionTrialSettingsEndBehavior>
for UpdateSubscriptionTrialSettingsEndBehavior
{
- fn from(value: crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehavior) -> Self {
+ fn from(value: StripeSubscriptionTrialSettingsEndBehavior) -> Self {
Self {
missing_payment_method: value.missing_payment_method.into(),
}
}
}
-impl From<crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod>
+impl From<StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod>
for UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod
{
- fn from(
- value: crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
- ) -> Self {
+ fn from(value: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod) -> Self {
match value {
- crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel,
- crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => {
+ StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel,
+ StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => {
Self::CreateInvoice
}
- crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause,
+ StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause,
}
}
}
@@ -279,3 +302,103 @@ impl From<Recurring> for StripePriceRecurring {
Self { meter: value.meter }
}
}
+
+impl<'a> TryFrom<StripeCreateCheckoutSessionParams<'a>> for CreateCheckoutSession<'a> {
+ type Error = anyhow::Error;
+
+ fn try_from(value: StripeCreateCheckoutSessionParams<'a>) -> Result<Self, Self::Error> {
+ Ok(Self {
+ customer: value
+ .customer
+ .map(|customer_id| customer_id.try_into())
+ .transpose()?,
+ client_reference_id: value.client_reference_id,
+ mode: value.mode.map(Into::into),
+ line_items: value
+ .line_items
+ .map(|line_items| line_items.into_iter().map(Into::into).collect()),
+ payment_method_collection: value.payment_method_collection.map(Into::into),
+ subscription_data: value.subscription_data.map(Into::into),
+ success_url: value.success_url,
+ ..Default::default()
+ })
+ }
+}
+
+impl From<StripeCheckoutSessionMode> for CheckoutSessionMode {
+ fn from(value: StripeCheckoutSessionMode) -> Self {
+ match value {
+ StripeCheckoutSessionMode::Payment => Self::Payment,
+ StripeCheckoutSessionMode::Setup => Self::Setup,
+ StripeCheckoutSessionMode::Subscription => Self::Subscription,
+ }
+ }
+}
+
+impl From<StripeCreateCheckoutSessionLineItems> for CreateCheckoutSessionLineItems {
+ fn from(value: StripeCreateCheckoutSessionLineItems) -> Self {
+ Self {
+ price: value.price,
+ quantity: value.quantity,
+ ..Default::default()
+ }
+ }
+}
+
+impl From<StripeCheckoutSessionPaymentMethodCollection> for CheckoutSessionPaymentMethodCollection {
+ fn from(value: StripeCheckoutSessionPaymentMethodCollection) -> Self {
+ match value {
+ StripeCheckoutSessionPaymentMethodCollection::Always => Self::Always,
+ StripeCheckoutSessionPaymentMethodCollection::IfRequired => Self::IfRequired,
+ }
+ }
+}
+
+impl From<StripeCreateCheckoutSessionSubscriptionData> for CreateCheckoutSessionSubscriptionData {
+ fn from(value: StripeCreateCheckoutSessionSubscriptionData) -> Self {
+ Self {
+ trial_period_days: value.trial_period_days,
+ trial_settings: value.trial_settings.map(Into::into),
+ metadata: value.metadata,
+ ..Default::default()
+ }
+ }
+}
+
+impl From<StripeSubscriptionTrialSettings> for CreateCheckoutSessionSubscriptionDataTrialSettings {
+ fn from(value: StripeSubscriptionTrialSettings) -> Self {
+ Self {
+ end_behavior: value.end_behavior.into(),
+ }
+ }
+}
+
+impl From<StripeSubscriptionTrialSettingsEndBehavior>
+ for CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior
+{
+ fn from(value: StripeSubscriptionTrialSettingsEndBehavior) -> Self {
+ Self {
+ missing_payment_method: value.missing_payment_method.into(),
+ }
+ }
+}
+
+impl From<StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod>
+ for CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod
+{
+ fn from(value: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod) -> Self {
+ match value {
+ StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel,
+ StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => {
+ Self::CreateInvoice
+ }
+ StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause,
+ }
+ }
+}
+
+impl From<CheckoutSession> for StripeCheckoutSession {
+ fn from(value: CheckoutSession) -> Self {
+ Self { url: value.url }
+ }
+}
@@ -2,11 +2,15 @@ use std::sync::Arc;
use pretty_assertions::assert_eq;
+use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
use crate::stripe_billing::StripeBilling;
use crate::stripe_client::{
- FakeStripeClient, StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId,
- StripePriceRecurring, StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem,
- StripeSubscriptionItemId, UpdateSubscriptionItems,
+ FakeStripeClient, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection,
+ StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionSubscriptionData,
+ StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripePriceRecurring,
+ StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId,
+ StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior,
+ StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
};
fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
@@ -241,3 +245,204 @@ async fn test_bill_model_request_usage() {
);
assert_eq!(create_meter_event_calls[0].value, 73);
}
+
+#[gpui::test]
+async fn test_checkout_with_zed_pro() {
+ let (stripe_billing, stripe_client) = make_stripe_billing();
+
+ let customer_id = StripeCustomerId("cus_test".into());
+ let github_login = "zeduser1";
+ let success_url = "https://example.com/success";
+
+ // It returns an error when the Zed Pro price doesn't exist.
+ {
+ let result = stripe_billing
+ .checkout_with_zed_pro(&customer_id, github_login, success_url)
+ .await;
+
+ assert!(result.is_err());
+ assert_eq!(
+ result.err().unwrap().to_string(),
+ r#"no price ID found for "zed-pro""#
+ );
+ }
+
+ // Successful checkout.
+ {
+ let price = StripePrice {
+ id: StripePriceId("price_1".into()),
+ unit_amount: Some(2000),
+ lookup_key: Some("zed-pro".to_string()),
+ recurring: None,
+ };
+ stripe_client
+ .prices
+ .lock()
+ .insert(price.id.clone(), price.clone());
+
+ stripe_billing.initialize().await.unwrap();
+
+ let checkout_url = stripe_billing
+ .checkout_with_zed_pro(&customer_id, github_login, success_url)
+ .await
+ .unwrap();
+
+ assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
+
+ let create_checkout_session_calls = stripe_client
+ .create_checkout_session_calls
+ .lock()
+ .drain(..)
+ .collect::<Vec<_>>();
+ assert_eq!(create_checkout_session_calls.len(), 1);
+ let call = create_checkout_session_calls.into_iter().next().unwrap();
+ assert_eq!(call.customer, Some(customer_id));
+ assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
+ assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
+ assert_eq!(
+ call.line_items,
+ Some(vec![StripeCreateCheckoutSessionLineItems {
+ price: Some(price.id.to_string()),
+ quantity: Some(1)
+ }])
+ );
+ assert_eq!(call.payment_method_collection, None);
+ assert_eq!(call.subscription_data, None);
+ assert_eq!(call.success_url.as_deref(), Some(success_url));
+ }
+}
+
+#[gpui::test]
+async fn test_checkout_with_zed_pro_trial() {
+ let (stripe_billing, stripe_client) = make_stripe_billing();
+
+ let customer_id = StripeCustomerId("cus_test".into());
+ let github_login = "zeduser1";
+ let success_url = "https://example.com/success";
+
+ // It returns an error when the Zed Pro price doesn't exist.
+ {
+ let result = stripe_billing
+ .checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
+ .await;
+
+ assert!(result.is_err());
+ assert_eq!(
+ result.err().unwrap().to_string(),
+ r#"no price ID found for "zed-pro""#
+ );
+ }
+
+ let price = StripePrice {
+ id: StripePriceId("price_1".into()),
+ unit_amount: Some(2000),
+ lookup_key: Some("zed-pro".to_string()),
+ recurring: None,
+ };
+ stripe_client
+ .prices
+ .lock()
+ .insert(price.id.clone(), price.clone());
+
+ stripe_billing.initialize().await.unwrap();
+
+ // Successful checkout.
+ {
+ let checkout_url = stripe_billing
+ .checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
+ .await
+ .unwrap();
+
+ assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
+
+ let create_checkout_session_calls = stripe_client
+ .create_checkout_session_calls
+ .lock()
+ .drain(..)
+ .collect::<Vec<_>>();
+ assert_eq!(create_checkout_session_calls.len(), 1);
+ let call = create_checkout_session_calls.into_iter().next().unwrap();
+ assert_eq!(call.customer.as_ref(), Some(&customer_id));
+ assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
+ assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
+ assert_eq!(
+ call.line_items,
+ Some(vec![StripeCreateCheckoutSessionLineItems {
+ price: Some(price.id.to_string()),
+ quantity: Some(1)
+ }])
+ );
+ assert_eq!(
+ call.payment_method_collection,
+ Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
+ );
+ assert_eq!(
+ call.subscription_data,
+ Some(StripeCreateCheckoutSessionSubscriptionData {
+ trial_period_days: Some(14),
+ trial_settings: Some(StripeSubscriptionTrialSettings {
+ end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
+ missing_payment_method:
+ StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
+ },
+ }),
+ metadata: None,
+ })
+ );
+ assert_eq!(call.success_url.as_deref(), Some(success_url));
+ }
+
+ // Successful checkout with extended trial.
+ {
+ let checkout_url = stripe_billing
+ .checkout_with_zed_pro_trial(
+ &customer_id,
+ github_login,
+ vec![AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string()],
+ success_url,
+ )
+ .await
+ .unwrap();
+
+ assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
+
+ let create_checkout_session_calls = stripe_client
+ .create_checkout_session_calls
+ .lock()
+ .drain(..)
+ .collect::<Vec<_>>();
+ assert_eq!(create_checkout_session_calls.len(), 1);
+ let call = create_checkout_session_calls.into_iter().next().unwrap();
+ assert_eq!(call.customer, Some(customer_id));
+ assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
+ assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
+ assert_eq!(
+ call.line_items,
+ Some(vec![StripeCreateCheckoutSessionLineItems {
+ price: Some(price.id.to_string()),
+ quantity: Some(1)
+ }])
+ );
+ assert_eq!(
+ call.payment_method_collection,
+ Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
+ );
+ assert_eq!(
+ call.subscription_data,
+ Some(StripeCreateCheckoutSessionSubscriptionData {
+ trial_period_days: Some(60),
+ trial_settings: Some(StripeSubscriptionTrialSettings {
+ end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
+ missing_payment_method:
+ StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
+ },
+ }),
+ metadata: Some(std::collections::HashMap::from_iter([(
+ "promo_feature_flag".into(),
+ AGENT_EXTENDED_TRIAL_FEATURE_FLAG.into()
+ )])),
+ })
+ );
+ assert_eq!(call.success_url.as_deref(), Some(success_url));
+ }
+}