Detailed changes
@@ -1182,10 +1182,8 @@ async fn sync_subscription(
.has_active_billing_subscription(billing_customer.user_id)
.await?;
if !already_has_active_billing_subscription {
- let stripe_customer_id = billing_customer
- .stripe_customer_id
- .parse::<stripe::CustomerId>()
- .context("failed to parse Stripe customer ID from database")?;
+ let stripe_customer_id =
+ StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
stripe_billing
.subscribe_to_zed_free(stripe_customer_id)
@@ -5,6 +5,7 @@ use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
use crate::db::billing_subscription::SubscriptionKind;
use crate::llm::db::LlmDatabase;
use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
+use crate::stripe_client::StripeCustomerId;
use crate::{
AppState, Error, Result, auth,
db::{
@@ -4055,10 +4056,8 @@ async fn get_llm_api_token(
if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
billing_subscription
} else {
- let stripe_customer_id = billing_customer
- .stripe_customer_id
- .parse::<stripe::CustomerId>()
- .context("failed to parse Stripe customer ID from database")?;
+ let stripe_customer_id =
+ StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
let stripe_subscription = stripe_billing
.subscribe_to_zed_free(stripe_customer_id)
@@ -14,8 +14,9 @@ use crate::stripe_client::{
RealStripeClient, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection,
StripeClient, StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
- StripeCreateMeterEventPayload, StripeCustomerId, StripeMeter, StripePrice, StripePriceId,
- StripeSubscription, StripeSubscriptionId, StripeSubscriptionTrialSettings,
+ StripeCreateMeterEventPayload, StripeCreateSubscriptionItems, StripeCreateSubscriptionParams,
+ StripeCustomerId, StripeMeter, StripePrice, StripePriceId, StripeSubscription,
+ StripeSubscriptionId, StripeSubscriptionTrialSettings,
StripeSubscriptionTrialSettingsEndBehavior,
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
UpdateSubscriptionParams,
@@ -23,7 +24,6 @@ use crate::stripe_client::{
pub struct StripeBilling {
state: RwLock<StripeBillingState>,
- real_client: Arc<stripe::Client>,
client: Arc<dyn StripeClient>,
}
@@ -38,7 +38,6 @@ impl StripeBilling {
pub fn new(client: Arc<stripe::Client>) -> Self {
Self {
client: Arc::new(RealStripeClient::new(client.clone())),
- real_client: client,
state: RwLock::default(),
}
}
@@ -46,8 +45,6 @@ impl StripeBilling {
#[cfg(test)]
pub fn test(client: Arc<crate::stripe_client::FakeStripeClient>) -> Self {
Self {
- // This is just temporary until we can remove all usages of the real Stripe client.
- real_client: Arc::new(stripe::Client::new("sk_test")),
client,
state: RwLock::default(),
}
@@ -306,40 +303,33 @@ impl StripeBilling {
pub async fn subscribe_to_zed_free(
&self,
- customer_id: stripe::CustomerId,
- ) -> Result<stripe::Subscription> {
+ customer_id: StripeCustomerId,
+ ) -> Result<StripeSubscription> {
let zed_free_price_id = self.zed_free_price_id().await?;
- let existing_subscriptions = stripe::Subscription::list(
- &self.real_client,
- &stripe::ListSubscriptions {
- customer: Some(customer_id.clone()),
- status: None,
- ..Default::default()
- },
- )
- .await?;
+ let existing_subscriptions = self
+ .client
+ .list_subscriptions_for_customer(&customer_id)
+ .await?;
let existing_active_subscription =
- existing_subscriptions
- .data
- .into_iter()
- .find(|subscription| {
- subscription.status == SubscriptionStatus::Active
- || subscription.status == SubscriptionStatus::Trialing
- });
+ existing_subscriptions.into_iter().find(|subscription| {
+ subscription.status == SubscriptionStatus::Active
+ || subscription.status == SubscriptionStatus::Trialing
+ });
if let Some(subscription) = existing_active_subscription {
return Ok(subscription);
}
- let mut params = stripe::CreateSubscription::new(customer_id);
- params.items = Some(vec![stripe::CreateSubscriptionItems {
- price: Some(zed_free_price_id.to_string()),
- quantity: Some(1),
- ..Default::default()
- }]);
+ let params = StripeCreateSubscriptionParams {
+ customer: customer_id,
+ items: vec![StripeCreateSubscriptionItems {
+ price: Some(zed_free_price_id),
+ quantity: Some(1),
+ }],
+ };
- let subscription = stripe::Subscription::create(&self.real_client, params).await?;
+ let subscription = self.client.create_subscription(params).await?;
Ok(subscription)
}
@@ -30,21 +30,38 @@ pub struct CreateCustomerParams<'a> {
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
pub struct StripeSubscriptionId(pub Arc<str>);
-#[derive(Debug, Clone)]
+#[derive(Debug, PartialEq, Clone)]
pub struct StripeSubscription {
pub id: StripeSubscriptionId,
+ pub customer: StripeCustomerId,
+ // TODO: Create our own version of this enum.
+ pub status: stripe::SubscriptionStatus,
+ pub current_period_end: i64,
+ pub current_period_start: i64,
pub items: Vec<StripeSubscriptionItem>,
}
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
pub struct StripeSubscriptionItemId(pub Arc<str>);
-#[derive(Debug, Clone)]
+#[derive(Debug, PartialEq, Clone)]
pub struct StripeSubscriptionItem {
pub id: StripeSubscriptionItemId,
pub price: Option<StripePrice>,
}
+#[derive(Debug)]
+pub struct StripeCreateSubscriptionParams {
+ pub customer: StripeCustomerId,
+ pub items: Vec<StripeCreateSubscriptionItems>,
+}
+
+#[derive(Debug)]
+pub struct StripeCreateSubscriptionItems {
+ pub price: Option<StripePriceId>,
+ pub quantity: Option<u64>,
+}
+
#[derive(Debug, Clone)]
pub struct UpdateSubscriptionParams {
pub items: Option<Vec<UpdateSubscriptionItems>>,
@@ -76,7 +93,7 @@ pub enum StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod {
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
pub struct StripePriceId(pub Arc<str>);
-#[derive(Debug, Clone)]
+#[derive(Debug, PartialEq, Clone)]
pub struct StripePrice {
pub id: StripePriceId,
pub unit_amount: Option<i64>,
@@ -84,7 +101,7 @@ pub struct StripePrice {
pub recurring: Option<StripePriceRecurring>,
}
-#[derive(Debug, Clone)]
+#[derive(Debug, PartialEq, Clone)]
pub struct StripePriceRecurring {
pub meter: Option<String>,
}
@@ -160,11 +177,21 @@ pub trait StripeClient: Send + Sync {
async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer>;
+ async fn list_subscriptions_for_customer(
+ &self,
+ customer_id: &StripeCustomerId,
+ ) -> Result<Vec<StripeSubscription>>;
+
async fn get_subscription(
&self,
subscription_id: &StripeSubscriptionId,
) -> Result<StripeSubscription>;
+ async fn create_subscription(
+ &self,
+ params: StripeCreateSubscriptionParams,
+ ) -> Result<StripeSubscription>;
+
async fn update_subscription(
&self,
subscription_id: &StripeSubscriptionId,
@@ -2,6 +2,7 @@ use std::sync::Arc;
use anyhow::{Result, anyhow};
use async_trait::async_trait;
+use chrono::{Duration, Utc};
use collections::HashMap;
use parking_lot::Mutex;
use uuid::Uuid;
@@ -10,9 +11,10 @@ use crate::stripe_client::{
CreateCustomerParams, StripeCheckoutSession, StripeCheckoutSessionMode,
StripeCheckoutSessionPaymentMethodCollection, StripeClient,
StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
- StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, StripeCustomer,
- StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripeSubscription,
- StripeSubscriptionId, UpdateSubscriptionParams,
+ StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
+ StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeMeter, StripeMeterId,
+ StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem,
+ StripeSubscriptionItemId, UpdateSubscriptionParams,
};
#[derive(Debug, Clone)]
@@ -85,6 +87,21 @@ impl StripeClient for FakeStripeClient {
Ok(customer)
}
+ async fn list_subscriptions_for_customer(
+ &self,
+ customer_id: &StripeCustomerId,
+ ) -> Result<Vec<StripeSubscription>> {
+ let subscriptions = self
+ .subscriptions
+ .lock()
+ .values()
+ .filter(|subscription| subscription.customer == *customer_id)
+ .cloned()
+ .collect();
+
+ Ok(subscriptions)
+ }
+
async fn get_subscription(
&self,
subscription_id: &StripeSubscriptionId,
@@ -96,6 +113,37 @@ impl StripeClient for FakeStripeClient {
.ok_or_else(|| anyhow!("no subscription found for {subscription_id:?}"))
}
+ async fn create_subscription(
+ &self,
+ params: StripeCreateSubscriptionParams,
+ ) -> Result<StripeSubscription> {
+ let now = Utc::now();
+
+ let subscription = StripeSubscription {
+ id: StripeSubscriptionId(format!("sub_{}", Uuid::new_v4()).into()),
+ customer: params.customer,
+ status: stripe::SubscriptionStatus::Active,
+ current_period_start: now.timestamp(),
+ current_period_end: (now + Duration::days(30)).timestamp(),
+ items: params
+ .items
+ .into_iter()
+ .map(|item| StripeSubscriptionItem {
+ id: StripeSubscriptionItemId(format!("si_{}", Uuid::new_v4()).into()),
+ price: item
+ .price
+ .and_then(|price_id| self.prices.lock().get(&price_id).cloned()),
+ })
+ .collect(),
+ };
+
+ self.subscriptions
+ .lock()
+ .insert(subscription.id.clone(), subscription.clone());
+
+ Ok(subscription)
+ }
+
async fn update_subscription(
&self,
subscription_id: &StripeSubscriptionId,
@@ -20,10 +20,11 @@ use crate::stripe_client::{
CreateCustomerParams, StripeCheckoutSession, StripeCheckoutSessionMode,
StripeCheckoutSessionPaymentMethodCollection, StripeClient,
StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
- StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, StripeCustomer,
- StripeCustomerId, StripeMeter, StripePrice, StripePriceId, StripePriceRecurring,
- StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId,
- StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior,
+ StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
+ StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeMeter, StripePrice,
+ StripePriceId, StripePriceRecurring, StripeSubscription, StripeSubscriptionId,
+ StripeSubscriptionItem, StripeSubscriptionItemId, StripeSubscriptionTrialSettings,
+ StripeSubscriptionTrialSettingsEndBehavior,
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionParams,
};
@@ -69,6 +70,29 @@ impl StripeClient for RealStripeClient {
Ok(StripeCustomer::from(customer))
}
+ async fn list_subscriptions_for_customer(
+ &self,
+ customer_id: &StripeCustomerId,
+ ) -> Result<Vec<StripeSubscription>> {
+ let customer_id = customer_id.try_into()?;
+
+ let subscriptions = stripe::Subscription::list(
+ &self.client,
+ &stripe::ListSubscriptions {
+ customer: Some(customer_id),
+ status: None,
+ ..Default::default()
+ },
+ )
+ .await?;
+
+ Ok(subscriptions
+ .data
+ .into_iter()
+ .map(StripeSubscription::from)
+ .collect())
+ }
+
async fn get_subscription(
&self,
subscription_id: &StripeSubscriptionId,
@@ -80,6 +104,30 @@ impl StripeClient for RealStripeClient {
Ok(StripeSubscription::from(subscription))
}
+ async fn create_subscription(
+ &self,
+ params: StripeCreateSubscriptionParams,
+ ) -> Result<StripeSubscription> {
+ let customer_id = params.customer.try_into()?;
+
+ let mut create_subscription = stripe::CreateSubscription::new(customer_id);
+ create_subscription.items = Some(
+ params
+ .items
+ .into_iter()
+ .map(|item| stripe::CreateSubscriptionItems {
+ price: item.price.map(|price| price.to_string()),
+ quantity: item.quantity,
+ ..Default::default()
+ })
+ .collect(),
+ );
+
+ let subscription = Subscription::create(&self.client, create_subscription).await?;
+
+ Ok(StripeSubscription::from(subscription))
+ }
+
async fn update_subscription(
&self,
subscription_id: &StripeSubscriptionId,
@@ -220,6 +268,10 @@ impl From<Subscription> for StripeSubscription {
fn from(value: Subscription) -> Self {
Self {
id: value.id.into(),
+ customer: value.customer.id().into(),
+ status: value.status,
+ current_period_start: value.current_period_start,
+ current_period_end: value.current_period_end,
items: value.items.data.into_iter().map(Into::into).collect(),
}
}
@@ -1,5 +1,6 @@
use std::sync::Arc;
+use chrono::{Duration, Utc};
use pretty_assertions::assert_eq;
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
@@ -163,8 +164,13 @@ async fn test_subscribe_to_price() {
.lock()
.insert(price.id.clone(), price.clone());
+ let now = Utc::now();
let subscription = StripeSubscription {
id: StripeSubscriptionId("sub_test".into()),
+ customer: StripeCustomerId("cus_test".into()),
+ status: stripe::SubscriptionStatus::Active,
+ current_period_start: now.timestamp(),
+ current_period_end: (now + Duration::days(30)).timestamp(),
items: vec![],
};
stripe_client
@@ -194,8 +200,13 @@ async fn test_subscribe_to_price() {
// Subscribing to a price that is already on the subscription is a no-op.
{
+ let now = Utc::now();
let subscription = StripeSubscription {
id: StripeSubscriptionId("sub_test".into()),
+ customer: StripeCustomerId("cus_test".into()),
+ status: stripe::SubscriptionStatus::Active,
+ current_period_start: now.timestamp(),
+ current_period_end: (now + Duration::days(30)).timestamp(),
items: vec![StripeSubscriptionItem {
id: StripeSubscriptionItemId("si_test".into()),
price: Some(price.clone()),
@@ -215,6 +226,104 @@ async fn test_subscribe_to_price() {
}
}
+#[gpui::test]
+async fn test_subscribe_to_zed_free() {
+ let (stripe_billing, stripe_client) = make_stripe_billing();
+
+ let zed_pro_price = StripePrice {
+ id: StripePriceId("price_1".into()),
+ unit_amount: Some(0),
+ lookup_key: Some("zed-pro".to_string()),
+ recurring: None,
+ };
+ stripe_client
+ .prices
+ .lock()
+ .insert(zed_pro_price.id.clone(), zed_pro_price.clone());
+ let zed_free_price = StripePrice {
+ id: StripePriceId("price_2".into()),
+ unit_amount: Some(0),
+ lookup_key: Some("zed-free".to_string()),
+ recurring: None,
+ };
+ stripe_client
+ .prices
+ .lock()
+ .insert(zed_free_price.id.clone(), zed_free_price.clone());
+
+ stripe_billing.initialize().await.unwrap();
+
+ // Customer is subscribed to Zed Free when not already subscribed to a plan.
+ {
+ let customer_id = StripeCustomerId("cus_no_plan".into());
+
+ let subscription = stripe_billing
+ .subscribe_to_zed_free(customer_id)
+ .await
+ .unwrap();
+
+ assert_eq!(subscription.items[0].price.as_ref(), Some(&zed_free_price));
+ }
+
+ // Customer is not subscribed to Zed Free when they already have an active subscription.
+ {
+ let customer_id = StripeCustomerId("cus_active_subscription".into());
+
+ let now = Utc::now();
+ let existing_subscription = StripeSubscription {
+ id: StripeSubscriptionId("sub_existing_active".into()),
+ customer: customer_id.clone(),
+ status: stripe::SubscriptionStatus::Active,
+ current_period_start: now.timestamp(),
+ current_period_end: (now + Duration::days(30)).timestamp(),
+ items: vec![StripeSubscriptionItem {
+ id: StripeSubscriptionItemId("si_test".into()),
+ price: Some(zed_pro_price.clone()),
+ }],
+ };
+ stripe_client.subscriptions.lock().insert(
+ existing_subscription.id.clone(),
+ existing_subscription.clone(),
+ );
+
+ let subscription = stripe_billing
+ .subscribe_to_zed_free(customer_id)
+ .await
+ .unwrap();
+
+ assert_eq!(subscription, existing_subscription);
+ }
+
+ // Customer is not subscribed to Zed Free when they already have a trial subscription.
+ {
+ let customer_id = StripeCustomerId("cus_trial_subscription".into());
+
+ let now = Utc::now();
+ let existing_subscription = StripeSubscription {
+ id: StripeSubscriptionId("sub_existing_trial".into()),
+ customer: customer_id.clone(),
+ status: stripe::SubscriptionStatus::Trialing,
+ current_period_start: now.timestamp(),
+ current_period_end: (now + Duration::days(14)).timestamp(),
+ items: vec![StripeSubscriptionItem {
+ id: StripeSubscriptionItemId("si_test".into()),
+ price: Some(zed_pro_price.clone()),
+ }],
+ };
+ stripe_client.subscriptions.lock().insert(
+ existing_subscription.id.clone(),
+ existing_subscription.clone(),
+ );
+
+ let subscription = stripe_billing
+ .subscribe_to_zed_free(customer_id)
+ .await
+ .unwrap();
+
+ assert_eq!(subscription, existing_subscription);
+ }
+}
+
#[gpui::test]
async fn test_bill_model_request_usage() {
let (stripe_billing, stripe_client) = make_stripe_billing();