1use std::str::FromStr;
2use std::sync::Arc;
3use std::time::Duration;
4
5use anyhow::{anyhow, bail, Context};
6use axum::{extract, routing::post, Extension, Json, Router};
7use reqwest::StatusCode;
8use sea_orm::ActiveValue;
9use serde::{Deserialize, Serialize};
10use stripe::{
11 BillingPortalSession, CheckoutSession, CreateBillingPortalSession,
12 CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
13 CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
14 CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems,
15 CreateCustomer, Customer, CustomerId, EventId, EventObject, EventType, Expandable, ListEvents,
16 SubscriptionStatus,
17};
18use util::ResultExt;
19
20use crate::db::billing_subscription::StripeSubscriptionStatus;
21use crate::db::{
22 billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
23 CreateBillingSubscriptionParams, UpdateBillingCustomerParams, UpdateBillingSubscriptionParams,
24};
25use crate::{AppState, Error, Result};
26
27pub fn router() -> Router {
28 Router::new()
29 .route("/billing/subscriptions", post(create_billing_subscription))
30 .route(
31 "/billing/subscriptions/manage",
32 post(manage_billing_subscription),
33 )
34}
35
/// JSON request body accepted by [`create_billing_subscription`].
#[derive(Debug, Deserialize)]
struct CreateBillingSubscriptionBody {
    /// GitHub user ID used to look up the user the subscription is created for.
    github_user_id: i32,
}
40
/// JSON response returned by [`create_billing_subscription`].
#[derive(Debug, Serialize)]
struct CreateBillingSubscriptionResponse {
    /// URL of the Stripe Checkout session that was created for the user.
    checkout_session_url: String,
}
45
46/// Initiates a Stripe Checkout session for creating a billing subscription.
47async fn create_billing_subscription(
48 Extension(app): Extension<Arc<AppState>>,
49 extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
50) -> Result<Json<CreateBillingSubscriptionResponse>> {
51 let user = app
52 .db
53 .get_user_by_github_user_id(body.github_user_id)
54 .await?
55 .ok_or_else(|| anyhow!("user not found"))?;
56
57 let Some((stripe_client, stripe_price_id)) = app
58 .stripe_client
59 .clone()
60 .zip(app.config.stripe_price_id.clone())
61 else {
62 log::error!("failed to retrieve Stripe client or price ID");
63 Err(Error::Http(
64 StatusCode::NOT_IMPLEMENTED,
65 "not supported".into(),
66 ))?
67 };
68
69 let customer_id =
70 if let Some(existing_customer) = app.db.get_billing_customer_by_user_id(user.id).await? {
71 CustomerId::from_str(&existing_customer.stripe_customer_id)
72 .context("failed to parse customer ID")?
73 } else {
74 let customer = Customer::create(
75 &stripe_client,
76 CreateCustomer {
77 email: user.email_address.as_deref(),
78 ..Default::default()
79 },
80 )
81 .await?;
82
83 customer.id
84 };
85
86 let checkout_session = {
87 let mut params = CreateCheckoutSession::new();
88 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
89 params.customer = Some(customer_id);
90 params.client_reference_id = Some(user.github_login.as_str());
91 params.line_items = Some(vec![CreateCheckoutSessionLineItems {
92 price: Some(stripe_price_id.to_string()),
93 quantity: Some(1),
94 ..Default::default()
95 }]);
96 params.success_url = Some("https://zed.dev/billing/success");
97
98 CheckoutSession::create(&stripe_client, params).await?
99 };
100
101 Ok(Json(CreateBillingSubscriptionResponse {
102 checkout_session_url: checkout_session
103 .url
104 .ok_or_else(|| anyhow!("no checkout session URL"))?,
105 }))
106}
107
/// What the caller wants to do with a subscription in the Stripe customer portal.
///
/// Deserialized from a `snake_case` string (e.g. `"cancel"`).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ManageSubscriptionIntent {
    /// The user intends to cancel their subscription.
    Cancel,
}
114
/// JSON request body accepted by [`manage_billing_subscription`].
#[derive(Debug, Deserialize)]
struct ManageBillingSubscriptionBody {
    /// GitHub user ID used to look up the user whose subscription is managed.
    github_user_id: i32,
    /// The action the user wants to perform on the subscription.
    intent: ManageSubscriptionIntent,
    /// The ID of the subscription to manage.
    ///
    /// If not provided, we will try to use the active subscription (if there is only one).
    subscription_id: Option<BillingSubscriptionId>,
}
124
/// JSON response returned by [`manage_billing_subscription`].
#[derive(Debug, Serialize)]
struct ManageBillingSubscriptionResponse {
    /// URL of the Stripe billing portal session that was created for the user.
    billing_portal_session_url: String,
}
129
130/// Initiates a Stripe customer portal session for managing a billing subscription.
131async fn manage_billing_subscription(
132 Extension(app): Extension<Arc<AppState>>,
133 extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
134) -> Result<Json<ManageBillingSubscriptionResponse>> {
135 let user = app
136 .db
137 .get_user_by_github_user_id(body.github_user_id)
138 .await?
139 .ok_or_else(|| anyhow!("user not found"))?;
140
141 let Some(stripe_client) = app.stripe_client.clone() else {
142 log::error!("failed to retrieve Stripe client");
143 Err(Error::Http(
144 StatusCode::NOT_IMPLEMENTED,
145 "not supported".into(),
146 ))?
147 };
148
149 let customer = app
150 .db
151 .get_billing_customer_by_user_id(user.id)
152 .await?
153 .ok_or_else(|| anyhow!("billing customer not found"))?;
154 let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
155 .context("failed to parse customer ID")?;
156
157 let subscription = if let Some(subscription_id) = body.subscription_id {
158 app.db
159 .get_billing_subscription_by_id(subscription_id)
160 .await?
161 .ok_or_else(|| anyhow!("subscription not found"))?
162 } else {
163 // If no subscription ID was provided, try to find the only active subscription ID.
164 let subscriptions = app.db.get_active_billing_subscriptions(user.id).await?;
165 if subscriptions.len() > 1 {
166 Err(anyhow!("user has multiple active subscriptions"))?;
167 }
168
169 subscriptions
170 .into_iter()
171 .next()
172 .ok_or_else(|| anyhow!("user has no active subscriptions"))?
173 };
174
175 let flow = match body.intent {
176 ManageSubscriptionIntent::Cancel => CreateBillingPortalSessionFlowData {
177 type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
178 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
179 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
180 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
181 return_url: "https://zed.dev/billing".into(),
182 }),
183 ..Default::default()
184 }),
185 subscription_cancel: Some(
186 stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
187 subscription: subscription.stripe_subscription_id,
188 retention: None,
189 },
190 ),
191 ..Default::default()
192 },
193 };
194
195 let mut params = CreateBillingPortalSession::new(customer_id);
196 params.flow_data = Some(flow);
197 params.return_url = Some("https://zed.dev/billing");
198
199 let session = BillingPortalSession::create(&stripe_client, params).await?;
200
201 Ok(Json(ManageBillingSubscriptionResponse {
202 billing_portal_session_url: session.url,
203 }))
204}
205
/// How long to sleep between two consecutive polls of the Stripe events API (5 minutes).
const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5 * 60);
207
208/// Polls the Stripe events API periodically to reconcile the records in our
209/// database with the data in Stripe.
210pub fn poll_stripe_events_periodically(app: Arc<AppState>) {
211 let Some(stripe_client) = app.stripe_client.clone() else {
212 log::warn!("failed to retrieve Stripe client");
213 return;
214 };
215
216 let executor = app.executor.clone();
217 executor.spawn_detached({
218 let executor = executor.clone();
219 async move {
220 loop {
221 poll_stripe_events(&app, &stripe_client).await.log_err();
222
223 executor.sleep(POLL_EVENTS_INTERVAL).await;
224 }
225 }
226 });
227}
228
229async fn poll_stripe_events(
230 app: &Arc<AppState>,
231 stripe_client: &stripe::Client,
232) -> anyhow::Result<()> {
233 let event_types = [
234 EventType::CustomerCreated.to_string(),
235 EventType::CustomerUpdated.to_string(),
236 EventType::CustomerSubscriptionCreated.to_string(),
237 EventType::CustomerSubscriptionUpdated.to_string(),
238 EventType::CustomerSubscriptionPaused.to_string(),
239 EventType::CustomerSubscriptionResumed.to_string(),
240 EventType::CustomerSubscriptionDeleted.to_string(),
241 ]
242 .into_iter()
243 .map(|event_type| {
244 // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
245 // so we need to unquote it.
246 event_type.trim_matches('"').to_string()
247 })
248 .collect::<Vec<_>>();
249
250 loop {
251 log::info!("retrieving events from Stripe: {}", event_types.join(", "));
252
253 let mut params = ListEvents::new();
254 params.types = Some(event_types.clone());
255 params.limit = Some(100);
256
257 let events = stripe::Event::list(stripe_client, ¶ms).await?;
258 for event in events.data {
259 match event.type_ {
260 EventType::CustomerCreated | EventType::CustomerUpdated => {
261 handle_customer_event(app, stripe_client, event)
262 .await
263 .log_err();
264 }
265 EventType::CustomerSubscriptionCreated
266 | EventType::CustomerSubscriptionUpdated
267 | EventType::CustomerSubscriptionPaused
268 | EventType::CustomerSubscriptionResumed
269 | EventType::CustomerSubscriptionDeleted => {
270 handle_customer_subscription_event(app, stripe_client, event)
271 .await
272 .log_err();
273 }
274 _ => {}
275 }
276 }
277
278 if !events.has_more {
279 break;
280 }
281 }
282
283 Ok(())
284}
285
286async fn handle_customer_event(
287 app: &Arc<AppState>,
288 _stripe_client: &stripe::Client,
289 event: stripe::Event,
290) -> anyhow::Result<()> {
291 let EventObject::Customer(customer) = event.data.object else {
292 bail!("unexpected event payload for {}", event.id);
293 };
294
295 log::info!("handling Stripe {} event: {}", event.type_, event.id);
296
297 let Some(email) = customer.email else {
298 log::info!("Stripe customer has no email: skipping");
299 return Ok(());
300 };
301
302 let Some(user) = app.db.get_user_by_email(&email).await? else {
303 log::info!("no user found for email: skipping");
304 return Ok(());
305 };
306
307 if let Some(existing_customer) = app
308 .db
309 .get_billing_customer_by_stripe_customer_id(&customer.id)
310 .await?
311 {
312 if should_ignore_event(&event.id, existing_customer.last_stripe_event_id.as_deref()) {
313 log::info!(
314 "ignoring Stripe event {} based on last seen event ID",
315 event.id
316 );
317 return Ok(());
318 }
319
320 app.db
321 .update_billing_customer(
322 existing_customer.id,
323 &UpdateBillingCustomerParams {
324 // For now we just update the last event ID for the customer
325 // and leave the rest of the information as-is, as it is not
326 // likely to change.
327 last_stripe_event_id: ActiveValue::set(Some(event.id.to_string())),
328 ..Default::default()
329 },
330 )
331 .await?;
332 } else {
333 app.db
334 .create_billing_customer(&CreateBillingCustomerParams {
335 user_id: user.id,
336 stripe_customer_id: customer.id.to_string(),
337 last_stripe_event_id: Some(event.id.to_string()),
338 })
339 .await?;
340 }
341
342 Ok(())
343}
344
345async fn handle_customer_subscription_event(
346 app: &Arc<AppState>,
347 stripe_client: &stripe::Client,
348 event: stripe::Event,
349) -> anyhow::Result<()> {
350 let EventObject::Subscription(subscription) = event.data.object else {
351 bail!("unexpected event payload for {}", event.id);
352 };
353
354 log::info!("handling Stripe {} event: {}", event.type_, event.id);
355
356 let billing_customer = find_or_create_billing_customer(
357 app,
358 stripe_client,
359 // Even though we're handling a subscription event, we can still set
360 // the ID as the last seen event ID on the customer in the event that
361 // we have to create it.
362 //
363 // This is done to avoid any potential rollback in the customer's values
364 // if we then see an older event that pertains to the customer.
365 &event.id,
366 subscription.customer,
367 )
368 .await?
369 .ok_or_else(|| anyhow!("billing customer not found"))?;
370
371 if let Some(existing_subscription) = app
372 .db
373 .get_billing_subscription_by_stripe_subscription_id(&subscription.id)
374 .await?
375 {
376 if should_ignore_event(
377 &event.id,
378 existing_subscription.last_stripe_event_id.as_deref(),
379 ) {
380 log::info!(
381 "ignoring Stripe event {} based on last seen event ID",
382 event.id
383 );
384 return Ok(());
385 }
386
387 app.db
388 .update_billing_subscription(
389 existing_subscription.id,
390 &UpdateBillingSubscriptionParams {
391 billing_customer_id: ActiveValue::set(billing_customer.id),
392 stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
393 stripe_subscription_status: ActiveValue::set(subscription.status.into()),
394 last_stripe_event_id: ActiveValue::set(Some(event.id.to_string())),
395 },
396 )
397 .await?;
398 } else {
399 app.db
400 .create_billing_subscription(&CreateBillingSubscriptionParams {
401 billing_customer_id: billing_customer.id,
402 stripe_subscription_id: subscription.id.to_string(),
403 stripe_subscription_status: subscription.status.into(),
404 last_stripe_event_id: Some(event.id.to_string()),
405 })
406 .await?;
407 }
408
409 Ok(())
410}
411
412impl From<SubscriptionStatus> for StripeSubscriptionStatus {
413 fn from(value: SubscriptionStatus) -> Self {
414 match value {
415 SubscriptionStatus::Incomplete => Self::Incomplete,
416 SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
417 SubscriptionStatus::Trialing => Self::Trialing,
418 SubscriptionStatus::Active => Self::Active,
419 SubscriptionStatus::PastDue => Self::PastDue,
420 SubscriptionStatus::Canceled => Self::Canceled,
421 SubscriptionStatus::Unpaid => Self::Unpaid,
422 SubscriptionStatus::Paused => Self::Paused,
423 }
424 }
425}
426
427/// Finds or creates a billing customer using the provided customer.
428async fn find_or_create_billing_customer(
429 app: &Arc<AppState>,
430 stripe_client: &stripe::Client,
431 event_id: &EventId,
432 customer_or_id: Expandable<Customer>,
433) -> anyhow::Result<Option<billing_customer::Model>> {
434 let customer_id = match &customer_or_id {
435 Expandable::Id(id) => id,
436 Expandable::Object(customer) => customer.id.as_ref(),
437 };
438
439 // If we already have a billing customer record associated with the Stripe customer,
440 // there's nothing more we need to do.
441 if let Some(billing_customer) = app
442 .db
443 .get_billing_customer_by_stripe_customer_id(&customer_id)
444 .await?
445 {
446 return Ok(Some(billing_customer));
447 }
448
449 // If all we have is a customer ID, resolve it to a full customer record by
450 // hitting the Stripe API.
451 let customer = match customer_or_id {
452 Expandable::Id(id) => Customer::retrieve(&stripe_client, &id, &[]).await?,
453 Expandable::Object(customer) => *customer,
454 };
455
456 let Some(email) = customer.email else {
457 return Ok(None);
458 };
459
460 let Some(user) = app.db.get_user_by_email(&email).await? else {
461 return Ok(None);
462 };
463
464 let billing_customer = app
465 .db
466 .create_billing_customer(&CreateBillingCustomerParams {
467 user_id: user.id,
468 stripe_customer_id: customer.id.to_string(),
469 last_stripe_event_id: Some(event_id.to_string()),
470 })
471 .await?;
472
473 Ok(Some(billing_customer))
474}
475
476/// Returns whether an [`Event`] should be ignored, based on its ID and the last
477/// seen event ID for this object.
478#[inline]
479fn should_ignore_event(event_id: &EventId, last_event_id: Option<&str>) -> bool {
480 !should_apply_event(event_id, last_event_id)
481}
482
483/// Returns whether an [`Event`] should be applied, based on its ID and the last
484/// seen event ID for this object.
485fn should_apply_event(event_id: &EventId, last_event_id: Option<&str>) -> bool {
486 let Some(last_event_id) = last_event_id else {
487 return true;
488 };
489
490 event_id.as_str() < last_event_id
491}
492
#[cfg(test)]
mod tests {
    use super::*;

    /// Covers the four cases of [`should_apply_event`]: identical IDs, both
    /// orderings of two distinct IDs, and a missing last event ID.
    #[test]
    fn test_should_apply_event() {
        let created = EventId::from_str("evt_1Pi5s9RxOf7d5PNafuZSGsmh").unwrap();
        let updated = EventId::from_str("evt_1Pi5s9RxOf7d5PNa5UZLSsto").unwrap();

        assert!(
            !should_apply_event(&created, Some(created.as_str())),
            "Events should not be applied when the IDs are the same."
        );

        assert!(
            !should_apply_event(&created, Some(updated.as_str())),
            "Events should not be applied when the last event ID is newer than the event ID."
        );

        assert!(
            should_apply_event(&created, None),
            "Events should be applied when we don't have a last event ID."
        );

        assert!(
            should_apply_event(&updated, Some(created.as_str())),
            "Events should be applied when the event ID is newer than the last event ID."
        );
    }
}