Detailed changes
@@ -834,6 +834,26 @@ dependencies = [
"syn 2.0.59",
]
+[[package]]
+name = "async-stripe"
+version = "0.37.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e2f14b5943a52cf051bbbbb68538e93a69d1e291934174121e769f4b181113f5"
+dependencies = [
+ "futures-util",
+ "http-types",
+ "hyper",
+ "hyper-rustls",
+ "serde",
+ "serde_json",
+ "serde_path_to_error",
+ "serde_qs 0.10.1",
+ "smart-default",
+ "smol_str",
+ "thiserror",
+ "tokio",
+]
+
[[package]]
name = "async-tar"
version = "0.4.2"
@@ -1462,6 +1482,12 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "349a06037c7bf932dd7e7d1f653678b2038b9ad46a74102f1fc7bd7872678cce"
+[[package]]
+name = "base64"
+version = "0.13.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
+
[[package]]
name = "base64"
version = "0.21.7"
@@ -2425,6 +2451,7 @@ dependencies = [
"anthropic",
"anyhow",
"assistant",
+ "async-stripe",
"async-trait",
"async-tungstenite",
"audio",
@@ -5254,6 +5281,27 @@ version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "add0ab9360ddbd88cfeb3bd9574a1d85cfdfa14db10b3e21d3700dbc4328758f"
+[[package]]
+name = "http-types"
+version = "2.12.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6e9b187a72d63adbfba487f48095306ac823049cb504ee195541e91c7775f5ad"
+dependencies = [
+ "anyhow",
+ "async-channel 1.9.0",
+ "base64 0.13.1",
+ "futures-lite 1.13.0",
+ "http 0.2.9",
+ "infer",
+ "pin-project-lite",
+ "rand 0.7.3",
+ "serde",
+ "serde_json",
+ "serde_qs 0.8.5",
+ "serde_urlencoded",
+ "url",
+]
+
[[package]]
name = "http_client"
version = "0.1.0"
@@ -5512,6 +5560,12 @@ version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5"
+[[package]]
+name = "infer"
+version = "0.2.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "64e9829a50b42bb782c1df523f78d332fe371b10c661e78b7a3c34b0198e9fac"
+
[[package]]
name = "inherent"
version = "1.0.10"
@@ -9564,6 +9618,28 @@ dependencies = [
"serde",
]
+[[package]]
+name = "serde_qs"
+version = "0.8.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c7715380eec75f029a4ef7de39a9200e0a63823176b759d055b613f5a87df6a6"
+dependencies = [
+ "percent-encoding",
+ "serde",
+ "thiserror",
+]
+
+[[package]]
+name = "serde_qs"
+version = "0.10.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8cac3f1e2ca2fe333923a1ae72caca910b98ed0630bb35ef6f8c8517d6e81afa"
+dependencies = [
+ "percent-encoding",
+ "serde",
+ "thiserror",
+]
+
[[package]]
name = "serde_repr"
version = "0.1.16"
@@ -9880,6 +9956,17 @@ dependencies = [
"serde",
]
+[[package]]
+name = "smart-default"
+version = "0.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "133659a15339456eeeb07572eb02a91c91e9815e9cbc89566944d2c8d3efdbf6"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 1.0.109",
+]
+
[[package]]
name = "smol"
version = "1.3.0"
@@ -9897,6 +9984,15 @@ dependencies = [
"futures-lite 1.13.0",
]
+[[package]]
+name = "smol_str"
+version = "0.1.24"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fad6c857cbab2627dcf01ec85a623ca4e7dcb5691cbaa3d7fb7653671f0d09c9"
+dependencies = [
+ "serde",
+]
+
[[package]]
name = "snippet"
version = "0.1.0"
@@ -309,6 +309,7 @@ async-dispatcher = "0.1"
async-fs = "1.6"
async-pipe = { git = "https://github.com/zed-industries/async-pipe-rs", rev = "82d00a04211cf4e1236029aa03e6b6ce2a74c553" }
async-recursion = "1.0.0"
+async-stripe = { version = "0.37", default-features = false, features = ["runtime-tokio-hyper-rustls", "billing", "checkout"] }
async-tar = "0.4.2"
async-trait = "0.1"
async-tungstenite = "0.23"
@@ -20,6 +20,7 @@ test-support = ["sqlite"]
[dependencies]
anthropic.workspace = true
anyhow.workspace = true
+async-stripe.workspace = true
async-tungstenite.workspace = true
aws-config = { version = "1.1.5" }
aws-sdk-s3 = { version = "1.15.0" }
@@ -116,3 +117,6 @@ util.workspace = true
workspace = { workspace = true, features = ["test-support"] }
worktree = { workspace = true, features = ["test-support"] }
headless.workspace = true
+
+[package.metadata.cargo-machete]
+ignored = ["async-stripe"]
@@ -1,3 +1,4 @@
+pub mod billing;
pub mod contributors;
pub mod events;
pub mod extensions;
@@ -31,6 +32,7 @@ pub fn routes(rpc_server: Option<Arc<rpc::Server>>, state: Arc<AppState>) -> Rou
.route("/user", get(get_authenticated_user))
.route("/users/:id/access_tokens", post(create_access_token))
.route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
+ .merge(billing::router())
.merge(contributors::router())
.layer(
ServiceBuilder::new()
@@ -0,0 +1,88 @@
+use std::str::FromStr;
+use std::sync::Arc;
+
+use anyhow::anyhow;
+use axum::{extract, routing::post, Extension, Json, Router};
+use collections::HashSet;
+use reqwest::StatusCode;
+use serde::{Deserialize, Serialize};
+use stripe::{CheckoutSession, CreateCheckoutSession, CreateCheckoutSessionLineItems, CustomerId};
+
+use crate::{AppState, Error, Result};
+
+pub fn router() -> Router {
+ Router::new().route("/billing/subscriptions", post(create_billing_subscription))
+}
+
+#[derive(Debug, Deserialize)]
+struct CreateBillingSubscriptionBody {
+ github_user_id: i32,
+}
+
+#[derive(Debug, Serialize)]
+struct CreateBillingSubscriptionResponse {
+ checkout_session_url: String,
+}
+
+/// Initiates a Stripe Checkout session for creating a billing subscription.
+async fn create_billing_subscription(
+ Extension(app): Extension<Arc<AppState>>,
+ extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
+) -> Result<Json<CreateBillingSubscriptionResponse>> {
+ let user = app
+ .db
+ .get_user_by_github_user_id(body.github_user_id)
+ .await?
+ .ok_or_else(|| anyhow!("user not found"))?;
+
+ let Some((stripe_client, stripe_price_id)) = app
+ .stripe_client
+ .clone()
+ .zip(app.config.stripe_price_id.clone())
+ else {
+ log::error!("failed to retrieve Stripe client or price ID");
+ Err(Error::Http(
+ StatusCode::NOT_IMPLEMENTED,
+ "not supported".into(),
+ ))?
+ };
+
+ let existing_customer_id = {
+ let existing_subscriptions = app.db.get_billing_subscriptions(user.id).await?;
+ let distinct_customer_ids = existing_subscriptions
+ .iter()
+ .map(|subscription| subscription.stripe_customer_id.as_str())
+ .collect::<HashSet<_>>();
+ // Sanity: Make sure we can determine a single Stripe customer ID for the user.
+ if distinct_customer_ids.len() > 1 {
+ Err(anyhow!("user has multiple existing customer IDs"))?;
+ }
+
+ distinct_customer_ids
+ .into_iter()
+ .next()
+ .map(|id| CustomerId::from_str(id).map_err(|err| anyhow!(err)))
+ .transpose()
+ }?;
+
+ let checkout_session = {
+ let mut params = CreateCheckoutSession::new();
+ params.mode = Some(stripe::CheckoutSessionMode::Subscription);
+ params.customer = existing_customer_id;
+ params.client_reference_id = Some(user.github_login.as_str());
+ params.line_items = Some(vec![CreateCheckoutSessionLineItems {
+ price: Some(stripe_price_id.to_string()),
+ quantity: Some(1),
+ ..Default::default()
+ }]);
+ params.success_url = Some("https://zed.dev/billing/success");
+
+ CheckoutSession::create(&stripe_client, params).await?
+ };
+
+ Ok(Json(CreateBillingSubscriptionResponse {
+ checkout_session_url: checkout_session
+ .url
+ .ok_or_else(|| anyhow!("no checkout session URL"))?,
+ }))
+}
@@ -32,6 +32,26 @@ impl Database {
.await
}
+ /// Returns all of the billing subscriptions for the user with the specified ID.
+ ///
+ /// Note that this returns the subscriptions regardless of their status.
+ /// If you're wanting to check if a use has an active billing subscription,
+ /// use `get_active_billing_subscriptions` instead.
+ pub async fn get_billing_subscriptions(
+ &self,
+ user_id: UserId,
+ ) -> Result<Vec<billing_subscription::Model>> {
+ self.transaction(|tx| async move {
+ let subscriptions = billing_subscription::Entity::find()
+ .filter(billing_subscription::Column::UserId.eq(user_id))
+ .all(&*tx)
+ .await?;
+
+ Ok(subscriptions)
+ })
+ .await
+ }
+
/// Returns all of the active billing subscriptions for the user with the specified ID.
pub async fn get_active_billing_subscriptions(
&self,
@@ -61,6 +61,17 @@ impl Database {
.await
}
+ /// Returns a user by GitHub user ID. There are no access checks here, so this should only be used internally.
+ pub async fn get_user_by_github_user_id(&self, github_user_id: i32) -> Result<Option<User>> {
+ self.transaction(|tx| async move {
+ Ok(user::Entity::find()
+ .filter(user::Column::GithubUserId.eq(github_user_id))
+ .one(&*tx)
+ .await?)
+ })
+ .await
+ }
+
/// Returns a user by GitHub login. There are no access checks here, so this should only be used internally.
pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
self.transaction(|tx| async move {
@@ -26,6 +26,7 @@ pub enum Error {
Http(StatusCode, String),
Database(sea_orm::error::DbErr),
Internal(anyhow::Error),
+ Stripe(stripe::StripeError),
}
impl From<anyhow::Error> for Error {
@@ -40,6 +41,12 @@ impl From<sea_orm::error::DbErr> for Error {
}
}
+impl From<stripe::StripeError> for Error {
+ fn from(error: stripe::StripeError) -> Self {
+ Self::Stripe(error)
+ }
+}
+
impl From<axum::Error> for Error {
fn from(error: axum::Error) -> Self {
Self::Internal(error.into())
@@ -81,6 +88,14 @@ impl IntoResponse for Error {
);
(StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response()
}
+ Error::Stripe(error) => {
+ log::error!(
+ "HTTP error {}: {:?}",
+ StatusCode::INTERNAL_SERVER_ERROR,
+ &error
+ );
+ (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response()
+ }
}
}
}
@@ -91,6 +106,7 @@ impl std::fmt::Debug for Error {
Error::Http(code, message) => (code, message).fmt(f),
Error::Database(error) => error.fmt(f),
Error::Internal(error) => error.fmt(f),
+ Error::Stripe(error) => error.fmt(f),
}
}
}
@@ -101,6 +117,7 @@ impl std::fmt::Display for Error {
Error::Http(code, message) => write!(f, "{code}: {message}"),
Error::Database(error) => error.fmt(f),
Error::Internal(error) => error.fmt(f),
+ Error::Stripe(error) => error.fmt(f),
}
}
}
@@ -137,6 +154,8 @@ pub struct Config {
pub zed_client_checksum_seed: Option<String>,
pub slack_panics_webhook: Option<String>,
pub auto_join_channel_id: Option<ChannelId>,
+ pub stripe_api_key: Option<String>,
+ pub stripe_price_id: Option<Arc<str>>,
pub supermaven_admin_api_key: Option<Arc<str>>,
}
@@ -150,6 +169,7 @@ pub struct AppState {
pub db: Arc<Database>,
pub live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
pub blob_store_client: Option<aws_sdk_s3::Client>,
+ pub stripe_client: Option<Arc<stripe::Client>>,
pub rate_limiter: Arc<RateLimiter>,
pub executor: Executor,
pub clickhouse_client: Option<clickhouse::Client>,
@@ -183,6 +203,10 @@ impl AppState {
db: db.clone(),
live_kit_client,
blob_store_client: build_blob_store_client(&config).await.log_err(),
+ stripe_client: build_stripe_client(&config)
+ .await
+ .map(|client| Arc::new(client))
+ .log_err(),
rate_limiter: Arc::new(RateLimiter::new(db)),
executor,
clickhouse_client: config
@@ -195,6 +219,15 @@ impl AppState {
}
}
+async fn build_stripe_client(config: &Config) -> anyhow::Result<stripe::Client> {
+ let api_key = config
+ .stripe_api_key
+ .as_ref()
+ .ok_or_else(|| anyhow!("missing stripe_api_key"))?;
+
+ Ok(stripe::Client::new(api_key))
+}
+
async fn build_blob_store_client(config: &Config) -> anyhow::Result<aws_sdk_s3::Client> {
let keys = aws_sdk_s3::config::Credentials::new(
config
@@ -637,6 +637,7 @@ impl TestServer {
db: test_db.db().clone(),
live_kit_client: Some(Arc::new(live_kit_test_server.create_api_client())),
blob_store_client: None,
+ stripe_client: None,
rate_limiter: Arc::new(RateLimiter::new(test_db.db().clone())),
executor,
clickhouse_client: None,
@@ -669,6 +670,8 @@ impl TestServer {
auto_join_channel_id: None,
migrations_path: None,
seed_path: None,
+ stripe_api_key: None,
+ stripe_price_id: None,
supermaven_admin_api_key: None,
},
})