From dad6067e187a2e585b3983da53bd0998a7c4fbe4 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Tue, 15 Apr 2025 16:49:16 -0400 Subject: [PATCH] collab: Add support for subscribing to Zed Pro trials (#28812) This PR adds support for subscribing to Zed Pro trials (and then upgrading from a trial to Zed Pro). Release Notes: - N/A --- crates/collab/src/api/billing.rs | 121 ++++++++++++++---- .../src/db/queries/billing_subscriptions.rs | 5 +- .../src/db/tables/billing_subscription.rs | 4 + crates/collab/src/lib.rs | 36 +++++- crates/collab/src/stripe_billing.rs | 17 +-- crates/collab/src/tests/test_server.rs | 2 + 6 files changed, 143 insertions(+), 42 deletions(-) diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index 25a0dfe8397612a723bcc7579d2b9a7a83bac35a..bbb8ac8428e9d14f440d86236022116fde3a4691 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -15,10 +15,12 @@ use stripe::{ BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession, CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion, CreateBillingPortalSessionFlowDataAfterCompletionRedirect, + CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm, + CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems, CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject, EventType, Expandable, ListEvents, Subscription, SubscriptionId, SubscriptionStatus, }; -use util::ResultExt; +use util::{ResultExt, maybe}; use crate::api::events::SnowflakeRow; use crate::db::billing_subscription::{ @@ -159,6 +161,7 @@ struct BillingSubscriptionJson { id: BillingSubscriptionId, name: String, status: StripeSubscriptionStatus, + trial_end_at: Option, cancel_at: Option, /// Whether this subscription can be canceled. is_cancelable: bool, @@ -188,9 +191,21 @@ async fn list_billing_subscriptions( id: subscription.id, name: match subscription.kind { Some(SubscriptionKind::ZedPro) => "Zed Pro".to_string(), + Some(SubscriptionKind::ZedProTrial) => "Zed Pro (Trial)".to_string(), + Some(SubscriptionKind::ZedFree) => "Zed Free".to_string(), None => "Zed LLM Usage".to_string(), }, status: subscription.stripe_subscription_status, + trial_end_at: if subscription.kind == Some(SubscriptionKind::ZedProTrial) { + maybe!({ + let end_at = subscription.stripe_current_period_end?; + let end_at = DateTime::from_timestamp(end_at, 0)?; + + Some(end_at.to_rfc3339_opts(SecondsFormat::Millis, true)) + }) + } else { + None + }, cancel_at: subscription.stripe_cancel_at.map(|cancel_at| { cancel_at .and_utc() @@ -207,6 +222,7 @@ async fn list_billing_subscriptions( #[serde(rename_all = "snake_case")] enum ProductCode { ZedPro, + ZedProTrial, } #[derive(Debug, Deserialize)] @@ -286,24 +302,36 @@ async fn create_billing_subscription( customer.id }; + let success_url = format!( + "{}/account?checkout_complete=1", + app.config.zed_dot_dev_url() + ); + let checkout_session_url = match body.product { Some(ProductCode::ZedPro) => { - let success_url = format!( - "{}/account?checkout_complete=1", - app.config.zed_dot_dev_url() - ); stripe_billing - .checkout_with_zed_pro(customer_id, &user.github_login, &success_url) + .checkout_with_price( + app.config.zed_pro_price_id()?, + customer_id, + &user.github_login, + &success_url, + ) + .await? + } + Some(ProductCode::ZedProTrial) => { + stripe_billing + .checkout_with_price( + app.config.zed_pro_trial_price_id()?, + customer_id, + &user.github_login, + &success_url, + ) .await? } None => { let default_model = llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-7-sonnet")?; let stripe_model = stripe_billing.register_model(default_model).await?; - let success_url = format!( - "{}/account?checkout_complete=1", - app.config.zed_dot_dev_url() - ); stripe_billing .checkout(customer_id, &user.github_login, &stripe_model, &success_url) .await? @@ -322,6 +350,8 @@ enum ManageSubscriptionIntent { /// /// This will open the Stripe billing portal without putting the user in a specific flow. ManageSubscription, + /// The user intends to upgrade to Zed Pro. + UpgradeToPro, /// The user intends to cancel their subscription. Cancel, /// The user intends to stop the cancellation of their subscription. @@ -373,11 +403,10 @@ async fn manage_billing_subscription( .get_billing_subscription_by_id(body.subscription_id) .await? .ok_or_else(|| anyhow!("subscription not found"))?; + let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id) + .context("failed to parse subscription ID")?; if body.intent == ManageSubscriptionIntent::StopCancellation { - let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id) - .context("failed to parse subscription ID")?; - let updated_stripe_subscription = Subscription::update( &stripe_client, &subscription_id, @@ -410,6 +439,47 @@ async fn manage_billing_subscription( let flow = match body.intent { ManageSubscriptionIntent::ManageSubscription => None, + ManageSubscriptionIntent::UpgradeToPro => { + let zed_pro_price_id = app.config.zed_pro_price_id()?; + let zed_pro_trial_price_id = app.config.zed_pro_trial_price_id()?; + let zed_free_price_id = app.config.zed_free_price_id()?; + + let stripe_subscription = + Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?; + + let subscription_item_to_update = stripe_subscription + .items + .data + .iter() + .find_map(|item| { + let price = item.price.as_ref()?; + + if price.id == zed_free_price_id || price.id == zed_pro_trial_price_id { + Some(item.id.clone()) + } else { + None + } + }) + .ok_or_else(|| anyhow!("No subscription item to update"))?; + + Some(CreateBillingPortalSessionFlowData { + type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm, + subscription_update_confirm: Some( + CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm { + subscription: subscription.stripe_subscription_id, + items: vec![ + CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems { + id: subscription_item_to_update.to_string(), + price: Some(zed_pro_price_id.to_string()), + quantity: Some(1), + }, + ], + discounts: None, + }, + ), + ..Default::default() + }) + } ManageSubscriptionIntent::Cancel => Some(CreateBillingPortalSessionFlowData { type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel, after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion { @@ -696,22 +766,25 @@ async fn handle_customer_subscription_event( log::info!("handling Stripe {} event: {}", event.type_, event.id); - let subscription_kind = - if let Some(zed_pro_price_id) = app.config.stripe_zed_pro_price_id.as_deref() { - let has_zed_pro_price = subscription.items.data.iter().any(|item| { - item.price - .as_ref() - .map_or(false, |price| price.id.as_str() == zed_pro_price_id) - }); + let subscription_kind = maybe!({ + let zed_pro_price_id = app.config.zed_pro_price_id().ok()?; + let zed_pro_trial_price_id = app.config.zed_pro_trial_price_id().ok()?; + let zed_free_price_id = app.config.zed_free_price_id().ok()?; + + subscription.items.data.iter().find_map(|item| { + let price = item.price.as_ref()?; - if has_zed_pro_price { + if price.id == zed_pro_price_id { Some(SubscriptionKind::ZedPro) + } else if price.id == zed_pro_trial_price_id { + Some(SubscriptionKind::ZedProTrial) + } else if price.id == zed_free_price_id { + Some(SubscriptionKind::ZedFree) } else { None } - } else { - None - }; + }) + }); let billing_customer = find_or_create_billing_customer(app, stripe_client, subscription.customer) diff --git a/crates/collab/src/db/queries/billing_subscriptions.rs b/crates/collab/src/db/queries/billing_subscriptions.rs index a868fd3e1f0e03dd74dcc45a8a17e8e10562e1fe..1eef99beb76b7eeb903312ea6b0f96e69280f224 100644 --- a/crates/collab/src/db/queries/billing_subscriptions.rs +++ b/crates/collab/src/db/queries/billing_subscriptions.rs @@ -62,11 +62,14 @@ impl Database { billing_subscription::Entity::update(billing_subscription::ActiveModel { id: ActiveValue::set(id), billing_customer_id: params.billing_customer_id.clone(), + kind: params.kind.clone(), stripe_subscription_id: params.stripe_subscription_id.clone(), stripe_subscription_status: params.stripe_subscription_status.clone(), stripe_cancel_at: params.stripe_cancel_at.clone(), stripe_cancellation_reason: params.stripe_cancellation_reason.clone(), - ..Default::default() + stripe_current_period_start: params.stripe_current_period_start.clone(), + stripe_current_period_end: params.stripe_current_period_end.clone(), + created_at: ActiveValue::not_set(), }) .exec(&*tx) .await?; diff --git a/crates/collab/src/db/tables/billing_subscription.rs b/crates/collab/src/db/tables/billing_subscription.rs index 16da6880016fafcfaa223d6f4998ce15e4e01460..d834a2d3ac7723861e57c188af330253e09b2559 100644 --- a/crates/collab/src/db/tables/billing_subscription.rs +++ b/crates/collab/src/db/tables/billing_subscription.rs @@ -43,6 +43,10 @@ impl ActiveModelBehavior for ActiveModel {} pub enum SubscriptionKind { #[sea_orm(string_value = "zed_pro")] ZedPro, + #[sea_orm(string_value = "zed_pro_trial")] + ZedProTrial, + #[sea_orm(string_value = "zed_free")] + ZedFree, } /// The status of a Stripe subscription. diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 6cc4274ab12ada87dde0bd0638840347c95c4509..f697e01e43423e055361ab459dc2ff4d7c8b7117 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -183,6 +183,8 @@ pub struct Config { pub auto_join_channel_id: Option, pub stripe_api_key: Option, pub stripe_zed_pro_price_id: Option, + pub stripe_zed_pro_trial_price_id: Option, + pub stripe_zed_free_price_id: Option, pub supermaven_admin_api_key: Option>, pub user_backfiller_github_access_token: Option>, } @@ -201,6 +203,29 @@ impl Config { } } + pub fn zed_pro_price_id(&self) -> anyhow::Result { + Self::parse_stripe_price_id("Zed Pro", self.stripe_zed_pro_price_id.as_deref()) + } + + pub fn zed_pro_trial_price_id(&self) -> anyhow::Result { + Self::parse_stripe_price_id( + "Zed Pro Trial", + self.stripe_zed_pro_trial_price_id.as_deref(), + ) + } + + pub fn zed_free_price_id(&self) -> anyhow::Result { + Self::parse_stripe_price_id("Zed Free", self.stripe_zed_pro_price_id.as_deref()) + } + + fn parse_stripe_price_id(name: &str, value: Option<&str>) -> anyhow::Result { + use std::str::FromStr as _; + + let price_id = value.ok_or_else(|| anyhow!("{name} price ID not set"))?; + + Ok(stripe::PriceId::from_str(price_id)?) + } + #[cfg(test)] pub fn test() -> Self { Self { @@ -239,6 +264,8 @@ impl Config { seed_path: None, stripe_api_key: None, stripe_zed_pro_price_id: None, + stripe_zed_pro_trial_price_id: None, + stripe_zed_free_price_id: None, supermaven_admin_api_key: None, user_backfiller_github_access_token: None, kinesis_region: None, @@ -324,12 +351,9 @@ impl AppState { llm_db, livekit_client, blob_store_client: build_blob_store_client(&config).await.log_err(), - stripe_billing: stripe_client.clone().map(|stripe_client| { - Arc::new(StripeBilling::new( - stripe_client, - config.stripe_zed_pro_price_id.clone(), - )) - }), + stripe_billing: stripe_client + .clone() + .map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))), stripe_client, rate_limiter: Arc::new(RateLimiter::new(db)), executor, diff --git a/crates/collab/src/stripe_billing.rs b/crates/collab/src/stripe_billing.rs index edbeab1b0109743eefc72e90476359d4d0ff1ddb..44d27714ade2afa27e1ff62ebd8ca8652822bc2d 100644 --- a/crates/collab/src/stripe_billing.rs +++ b/crates/collab/src/stripe_billing.rs @@ -1,16 +1,16 @@ use std::sync::Arc; use crate::{Cents, Result, llm}; -use anyhow::{Context as _, anyhow}; +use anyhow::Context as _; use chrono::{Datelike, Utc}; use collections::HashMap; use serde::{Deserialize, Serialize}; +use stripe::PriceId; use tokio::sync::RwLock; pub struct StripeBilling { state: RwLock, client: Arc, - zed_pro_price_id: Option, } #[derive(Default)] @@ -32,11 +32,10 @@ struct StripeBillingPrice { } impl StripeBilling { - pub fn new(client: Arc, zed_pro_price_id: Option) -> Self { + pub fn new(client: Arc) -> Self { Self { client, state: RwLock::default(), - zed_pro_price_id, } } @@ -385,23 +384,19 @@ impl StripeBilling { Ok(session.url.context("no checkout session URL")?) } - pub async fn checkout_with_zed_pro( + pub async fn checkout_with_price( &self, + price_id: PriceId, customer_id: stripe::CustomerId, github_login: &str, success_url: &str, ) -> Result { - let zed_pro_price_id = self - .zed_pro_price_id - .as_ref() - .ok_or_else(|| anyhow!("Zed Pro price ID not set"))?; - let mut params = stripe::CreateCheckoutSession::new(); params.mode = Some(stripe::CheckoutSessionMode::Subscription); params.customer = Some(customer_id); params.client_reference_id = Some(github_login); params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems { - price: Some(zed_pro_price_id.clone()), + price: Some(price_id.to_string()), quantity: Some(1), ..Default::default() }]); diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index 8c6130def2f3d249ced7144b99a013d826b9e879..f7b2d7ce6d7678b304717ad4ee0bd00049824996 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -558,6 +558,8 @@ impl TestServer { seed_path: None, stripe_api_key: None, stripe_zed_pro_price_id: None, + stripe_zed_pro_trial_price_id: None, + stripe_zed_free_price_id: None, supermaven_admin_api_key: None, user_backfiller_github_access_token: None, kinesis_region: None,