collab: Add Zed Pro checkout flow (#28776)

Marshall Bowers created

This PR adds support for initiating a checkout flow for Zed Pro.

Release Notes:

- N/A

Change summary

crates/collab/src/api/billing.rs       | 40 +++++++++++++++++++++------
crates/collab/src/lib.rs               | 11 +++++--
crates/collab/src/stripe_billing.rs    | 32 +++++++++++++++++++++-
crates/collab/src/tests/test_server.rs |  1 
4 files changed, 70 insertions(+), 14 deletions(-)

Detailed changes

crates/collab/src/api/billing.rs 🔗

@@ -198,9 +198,16 @@ async fn list_billing_subscriptions(
     }))
 }
 
+#[derive(Debug, Clone, Copy, Deserialize)]
+#[serde(rename_all = "snake_case")]
+enum ProductCode {
+    ZedPro,
+}
+
 #[derive(Debug, Deserialize)]
 struct CreateBillingSubscriptionBody {
     github_user_id: i32,
+    product: Option<ProductCode>,
 }
 
 #[derive(Debug, Serialize)]
@@ -274,15 +281,30 @@ async fn create_billing_subscription(
         customer.id
     };
 
-    let default_model = llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-7-sonnet")?;
-    let stripe_model = stripe_billing.register_model(default_model).await?;
-    let success_url = format!(
-        "{}/account?checkout_complete=1",
-        app.config.zed_dot_dev_url()
-    );
-    let checkout_session_url = stripe_billing
-        .checkout(customer_id, &user.github_login, &stripe_model, &success_url)
-        .await?;
+    let checkout_session_url = match body.product {
+        Some(ProductCode::ZedPro) => {
+            let success_url = format!(
+                "{}/account?checkout_complete=1",
+                app.config.zed_dot_dev_url()
+            );
+            stripe_billing
+                .checkout_with_zed_pro(customer_id, &user.github_login, &success_url)
+                .await?
+        }
+        None => {
+            let default_model =
+                llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-7-sonnet")?;
+            let stripe_model = stripe_billing.register_model(default_model).await?;
+            let success_url = format!(
+                "{}/account?checkout_complete=1",
+                app.config.zed_dot_dev_url()
+            );
+            stripe_billing
+                .checkout(customer_id, &user.github_login, &stripe_model, &success_url)
+                .await?
+        }
+    };
+
     Ok(Json(CreateBillingSubscriptionResponse {
         checkout_session_url,
     }))

crates/collab/src/lib.rs 🔗

@@ -182,6 +182,7 @@ pub struct Config {
     pub slack_panics_webhook: Option<String>,
     pub auto_join_channel_id: Option<ChannelId>,
     pub stripe_api_key: Option<String>,
+    pub stripe_zed_pro_price_id: Option<String>,
     pub supermaven_admin_api_key: Option<Arc<str>>,
     pub user_backfiller_github_access_token: Option<Arc<str>>,
 }
@@ -237,6 +238,7 @@ impl Config {
             migrations_path: None,
             seed_path: None,
             stripe_api_key: None,
+            stripe_zed_pro_price_id: None,
             supermaven_admin_api_key: None,
             user_backfiller_github_access_token: None,
             kinesis_region: None,
@@ -322,9 +324,12 @@ impl AppState {
             llm_db,
             livekit_client,
             blob_store_client: build_blob_store_client(&config).await.log_err(),
-            stripe_billing: stripe_client
-                .clone()
-                .map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))),
+            stripe_billing: stripe_client.clone().map(|stripe_client| {
+                Arc::new(StripeBilling::new(
+                    stripe_client,
+                    config.stripe_zed_pro_price_id.clone(),
+                ))
+            }),
             stripe_client,
             rate_limiter: Arc::new(RateLimiter::new(db)),
             executor,

crates/collab/src/stripe_billing.rs 🔗

@@ -1,7 +1,7 @@
 use std::sync::Arc;
 
 use crate::{Cents, Result, llm};
-use anyhow::Context as _;
+use anyhow::{Context as _, anyhow};
 use chrono::{Datelike, Utc};
 use collections::HashMap;
 use serde::{Deserialize, Serialize};
@@ -10,6 +10,7 @@ use tokio::sync::RwLock;
 pub struct StripeBilling {
     state: RwLock<StripeBillingState>,
     client: Arc<stripe::Client>,
+    zed_pro_price_id: Option<String>,
 }
 
 #[derive(Default)]
@@ -31,10 +32,11 @@ struct StripeBillingPrice {
 }
 
 impl StripeBilling {
-    pub fn new(client: Arc<stripe::Client>) -> Self {
+    pub fn new(client: Arc<stripe::Client>, zed_pro_price_id: Option<String>) -> Self {
         Self {
             client,
             state: RwLock::default(),
+            zed_pro_price_id,
         }
     }
 
@@ -382,6 +384,32 @@ impl StripeBilling {
         let session = stripe::CheckoutSession::create(&self.client, params).await?;
         Ok(session.url.context("no checkout session URL")?)
     }
+
+    pub async fn checkout_with_zed_pro(
+        &self,
+        customer_id: stripe::CustomerId,
+        github_login: &str,
+        success_url: &str,
+    ) -> Result<String> {
+        let zed_pro_price_id = self
+            .zed_pro_price_id
+            .as_ref()
+            .ok_or_else(|| anyhow!("Zed Pro price ID not set"))?;
+
+        let mut params = stripe::CreateCheckoutSession::new();
+        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
+        params.customer = Some(customer_id);
+        params.client_reference_id = Some(github_login);
+        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
+            price: Some(zed_pro_price_id.clone()),
+            quantity: Some(1),
+            ..Default::default()
+        }]);
+        params.success_url = Some(success_url);
+
+        let session = stripe::CheckoutSession::create(&self.client, params).await?;
+        Ok(session.url.context("no checkout session URL")?)
+    }
 }
 
 #[derive(Serialize)]

crates/collab/src/tests/test_server.rs 🔗

@@ -557,6 +557,7 @@ impl TestServer {
                 migrations_path: None,
                 seed_path: None,
                 stripe_api_key: None,
+                stripe_zed_pro_price_id: None,
                 supermaven_admin_api_key: None,
                 user_backfiller_github_access_token: None,
                 kinesis_region: None,