collab: Add support for subscribing to Zed Pro trials (#28812)

Marshall Bowers created

This PR adds support for subscribing to Zed Pro trials (and then
upgrading from a trial to Zed Pro).

Release Notes:

- N/A

Change summary

crates/collab/src/api/billing.rs                      | 121 ++++++++++--
crates/collab/src/db/queries/billing_subscriptions.rs |   5 
crates/collab/src/db/tables/billing_subscription.rs   |   4 
crates/collab/src/lib.rs                              |  36 +++
crates/collab/src/stripe_billing.rs                   |  17 -
crates/collab/src/tests/test_server.rs                |   2 
6 files changed, 143 insertions(+), 42 deletions(-)

Detailed changes

crates/collab/src/api/billing.rs 🔗

@@ -15,10 +15,12 @@ use stripe::{
     BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
     CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
     CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
+    CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
+    CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
     CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
     EventType, Expandable, ListEvents, Subscription, SubscriptionId, SubscriptionStatus,
 };
-use util::ResultExt;
+use util::{ResultExt, maybe};
 
 use crate::api::events::SnowflakeRow;
 use crate::db::billing_subscription::{
@@ -159,6 +161,7 @@ struct BillingSubscriptionJson {
     id: BillingSubscriptionId,
     name: String,
     status: StripeSubscriptionStatus,
+    trial_end_at: Option<String>,
     cancel_at: Option<String>,
     /// Whether this subscription can be canceled.
     is_cancelable: bool,
@@ -188,9 +191,21 @@ async fn list_billing_subscriptions(
                 id: subscription.id,
                 name: match subscription.kind {
                     Some(SubscriptionKind::ZedPro) => "Zed Pro".to_string(),
+                    Some(SubscriptionKind::ZedProTrial) => "Zed Pro (Trial)".to_string(),
+                    Some(SubscriptionKind::ZedFree) => "Zed Free".to_string(),
                     None => "Zed LLM Usage".to_string(),
                 },
                 status: subscription.stripe_subscription_status,
+                trial_end_at: if subscription.kind == Some(SubscriptionKind::ZedProTrial) {
+                    maybe!({
+                        let end_at = subscription.stripe_current_period_end?;
+                        let end_at = DateTime::from_timestamp(end_at, 0)?;
+
+                        Some(end_at.to_rfc3339_opts(SecondsFormat::Millis, true))
+                    })
+                } else {
+                    None
+                },
                 cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
                     cancel_at
                         .and_utc()
@@ -207,6 +222,7 @@ async fn list_billing_subscriptions(
 #[serde(rename_all = "snake_case")]
 enum ProductCode {
     ZedPro,
+    ZedProTrial,
 }
 
 #[derive(Debug, Deserialize)]
@@ -286,24 +302,36 @@ async fn create_billing_subscription(
         customer.id
     };
 
+    let success_url = format!(
+        "{}/account?checkout_complete=1",
+        app.config.zed_dot_dev_url()
+    );
+
     let checkout_session_url = match body.product {
         Some(ProductCode::ZedPro) => {
-            let success_url = format!(
-                "{}/account?checkout_complete=1",
-                app.config.zed_dot_dev_url()
-            );
             stripe_billing
-                .checkout_with_zed_pro(customer_id, &user.github_login, &success_url)
+                .checkout_with_price(
+                    app.config.zed_pro_price_id()?,
+                    customer_id,
+                    &user.github_login,
+                    &success_url,
+                )
+                .await?
+        }
+        Some(ProductCode::ZedProTrial) => {
+            stripe_billing
+                .checkout_with_price(
+                    app.config.zed_pro_trial_price_id()?,
+                    customer_id,
+                    &user.github_login,
+                    &success_url,
+                )
                 .await?
         }
         None => {
             let default_model =
                 llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-7-sonnet")?;
             let stripe_model = stripe_billing.register_model(default_model).await?;
-            let success_url = format!(
-                "{}/account?checkout_complete=1",
-                app.config.zed_dot_dev_url()
-            );
             stripe_billing
                 .checkout(customer_id, &user.github_login, &stripe_model, &success_url)
                 .await?
@@ -322,6 +350,8 @@ enum ManageSubscriptionIntent {
     ///
     /// This will open the Stripe billing portal without putting the user in a specific flow.
     ManageSubscription,
+    /// The user intends to upgrade to Zed Pro.
+    UpgradeToPro,
     /// The user intends to cancel their subscription.
     Cancel,
     /// The user intends to stop the cancellation of their subscription.
@@ -373,11 +403,10 @@ async fn manage_billing_subscription(
         .get_billing_subscription_by_id(body.subscription_id)
         .await?
         .ok_or_else(|| anyhow!("subscription not found"))?;
+    let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
+        .context("failed to parse subscription ID")?;
 
     if body.intent == ManageSubscriptionIntent::StopCancellation {
-        let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
-            .context("failed to parse subscription ID")?;
-
         let updated_stripe_subscription = Subscription::update(
             &stripe_client,
             &subscription_id,
@@ -410,6 +439,47 @@ async fn manage_billing_subscription(
 
     let flow = match body.intent {
         ManageSubscriptionIntent::ManageSubscription => None,
+        ManageSubscriptionIntent::UpgradeToPro => {
+            let zed_pro_price_id = app.config.zed_pro_price_id()?;
+            let zed_pro_trial_price_id = app.config.zed_pro_trial_price_id()?;
+            let zed_free_price_id = app.config.zed_free_price_id()?;
+
+            let stripe_subscription =
+                Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
+
+            let subscription_item_to_update = stripe_subscription
+                .items
+                .data
+                .iter()
+                .find_map(|item| {
+                    let price = item.price.as_ref()?;
+
+                    if price.id == zed_free_price_id || price.id == zed_pro_trial_price_id {
+                        Some(item.id.clone())
+                    } else {
+                        None
+                    }
+                })
+                .ok_or_else(|| anyhow!("No subscription item to update"))?;
+
+            Some(CreateBillingPortalSessionFlowData {
+                type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm,
+                subscription_update_confirm: Some(
+                    CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm {
+                        subscription: subscription.stripe_subscription_id,
+                        items: vec![
+                            CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems {
+                                id: subscription_item_to_update.to_string(),
+                                price: Some(zed_pro_price_id.to_string()),
+                                quantity: Some(1),
+                            },
+                        ],
+                        discounts: None,
+                    },
+                ),
+                ..Default::default()
+            })
+        }
         ManageSubscriptionIntent::Cancel => Some(CreateBillingPortalSessionFlowData {
             type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
             after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
@@ -696,22 +766,25 @@ async fn handle_customer_subscription_event(
 
     log::info!("handling Stripe {} event: {}", event.type_, event.id);
 
-    let subscription_kind =
-        if let Some(zed_pro_price_id) = app.config.stripe_zed_pro_price_id.as_deref() {
-            let has_zed_pro_price = subscription.items.data.iter().any(|item| {
-                item.price
-                    .as_ref()
-                    .map_or(false, |price| price.id.as_str() == zed_pro_price_id)
-            });
+    let subscription_kind = maybe!({
+        let zed_pro_price_id = app.config.zed_pro_price_id().ok()?;
+        let zed_pro_trial_price_id = app.config.zed_pro_trial_price_id().ok()?;
+        let zed_free_price_id = app.config.zed_free_price_id().ok()?;
+
+        subscription.items.data.iter().find_map(|item| {
+            let price = item.price.as_ref()?;
 
-            if has_zed_pro_price {
+            if price.id == zed_pro_price_id {
                 Some(SubscriptionKind::ZedPro)
+            } else if price.id == zed_pro_trial_price_id {
+                Some(SubscriptionKind::ZedProTrial)
+            } else if price.id == zed_free_price_id {
+                Some(SubscriptionKind::ZedFree)
             } else {
                 None
             }
-        } else {
-            None
-        };
+        })
+    });
 
     let billing_customer =
         find_or_create_billing_customer(app, stripe_client, subscription.customer)

crates/collab/src/db/queries/billing_subscriptions.rs 🔗

@@ -62,11 +62,14 @@ impl Database {
             billing_subscription::Entity::update(billing_subscription::ActiveModel {
                 id: ActiveValue::set(id),
                 billing_customer_id: params.billing_customer_id.clone(),
+                kind: params.kind.clone(),
                 stripe_subscription_id: params.stripe_subscription_id.clone(),
                 stripe_subscription_status: params.stripe_subscription_status.clone(),
                 stripe_cancel_at: params.stripe_cancel_at.clone(),
                 stripe_cancellation_reason: params.stripe_cancellation_reason.clone(),
-                ..Default::default()
+                stripe_current_period_start: params.stripe_current_period_start.clone(),
+                stripe_current_period_end: params.stripe_current_period_end.clone(),
+                created_at: ActiveValue::not_set(),
             })
             .exec(&*tx)
             .await?;

crates/collab/src/db/tables/billing_subscription.rs 🔗

@@ -43,6 +43,10 @@ impl ActiveModelBehavior for ActiveModel {}
 pub enum SubscriptionKind {
     #[sea_orm(string_value = "zed_pro")]
     ZedPro,
+    #[sea_orm(string_value = "zed_pro_trial")]
+    ZedProTrial,
+    #[sea_orm(string_value = "zed_free")]
+    ZedFree,
 }
 
 /// The status of a Stripe subscription.

crates/collab/src/lib.rs 🔗

@@ -183,6 +183,8 @@ pub struct Config {
     pub auto_join_channel_id: Option<ChannelId>,
     pub stripe_api_key: Option<String>,
     pub stripe_zed_pro_price_id: Option<String>,
+    pub stripe_zed_pro_trial_price_id: Option<String>,
+    pub stripe_zed_free_price_id: Option<String>,
     pub supermaven_admin_api_key: Option<Arc<str>>,
     pub user_backfiller_github_access_token: Option<Arc<str>>,
 }
@@ -201,6 +203,29 @@ impl Config {
         }
     }
 
+    pub fn zed_pro_price_id(&self) -> anyhow::Result<stripe::PriceId> {
+        Self::parse_stripe_price_id("Zed Pro", self.stripe_zed_pro_price_id.as_deref())
+    }
+
+    pub fn zed_pro_trial_price_id(&self) -> anyhow::Result<stripe::PriceId> {
+        Self::parse_stripe_price_id(
+            "Zed Pro Trial",
+            self.stripe_zed_pro_trial_price_id.as_deref(),
+        )
+    }
+
+    pub fn zed_free_price_id(&self) -> anyhow::Result<stripe::PriceId> {
+        Self::parse_stripe_price_id("Zed Free", self.stripe_zed_pro_price_id.as_deref())
+    }
+
+    fn parse_stripe_price_id(name: &str, value: Option<&str>) -> anyhow::Result<stripe::PriceId> {
+        use std::str::FromStr as _;
+
+        let price_id = value.ok_or_else(|| anyhow!("{name} price ID not set"))?;
+
+        Ok(stripe::PriceId::from_str(price_id)?)
+    }
+
     #[cfg(test)]
     pub fn test() -> Self {
         Self {
@@ -239,6 +264,8 @@ impl Config {
             seed_path: None,
             stripe_api_key: None,
             stripe_zed_pro_price_id: None,
+            stripe_zed_pro_trial_price_id: None,
+            stripe_zed_free_price_id: None,
             supermaven_admin_api_key: None,
             user_backfiller_github_access_token: None,
             kinesis_region: None,
@@ -324,12 +351,9 @@ impl AppState {
             llm_db,
             livekit_client,
             blob_store_client: build_blob_store_client(&config).await.log_err(),
-            stripe_billing: stripe_client.clone().map(|stripe_client| {
-                Arc::new(StripeBilling::new(
-                    stripe_client,
-                    config.stripe_zed_pro_price_id.clone(),
-                ))
-            }),
+            stripe_billing: stripe_client
+                .clone()
+                .map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))),
             stripe_client,
             rate_limiter: Arc::new(RateLimiter::new(db)),
             executor,

crates/collab/src/stripe_billing.rs 🔗

@@ -1,16 +1,16 @@
 use std::sync::Arc;
 
 use crate::{Cents, Result, llm};
-use anyhow::{Context as _, anyhow};
+use anyhow::Context as _;
 use chrono::{Datelike, Utc};
 use collections::HashMap;
 use serde::{Deserialize, Serialize};
+use stripe::PriceId;
 use tokio::sync::RwLock;
 
 pub struct StripeBilling {
     state: RwLock<StripeBillingState>,
     client: Arc<stripe::Client>,
-    zed_pro_price_id: Option<String>,
 }
 
 #[derive(Default)]
@@ -32,11 +32,10 @@ struct StripeBillingPrice {
 }
 
 impl StripeBilling {
-    pub fn new(client: Arc<stripe::Client>, zed_pro_price_id: Option<String>) -> Self {
+    pub fn new(client: Arc<stripe::Client>) -> Self {
         Self {
             client,
             state: RwLock::default(),
-            zed_pro_price_id,
         }
     }
 
@@ -385,23 +384,19 @@ impl StripeBilling {
         Ok(session.url.context("no checkout session URL")?)
     }
 
-    pub async fn checkout_with_zed_pro(
+    pub async fn checkout_with_price(
         &self,
+        price_id: PriceId,
         customer_id: stripe::CustomerId,
         github_login: &str,
         success_url: &str,
     ) -> Result<String> {
-        let zed_pro_price_id = self
-            .zed_pro_price_id
-            .as_ref()
-            .ok_or_else(|| anyhow!("Zed Pro price ID not set"))?;
-
         let mut params = stripe::CreateCheckoutSession::new();
         params.mode = Some(stripe::CheckoutSessionMode::Subscription);
         params.customer = Some(customer_id);
         params.client_reference_id = Some(github_login);
         params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
-            price: Some(zed_pro_price_id.clone()),
+            price: Some(price_id.to_string()),
             quantity: Some(1),
             ..Default::default()
         }]);

crates/collab/src/tests/test_server.rs 🔗

@@ -558,6 +558,8 @@ impl TestServer {
                 seed_path: None,
                 stripe_api_key: None,
                 stripe_zed_pro_price_id: None,
+                stripe_zed_pro_trial_price_id: None,
+                stripe_zed_free_price_id: None,
                 supermaven_admin_api_key: None,
                 user_backfiller_github_access_token: None,
                 kinesis_region: None,