@@ -25,7 +25,7 @@ use crate::llm::db::subscription_usage_meter::{self, CompletionMode};
use crate::rpc::{ResultExt as _, Server};
use crate::stripe_client::{
StripeCancellationDetailsReason, StripeClient, StripeCustomerId, StripeSubscription,
- StripeSubscriptionId, UpdateCustomerParams,
+ StripeSubscriptionId,
};
use crate::{AppState, Error, Result};
use crate::{db::UserId, llm::db::LlmDatabase};
@@ -40,7 +40,6 @@ use crate::{
pub fn router() -> Router {
Router::new()
- .route("/billing/subscriptions", post(create_billing_subscription))
.route(
"/billing/subscriptions/manage",
post(manage_billing_subscription),
@@ -51,122 +50,6 @@ pub fn router() -> Router {
)
}
-#[derive(Debug, PartialEq, Clone, Copy, Deserialize)]
-#[serde(rename_all = "snake_case")]
-enum ProductCode {
- ZedPro,
- ZedProTrial,
-}
-
-#[derive(Debug, Deserialize)]
-struct CreateBillingSubscriptionBody {
- github_user_id: i32,
- product: ProductCode,
-}
-
-#[derive(Debug, Serialize)]
-struct CreateBillingSubscriptionResponse {
- checkout_session_url: String,
-}
-
-/// Initiates a Stripe Checkout session for creating a billing subscription.
-async fn create_billing_subscription(
- Extension(app): Extension<Arc<AppState>>,
- extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
-) -> Result<Json<CreateBillingSubscriptionResponse>> {
- let user = app
- .db
- .get_user_by_github_user_id(body.github_user_id)
- .await?
- .context("user not found")?;
-
- let Some(stripe_billing) = app.stripe_billing.clone() else {
- log::error!("failed to retrieve Stripe billing object");
- Err(Error::http(
- StatusCode::NOT_IMPLEMENTED,
- "not supported".into(),
- ))?
- };
-
- if let Some(existing_subscription) = app.db.get_active_billing_subscription(user.id).await? {
- let is_checkout_allowed = body.product == ProductCode::ZedProTrial
- && existing_subscription.kind == Some(SubscriptionKind::ZedFree);
-
- if !is_checkout_allowed {
- return Err(Error::http(
- StatusCode::CONFLICT,
- "user already has an active subscription".into(),
- ));
- }
- }
-
- let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
- if let Some(existing_billing_customer) = &existing_billing_customer {
- if existing_billing_customer.has_overdue_invoices {
- return Err(Error::http(
- StatusCode::PAYMENT_REQUIRED,
- "user has overdue invoices".into(),
- ));
- }
- }
-
- let customer_id = if let Some(existing_customer) = &existing_billing_customer {
- let customer_id = StripeCustomerId(existing_customer.stripe_customer_id.clone().into());
- if let Some(email) = user.email_address.as_deref() {
- stripe_billing
- .client()
- .update_customer(&customer_id, UpdateCustomerParams { email: Some(email) })
- .await
- // Update of email address is best-effort - continue checkout even if it fails
- .context("error updating stripe customer email address")
- .log_err();
- }
- customer_id
- } else {
- stripe_billing
- .find_or_create_customer_by_email(user.email_address.as_deref())
- .await?
- };
-
- let success_url = format!(
- "{}/account?checkout_complete=1",
- app.config.zed_dot_dev_url()
- );
-
- let checkout_session_url = match body.product {
- ProductCode::ZedPro => {
- stripe_billing
- .checkout_with_zed_pro(&customer_id, &user.github_login, &success_url)
- .await?
- }
- ProductCode::ZedProTrial => {
- if let Some(existing_billing_customer) = &existing_billing_customer {
- if existing_billing_customer.trial_started_at.is_some() {
- return Err(Error::http(
- StatusCode::FORBIDDEN,
- "user already used free trial".into(),
- ));
- }
- }
-
- let feature_flags = app.db.get_user_flags(user.id).await?;
-
- stripe_billing
- .checkout_with_zed_pro_trial(
- &customer_id,
- &user.github_login,
- feature_flags,
- &success_url,
- )
- .await?
- }
- };
-
- Ok(Json(CreateBillingSubscriptionResponse {
- checkout_session_url,
- }))
-}
-
#[derive(Debug, PartialEq, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ManageSubscriptionIntent {
@@ -1,6 +1,6 @@
use std::sync::Arc;
-use anyhow::{Context as _, anyhow};
+use anyhow::anyhow;
use chrono::Utc;
use collections::HashMap;
use stripe::SubscriptionStatus;
@@ -9,18 +9,13 @@ use uuid::Uuid;
use crate::Result;
use crate::db::billing_subscription::SubscriptionKind;
-use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
use crate::stripe_client::{
- RealStripeClient, StripeAutomaticTax, StripeBillingAddressCollection,
- StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection, StripeClient,
- StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
- StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
+ RealStripeClient, StripeAutomaticTax, StripeClient, StripeCreateMeterEventParams,
StripeCreateMeterEventPayload, StripeCreateSubscriptionItems, StripeCreateSubscriptionParams,
- StripeCustomerId, StripeCustomerUpdate, StripeCustomerUpdateAddress, StripeCustomerUpdateName,
- StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId,
+ StripeCustomerId, StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId,
StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior,
- StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, StripeTaxIdCollection,
- UpdateSubscriptionItems, UpdateSubscriptionParams,
+ StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
+ UpdateSubscriptionParams,
};
pub struct StripeBilling {
@@ -214,95 +209,6 @@ impl StripeBilling {
Ok(())
}
- pub async fn checkout_with_zed_pro(
- &self,
- customer_id: &StripeCustomerId,
- github_login: &str,
- success_url: &str,
- ) -> Result<String> {
- let zed_pro_price_id = self.zed_pro_price_id().await?;
-
- let mut params = StripeCreateCheckoutSessionParams::default();
- params.mode = Some(StripeCheckoutSessionMode::Subscription);
- params.customer = Some(customer_id);
- params.client_reference_id = Some(github_login);
- params.line_items = Some(vec![StripeCreateCheckoutSessionLineItems {
- price: Some(zed_pro_price_id.to_string()),
- quantity: Some(1),
- }]);
- params.success_url = Some(success_url);
- params.billing_address_collection = Some(StripeBillingAddressCollection::Required);
- params.customer_update = Some(StripeCustomerUpdate {
- address: Some(StripeCustomerUpdateAddress::Auto),
- name: Some(StripeCustomerUpdateName::Auto),
- shipping: None,
- });
- params.tax_id_collection = Some(StripeTaxIdCollection { enabled: true });
-
- let session = self.client.create_checkout_session(params).await?;
- Ok(session.url.context("no checkout session URL")?)
- }
-
- pub async fn checkout_with_zed_pro_trial(
- &self,
- customer_id: &StripeCustomerId,
- github_login: &str,
- feature_flags: Vec<String>,
- success_url: &str,
- ) -> Result<String> {
- let zed_pro_price_id = self.zed_pro_price_id().await?;
-
- let eligible_for_extended_trial = feature_flags
- .iter()
- .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
-
- let trial_period_days = if eligible_for_extended_trial { 60 } else { 14 };
-
- let mut subscription_metadata = std::collections::HashMap::new();
- if eligible_for_extended_trial {
- subscription_metadata.insert(
- "promo_feature_flag".to_string(),
- AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string(),
- );
- }
-
- let mut params = StripeCreateCheckoutSessionParams::default();
- params.subscription_data = Some(StripeCreateCheckoutSessionSubscriptionData {
- trial_period_days: Some(trial_period_days),
- trial_settings: Some(StripeSubscriptionTrialSettings {
- end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
- missing_payment_method:
- StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
- },
- }),
- metadata: if !subscription_metadata.is_empty() {
- Some(subscription_metadata)
- } else {
- None
- },
- });
- params.mode = Some(StripeCheckoutSessionMode::Subscription);
- params.payment_method_collection =
- Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired);
- params.customer = Some(customer_id);
- params.client_reference_id = Some(github_login);
- params.line_items = Some(vec![StripeCreateCheckoutSessionLineItems {
- price: Some(zed_pro_price_id.to_string()),
- quantity: Some(1),
- }]);
- params.success_url = Some(success_url);
- params.billing_address_collection = Some(StripeBillingAddressCollection::Required);
- params.customer_update = Some(StripeCustomerUpdate {
- address: Some(StripeCustomerUpdateAddress::Auto),
- name: Some(StripeCustomerUpdateName::Auto),
- shipping: None,
- });
- params.tax_id_collection = Some(StripeTaxIdCollection { enabled: true });
-
- let session = self.client.create_checkout_session(params).await?;
- Ok(session.url.context("no checkout session URL")?)
- }
-
pub async fn subscribe_to_zed_free(
&self,
customer_id: StripeCustomerId,
@@ -3,17 +3,11 @@ use std::sync::Arc;
use chrono::{Duration, Utc};
use pretty_assertions::assert_eq;
-use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
use crate::stripe_billing::StripeBilling;
use crate::stripe_client::{
- FakeStripeClient, StripeBillingAddressCollection, StripeCheckoutSessionMode,
- StripeCheckoutSessionPaymentMethodCollection, StripeCreateCheckoutSessionLineItems,
- StripeCreateCheckoutSessionSubscriptionData, StripeCustomerId, StripeCustomerUpdate,
- StripeCustomerUpdateAddress, StripeCustomerUpdateName, StripeMeter, StripeMeterId, StripePrice,
- StripePriceId, StripePriceRecurring, StripeSubscription, StripeSubscriptionId,
- StripeSubscriptionItem, StripeSubscriptionItemId, StripeSubscriptionTrialSettings,
- StripeSubscriptionTrialSettingsEndBehavior,
- StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
+ FakeStripeClient, StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId,
+ StripePriceRecurring, StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem,
+ StripeSubscriptionItemId, UpdateSubscriptionItems,
};
fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
@@ -364,240 +358,3 @@ async fn test_bill_model_request_usage() {
);
assert_eq!(create_meter_event_calls[0].value, 73);
}
-
-#[gpui::test]
-async fn test_checkout_with_zed_pro() {
- let (stripe_billing, stripe_client) = make_stripe_billing();
-
- let customer_id = StripeCustomerId("cus_test".into());
- let github_login = "zeduser1";
- let success_url = "https://example.com/success";
-
- // It returns an error when the Zed Pro price doesn't exist.
- {
- let result = stripe_billing
- .checkout_with_zed_pro(&customer_id, github_login, success_url)
- .await;
-
- assert!(result.is_err());
- assert_eq!(
- result.err().unwrap().to_string(),
- r#"no price ID found for "zed-pro""#
- );
- }
-
- // Successful checkout.
- {
- let price = StripePrice {
- id: StripePriceId("price_1".into()),
- unit_amount: Some(2000),
- lookup_key: Some("zed-pro".to_string()),
- recurring: None,
- };
- stripe_client
- .prices
- .lock()
- .insert(price.id.clone(), price.clone());
-
- stripe_billing.initialize().await.unwrap();
-
- let checkout_url = stripe_billing
- .checkout_with_zed_pro(&customer_id, github_login, success_url)
- .await
- .unwrap();
-
- assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
-
- let create_checkout_session_calls = stripe_client
- .create_checkout_session_calls
- .lock()
- .drain(..)
- .collect::<Vec<_>>();
- assert_eq!(create_checkout_session_calls.len(), 1);
- let call = create_checkout_session_calls.into_iter().next().unwrap();
- assert_eq!(call.customer, Some(customer_id));
- assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
- assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
- assert_eq!(
- call.line_items,
- Some(vec![StripeCreateCheckoutSessionLineItems {
- price: Some(price.id.to_string()),
- quantity: Some(1)
- }])
- );
- assert_eq!(call.payment_method_collection, None);
- assert_eq!(call.subscription_data, None);
- assert_eq!(call.success_url.as_deref(), Some(success_url));
- assert_eq!(
- call.billing_address_collection,
- Some(StripeBillingAddressCollection::Required)
- );
- assert_eq!(
- call.customer_update,
- Some(StripeCustomerUpdate {
- address: Some(StripeCustomerUpdateAddress::Auto),
- name: Some(StripeCustomerUpdateName::Auto),
- shipping: None,
- })
- );
- }
-}
-
-#[gpui::test]
-async fn test_checkout_with_zed_pro_trial() {
- let (stripe_billing, stripe_client) = make_stripe_billing();
-
- let customer_id = StripeCustomerId("cus_test".into());
- let github_login = "zeduser1";
- let success_url = "https://example.com/success";
-
- // It returns an error when the Zed Pro price doesn't exist.
- {
- let result = stripe_billing
- .checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
- .await;
-
- assert!(result.is_err());
- assert_eq!(
- result.err().unwrap().to_string(),
- r#"no price ID found for "zed-pro""#
- );
- }
-
- let price = StripePrice {
- id: StripePriceId("price_1".into()),
- unit_amount: Some(2000),
- lookup_key: Some("zed-pro".to_string()),
- recurring: None,
- };
- stripe_client
- .prices
- .lock()
- .insert(price.id.clone(), price.clone());
-
- stripe_billing.initialize().await.unwrap();
-
- // Successful checkout.
- {
- let checkout_url = stripe_billing
- .checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
- .await
- .unwrap();
-
- assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
-
- let create_checkout_session_calls = stripe_client
- .create_checkout_session_calls
- .lock()
- .drain(..)
- .collect::<Vec<_>>();
- assert_eq!(create_checkout_session_calls.len(), 1);
- let call = create_checkout_session_calls.into_iter().next().unwrap();
- assert_eq!(call.customer.as_ref(), Some(&customer_id));
- assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
- assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
- assert_eq!(
- call.line_items,
- Some(vec![StripeCreateCheckoutSessionLineItems {
- price: Some(price.id.to_string()),
- quantity: Some(1)
- }])
- );
- assert_eq!(
- call.payment_method_collection,
- Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
- );
- assert_eq!(
- call.subscription_data,
- Some(StripeCreateCheckoutSessionSubscriptionData {
- trial_period_days: Some(14),
- trial_settings: Some(StripeSubscriptionTrialSettings {
- end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
- missing_payment_method:
- StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
- },
- }),
- metadata: None,
- })
- );
- assert_eq!(call.success_url.as_deref(), Some(success_url));
- assert_eq!(
- call.billing_address_collection,
- Some(StripeBillingAddressCollection::Required)
- );
- assert_eq!(
- call.customer_update,
- Some(StripeCustomerUpdate {
- address: Some(StripeCustomerUpdateAddress::Auto),
- name: Some(StripeCustomerUpdateName::Auto),
- shipping: None,
- })
- );
- }
-
- // Successful checkout with extended trial.
- {
- let checkout_url = stripe_billing
- .checkout_with_zed_pro_trial(
- &customer_id,
- github_login,
- vec![AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string()],
- success_url,
- )
- .await
- .unwrap();
-
- assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
-
- let create_checkout_session_calls = stripe_client
- .create_checkout_session_calls
- .lock()
- .drain(..)
- .collect::<Vec<_>>();
- assert_eq!(create_checkout_session_calls.len(), 1);
- let call = create_checkout_session_calls.into_iter().next().unwrap();
- assert_eq!(call.customer, Some(customer_id));
- assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
- assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
- assert_eq!(
- call.line_items,
- Some(vec![StripeCreateCheckoutSessionLineItems {
- price: Some(price.id.to_string()),
- quantity: Some(1)
- }])
- );
- assert_eq!(
- call.payment_method_collection,
- Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
- );
- assert_eq!(
- call.subscription_data,
- Some(StripeCreateCheckoutSessionSubscriptionData {
- trial_period_days: Some(60),
- trial_settings: Some(StripeSubscriptionTrialSettings {
- end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
- missing_payment_method:
- StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
- },
- }),
- metadata: Some(std::collections::HashMap::from_iter([(
- "promo_feature_flag".into(),
- AGENT_EXTENDED_TRIAL_FEATURE_FLAG.into()
- )])),
- })
- );
- assert_eq!(call.success_url.as_deref(), Some(success_url));
- assert_eq!(
- call.billing_address_collection,
- Some(StripeBillingAddressCollection::Required)
- );
- assert_eq!(
- call.customer_update,
- Some(StripeCustomerUpdate {
- address: Some(StripeCustomerUpdateAddress::Auto),
- name: Some(StripeCustomerUpdateName::Auto),
- shipping: None,
- })
- );
- }
-}