Detailed changes
@@ -11,8 +11,9 @@ use crate::Result;
use crate::db::billing_subscription::SubscriptionKind;
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
use crate::stripe_client::{
- RealStripeClient, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection,
- StripeClient, StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
+ RealStripeClient, StripeBillingAddressCollection, StripeCheckoutSessionMode,
+ StripeCheckoutSessionPaymentMethodCollection, StripeClient,
+ StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
StripeCreateMeterEventPayload, StripeCreateSubscriptionItems, StripeCreateSubscriptionParams,
StripeCustomerId, StripeMeter, StripePrice, StripePriceId, StripeSubscription,
@@ -245,6 +246,7 @@ impl StripeBilling {
quantity: Some(1),
}]);
params.success_url = Some(success_url);
+ params.billing_address_collection = Some(StripeBillingAddressCollection::Required);
let session = self.client.create_checkout_session(params).await?;
Ok(session.url.context("no checkout session URL")?)
@@ -298,6 +300,7 @@ impl StripeBilling {
quantity: Some(1),
}]);
params.success_url = Some(success_url);
+ params.billing_address_collection = Some(StripeBillingAddressCollection::Required);
let session = self.client.create_checkout_session(params).await?;
Ok(session.url.context("no checkout session URL")?)
@@ -148,6 +148,12 @@ pub struct StripeCreateMeterEventPayload<'a> {
pub stripe_customer_id: &'a StripeCustomerId,
}
+#[derive(Debug, PartialEq, Eq, Clone, Copy)]
+pub enum StripeBillingAddressCollection {
+ Auto,
+ Required,
+}
+
#[derive(Debug, Default)]
pub struct StripeCreateCheckoutSessionParams<'a> {
pub customer: Option<&'a StripeCustomerId>,
@@ -157,6 +163,7 @@ pub struct StripeCreateCheckoutSessionParams<'a> {
pub payment_method_collection: Option<StripeCheckoutSessionPaymentMethodCollection>,
pub subscription_data: Option<StripeCreateCheckoutSessionSubscriptionData>,
pub success_url: Option<&'a str>,
+ pub billing_address_collection: Option<StripeBillingAddressCollection>,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
@@ -8,8 +8,8 @@ use parking_lot::Mutex;
use uuid::Uuid;
use crate::stripe_client::{
- CreateCustomerParams, StripeCheckoutSession, StripeCheckoutSessionMode,
- StripeCheckoutSessionPaymentMethodCollection, StripeClient,
+ CreateCustomerParams, StripeBillingAddressCollection, StripeCheckoutSession,
+ StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection, StripeClient,
StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeMeter, StripeMeterId,
@@ -35,6 +35,7 @@ pub struct StripeCreateCheckoutSessionCall {
pub payment_method_collection: Option<StripeCheckoutSessionPaymentMethodCollection>,
pub subscription_data: Option<StripeCreateCheckoutSessionSubscriptionData>,
pub success_url: Option<String>,
+ pub billing_address_collection: Option<StripeBillingAddressCollection>,
}
pub struct FakeStripeClient {
@@ -231,6 +232,7 @@ impl StripeClient for FakeStripeClient {
payment_method_collection: params.payment_method_collection,
subscription_data: params.subscription_data,
success_url: params.success_url.map(|url| url.to_string()),
+ billing_address_collection: params.billing_address_collection,
});
Ok(StripeCheckoutSession {
@@ -17,9 +17,10 @@ use stripe::{
};
use crate::stripe_client::{
- CreateCustomerParams, StripeCancellationDetails, StripeCancellationDetailsReason,
- StripeCheckoutSession, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection,
- StripeClient, StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
+ CreateCustomerParams, StripeBillingAddressCollection, StripeCancellationDetails,
+ StripeCancellationDetailsReason, StripeCheckoutSession, StripeCheckoutSessionMode,
+ StripeCheckoutSessionPaymentMethodCollection, StripeClient,
+ StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeMeter, StripePrice,
StripePriceId, StripePriceRecurring, StripeSubscription, StripeSubscriptionId,
@@ -444,6 +445,7 @@ impl<'a> TryFrom<StripeCreateCheckoutSessionParams<'a>> for CreateCheckoutSessio
payment_method_collection: value.payment_method_collection.map(Into::into),
subscription_data: value.subscription_data.map(Into::into),
success_url: value.success_url,
+ billing_address_collection: value.billing_address_collection.map(Into::into),
..Default::default()
})
}
@@ -526,3 +528,16 @@ impl From<CheckoutSession> for StripeCheckoutSession {
Self { url: value.url }
}
}
+
+impl From<StripeBillingAddressCollection> for stripe::CheckoutSessionBillingAddressCollection {
+ fn from(value: StripeBillingAddressCollection) -> Self {
+ match value {
+ StripeBillingAddressCollection::Auto => {
+ stripe::CheckoutSessionBillingAddressCollection::Auto
+ }
+ StripeBillingAddressCollection::Required => {
+ stripe::CheckoutSessionBillingAddressCollection::Required
+ }
+ }
+ }
+}
@@ -6,11 +6,12 @@ use pretty_assertions::assert_eq;
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
use crate::stripe_billing::StripeBilling;
use crate::stripe_client::{
- FakeStripeClient, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection,
- StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionSubscriptionData,
- StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripePriceRecurring,
- StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId,
- StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior,
+ FakeStripeClient, StripeBillingAddressCollection, StripeCheckoutSessionMode,
+ StripeCheckoutSessionPaymentMethodCollection, StripeCreateCheckoutSessionLineItems,
+ StripeCreateCheckoutSessionSubscriptionData, StripeCustomerId, StripeMeter, StripeMeterId,
+ StripePrice, StripePriceId, StripePriceRecurring, StripeSubscription, StripeSubscriptionId,
+ StripeSubscriptionItem, StripeSubscriptionItemId, StripeSubscriptionTrialSettings,
+ StripeSubscriptionTrialSettingsEndBehavior,
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
};
@@ -426,6 +427,10 @@ async fn test_checkout_with_zed_pro() {
assert_eq!(call.payment_method_collection, None);
assert_eq!(call.subscription_data, None);
assert_eq!(call.success_url.as_deref(), Some(success_url));
+ assert_eq!(
+ call.billing_address_collection,
+ Some(StripeBillingAddressCollection::Required)
+ );
}
}
@@ -507,6 +512,10 @@ async fn test_checkout_with_zed_pro_trial() {
})
);
assert_eq!(call.success_url.as_deref(), Some(success_url));
+ assert_eq!(
+ call.billing_address_collection,
+ Some(StripeBillingAddressCollection::Required)
+ );
}
// Successful checkout with extended trial.
@@ -561,5 +570,9 @@ async fn test_checkout_with_zed_pro_trial() {
})
);
assert_eq!(call.success_url.as_deref(), Some(success_url));
+ assert_eq!(
+ call.billing_address_collection,
+ Some(StripeBillingAddressCollection::Required)
+ );
}
}