1use std::str::FromStr;
2use std::sync::Arc;
3use std::time::Duration;
4
5use anyhow::{anyhow, bail, Context};
6use axum::{
7 extract::{self, Query},
8 routing::{get, post},
9 Extension, Json, Router,
10};
11use reqwest::StatusCode;
12use sea_orm::ActiveValue;
13use serde::{Deserialize, Serialize};
14use stripe::{
15 BillingPortalSession, CheckoutSession, CreateBillingPortalSession,
16 CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
17 CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
18 CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems,
19 CreateCustomer, Customer, CustomerId, EventObject, EventType, Expandable, ListEvents,
20 SubscriptionStatus,
21};
22use util::ResultExt;
23
24use crate::db::billing_subscription::StripeSubscriptionStatus;
25use crate::db::{
26 billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
27 CreateBillingSubscriptionParams, CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
28 UpdateBillingSubscriptionParams,
29};
30use crate::{AppState, Error, Result};
31
32pub fn router() -> Router {
33 Router::new()
34 .route(
35 "/billing/subscriptions",
36 get(list_billing_subscriptions).post(create_billing_subscription),
37 )
38 .route(
39 "/billing/subscriptions/manage",
40 post(manage_billing_subscription),
41 )
42}
43
/// Query parameters accepted by `list_billing_subscriptions`.
#[derive(Debug, Deserialize)]
struct ListBillingSubscriptionsParams {
    /// GitHub user ID used to look up the user whose subscriptions are listed.
    github_user_id: i32,
}

/// A single billing subscription as serialized in the list response.
#[derive(Debug, Serialize)]
struct BillingSubscriptionJson {
    /// Database ID of the billing subscription.
    id: BillingSubscriptionId,
    /// Display name of the subscription (the handler currently always sets "Zed Pro").
    name: String,
    /// Subscription status as stored from Stripe.
    status: StripeSubscriptionStatus,
    /// Whether this subscription can be canceled.
    is_cancelable: bool,
}

/// Response body returned by `list_billing_subscriptions`.
#[derive(Debug, Serialize)]
struct ListBillingSubscriptionsResponse {
    /// Every billing subscription found for the user.
    subscriptions: Vec<BillingSubscriptionJson>,
}
62
63async fn list_billing_subscriptions(
64 Extension(app): Extension<Arc<AppState>>,
65 Query(params): Query<ListBillingSubscriptionsParams>,
66) -> Result<Json<ListBillingSubscriptionsResponse>> {
67 let user = app
68 .db
69 .get_user_by_github_user_id(params.github_user_id)
70 .await?
71 .ok_or_else(|| anyhow!("user not found"))?;
72
73 let subscriptions = app.db.get_billing_subscriptions(user.id).await?;
74
75 Ok(Json(ListBillingSubscriptionsResponse {
76 subscriptions: subscriptions
77 .into_iter()
78 .map(|subscription| BillingSubscriptionJson {
79 id: subscription.id,
80 name: "Zed Pro".to_string(),
81 status: subscription.stripe_subscription_status,
82 is_cancelable: subscription.stripe_subscription_status.is_cancelable(),
83 })
84 .collect(),
85 }))
86}
87
/// JSON request body accepted by `create_billing_subscription`.
#[derive(Debug, Deserialize)]
struct CreateBillingSubscriptionBody {
    /// GitHub user ID used to look up the user starting the checkout.
    github_user_id: i32,
}

/// JSON response body returned by `create_billing_subscription`.
#[derive(Debug, Serialize)]
struct CreateBillingSubscriptionResponse {
    /// URL of the Stripe Checkout session that was created.
    checkout_session_url: String,
}
97
98/// Initiates a Stripe Checkout session for creating a billing subscription.
99async fn create_billing_subscription(
100 Extension(app): Extension<Arc<AppState>>,
101 extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
102) -> Result<Json<CreateBillingSubscriptionResponse>> {
103 let user = app
104 .db
105 .get_user_by_github_user_id(body.github_user_id)
106 .await?
107 .ok_or_else(|| anyhow!("user not found"))?;
108
109 let Some((stripe_client, stripe_price_id)) = app
110 .stripe_client
111 .clone()
112 .zip(app.config.stripe_price_id.clone())
113 else {
114 log::error!("failed to retrieve Stripe client or price ID");
115 Err(Error::Http(
116 StatusCode::NOT_IMPLEMENTED,
117 "not supported".into(),
118 ))?
119 };
120
121 let customer_id =
122 if let Some(existing_customer) = app.db.get_billing_customer_by_user_id(user.id).await? {
123 CustomerId::from_str(&existing_customer.stripe_customer_id)
124 .context("failed to parse customer ID")?
125 } else {
126 let customer = Customer::create(
127 &stripe_client,
128 CreateCustomer {
129 email: user.email_address.as_deref(),
130 ..Default::default()
131 },
132 )
133 .await?;
134
135 customer.id
136 };
137
138 let checkout_session = {
139 let mut params = CreateCheckoutSession::new();
140 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
141 params.customer = Some(customer_id);
142 params.client_reference_id = Some(user.github_login.as_str());
143 params.line_items = Some(vec![CreateCheckoutSessionLineItems {
144 price: Some(stripe_price_id.to_string()),
145 quantity: Some(1),
146 ..Default::default()
147 }]);
148 params.success_url = Some("https://zed.dev/billing/success");
149
150 CheckoutSession::create(&stripe_client, params).await?
151 };
152
153 Ok(Json(CreateBillingSubscriptionResponse {
154 checkout_session_url: checkout_session
155 .url
156 .ok_or_else(|| anyhow!("no checkout session URL"))?,
157 }))
158}
159
/// The action a user wants to take when managing a subscription.
///
/// Deserialized from snake_case variant names (e.g. `"cancel"`).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ManageSubscriptionIntent {
    /// The user intends to cancel their subscription.
    Cancel,
}

/// JSON request body accepted by `manage_billing_subscription`.
#[derive(Debug, Deserialize)]
struct ManageBillingSubscriptionBody {
    /// GitHub user ID used to look up the user managing their subscription.
    github_user_id: i32,
    /// Which management flow to start in the Stripe billing portal.
    intent: ManageSubscriptionIntent,
    /// The ID of the subscription to manage.
    ///
    /// If not provided, we will try to use the active subscription (if there is only one).
    subscription_id: Option<BillingSubscriptionId>,
}

/// JSON response body returned by `manage_billing_subscription`.
#[derive(Debug, Serialize)]
struct ManageBillingSubscriptionResponse {
    /// URL of the Stripe billing portal session that was created.
    billing_portal_session_url: String,
}
181
182/// Initiates a Stripe customer portal session for managing a billing subscription.
183async fn manage_billing_subscription(
184 Extension(app): Extension<Arc<AppState>>,
185 extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
186) -> Result<Json<ManageBillingSubscriptionResponse>> {
187 let user = app
188 .db
189 .get_user_by_github_user_id(body.github_user_id)
190 .await?
191 .ok_or_else(|| anyhow!("user not found"))?;
192
193 let Some(stripe_client) = app.stripe_client.clone() else {
194 log::error!("failed to retrieve Stripe client");
195 Err(Error::Http(
196 StatusCode::NOT_IMPLEMENTED,
197 "not supported".into(),
198 ))?
199 };
200
201 let customer = app
202 .db
203 .get_billing_customer_by_user_id(user.id)
204 .await?
205 .ok_or_else(|| anyhow!("billing customer not found"))?;
206 let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
207 .context("failed to parse customer ID")?;
208
209 let subscription = if let Some(subscription_id) = body.subscription_id {
210 app.db
211 .get_billing_subscription_by_id(subscription_id)
212 .await?
213 .ok_or_else(|| anyhow!("subscription not found"))?
214 } else {
215 // If no subscription ID was provided, try to find the only active subscription ID.
216 let subscriptions = app.db.get_active_billing_subscriptions(user.id).await?;
217 if subscriptions.len() > 1 {
218 Err(anyhow!("user has multiple active subscriptions"))?;
219 }
220
221 subscriptions
222 .into_iter()
223 .next()
224 .ok_or_else(|| anyhow!("user has no active subscriptions"))?
225 };
226
227 let flow = match body.intent {
228 ManageSubscriptionIntent::Cancel => CreateBillingPortalSessionFlowData {
229 type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
230 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
231 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
232 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
233 return_url: "https://zed.dev/settings".into(),
234 }),
235 ..Default::default()
236 }),
237 subscription_cancel: Some(
238 stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
239 subscription: subscription.stripe_subscription_id,
240 retention: None,
241 },
242 ),
243 ..Default::default()
244 },
245 };
246
247 let mut params = CreateBillingPortalSession::new(customer_id);
248 params.flow_data = Some(flow);
249 params.return_url = Some("https://zed.dev/settings");
250
251 let session = BillingPortalSession::create(&stripe_client, params).await?;
252
253 Ok(Json(ManageBillingSubscriptionResponse {
254 billing_portal_session_url: session.url,
255 }))
256}
257
258const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5 * 60);
259
260/// Polls the Stripe events API periodically to reconcile the records in our
261/// database with the data in Stripe.
262pub fn poll_stripe_events_periodically(app: Arc<AppState>) {
263 let Some(stripe_client) = app.stripe_client.clone() else {
264 log::warn!("failed to retrieve Stripe client");
265 return;
266 };
267
268 let executor = app.executor.clone();
269 executor.spawn_detached({
270 let executor = executor.clone();
271 async move {
272 loop {
273 poll_stripe_events(&app, &stripe_client).await.log_err();
274
275 executor.sleep(POLL_EVENTS_INTERVAL).await;
276 }
277 }
278 });
279}
280
281async fn poll_stripe_events(
282 app: &Arc<AppState>,
283 stripe_client: &stripe::Client,
284) -> anyhow::Result<()> {
285 fn event_type_to_string(event_type: EventType) -> String {
286 // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
287 // so we need to unquote it.
288 event_type.to_string().trim_matches('"').to_string()
289 }
290
291 let event_types = [
292 EventType::CustomerCreated,
293 EventType::CustomerUpdated,
294 EventType::CustomerSubscriptionCreated,
295 EventType::CustomerSubscriptionUpdated,
296 EventType::CustomerSubscriptionPaused,
297 EventType::CustomerSubscriptionResumed,
298 EventType::CustomerSubscriptionDeleted,
299 ]
300 .into_iter()
301 .map(event_type_to_string)
302 .collect::<Vec<_>>();
303
304 let mut unprocessed_events = Vec::new();
305
306 loop {
307 log::info!("retrieving events from Stripe: {}", event_types.join(", "));
308
309 let mut params = ListEvents::new();
310 params.types = Some(event_types.clone());
311 params.limit = Some(100);
312
313 let events = stripe::Event::list(stripe_client, ¶ms).await?;
314
315 let processed_event_ids = {
316 let event_ids = &events
317 .data
318 .iter()
319 .map(|event| event.id.as_str())
320 .collect::<Vec<_>>();
321
322 app.db
323 .get_processed_stripe_events_by_event_ids(event_ids)
324 .await?
325 .into_iter()
326 .map(|event| event.stripe_event_id)
327 .collect::<Vec<_>>()
328 };
329
330 for event in events.data {
331 if processed_event_ids.contains(&event.id.to_string()) {
332 log::info!("Stripe event {} already processed: skipping", event.id);
333 } else {
334 unprocessed_events.push(event);
335 }
336 }
337
338 if !events.has_more {
339 break;
340 }
341 }
342
343 log::info!(
344 "unprocessed events from Stripe: {}",
345 unprocessed_events.len()
346 );
347
348 // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
349 unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
350
351 for event in unprocessed_events {
352 let processed_event_params = CreateProcessedStripeEventParams {
353 stripe_event_id: event.id.to_string(),
354 stripe_event_type: event_type_to_string(event.type_),
355 stripe_event_created_timestamp: event.created,
356 };
357
358 match event.type_ {
359 EventType::CustomerCreated | EventType::CustomerUpdated => {
360 handle_customer_event(app, stripe_client, event)
361 .await
362 .log_err();
363 }
364 EventType::CustomerSubscriptionCreated
365 | EventType::CustomerSubscriptionUpdated
366 | EventType::CustomerSubscriptionPaused
367 | EventType::CustomerSubscriptionResumed
368 | EventType::CustomerSubscriptionDeleted => {
369 handle_customer_subscription_event(app, stripe_client, event)
370 .await
371 .log_err();
372 }
373 _ => {}
374 }
375
376 app.db
377 .create_processed_stripe_event(&processed_event_params)
378 .await?;
379 }
380
381 Ok(())
382}
383
384async fn handle_customer_event(
385 app: &Arc<AppState>,
386 _stripe_client: &stripe::Client,
387 event: stripe::Event,
388) -> anyhow::Result<()> {
389 let EventObject::Customer(customer) = event.data.object else {
390 bail!("unexpected event payload for {}", event.id);
391 };
392
393 log::info!("handling Stripe {} event: {}", event.type_, event.id);
394
395 let Some(email) = customer.email else {
396 log::info!("Stripe customer has no email: skipping");
397 return Ok(());
398 };
399
400 let Some(user) = app.db.get_user_by_email(&email).await? else {
401 log::info!("no user found for email: skipping");
402 return Ok(());
403 };
404
405 if let Some(existing_customer) = app
406 .db
407 .get_billing_customer_by_stripe_customer_id(&customer.id)
408 .await?
409 {
410 app.db
411 .update_billing_customer(
412 existing_customer.id,
413 &UpdateBillingCustomerParams {
414 // For now we just leave the information as-is, as it is not
415 // likely to change.
416 ..Default::default()
417 },
418 )
419 .await?;
420 } else {
421 app.db
422 .create_billing_customer(&CreateBillingCustomerParams {
423 user_id: user.id,
424 stripe_customer_id: customer.id.to_string(),
425 })
426 .await?;
427 }
428
429 Ok(())
430}
431
432async fn handle_customer_subscription_event(
433 app: &Arc<AppState>,
434 stripe_client: &stripe::Client,
435 event: stripe::Event,
436) -> anyhow::Result<()> {
437 let EventObject::Subscription(subscription) = event.data.object else {
438 bail!("unexpected event payload for {}", event.id);
439 };
440
441 log::info!("handling Stripe {} event: {}", event.type_, event.id);
442
443 let billing_customer =
444 find_or_create_billing_customer(app, stripe_client, subscription.customer)
445 .await?
446 .ok_or_else(|| anyhow!("billing customer not found"))?;
447
448 if let Some(existing_subscription) = app
449 .db
450 .get_billing_subscription_by_stripe_subscription_id(&subscription.id)
451 .await?
452 {
453 app.db
454 .update_billing_subscription(
455 existing_subscription.id,
456 &UpdateBillingSubscriptionParams {
457 billing_customer_id: ActiveValue::set(billing_customer.id),
458 stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
459 stripe_subscription_status: ActiveValue::set(subscription.status.into()),
460 },
461 )
462 .await?;
463 } else {
464 app.db
465 .create_billing_subscription(&CreateBillingSubscriptionParams {
466 billing_customer_id: billing_customer.id,
467 stripe_subscription_id: subscription.id.to_string(),
468 stripe_subscription_status: subscription.status.into(),
469 })
470 .await?;
471 }
472
473 Ok(())
474}
475
476impl From<SubscriptionStatus> for StripeSubscriptionStatus {
477 fn from(value: SubscriptionStatus) -> Self {
478 match value {
479 SubscriptionStatus::Incomplete => Self::Incomplete,
480 SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
481 SubscriptionStatus::Trialing => Self::Trialing,
482 SubscriptionStatus::Active => Self::Active,
483 SubscriptionStatus::PastDue => Self::PastDue,
484 SubscriptionStatus::Canceled => Self::Canceled,
485 SubscriptionStatus::Unpaid => Self::Unpaid,
486 SubscriptionStatus::Paused => Self::Paused,
487 }
488 }
489}
490
491/// Finds or creates a billing customer using the provided customer.
492async fn find_or_create_billing_customer(
493 app: &Arc<AppState>,
494 stripe_client: &stripe::Client,
495 customer_or_id: Expandable<Customer>,
496) -> anyhow::Result<Option<billing_customer::Model>> {
497 let customer_id = match &customer_or_id {
498 Expandable::Id(id) => id,
499 Expandable::Object(customer) => customer.id.as_ref(),
500 };
501
502 // If we already have a billing customer record associated with the Stripe customer,
503 // there's nothing more we need to do.
504 if let Some(billing_customer) = app
505 .db
506 .get_billing_customer_by_stripe_customer_id(&customer_id)
507 .await?
508 {
509 return Ok(Some(billing_customer));
510 }
511
512 // If all we have is a customer ID, resolve it to a full customer record by
513 // hitting the Stripe API.
514 let customer = match customer_or_id {
515 Expandable::Id(id) => Customer::retrieve(&stripe_client, &id, &[]).await?,
516 Expandable::Object(customer) => *customer,
517 };
518
519 let Some(email) = customer.email else {
520 return Ok(None);
521 };
522
523 let Some(user) = app.db.get_user_by_email(&email).await? else {
524 return Ok(None);
525 };
526
527 let billing_customer = app
528 .db
529 .create_billing_customer(&CreateBillingCustomerParams {
530 user_id: user.id,
531 stripe_customer_id: customer.id.to_string(),
532 })
533 .await?;
534
535 Ok(Some(billing_customer))
536}