Detailed changes
@@ -198,9 +198,16 @@ async fn list_billing_subscriptions(
}))
}
+#[derive(Debug, Clone, Copy, Deserialize)]
+#[serde(rename_all = "snake_case")]
+enum ProductCode {
+ ZedPro,
+}
+
#[derive(Debug, Deserialize)]
struct CreateBillingSubscriptionBody {
github_user_id: i32,
+ product: Option<ProductCode>,
}
#[derive(Debug, Serialize)]
@@ -274,15 +281,30 @@ async fn create_billing_subscription(
customer.id
};
- let default_model = llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-7-sonnet")?;
- let stripe_model = stripe_billing.register_model(default_model).await?;
- let success_url = format!(
- "{}/account?checkout_complete=1",
- app.config.zed_dot_dev_url()
- );
- let checkout_session_url = stripe_billing
- .checkout(customer_id, &user.github_login, &stripe_model, &success_url)
- .await?;
+ let checkout_session_url = match body.product {
+ Some(ProductCode::ZedPro) => {
+ let success_url = format!(
+ "{}/account?checkout_complete=1",
+ app.config.zed_dot_dev_url()
+ );
+ stripe_billing
+ .checkout_with_zed_pro(customer_id, &user.github_login, &success_url)
+ .await?
+ }
+ None => {
+ let default_model =
+ llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-7-sonnet")?;
+ let stripe_model = stripe_billing.register_model(default_model).await?;
+ let success_url = format!(
+ "{}/account?checkout_complete=1",
+ app.config.zed_dot_dev_url()
+ );
+ stripe_billing
+ .checkout(customer_id, &user.github_login, &stripe_model, &success_url)
+ .await?
+ }
+ };
+
Ok(Json(CreateBillingSubscriptionResponse {
checkout_session_url,
}))
@@ -182,6 +182,7 @@ pub struct Config {
pub slack_panics_webhook: Option<String>,
pub auto_join_channel_id: Option<ChannelId>,
pub stripe_api_key: Option<String>,
+ pub stripe_zed_pro_price_id: Option<String>,
pub supermaven_admin_api_key: Option<Arc<str>>,
pub user_backfiller_github_access_token: Option<Arc<str>>,
}
@@ -237,6 +238,7 @@ impl Config {
migrations_path: None,
seed_path: None,
stripe_api_key: None,
+ stripe_zed_pro_price_id: None,
supermaven_admin_api_key: None,
user_backfiller_github_access_token: None,
kinesis_region: None,
@@ -322,9 +324,12 @@ impl AppState {
llm_db,
livekit_client,
blob_store_client: build_blob_store_client(&config).await.log_err(),
- stripe_billing: stripe_client
- .clone()
- .map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))),
+ stripe_billing: stripe_client.clone().map(|stripe_client| {
+ Arc::new(StripeBilling::new(
+ stripe_client,
+ config.stripe_zed_pro_price_id.clone(),
+ ))
+ }),
stripe_client,
rate_limiter: Arc::new(RateLimiter::new(db)),
executor,
@@ -1,7 +1,7 @@
use std::sync::Arc;
use crate::{Cents, Result, llm};
-use anyhow::Context as _;
+use anyhow::{Context as _, anyhow};
use chrono::{Datelike, Utc};
use collections::HashMap;
use serde::{Deserialize, Serialize};
@@ -10,6 +10,7 @@ use tokio::sync::RwLock;
pub struct StripeBilling {
state: RwLock<StripeBillingState>,
client: Arc<stripe::Client>,
+ zed_pro_price_id: Option<String>,
}
#[derive(Default)]
@@ -31,10 +32,11 @@ struct StripeBillingPrice {
}
impl StripeBilling {
- pub fn new(client: Arc<stripe::Client>) -> Self {
+ pub fn new(client: Arc<stripe::Client>, zed_pro_price_id: Option<String>) -> Self {
Self {
client,
state: RwLock::default(),
+ zed_pro_price_id,
}
}
@@ -382,6 +384,32 @@ impl StripeBilling {
let session = stripe::CheckoutSession::create(&self.client, params).await?;
Ok(session.url.context("no checkout session URL")?)
}
+
+ pub async fn checkout_with_zed_pro(
+ &self,
+ customer_id: stripe::CustomerId,
+ github_login: &str,
+ success_url: &str,
+ ) -> Result<String> {
+ let zed_pro_price_id = self
+ .zed_pro_price_id
+ .as_ref()
+ .ok_or_else(|| anyhow!("Zed Pro price ID not set"))?;
+
+ let mut params = stripe::CreateCheckoutSession::new();
+ params.mode = Some(stripe::CheckoutSessionMode::Subscription);
+ params.customer = Some(customer_id);
+ params.client_reference_id = Some(github_login);
+ params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
+ price: Some(zed_pro_price_id.clone()),
+ quantity: Some(1),
+ ..Default::default()
+ }]);
+ params.success_url = Some(success_url);
+
+ let session = stripe::CheckoutSession::create(&self.client, params).await?;
+ Ok(session.url.context("no checkout session URL")?)
+ }
}
#[derive(Serialize)]
@@ -557,6 +557,7 @@ impl TestServer {
migrations_path: None,
seed_path: None,
stripe_api_key: None,
+ stripe_zed_pro_price_id: None,
supermaven_admin_api_key: None,
user_backfiller_github_access_token: None,
kinesis_region: None,