1use std::str::FromStr;
2use std::sync::Arc;
3use std::time::Duration;
4
5use anyhow::{anyhow, bail, Context};
6use axum::{
7 extract::{self, Query},
8 routing::{get, post},
9 Extension, Json, Router,
10};
11use reqwest::StatusCode;
12use sea_orm::ActiveValue;
13use serde::{Deserialize, Serialize};
14use stripe::{
15 BillingPortalSession, CheckoutSession, CreateBillingPortalSession,
16 CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
17 CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
18 CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems,
19 CreateCustomer, Customer, CustomerId, EventObject, EventType, Expandable, ListEvents,
20 SubscriptionStatus,
21};
22use util::ResultExt;
23
24use crate::db::billing_subscription::StripeSubscriptionStatus;
25use crate::db::{
26 billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
27 CreateBillingSubscriptionParams, CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
28 UpdateBillingSubscriptionParams,
29};
30use crate::{AppState, Error, Result};
31
32pub fn router() -> Router {
33 Router::new()
34 .route(
35 "/billing/subscriptions",
36 get(list_billing_subscriptions).post(create_billing_subscription),
37 )
38 .route(
39 "/billing/subscriptions/manage",
40 post(manage_billing_subscription),
41 )
42}
43
44#[derive(Debug, Deserialize)]
45struct ListBillingSubscriptionsParams {
46 github_user_id: i32,
47}
48
49#[derive(Debug, Serialize)]
50struct BillingSubscriptionJson {
51 id: BillingSubscriptionId,
52 name: String,
53 status: StripeSubscriptionStatus,
54 /// Whether this subscription can be canceled.
55 is_cancelable: bool,
56}
57
58#[derive(Debug, Serialize)]
59struct ListBillingSubscriptionsResponse {
60 subscriptions: Vec<BillingSubscriptionJson>,
61}
62
63async fn list_billing_subscriptions(
64 Extension(app): Extension<Arc<AppState>>,
65 Query(params): Query<ListBillingSubscriptionsParams>,
66) -> Result<Json<ListBillingSubscriptionsResponse>> {
67 let user = app
68 .db
69 .get_user_by_github_user_id(params.github_user_id)
70 .await?
71 .ok_or_else(|| anyhow!("user not found"))?;
72
73 let subscriptions = app.db.get_billing_subscriptions(user.id).await?;
74
75 Ok(Json(ListBillingSubscriptionsResponse {
76 subscriptions: subscriptions
77 .into_iter()
78 .map(|subscription| BillingSubscriptionJson {
79 id: subscription.id,
80 name: "Zed Pro".to_string(),
81 status: subscription.stripe_subscription_status,
82 is_cancelable: subscription.stripe_subscription_status.is_cancelable(),
83 })
84 .collect(),
85 }))
86}
87
88#[derive(Debug, Deserialize)]
89struct CreateBillingSubscriptionBody {
90 github_user_id: i32,
91}
92
93#[derive(Debug, Serialize)]
94struct CreateBillingSubscriptionResponse {
95 checkout_session_url: String,
96}
97
98/// Initiates a Stripe Checkout session for creating a billing subscription.
99async fn create_billing_subscription(
100 Extension(app): Extension<Arc<AppState>>,
101 extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
102) -> Result<Json<CreateBillingSubscriptionResponse>> {
103 let user = app
104 .db
105 .get_user_by_github_user_id(body.github_user_id)
106 .await?
107 .ok_or_else(|| anyhow!("user not found"))?;
108
109 let Some((stripe_client, stripe_price_id)) = app
110 .stripe_client
111 .clone()
112 .zip(app.config.stripe_price_id.clone())
113 else {
114 log::error!("failed to retrieve Stripe client or price ID");
115 Err(Error::Http(
116 StatusCode::NOT_IMPLEMENTED,
117 "not supported".into(),
118 ))?
119 };
120
121 let customer_id =
122 if let Some(existing_customer) = app.db.get_billing_customer_by_user_id(user.id).await? {
123 CustomerId::from_str(&existing_customer.stripe_customer_id)
124 .context("failed to parse customer ID")?
125 } else {
126 let customer = Customer::create(
127 &stripe_client,
128 CreateCustomer {
129 email: user.email_address.as_deref(),
130 ..Default::default()
131 },
132 )
133 .await?;
134
135 customer.id
136 };
137
138 let checkout_session = {
139 let mut params = CreateCheckoutSession::new();
140 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
141 params.customer = Some(customer_id);
142 params.client_reference_id = Some(user.github_login.as_str());
143 params.line_items = Some(vec![CreateCheckoutSessionLineItems {
144 price: Some(stripe_price_id.to_string()),
145 quantity: Some(1),
146 ..Default::default()
147 }]);
148 params.success_url = Some("https://zed.dev/billing/success");
149
150 CheckoutSession::create(&stripe_client, params).await?
151 };
152
153 Ok(Json(CreateBillingSubscriptionResponse {
154 checkout_session_url: checkout_session
155 .url
156 .ok_or_else(|| anyhow!("no checkout session URL"))?,
157 }))
158}
159
160#[derive(Debug, Deserialize)]
161#[serde(rename_all = "snake_case")]
162enum ManageSubscriptionIntent {
163 /// The user intends to cancel their subscription.
164 Cancel,
165}
166
167#[derive(Debug, Deserialize)]
168struct ManageBillingSubscriptionBody {
169 github_user_id: i32,
170 intent: ManageSubscriptionIntent,
171 /// The ID of the subscription to manage.
172 subscription_id: BillingSubscriptionId,
173}
174
175#[derive(Debug, Serialize)]
176struct ManageBillingSubscriptionResponse {
177 billing_portal_session_url: String,
178}
179
180/// Initiates a Stripe customer portal session for managing a billing subscription.
181async fn manage_billing_subscription(
182 Extension(app): Extension<Arc<AppState>>,
183 extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
184) -> Result<Json<ManageBillingSubscriptionResponse>> {
185 let user = app
186 .db
187 .get_user_by_github_user_id(body.github_user_id)
188 .await?
189 .ok_or_else(|| anyhow!("user not found"))?;
190
191 let Some(stripe_client) = app.stripe_client.clone() else {
192 log::error!("failed to retrieve Stripe client");
193 Err(Error::Http(
194 StatusCode::NOT_IMPLEMENTED,
195 "not supported".into(),
196 ))?
197 };
198
199 let customer = app
200 .db
201 .get_billing_customer_by_user_id(user.id)
202 .await?
203 .ok_or_else(|| anyhow!("billing customer not found"))?;
204 let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
205 .context("failed to parse customer ID")?;
206
207 let subscription = app
208 .db
209 .get_billing_subscription_by_id(body.subscription_id)
210 .await?
211 .ok_or_else(|| anyhow!("subscription not found"))?;
212
213 let flow = match body.intent {
214 ManageSubscriptionIntent::Cancel => CreateBillingPortalSessionFlowData {
215 type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
216 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
217 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
218 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
219 return_url: "https://zed.dev/settings".into(),
220 }),
221 ..Default::default()
222 }),
223 subscription_cancel: Some(
224 stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
225 subscription: subscription.stripe_subscription_id,
226 retention: None,
227 },
228 ),
229 ..Default::default()
230 },
231 };
232
233 let mut params = CreateBillingPortalSession::new(customer_id);
234 params.flow_data = Some(flow);
235 params.return_url = Some("https://zed.dev/settings");
236
237 let session = BillingPortalSession::create(&stripe_client, params).await?;
238
239 Ok(Json(ManageBillingSubscriptionResponse {
240 billing_portal_session_url: session.url,
241 }))
242}
243
/// Delay between consecutive Stripe event polls (five minutes).
const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(300);
245
246/// Polls the Stripe events API periodically to reconcile the records in our
247/// database with the data in Stripe.
248pub fn poll_stripe_events_periodically(app: Arc<AppState>) {
249 let Some(stripe_client) = app.stripe_client.clone() else {
250 log::warn!("failed to retrieve Stripe client");
251 return;
252 };
253
254 let executor = app.executor.clone();
255 executor.spawn_detached({
256 let executor = executor.clone();
257 async move {
258 loop {
259 poll_stripe_events(&app, &stripe_client).await.log_err();
260
261 executor.sleep(POLL_EVENTS_INTERVAL).await;
262 }
263 }
264 });
265}
266
267async fn poll_stripe_events(
268 app: &Arc<AppState>,
269 stripe_client: &stripe::Client,
270) -> anyhow::Result<()> {
271 fn event_type_to_string(event_type: EventType) -> String {
272 // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
273 // so we need to unquote it.
274 event_type.to_string().trim_matches('"').to_string()
275 }
276
277 let event_types = [
278 EventType::CustomerCreated,
279 EventType::CustomerUpdated,
280 EventType::CustomerSubscriptionCreated,
281 EventType::CustomerSubscriptionUpdated,
282 EventType::CustomerSubscriptionPaused,
283 EventType::CustomerSubscriptionResumed,
284 EventType::CustomerSubscriptionDeleted,
285 ]
286 .into_iter()
287 .map(event_type_to_string)
288 .collect::<Vec<_>>();
289
290 let mut unprocessed_events = Vec::new();
291
292 loop {
293 log::info!("retrieving events from Stripe: {}", event_types.join(", "));
294
295 let mut params = ListEvents::new();
296 params.types = Some(event_types.clone());
297 params.limit = Some(100);
298
299 let events = stripe::Event::list(stripe_client, ¶ms).await?;
300
301 let processed_event_ids = {
302 let event_ids = &events
303 .data
304 .iter()
305 .map(|event| event.id.as_str())
306 .collect::<Vec<_>>();
307
308 app.db
309 .get_processed_stripe_events_by_event_ids(event_ids)
310 .await?
311 .into_iter()
312 .map(|event| event.stripe_event_id)
313 .collect::<Vec<_>>()
314 };
315
316 for event in events.data {
317 if processed_event_ids.contains(&event.id.to_string()) {
318 log::info!("Stripe event {} already processed: skipping", event.id);
319 } else {
320 unprocessed_events.push(event);
321 }
322 }
323
324 if !events.has_more {
325 break;
326 }
327 }
328
329 log::info!(
330 "unprocessed events from Stripe: {}",
331 unprocessed_events.len()
332 );
333
334 // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
335 unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
336
337 for event in unprocessed_events {
338 let processed_event_params = CreateProcessedStripeEventParams {
339 stripe_event_id: event.id.to_string(),
340 stripe_event_type: event_type_to_string(event.type_),
341 stripe_event_created_timestamp: event.created,
342 };
343
344 match event.type_ {
345 EventType::CustomerCreated | EventType::CustomerUpdated => {
346 handle_customer_event(app, stripe_client, event)
347 .await
348 .log_err();
349 }
350 EventType::CustomerSubscriptionCreated
351 | EventType::CustomerSubscriptionUpdated
352 | EventType::CustomerSubscriptionPaused
353 | EventType::CustomerSubscriptionResumed
354 | EventType::CustomerSubscriptionDeleted => {
355 handle_customer_subscription_event(app, stripe_client, event)
356 .await
357 .log_err();
358 }
359 _ => {}
360 }
361
362 app.db
363 .create_processed_stripe_event(&processed_event_params)
364 .await?;
365 }
366
367 Ok(())
368}
369
370async fn handle_customer_event(
371 app: &Arc<AppState>,
372 _stripe_client: &stripe::Client,
373 event: stripe::Event,
374) -> anyhow::Result<()> {
375 let EventObject::Customer(customer) = event.data.object else {
376 bail!("unexpected event payload for {}", event.id);
377 };
378
379 log::info!("handling Stripe {} event: {}", event.type_, event.id);
380
381 let Some(email) = customer.email else {
382 log::info!("Stripe customer has no email: skipping");
383 return Ok(());
384 };
385
386 let Some(user) = app.db.get_user_by_email(&email).await? else {
387 log::info!("no user found for email: skipping");
388 return Ok(());
389 };
390
391 if let Some(existing_customer) = app
392 .db
393 .get_billing_customer_by_stripe_customer_id(&customer.id)
394 .await?
395 {
396 app.db
397 .update_billing_customer(
398 existing_customer.id,
399 &UpdateBillingCustomerParams {
400 // For now we just leave the information as-is, as it is not
401 // likely to change.
402 ..Default::default()
403 },
404 )
405 .await?;
406 } else {
407 app.db
408 .create_billing_customer(&CreateBillingCustomerParams {
409 user_id: user.id,
410 stripe_customer_id: customer.id.to_string(),
411 })
412 .await?;
413 }
414
415 Ok(())
416}
417
418async fn handle_customer_subscription_event(
419 app: &Arc<AppState>,
420 stripe_client: &stripe::Client,
421 event: stripe::Event,
422) -> anyhow::Result<()> {
423 let EventObject::Subscription(subscription) = event.data.object else {
424 bail!("unexpected event payload for {}", event.id);
425 };
426
427 log::info!("handling Stripe {} event: {}", event.type_, event.id);
428
429 let billing_customer =
430 find_or_create_billing_customer(app, stripe_client, subscription.customer)
431 .await?
432 .ok_or_else(|| anyhow!("billing customer not found"))?;
433
434 if let Some(existing_subscription) = app
435 .db
436 .get_billing_subscription_by_stripe_subscription_id(&subscription.id)
437 .await?
438 {
439 app.db
440 .update_billing_subscription(
441 existing_subscription.id,
442 &UpdateBillingSubscriptionParams {
443 billing_customer_id: ActiveValue::set(billing_customer.id),
444 stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
445 stripe_subscription_status: ActiveValue::set(subscription.status.into()),
446 },
447 )
448 .await?;
449 } else {
450 app.db
451 .create_billing_subscription(&CreateBillingSubscriptionParams {
452 billing_customer_id: billing_customer.id,
453 stripe_subscription_id: subscription.id.to_string(),
454 stripe_subscription_status: subscription.status.into(),
455 })
456 .await?;
457 }
458
459 Ok(())
460}
461
462impl From<SubscriptionStatus> for StripeSubscriptionStatus {
463 fn from(value: SubscriptionStatus) -> Self {
464 match value {
465 SubscriptionStatus::Incomplete => Self::Incomplete,
466 SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
467 SubscriptionStatus::Trialing => Self::Trialing,
468 SubscriptionStatus::Active => Self::Active,
469 SubscriptionStatus::PastDue => Self::PastDue,
470 SubscriptionStatus::Canceled => Self::Canceled,
471 SubscriptionStatus::Unpaid => Self::Unpaid,
472 SubscriptionStatus::Paused => Self::Paused,
473 }
474 }
475}
476
477/// Finds or creates a billing customer using the provided customer.
478async fn find_or_create_billing_customer(
479 app: &Arc<AppState>,
480 stripe_client: &stripe::Client,
481 customer_or_id: Expandable<Customer>,
482) -> anyhow::Result<Option<billing_customer::Model>> {
483 let customer_id = match &customer_or_id {
484 Expandable::Id(id) => id,
485 Expandable::Object(customer) => customer.id.as_ref(),
486 };
487
488 // If we already have a billing customer record associated with the Stripe customer,
489 // there's nothing more we need to do.
490 if let Some(billing_customer) = app
491 .db
492 .get_billing_customer_by_stripe_customer_id(&customer_id)
493 .await?
494 {
495 return Ok(Some(billing_customer));
496 }
497
498 // If all we have is a customer ID, resolve it to a full customer record by
499 // hitting the Stripe API.
500 let customer = match customer_or_id {
501 Expandable::Id(id) => Customer::retrieve(&stripe_client, &id, &[]).await?,
502 Expandable::Object(customer) => *customer,
503 };
504
505 let Some(email) = customer.email else {
506 return Ok(None);
507 };
508
509 let Some(user) = app.db.get_user_by_email(&email).await? else {
510 return Ok(None);
511 };
512
513 let billing_customer = app
514 .db
515 .create_billing_customer(&CreateBillingCustomerParams {
516 user_id: user.id,
517 stripe_customer_id: customer.id.to_string(),
518 })
519 .await?;
520
521 Ok(Some(billing_customer))
522}