1use std::str::FromStr;
2use std::sync::Arc;
3use std::time::Duration;
4
5use anyhow::{anyhow, bail, Context};
6use axum::{extract, routing::post, Extension, Json, Router};
7use reqwest::StatusCode;
8use sea_orm::ActiveValue;
9use serde::{Deserialize, Serialize};
10use stripe::{
11 BillingPortalSession, CheckoutSession, CreateBillingPortalSession,
12 CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
13 CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
14 CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems,
15 CreateCustomer, Customer, CustomerId, EventObject, EventType, Expandable, ListEvents,
16 SubscriptionStatus,
17};
18use util::ResultExt;
19
20use crate::db::billing_subscription::StripeSubscriptionStatus;
21use crate::db::{
22 billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
23 CreateBillingSubscriptionParams, CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
24 UpdateBillingSubscriptionParams,
25};
26use crate::{AppState, Error, Result};
27
28pub fn router() -> Router {
29 Router::new()
30 .route("/billing/subscriptions", post(create_billing_subscription))
31 .route(
32 "/billing/subscriptions/manage",
33 post(manage_billing_subscription),
34 )
35}
36
37#[derive(Debug, Deserialize)]
38struct CreateBillingSubscriptionBody {
39 github_user_id: i32,
40}
41
42#[derive(Debug, Serialize)]
43struct CreateBillingSubscriptionResponse {
44 checkout_session_url: String,
45}
46
47/// Initiates a Stripe Checkout session for creating a billing subscription.
48async fn create_billing_subscription(
49 Extension(app): Extension<Arc<AppState>>,
50 extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
51) -> Result<Json<CreateBillingSubscriptionResponse>> {
52 let user = app
53 .db
54 .get_user_by_github_user_id(body.github_user_id)
55 .await?
56 .ok_or_else(|| anyhow!("user not found"))?;
57
58 let Some((stripe_client, stripe_price_id)) = app
59 .stripe_client
60 .clone()
61 .zip(app.config.stripe_price_id.clone())
62 else {
63 log::error!("failed to retrieve Stripe client or price ID");
64 Err(Error::Http(
65 StatusCode::NOT_IMPLEMENTED,
66 "not supported".into(),
67 ))?
68 };
69
70 let customer_id =
71 if let Some(existing_customer) = app.db.get_billing_customer_by_user_id(user.id).await? {
72 CustomerId::from_str(&existing_customer.stripe_customer_id)
73 .context("failed to parse customer ID")?
74 } else {
75 let customer = Customer::create(
76 &stripe_client,
77 CreateCustomer {
78 email: user.email_address.as_deref(),
79 ..Default::default()
80 },
81 )
82 .await?;
83
84 customer.id
85 };
86
87 let checkout_session = {
88 let mut params = CreateCheckoutSession::new();
89 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
90 params.customer = Some(customer_id);
91 params.client_reference_id = Some(user.github_login.as_str());
92 params.line_items = Some(vec![CreateCheckoutSessionLineItems {
93 price: Some(stripe_price_id.to_string()),
94 quantity: Some(1),
95 ..Default::default()
96 }]);
97 params.success_url = Some("https://zed.dev/billing/success");
98
99 CheckoutSession::create(&stripe_client, params).await?
100 };
101
102 Ok(Json(CreateBillingSubscriptionResponse {
103 checkout_session_url: checkout_session
104 .url
105 .ok_or_else(|| anyhow!("no checkout session URL"))?,
106 }))
107}
108
109#[derive(Debug, Deserialize)]
110#[serde(rename_all = "snake_case")]
111enum ManageSubscriptionIntent {
112 /// The user intends to cancel their subscription.
113 Cancel,
114}
115
116#[derive(Debug, Deserialize)]
117struct ManageBillingSubscriptionBody {
118 github_user_id: i32,
119 intent: ManageSubscriptionIntent,
120 /// The ID of the subscription to manage.
121 ///
122 /// If not provided, we will try to use the active subscription (if there is only one).
123 subscription_id: Option<BillingSubscriptionId>,
124}
125
126#[derive(Debug, Serialize)]
127struct ManageBillingSubscriptionResponse {
128 billing_portal_session_url: String,
129}
130
131/// Initiates a Stripe customer portal session for managing a billing subscription.
132async fn manage_billing_subscription(
133 Extension(app): Extension<Arc<AppState>>,
134 extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
135) -> Result<Json<ManageBillingSubscriptionResponse>> {
136 let user = app
137 .db
138 .get_user_by_github_user_id(body.github_user_id)
139 .await?
140 .ok_or_else(|| anyhow!("user not found"))?;
141
142 let Some(stripe_client) = app.stripe_client.clone() else {
143 log::error!("failed to retrieve Stripe client");
144 Err(Error::Http(
145 StatusCode::NOT_IMPLEMENTED,
146 "not supported".into(),
147 ))?
148 };
149
150 let customer = app
151 .db
152 .get_billing_customer_by_user_id(user.id)
153 .await?
154 .ok_or_else(|| anyhow!("billing customer not found"))?;
155 let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
156 .context("failed to parse customer ID")?;
157
158 let subscription = if let Some(subscription_id) = body.subscription_id {
159 app.db
160 .get_billing_subscription_by_id(subscription_id)
161 .await?
162 .ok_or_else(|| anyhow!("subscription not found"))?
163 } else {
164 // If no subscription ID was provided, try to find the only active subscription ID.
165 let subscriptions = app.db.get_active_billing_subscriptions(user.id).await?;
166 if subscriptions.len() > 1 {
167 Err(anyhow!("user has multiple active subscriptions"))?;
168 }
169
170 subscriptions
171 .into_iter()
172 .next()
173 .ok_or_else(|| anyhow!("user has no active subscriptions"))?
174 };
175
176 let flow = match body.intent {
177 ManageSubscriptionIntent::Cancel => CreateBillingPortalSessionFlowData {
178 type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
179 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
180 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
181 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
182 return_url: "https://zed.dev/billing".into(),
183 }),
184 ..Default::default()
185 }),
186 subscription_cancel: Some(
187 stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
188 subscription: subscription.stripe_subscription_id,
189 retention: None,
190 },
191 ),
192 ..Default::default()
193 },
194 };
195
196 let mut params = CreateBillingPortalSession::new(customer_id);
197 params.flow_data = Some(flow);
198 params.return_url = Some("https://zed.dev/billing");
199
200 let session = BillingPortalSession::create(&stripe_client, params).await?;
201
202 Ok(Json(ManageBillingSubscriptionResponse {
203 billing_portal_session_url: session.url,
204 }))
205}
206
207const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5 * 60);
208
209/// Polls the Stripe events API periodically to reconcile the records in our
210/// database with the data in Stripe.
211pub fn poll_stripe_events_periodically(app: Arc<AppState>) {
212 let Some(stripe_client) = app.stripe_client.clone() else {
213 log::warn!("failed to retrieve Stripe client");
214 return;
215 };
216
217 let executor = app.executor.clone();
218 executor.spawn_detached({
219 let executor = executor.clone();
220 async move {
221 loop {
222 poll_stripe_events(&app, &stripe_client).await.log_err();
223
224 executor.sleep(POLL_EVENTS_INTERVAL).await;
225 }
226 }
227 });
228}
229
230async fn poll_stripe_events(
231 app: &Arc<AppState>,
232 stripe_client: &stripe::Client,
233) -> anyhow::Result<()> {
234 fn event_type_to_string(event_type: EventType) -> String {
235 // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
236 // so we need to unquote it.
237 event_type.to_string().trim_matches('"').to_string()
238 }
239
240 let event_types = [
241 EventType::CustomerCreated,
242 EventType::CustomerUpdated,
243 EventType::CustomerSubscriptionCreated,
244 EventType::CustomerSubscriptionUpdated,
245 EventType::CustomerSubscriptionPaused,
246 EventType::CustomerSubscriptionResumed,
247 EventType::CustomerSubscriptionDeleted,
248 ]
249 .into_iter()
250 .map(event_type_to_string)
251 .collect::<Vec<_>>();
252
253 let mut unprocessed_events = Vec::new();
254
255 loop {
256 log::info!("retrieving events from Stripe: {}", event_types.join(", "));
257
258 let mut params = ListEvents::new();
259 params.types = Some(event_types.clone());
260 params.limit = Some(100);
261
262 let events = stripe::Event::list(stripe_client, ¶ms).await?;
263
264 let processed_event_ids = {
265 let event_ids = &events
266 .data
267 .iter()
268 .map(|event| event.id.as_str())
269 .collect::<Vec<_>>();
270
271 app.db
272 .get_processed_stripe_events_by_event_ids(event_ids)
273 .await?
274 .into_iter()
275 .map(|event| event.stripe_event_id)
276 .collect::<Vec<_>>()
277 };
278
279 for event in events.data {
280 if processed_event_ids.contains(&event.id.to_string()) {
281 log::info!("Stripe event {} already processed: skipping", event.id);
282 } else {
283 unprocessed_events.push(event);
284 }
285 }
286
287 if !events.has_more {
288 break;
289 }
290 }
291
292 log::info!(
293 "unprocessed events from Stripe: {}",
294 unprocessed_events.len()
295 );
296
297 // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
298 unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
299
300 for event in unprocessed_events {
301 let processed_event_params = CreateProcessedStripeEventParams {
302 stripe_event_id: event.id.to_string(),
303 stripe_event_type: event_type_to_string(event.type_),
304 stripe_event_created_timestamp: event.created,
305 };
306
307 match event.type_ {
308 EventType::CustomerCreated | EventType::CustomerUpdated => {
309 handle_customer_event(app, stripe_client, event)
310 .await
311 .log_err();
312 }
313 EventType::CustomerSubscriptionCreated
314 | EventType::CustomerSubscriptionUpdated
315 | EventType::CustomerSubscriptionPaused
316 | EventType::CustomerSubscriptionResumed
317 | EventType::CustomerSubscriptionDeleted => {
318 handle_customer_subscription_event(app, stripe_client, event)
319 .await
320 .log_err();
321 }
322 _ => {}
323 }
324
325 app.db
326 .create_processed_stripe_event(&processed_event_params)
327 .await?;
328 }
329
330 Ok(())
331}
332
333async fn handle_customer_event(
334 app: &Arc<AppState>,
335 _stripe_client: &stripe::Client,
336 event: stripe::Event,
337) -> anyhow::Result<()> {
338 let EventObject::Customer(customer) = event.data.object else {
339 bail!("unexpected event payload for {}", event.id);
340 };
341
342 log::info!("handling Stripe {} event: {}", event.type_, event.id);
343
344 let Some(email) = customer.email else {
345 log::info!("Stripe customer has no email: skipping");
346 return Ok(());
347 };
348
349 let Some(user) = app.db.get_user_by_email(&email).await? else {
350 log::info!("no user found for email: skipping");
351 return Ok(());
352 };
353
354 if let Some(existing_customer) = app
355 .db
356 .get_billing_customer_by_stripe_customer_id(&customer.id)
357 .await?
358 {
359 app.db
360 .update_billing_customer(
361 existing_customer.id,
362 &UpdateBillingCustomerParams {
363 // For now we just leave the information as-is, as it is not
364 // likely to change.
365 ..Default::default()
366 },
367 )
368 .await?;
369 } else {
370 app.db
371 .create_billing_customer(&CreateBillingCustomerParams {
372 user_id: user.id,
373 stripe_customer_id: customer.id.to_string(),
374 })
375 .await?;
376 }
377
378 Ok(())
379}
380
381async fn handle_customer_subscription_event(
382 app: &Arc<AppState>,
383 stripe_client: &stripe::Client,
384 event: stripe::Event,
385) -> anyhow::Result<()> {
386 let EventObject::Subscription(subscription) = event.data.object else {
387 bail!("unexpected event payload for {}", event.id);
388 };
389
390 log::info!("handling Stripe {} event: {}", event.type_, event.id);
391
392 let billing_customer =
393 find_or_create_billing_customer(app, stripe_client, subscription.customer)
394 .await?
395 .ok_or_else(|| anyhow!("billing customer not found"))?;
396
397 if let Some(existing_subscription) = app
398 .db
399 .get_billing_subscription_by_stripe_subscription_id(&subscription.id)
400 .await?
401 {
402 app.db
403 .update_billing_subscription(
404 existing_subscription.id,
405 &UpdateBillingSubscriptionParams {
406 billing_customer_id: ActiveValue::set(billing_customer.id),
407 stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
408 stripe_subscription_status: ActiveValue::set(subscription.status.into()),
409 },
410 )
411 .await?;
412 } else {
413 app.db
414 .create_billing_subscription(&CreateBillingSubscriptionParams {
415 billing_customer_id: billing_customer.id,
416 stripe_subscription_id: subscription.id.to_string(),
417 stripe_subscription_status: subscription.status.into(),
418 })
419 .await?;
420 }
421
422 Ok(())
423}
424
425impl From<SubscriptionStatus> for StripeSubscriptionStatus {
426 fn from(value: SubscriptionStatus) -> Self {
427 match value {
428 SubscriptionStatus::Incomplete => Self::Incomplete,
429 SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
430 SubscriptionStatus::Trialing => Self::Trialing,
431 SubscriptionStatus::Active => Self::Active,
432 SubscriptionStatus::PastDue => Self::PastDue,
433 SubscriptionStatus::Canceled => Self::Canceled,
434 SubscriptionStatus::Unpaid => Self::Unpaid,
435 SubscriptionStatus::Paused => Self::Paused,
436 }
437 }
438}
439
440/// Finds or creates a billing customer using the provided customer.
441async fn find_or_create_billing_customer(
442 app: &Arc<AppState>,
443 stripe_client: &stripe::Client,
444 customer_or_id: Expandable<Customer>,
445) -> anyhow::Result<Option<billing_customer::Model>> {
446 let customer_id = match &customer_or_id {
447 Expandable::Id(id) => id,
448 Expandable::Object(customer) => customer.id.as_ref(),
449 };
450
451 // If we already have a billing customer record associated with the Stripe customer,
452 // there's nothing more we need to do.
453 if let Some(billing_customer) = app
454 .db
455 .get_billing_customer_by_stripe_customer_id(&customer_id)
456 .await?
457 {
458 return Ok(Some(billing_customer));
459 }
460
461 // If all we have is a customer ID, resolve it to a full customer record by
462 // hitting the Stripe API.
463 let customer = match customer_or_id {
464 Expandable::Id(id) => Customer::retrieve(&stripe_client, &id, &[]).await?,
465 Expandable::Object(customer) => *customer,
466 };
467
468 let Some(email) = customer.email else {
469 return Ok(None);
470 };
471
472 let Some(user) = app.db.get_user_by_email(&email).await? else {
473 return Ok(None);
474 };
475
476 let billing_customer = app
477 .db
478 .create_billing_customer(&CreateBillingCustomerParams {
479 user_id: user.id,
480 stripe_customer_id: customer.id.to_string(),
481 })
482 .await?;
483
484 Ok(Some(billing_customer))
485}