Detailed changes
@@ -1578,7 +1578,7 @@ async fn sync_model_request_usage_with_stripe(
};
stripe_billing
- .subscribe_to_price(&stripe_subscription_id, price)
+ .subscribe_to_price(&stripe_subscription_id.into(), price)
.await?;
stripe_billing
.bill_model_request_usage(
@@ -13,6 +13,9 @@ use crate::db::billing_subscription::SubscriptionKind;
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
use crate::stripe_client::{
RealStripeClient, StripeClient, StripeCustomerId, StripeMeter, StripePrice, StripePriceId,
+ StripeSubscription, StripeSubscriptionId, UpdateSubscriptionItems, UpdateSubscriptionParams,
+ UpdateSubscriptionTrialSettings, UpdateSubscriptionTrialSettingsEndBehavior,
+ UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
};
pub struct StripeBilling {
@@ -166,14 +169,12 @@ impl StripeBilling {
pub async fn subscribe_to_price(
&self,
- subscription_id: &stripe::SubscriptionId,
+ subscription_id: &StripeSubscriptionId,
price: &StripePrice,
) -> Result<()> {
- let subscription =
- stripe::Subscription::retrieve(&self.real_client, &subscription_id, &[]).await?;
+ let subscription = self.client.get_subscription(subscription_id).await?;
- let price_id = price.id.clone().try_into()?;
- if subscription_contains_price(&subscription, &price_id) {
+ if subscription_contains_price(&subscription, &price.id) {
return Ok(());
}
@@ -182,23 +183,21 @@ impl StripeBilling {
let price_per_unit = price.unit_amount.unwrap_or_default();
let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
- stripe::Subscription::update(
- &self.real_client,
- subscription_id,
- stripe::UpdateSubscription {
- items: Some(vec![stripe::UpdateSubscriptionItems {
- price: Some(price.id.to_string()),
- ..Default::default()
- }]),
- trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
- end_behavior: stripe::UpdateSubscriptionTrialSettingsEndBehavior {
- missing_payment_method: stripe::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
- },
- }),
- ..Default::default()
- },
- )
- .await?;
+ self.client
+ .update_subscription(
+ subscription_id,
+ UpdateSubscriptionParams {
+ items: Some(vec![UpdateSubscriptionItems {
+ price: Some(price.id.clone()),
+ }]),
+ trial_settings: Some(UpdateSubscriptionTrialSettings {
+ end_behavior: UpdateSubscriptionTrialSettingsEndBehavior {
+ missing_payment_method: UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel
+ },
+ }),
+ },
+ )
+ .await?;
Ok(())
}
@@ -419,10 +418,10 @@ struct StripeCreateMeterEventPayload<'a> {
}
fn subscription_contains_price(
- subscription: &stripe::Subscription,
- price_id: &stripe::PriceId,
+ subscription: &StripeSubscription,
+ price_id: &StripePriceId,
) -> bool {
- subscription.items.data.iter().any(|item| {
+ subscription.items.iter().any(|item| {
item.price
.as_ref()
.map_or(false, |price| price.id == *price_id)
@@ -26,6 +26,52 @@ pub struct CreateCustomerParams<'a> {
pub email: Option<&'a str>,
}
+#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
+pub struct StripeSubscriptionId(pub Arc<str>);
+
+#[derive(Debug, Clone)]
+pub struct StripeSubscription {
+ pub id: StripeSubscriptionId,
+ pub items: Vec<StripeSubscriptionItem>,
+}
+
+#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
+pub struct StripeSubscriptionItemId(pub Arc<str>);
+
+#[derive(Debug, Clone)]
+pub struct StripeSubscriptionItem {
+ pub id: StripeSubscriptionItemId,
+ pub price: Option<StripePrice>,
+}
+
+#[derive(Debug, Clone)]
+pub struct UpdateSubscriptionParams {
+ pub items: Option<Vec<UpdateSubscriptionItems>>,
+ pub trial_settings: Option<UpdateSubscriptionTrialSettings>,
+}
+
+#[derive(Debug, PartialEq, Clone)]
+pub struct UpdateSubscriptionItems {
+ pub price: Option<StripePriceId>,
+}
+
+#[derive(Debug, Clone)]
+pub struct UpdateSubscriptionTrialSettings {
+ pub end_behavior: UpdateSubscriptionTrialSettingsEndBehavior,
+}
+
+#[derive(Debug, Clone)]
+pub struct UpdateSubscriptionTrialSettingsEndBehavior {
+ pub missing_payment_method: UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
+}
+
+#[derive(Debug, PartialEq, Eq, Clone, Copy)]
+pub enum UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod {
+ Cancel,
+ CreateInvoice,
+ Pause,
+}
+
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
pub struct StripePriceId(pub Arc<str>);
@@ -57,6 +103,17 @@ pub trait StripeClient: Send + Sync {
async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer>;
+ async fn get_subscription(
+ &self,
+ subscription_id: &StripeSubscriptionId,
+ ) -> Result<StripeSubscription>;
+
+ async fn update_subscription(
+ &self,
+ subscription_id: &StripeSubscriptionId,
+ params: UpdateSubscriptionParams,
+ ) -> Result<()>;
+
async fn list_prices(&self) -> Result<Vec<StripePrice>>;
async fn list_meters(&self) -> Result<Vec<StripeMeter>>;
@@ -1,6 +1,6 @@
use std::sync::Arc;
-use anyhow::Result;
+use anyhow::{Result, anyhow};
use async_trait::async_trait;
use collections::HashMap;
use parking_lot::Mutex;
@@ -8,11 +8,15 @@ use uuid::Uuid;
use crate::stripe_client::{
CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId, StripeMeter,
- StripeMeterId, StripePrice, StripePriceId,
+ StripeMeterId, StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId,
+ UpdateSubscriptionParams,
};
pub struct FakeStripeClient {
pub customers: Arc<Mutex<HashMap<StripeCustomerId, StripeCustomer>>>,
+ pub subscriptions: Arc<Mutex<HashMap<StripeSubscriptionId, StripeSubscription>>>,
+ pub update_subscription_calls:
+ Arc<Mutex<Vec<(StripeSubscriptionId, UpdateSubscriptionParams)>>>,
pub prices: Arc<Mutex<HashMap<StripePriceId, StripePrice>>>,
pub meters: Arc<Mutex<HashMap<StripeMeterId, StripeMeter>>>,
}
@@ -21,6 +25,8 @@ impl FakeStripeClient {
pub fn new() -> Self {
Self {
customers: Arc::new(Mutex::new(HashMap::default())),
+ subscriptions: Arc::new(Mutex::new(HashMap::default())),
+ update_subscription_calls: Arc::new(Mutex::new(Vec::new())),
prices: Arc::new(Mutex::new(HashMap::default())),
meters: Arc::new(Mutex::new(HashMap::default())),
}
@@ -52,6 +58,31 @@ impl StripeClient for FakeStripeClient {
Ok(customer)
}
+ async fn get_subscription(
+ &self,
+ subscription_id: &StripeSubscriptionId,
+ ) -> Result<StripeSubscription> {
+ self.subscriptions
+ .lock()
+ .get(subscription_id)
+ .cloned()
+ .ok_or_else(|| anyhow!("no subscription found for {subscription_id:?}"))
+ }
+
+ async fn update_subscription(
+ &self,
+ subscription_id: &StripeSubscriptionId,
+ params: UpdateSubscriptionParams,
+ ) -> Result<()> {
+ let subscription = self.get_subscription(subscription_id).await?;
+
+ self.update_subscription_calls
+ .lock()
+ .push((subscription.id, params));
+
+ Ok(())
+ }
+
async fn list_prices(&self) -> Result<Vec<StripePrice>> {
let prices = self.prices.lock().values().cloned().collect();
@@ -4,11 +4,17 @@ use std::sync::Arc;
use anyhow::{Context as _, Result};
use async_trait::async_trait;
use serde::Serialize;
-use stripe::{CreateCustomer, Customer, CustomerId, ListCustomers, Price, PriceId, Recurring};
+use stripe::{
+ CreateCustomer, Customer, CustomerId, ListCustomers, Price, PriceId, Recurring, Subscription,
+ SubscriptionId, SubscriptionItem, SubscriptionItemId, UpdateSubscriptionItems,
+ UpdateSubscriptionTrialSettings, UpdateSubscriptionTrialSettingsEndBehavior,
+ UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
+};
use crate::stripe_client::{
CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId, StripeMeter, StripePrice,
- StripePriceId, StripePriceRecurring,
+ StripePriceId, StripePriceRecurring, StripeSubscription, StripeSubscriptionId,
+ StripeSubscriptionItem, StripeSubscriptionItemId, UpdateSubscriptionParams,
};
pub struct RealStripeClient {
@@ -53,6 +59,46 @@ impl StripeClient for RealStripeClient {
Ok(StripeCustomer::from(customer))
}
+ async fn get_subscription(
+ &self,
+ subscription_id: &StripeSubscriptionId,
+ ) -> Result<StripeSubscription> {
+ let subscription_id = subscription_id.try_into()?;
+
+ let subscription = Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
+
+ Ok(StripeSubscription::from(subscription))
+ }
+
+ async fn update_subscription(
+ &self,
+ subscription_id: &StripeSubscriptionId,
+ params: UpdateSubscriptionParams,
+ ) -> Result<()> {
+ let subscription_id = subscription_id.try_into()?;
+
+ stripe::Subscription::update(
+ &self.client,
+ &subscription_id,
+ stripe::UpdateSubscription {
+ items: params.items.map(|items| {
+ items
+ .into_iter()
+ .map(|item| UpdateSubscriptionItems {
+ price: item.price.map(|price| price.to_string()),
+ ..Default::default()
+ })
+ .collect()
+ }),
+ trial_settings: params.trial_settings.map(Into::into),
+ ..Default::default()
+ },
+ )
+ .await?;
+
+ Ok(())
+ }
+
async fn list_prices(&self) -> Result<Vec<StripePrice>> {
let response = stripe::Price::list(
&self.client,
@@ -108,6 +154,80 @@ impl From<Customer> for StripeCustomer {
}
}
+impl From<SubscriptionId> for StripeSubscriptionId {
+ fn from(value: SubscriptionId) -> Self {
+ Self(value.as_str().into())
+ }
+}
+
+impl TryFrom<&StripeSubscriptionId> for SubscriptionId {
+ type Error = anyhow::Error;
+
+ fn try_from(value: &StripeSubscriptionId) -> Result<Self, Self::Error> {
+ Self::from_str(value.0.as_ref()).context("failed to parse Stripe subscription ID")
+ }
+}
+
+impl From<Subscription> for StripeSubscription {
+ fn from(value: Subscription) -> Self {
+ Self {
+ id: value.id.into(),
+ items: value.items.data.into_iter().map(Into::into).collect(),
+ }
+ }
+}
+
+impl From<SubscriptionItemId> for StripeSubscriptionItemId {
+ fn from(value: SubscriptionItemId) -> Self {
+ Self(value.as_str().into())
+ }
+}
+
+impl From<SubscriptionItem> for StripeSubscriptionItem {
+ fn from(value: SubscriptionItem) -> Self {
+ Self {
+ id: value.id.into(),
+ price: value.price.map(Into::into),
+ }
+ }
+}
+
+impl From<crate::stripe_client::UpdateSubscriptionTrialSettings>
+ for UpdateSubscriptionTrialSettings
+{
+ fn from(value: crate::stripe_client::UpdateSubscriptionTrialSettings) -> Self {
+ Self {
+ end_behavior: value.end_behavior.into(),
+ }
+ }
+}
+
+impl From<crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehavior>
+ for UpdateSubscriptionTrialSettingsEndBehavior
+{
+ fn from(value: crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehavior) -> Self {
+ Self {
+ missing_payment_method: value.missing_payment_method.into(),
+ }
+ }
+}
+
+impl From<crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod>
+ for UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod
+{
+ fn from(
+ value: crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
+ ) -> Self {
+ match value {
+ crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel,
+ crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => {
+ Self::CreateInvoice
+ }
+ crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause,
+ }
+ }
+}
+
impl From<PriceId> for StripePriceId {
fn from(value: PriceId) -> Self {
Self(value.as_str().into())
@@ -5,6 +5,8 @@ use pretty_assertions::assert_eq;
use crate::stripe_billing::StripeBilling;
use crate::stripe_client::{
FakeStripeClient, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripePriceRecurring,
+ StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId,
+ UpdateSubscriptionItems,
};
fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
@@ -141,3 +143,70 @@ async fn test_find_or_create_customer_by_email() {
assert_eq!(customer.email.as_deref(), Some(email));
}
}
+
+#[gpui::test]
+async fn test_subscribe_to_price() {
+ let (stripe_billing, stripe_client) = make_stripe_billing();
+
+ let price = StripePrice {
+ id: StripePriceId("price_test".into()),
+ unit_amount: Some(2000),
+ lookup_key: Some("test-price".to_string()),
+ recurring: None,
+ };
+ stripe_client
+ .prices
+ .lock()
+ .insert(price.id.clone(), price.clone());
+
+ let subscription = StripeSubscription {
+ id: StripeSubscriptionId("sub_test".into()),
+ items: vec![],
+ };
+ stripe_client
+ .subscriptions
+ .lock()
+ .insert(subscription.id.clone(), subscription.clone());
+
+ stripe_billing
+ .subscribe_to_price(&subscription.id, &price)
+ .await
+ .unwrap();
+
+ let update_subscription_calls = stripe_client
+ .update_subscription_calls
+ .lock()
+ .iter()
+ .map(|(id, params)| (id.clone(), params.clone()))
+ .collect::<Vec<_>>();
+ assert_eq!(update_subscription_calls.len(), 1);
+ assert_eq!(update_subscription_calls[0].0, subscription.id);
+ assert_eq!(
+ update_subscription_calls[0].1.items,
+ Some(vec![UpdateSubscriptionItems {
+ price: Some(price.id.clone())
+ }])
+ );
+
+ // Subscribing to a price that is already on the subscription is a no-op.
+ {
+ let subscription = StripeSubscription {
+ id: StripeSubscriptionId("sub_test".into()),
+ items: vec![StripeSubscriptionItem {
+ id: StripeSubscriptionItemId("si_test".into()),
+ price: Some(price.clone()),
+ }],
+ };
+ stripe_client
+ .subscriptions
+ .lock()
+ .insert(subscription.id.clone(), subscription.clone());
+
+ stripe_billing
+ .subscribe_to_price(&subscription.id, &price)
+ .await
+ .unwrap();
+
+ assert_eq!(stripe_client.update_subscription_calls.lock().len(), 1);
+ }
+}