1use std::str::FromStr;
2use std::sync::Arc;
3use std::time::Duration;
4
5use anyhow::{anyhow, bail, Context};
6use axum::{extract, routing::post, Extension, Json, Router};
7use reqwest::StatusCode;
8use serde::{Deserialize, Serialize};
9use stripe::{
10 BillingPortalSession, CheckoutSession, CreateBillingPortalSession,
11 CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
12 CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
13 CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems,
14 CreateCustomer, Customer, CustomerId, EventObject, EventType, Expandable, ListEvents,
15 SubscriptionStatus,
16};
17use util::ResultExt;
18
19use crate::db::billing_subscription::StripeSubscriptionStatus;
20use crate::db::{
21 billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
22 CreateBillingSubscriptionParams,
23};
24use crate::{AppState, Error, Result};
25
26pub fn router() -> Router {
27 Router::new()
28 .route("/billing/subscriptions", post(create_billing_subscription))
29 .route(
30 "/billing/subscriptions/manage",
31 post(manage_billing_subscription),
32 )
33}
34
35#[derive(Debug, Deserialize)]
36struct CreateBillingSubscriptionBody {
37 github_user_id: i32,
38}
39
40#[derive(Debug, Serialize)]
41struct CreateBillingSubscriptionResponse {
42 checkout_session_url: String,
43}
44
45/// Initiates a Stripe Checkout session for creating a billing subscription.
46async fn create_billing_subscription(
47 Extension(app): Extension<Arc<AppState>>,
48 extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
49) -> Result<Json<CreateBillingSubscriptionResponse>> {
50 let user = app
51 .db
52 .get_user_by_github_user_id(body.github_user_id)
53 .await?
54 .ok_or_else(|| anyhow!("user not found"))?;
55
56 let Some((stripe_client, stripe_price_id)) = app
57 .stripe_client
58 .clone()
59 .zip(app.config.stripe_price_id.clone())
60 else {
61 log::error!("failed to retrieve Stripe client or price ID");
62 Err(Error::Http(
63 StatusCode::NOT_IMPLEMENTED,
64 "not supported".into(),
65 ))?
66 };
67
68 let customer_id =
69 if let Some(existing_customer) = app.db.get_billing_customer_by_user_id(user.id).await? {
70 CustomerId::from_str(&existing_customer.stripe_customer_id)
71 .context("failed to parse customer ID")?
72 } else {
73 let customer = Customer::create(
74 &stripe_client,
75 CreateCustomer {
76 email: user.email_address.as_deref(),
77 ..Default::default()
78 },
79 )
80 .await?;
81
82 customer.id
83 };
84
85 let checkout_session = {
86 let mut params = CreateCheckoutSession::new();
87 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
88 params.customer = Some(customer_id);
89 params.client_reference_id = Some(user.github_login.as_str());
90 params.line_items = Some(vec![CreateCheckoutSessionLineItems {
91 price: Some(stripe_price_id.to_string()),
92 quantity: Some(1),
93 ..Default::default()
94 }]);
95 params.success_url = Some("https://zed.dev/billing/success");
96
97 CheckoutSession::create(&stripe_client, params).await?
98 };
99
100 Ok(Json(CreateBillingSubscriptionResponse {
101 checkout_session_url: checkout_session
102 .url
103 .ok_or_else(|| anyhow!("no checkout session URL"))?,
104 }))
105}
106
107#[derive(Debug, Deserialize)]
108#[serde(rename_all = "snake_case")]
109enum ManageSubscriptionIntent {
110 /// The user intends to cancel their subscription.
111 Cancel,
112}
113
114#[derive(Debug, Deserialize)]
115struct ManageBillingSubscriptionBody {
116 github_user_id: i32,
117 intent: ManageSubscriptionIntent,
118 /// The ID of the subscription to manage.
119 ///
120 /// If not provided, we will try to use the active subscription (if there is only one).
121 subscription_id: Option<BillingSubscriptionId>,
122}
123
124#[derive(Debug, Serialize)]
125struct ManageBillingSubscriptionResponse {
126 billing_portal_session_url: String,
127}
128
129/// Initiates a Stripe customer portal session for managing a billing subscription.
130async fn manage_billing_subscription(
131 Extension(app): Extension<Arc<AppState>>,
132 extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
133) -> Result<Json<ManageBillingSubscriptionResponse>> {
134 let user = app
135 .db
136 .get_user_by_github_user_id(body.github_user_id)
137 .await?
138 .ok_or_else(|| anyhow!("user not found"))?;
139
140 let Some(stripe_client) = app.stripe_client.clone() else {
141 log::error!("failed to retrieve Stripe client");
142 Err(Error::Http(
143 StatusCode::NOT_IMPLEMENTED,
144 "not supported".into(),
145 ))?
146 };
147
148 let customer = app
149 .db
150 .get_billing_customer_by_user_id(user.id)
151 .await?
152 .ok_or_else(|| anyhow!("billing customer not found"))?;
153 let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
154 .context("failed to parse customer ID")?;
155
156 let subscription = if let Some(subscription_id) = body.subscription_id {
157 app.db
158 .get_billing_subscription_by_id(subscription_id)
159 .await?
160 .ok_or_else(|| anyhow!("subscription not found"))?
161 } else {
162 // If no subscription ID was provided, try to find the only active subscription ID.
163 let subscriptions = app.db.get_active_billing_subscriptions(user.id).await?;
164 if subscriptions.len() > 1 {
165 Err(anyhow!("user has multiple active subscriptions"))?;
166 }
167
168 subscriptions
169 .into_iter()
170 .next()
171 .ok_or_else(|| anyhow!("user has no active subscriptions"))?
172 };
173
174 let flow = match body.intent {
175 ManageSubscriptionIntent::Cancel => CreateBillingPortalSessionFlowData {
176 type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
177 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
178 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
179 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
180 return_url: "https://zed.dev/billing".into(),
181 }),
182 ..Default::default()
183 }),
184 subscription_cancel: Some(
185 stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
186 subscription: subscription.stripe_subscription_id,
187 retention: None,
188 },
189 ),
190 ..Default::default()
191 },
192 };
193
194 let mut params = CreateBillingPortalSession::new(customer_id);
195 params.flow_data = Some(flow);
196 params.return_url = Some("https://zed.dev/billing");
197
198 let session = BillingPortalSession::create(&stripe_client, params).await?;
199
200 Ok(Json(ManageBillingSubscriptionResponse {
201 billing_portal_session_url: session.url,
202 }))
203}
204
205const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5 * 60);
206
207/// Polls the Stripe events API periodically to reconcile the records in our
208/// database with the data in Stripe.
209pub fn poll_stripe_events_periodically(app: Arc<AppState>) {
210 let Some(stripe_client) = app.stripe_client.clone() else {
211 log::warn!("failed to retrieve Stripe client");
212 return;
213 };
214
215 let executor = app.executor.clone();
216 executor.spawn_detached({
217 let executor = executor.clone();
218 async move {
219 loop {
220 poll_stripe_events(&app, &stripe_client).await.log_err();
221
222 executor.sleep(POLL_EVENTS_INTERVAL).await;
223 }
224 }
225 });
226}
227
228async fn poll_stripe_events(
229 app: &Arc<AppState>,
230 stripe_client: &stripe::Client,
231) -> anyhow::Result<()> {
232 let event_types = [
233 EventType::CustomerCreated.to_string(),
234 EventType::CustomerSubscriptionCreated.to_string(),
235 EventType::CustomerSubscriptionUpdated.to_string(),
236 EventType::CustomerSubscriptionPaused.to_string(),
237 EventType::CustomerSubscriptionResumed.to_string(),
238 EventType::CustomerSubscriptionDeleted.to_string(),
239 ]
240 .into_iter()
241 .map(|event_type| {
242 // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
243 // so we need to unquote it.
244 event_type.trim_matches('"').to_string()
245 })
246 .collect::<Vec<_>>();
247
248 loop {
249 log::info!("retrieving events from Stripe: {}", event_types.join(", "));
250
251 let mut params = ListEvents::new();
252 params.types = Some(event_types.clone());
253 params.limit = Some(100);
254
255 let events = stripe::Event::list(stripe_client, ¶ms).await?;
256 for event in events.data {
257 match event.type_ {
258 EventType::CustomerCreated => {
259 handle_customer_event(app, stripe_client, event)
260 .await
261 .log_err();
262 }
263 EventType::CustomerSubscriptionCreated
264 | EventType::CustomerSubscriptionUpdated
265 | EventType::CustomerSubscriptionPaused
266 | EventType::CustomerSubscriptionResumed
267 | EventType::CustomerSubscriptionDeleted => {
268 handle_customer_subscription_event(app, stripe_client, event)
269 .await
270 .log_err();
271 }
272 _ => {}
273 }
274 }
275
276 if !events.has_more {
277 break;
278 }
279 }
280
281 Ok(())
282}
283
284async fn handle_customer_event(
285 app: &Arc<AppState>,
286 stripe_client: &stripe::Client,
287 event: stripe::Event,
288) -> anyhow::Result<()> {
289 let EventObject::Customer(customer) = event.data.object else {
290 bail!("unexpected event payload for {}", event.id);
291 };
292
293 find_or_create_billing_customer(app, stripe_client, Expandable::Object(Box::new(customer)))
294 .await?;
295
296 Ok(())
297}
298
299async fn handle_customer_subscription_event(
300 app: &Arc<AppState>,
301 stripe_client: &stripe::Client,
302 event: stripe::Event,
303) -> anyhow::Result<()> {
304 let EventObject::Subscription(subscription) = event.data.object else {
305 bail!("unexpected event payload for {}", event.id);
306 };
307
308 let billing_customer =
309 find_or_create_billing_customer(app, stripe_client, subscription.customer)
310 .await?
311 .ok_or_else(|| anyhow!("billing customer not found"))?;
312
313 app.db
314 .upsert_billing_subscription_by_stripe_subscription_id(&CreateBillingSubscriptionParams {
315 billing_customer_id: billing_customer.id,
316 stripe_subscription_id: subscription.id.to_string(),
317 stripe_subscription_status: subscription.status.into(),
318 })
319 .await?;
320
321 Ok(())
322}
323
324impl From<SubscriptionStatus> for StripeSubscriptionStatus {
325 fn from(value: SubscriptionStatus) -> Self {
326 match value {
327 SubscriptionStatus::Incomplete => Self::Incomplete,
328 SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
329 SubscriptionStatus::Trialing => Self::Trialing,
330 SubscriptionStatus::Active => Self::Active,
331 SubscriptionStatus::PastDue => Self::PastDue,
332 SubscriptionStatus::Canceled => Self::Canceled,
333 SubscriptionStatus::Unpaid => Self::Unpaid,
334 SubscriptionStatus::Paused => Self::Paused,
335 }
336 }
337}
338
339/// Finds or creates a billing customer using the provided customer.
340async fn find_or_create_billing_customer(
341 app: &Arc<AppState>,
342 stripe_client: &stripe::Client,
343 customer_or_id: Expandable<Customer>,
344) -> anyhow::Result<Option<billing_customer::Model>> {
345 let customer_id = match &customer_or_id {
346 Expandable::Id(id) => id,
347 Expandable::Object(customer) => customer.id.as_ref(),
348 };
349
350 // If we already have a billing customer record associated with the Stripe customer,
351 // there's nothing more we need to do.
352 if let Some(billing_customer) = app
353 .db
354 .get_billing_customer_by_stripe_customer_id(&customer_id)
355 .await?
356 {
357 return Ok(Some(billing_customer));
358 }
359
360 // If all we have is a customer ID, resolve it to a full customer record by
361 // hitting the Stripe API.
362 let customer = match customer_or_id {
363 Expandable::Id(id) => Customer::retrieve(&stripe_client, &id, &[]).await?,
364 Expandable::Object(customer) => *customer,
365 };
366
367 let Some(email) = customer.email else {
368 return Ok(None);
369 };
370
371 let Some(user) = app.db.get_user_by_email(&email).await? else {
372 return Ok(None);
373 };
374
375 let billing_customer = app
376 .db
377 .create_billing_customer(&CreateBillingCustomerParams {
378 user_id: user.id,
379 stripe_customer_id: customer.id.to_string(),
380 })
381 .await?;
382
383 Ok(Some(billing_customer))
384}