@@ -15,10 +15,12 @@ use stripe::{
BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
+ CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
+ CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
EventType, Expandable, ListEvents, Subscription, SubscriptionId, SubscriptionStatus,
};
-use util::ResultExt;
+use util::{ResultExt, maybe};
use crate::api::events::SnowflakeRow;
use crate::db::billing_subscription::{
@@ -159,6 +161,7 @@ struct BillingSubscriptionJson {
id: BillingSubscriptionId,
name: String,
status: StripeSubscriptionStatus,
+ trial_end_at: Option<String>,
cancel_at: Option<String>,
/// Whether this subscription can be canceled.
is_cancelable: bool,
@@ -188,9 +191,21 @@ async fn list_billing_subscriptions(
id: subscription.id,
name: match subscription.kind {
Some(SubscriptionKind::ZedPro) => "Zed Pro".to_string(),
+ Some(SubscriptionKind::ZedProTrial) => "Zed Pro (Trial)".to_string(),
+ Some(SubscriptionKind::ZedFree) => "Zed Free".to_string(),
None => "Zed LLM Usage".to_string(),
},
status: subscription.stripe_subscription_status,
+ trial_end_at: if subscription.kind == Some(SubscriptionKind::ZedProTrial) {
+ maybe!({
+ let end_at = subscription.stripe_current_period_end?;
+ let end_at = DateTime::from_timestamp(end_at, 0)?;
+
+ Some(end_at.to_rfc3339_opts(SecondsFormat::Millis, true))
+ })
+ } else {
+ None
+ },
cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
cancel_at
.and_utc()
@@ -207,6 +222,7 @@ async fn list_billing_subscriptions(
#[serde(rename_all = "snake_case")]
enum ProductCode {
ZedPro,
+ ZedProTrial,
}
#[derive(Debug, Deserialize)]
@@ -286,24 +302,36 @@ async fn create_billing_subscription(
customer.id
};
+ let success_url = format!(
+ "{}/account?checkout_complete=1",
+ app.config.zed_dot_dev_url()
+ );
+
let checkout_session_url = match body.product {
Some(ProductCode::ZedPro) => {
- let success_url = format!(
- "{}/account?checkout_complete=1",
- app.config.zed_dot_dev_url()
- );
stripe_billing
- .checkout_with_zed_pro(customer_id, &user.github_login, &success_url)
+ .checkout_with_price(
+ app.config.zed_pro_price_id()?,
+ customer_id,
+ &user.github_login,
+ &success_url,
+ )
+ .await?
+ }
+ Some(ProductCode::ZedProTrial) => {
+ stripe_billing
+ .checkout_with_price(
+ app.config.zed_pro_trial_price_id()?,
+ customer_id,
+ &user.github_login,
+ &success_url,
+ )
.await?
}
None => {
let default_model =
llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-7-sonnet")?;
let stripe_model = stripe_billing.register_model(default_model).await?;
- let success_url = format!(
- "{}/account?checkout_complete=1",
- app.config.zed_dot_dev_url()
- );
stripe_billing
.checkout(customer_id, &user.github_login, &stripe_model, &success_url)
.await?
@@ -322,6 +350,8 @@ enum ManageSubscriptionIntent {
///
/// This will open the Stripe billing portal without putting the user in a specific flow.
ManageSubscription,
+ /// The user intends to upgrade to Zed Pro.
+ UpgradeToPro,
/// The user intends to cancel their subscription.
Cancel,
/// The user intends to stop the cancellation of their subscription.
@@ -373,11 +403,10 @@ async fn manage_billing_subscription(
.get_billing_subscription_by_id(body.subscription_id)
.await?
.ok_or_else(|| anyhow!("subscription not found"))?;
+ let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
+ .context("failed to parse subscription ID")?;
if body.intent == ManageSubscriptionIntent::StopCancellation {
- let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
- .context("failed to parse subscription ID")?;
-
let updated_stripe_subscription = Subscription::update(
&stripe_client,
&subscription_id,
@@ -410,6 +439,47 @@ async fn manage_billing_subscription(
let flow = match body.intent {
ManageSubscriptionIntent::ManageSubscription => None,
+ ManageSubscriptionIntent::UpgradeToPro => {
+ let zed_pro_price_id = app.config.zed_pro_price_id()?;
+ let zed_pro_trial_price_id = app.config.zed_pro_trial_price_id()?;
+ let zed_free_price_id = app.config.zed_free_price_id()?;
+
+ let stripe_subscription =
+ Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
+
+ let subscription_item_to_update = stripe_subscription
+ .items
+ .data
+ .iter()
+ .find_map(|item| {
+ let price = item.price.as_ref()?;
+
+ if price.id == zed_free_price_id || price.id == zed_pro_trial_price_id {
+ Some(item.id.clone())
+ } else {
+ None
+ }
+ })
+ .ok_or_else(|| anyhow!("No subscription item to update"))?;
+
+ Some(CreateBillingPortalSessionFlowData {
+ type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm,
+ subscription_update_confirm: Some(
+ CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm {
+ subscription: subscription.stripe_subscription_id,
+ items: vec![
+ CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems {
+ id: subscription_item_to_update.to_string(),
+ price: Some(zed_pro_price_id.to_string()),
+ quantity: Some(1),
+ },
+ ],
+ discounts: None,
+ },
+ ),
+ ..Default::default()
+ })
+ }
ManageSubscriptionIntent::Cancel => Some(CreateBillingPortalSessionFlowData {
type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
@@ -696,22 +766,25 @@ async fn handle_customer_subscription_event(
log::info!("handling Stripe {} event: {}", event.type_, event.id);
- let subscription_kind =
- if let Some(zed_pro_price_id) = app.config.stripe_zed_pro_price_id.as_deref() {
- let has_zed_pro_price = subscription.items.data.iter().any(|item| {
- item.price
- .as_ref()
- .map_or(false, |price| price.id.as_str() == zed_pro_price_id)
- });
+ let subscription_kind = maybe!({
+ let zed_pro_price_id = app.config.zed_pro_price_id().ok()?;
+ let zed_pro_trial_price_id = app.config.zed_pro_trial_price_id().ok()?;
+ let zed_free_price_id = app.config.zed_free_price_id().ok()?;
+
+ subscription.items.data.iter().find_map(|item| {
+ let price = item.price.as_ref()?;
- if has_zed_pro_price {
+ if price.id == zed_pro_price_id {
Some(SubscriptionKind::ZedPro)
+ } else if price.id == zed_pro_trial_price_id {
+ Some(SubscriptionKind::ZedProTrial)
+ } else if price.id == zed_free_price_id {
+ Some(SubscriptionKind::ZedFree)
} else {
None
}
- } else {
- None
- };
+ })
+ });
let billing_customer =
find_or_create_billing_customer(app, stripe_client, subscription.customer)
@@ -183,6 +183,8 @@ pub struct Config {
pub auto_join_channel_id: Option<ChannelId>,
pub stripe_api_key: Option<String>,
pub stripe_zed_pro_price_id: Option<String>,
+ pub stripe_zed_pro_trial_price_id: Option<String>,
+ pub stripe_zed_free_price_id: Option<String>,
pub supermaven_admin_api_key: Option<Arc<str>>,
pub user_backfiller_github_access_token: Option<Arc<str>>,
}
@@ -201,6 +203,29 @@ impl Config {
}
}
+ pub fn zed_pro_price_id(&self) -> anyhow::Result<stripe::PriceId> {
+ Self::parse_stripe_price_id("Zed Pro", self.stripe_zed_pro_price_id.as_deref())
+ }
+
+ pub fn zed_pro_trial_price_id(&self) -> anyhow::Result<stripe::PriceId> {
+ Self::parse_stripe_price_id(
+ "Zed Pro Trial",
+ self.stripe_zed_pro_trial_price_id.as_deref(),
+ )
+ }
+
+ pub fn zed_free_price_id(&self) -> anyhow::Result<stripe::PriceId> {
+ Self::parse_stripe_price_id("Zed Free", self.stripe_zed_pro_price_id.as_deref())
+ }
+
+ fn parse_stripe_price_id(name: &str, value: Option<&str>) -> anyhow::Result<stripe::PriceId> {
+ use std::str::FromStr as _;
+
+ let price_id = value.ok_or_else(|| anyhow!("{name} price ID not set"))?;
+
+ Ok(stripe::PriceId::from_str(price_id)?)
+ }
+
#[cfg(test)]
pub fn test() -> Self {
Self {
@@ -239,6 +264,8 @@ impl Config {
seed_path: None,
stripe_api_key: None,
stripe_zed_pro_price_id: None,
+ stripe_zed_pro_trial_price_id: None,
+ stripe_zed_free_price_id: None,
supermaven_admin_api_key: None,
user_backfiller_github_access_token: None,
kinesis_region: None,
@@ -324,12 +351,9 @@ impl AppState {
llm_db,
livekit_client,
blob_store_client: build_blob_store_client(&config).await.log_err(),
- stripe_billing: stripe_client.clone().map(|stripe_client| {
- Arc::new(StripeBilling::new(
- stripe_client,
- config.stripe_zed_pro_price_id.clone(),
- ))
- }),
+ stripe_billing: stripe_client
+ .clone()
+ .map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))),
stripe_client,
rate_limiter: Arc::new(RateLimiter::new(db)),
executor,
@@ -1,16 +1,16 @@
use std::sync::Arc;
use crate::{Cents, Result, llm};
-use anyhow::{Context as _, anyhow};
+use anyhow::Context as _;
use chrono::{Datelike, Utc};
use collections::HashMap;
use serde::{Deserialize, Serialize};
+use stripe::PriceId;
use tokio::sync::RwLock;
pub struct StripeBilling {
state: RwLock<StripeBillingState>,
client: Arc<stripe::Client>,
- zed_pro_price_id: Option<String>,
}
#[derive(Default)]
@@ -32,11 +32,10 @@ struct StripeBillingPrice {
}
impl StripeBilling {
- pub fn new(client: Arc<stripe::Client>, zed_pro_price_id: Option<String>) -> Self {
+ pub fn new(client: Arc<stripe::Client>) -> Self {
Self {
client,
state: RwLock::default(),
- zed_pro_price_id,
}
}
@@ -385,23 +384,19 @@ impl StripeBilling {
Ok(session.url.context("no checkout session URL")?)
}
- pub async fn checkout_with_zed_pro(
+ pub async fn checkout_with_price(
&self,
+ price_id: PriceId,
customer_id: stripe::CustomerId,
github_login: &str,
success_url: &str,
) -> Result<String> {
- let zed_pro_price_id = self
- .zed_pro_price_id
- .as_ref()
- .ok_or_else(|| anyhow!("Zed Pro price ID not set"))?;
-
let mut params = stripe::CreateCheckoutSession::new();
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
params.customer = Some(customer_id);
params.client_reference_id = Some(github_login);
params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
- price: Some(zed_pro_price_id.clone()),
+ price: Some(price_id.to_string()),
quantity: Some(1),
..Default::default()
}]);