From 90dec1d4512555bc97cd424ecc0dd3f45ad46eb4 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Tue, 15 Apr 2025 11:45:51 -0400 Subject: [PATCH] collab: Add Zed Pro checkout flow (#28776) This PR adds support for initiating a checkout flow for Zed Pro. Release Notes: - N/A --- crates/collab/src/api/billing.rs | 40 ++++++++++++++++++++------ crates/collab/src/lib.rs | 11 +++++-- crates/collab/src/stripe_billing.rs | 32 +++++++++++++++++++-- crates/collab/src/tests/test_server.rs | 1 + 4 files changed, 70 insertions(+), 14 deletions(-) diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index 72b8b89f66500730831c070cd957fc68d7313cab..4fba728cae4757cd9a35931455cc8a4ce667d7ad 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -198,9 +198,16 @@ async fn list_billing_subscriptions( })) } +#[derive(Debug, Clone, Copy, Deserialize)] +#[serde(rename_all = "snake_case")] +enum ProductCode { + ZedPro, +} + #[derive(Debug, Deserialize)] struct CreateBillingSubscriptionBody { github_user_id: i32, + product: Option, } #[derive(Debug, Serialize)] @@ -274,15 +281,30 @@ async fn create_billing_subscription( customer.id }; - let default_model = llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-7-sonnet")?; - let stripe_model = stripe_billing.register_model(default_model).await?; - let success_url = format!( - "{}/account?checkout_complete=1", - app.config.zed_dot_dev_url() - ); - let checkout_session_url = stripe_billing - .checkout(customer_id, &user.github_login, &stripe_model, &success_url) - .await?; + let checkout_session_url = match body.product { + Some(ProductCode::ZedPro) => { + let success_url = format!( + "{}/account?checkout_complete=1", + app.config.zed_dot_dev_url() + ); + stripe_billing + .checkout_with_zed_pro(customer_id, &user.github_login, &success_url) + .await? + } + None => { + let default_model = + llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-7-sonnet")?; + let stripe_model = stripe_billing.register_model(default_model).await?; + let success_url = format!( + "{}/account?checkout_complete=1", + app.config.zed_dot_dev_url() + ); + stripe_billing + .checkout(customer_id, &user.github_login, &stripe_model, &success_url) + .await? + } + }; + Ok(Json(CreateBillingSubscriptionResponse { checkout_session_url, })) diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 2e682d287878aed8fb68160d307675d03a08d1df..6cc4274ab12ada87dde0bd0638840347c95c4509 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -182,6 +182,7 @@ pub struct Config { pub slack_panics_webhook: Option, pub auto_join_channel_id: Option, pub stripe_api_key: Option, + pub stripe_zed_pro_price_id: Option, pub supermaven_admin_api_key: Option>, pub user_backfiller_github_access_token: Option>, } @@ -237,6 +238,7 @@ impl Config { migrations_path: None, seed_path: None, stripe_api_key: None, + stripe_zed_pro_price_id: None, supermaven_admin_api_key: None, user_backfiller_github_access_token: None, kinesis_region: None, @@ -322,9 +324,12 @@ impl AppState { llm_db, livekit_client, blob_store_client: build_blob_store_client(&config).await.log_err(), - stripe_billing: stripe_client - .clone() - .map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))), + stripe_billing: stripe_client.clone().map(|stripe_client| { + Arc::new(StripeBilling::new( + stripe_client, + config.stripe_zed_pro_price_id.clone(), + )) + }), stripe_client, rate_limiter: Arc::new(RateLimiter::new(db)), executor, diff --git a/crates/collab/src/stripe_billing.rs b/crates/collab/src/stripe_billing.rs index 8ad5105c9b321f2f7002f291f05ccbb9e57ef0b7..edbeab1b0109743eefc72e90476359d4d0ff1ddb 100644 --- a/crates/collab/src/stripe_billing.rs +++ b/crates/collab/src/stripe_billing.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use crate::{Cents, Result, llm}; -use anyhow::Context as _; +use anyhow::{Context as _, anyhow}; use chrono::{Datelike, Utc}; use collections::HashMap; use serde::{Deserialize, Serialize}; @@ -10,6 +10,7 @@ use tokio::sync::RwLock; pub struct StripeBilling { state: RwLock, client: Arc, + zed_pro_price_id: Option, } #[derive(Default)] @@ -31,10 +32,11 @@ struct StripeBillingPrice { } impl StripeBilling { - pub fn new(client: Arc) -> Self { + pub fn new(client: Arc, zed_pro_price_id: Option) -> Self { Self { client, state: RwLock::default(), + zed_pro_price_id, } } @@ -382,6 +384,32 @@ impl StripeBilling { let session = stripe::CheckoutSession::create(&self.client, params).await?; Ok(session.url.context("no checkout session URL")?) } + + pub async fn checkout_with_zed_pro( + &self, + customer_id: stripe::CustomerId, + github_login: &str, + success_url: &str, + ) -> Result { + let zed_pro_price_id = self + .zed_pro_price_id + .as_ref() + .ok_or_else(|| anyhow!("Zed Pro price ID not set"))?; + + let mut params = stripe::CreateCheckoutSession::new(); + params.mode = Some(stripe::CheckoutSessionMode::Subscription); + params.customer = Some(customer_id); + params.client_reference_id = Some(github_login); + params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems { + price: Some(zed_pro_price_id.clone()), + quantity: Some(1), + ..Default::default() + }]); + params.success_url = Some(success_url); + + let session = stripe::CheckoutSession::create(&self.client, params).await?; + Ok(session.url.context("no checkout session URL")?) + } } #[derive(Serialize)] diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index 77fbf1f7f71285774ac94369642d520015c4b5d1..8c6130def2f3d249ced7144b99a013d826b9e879 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -557,6 +557,7 @@ impl TestServer { migrations_path: None, seed_path: None, stripe_api_key: None, + stripe_zed_pro_price_id: None, supermaven_admin_api_key: None, user_backfiller_github_access_token: None, kinesis_region: None,