1use anyhow::{Context as _, bail};
2use chrono::{DateTime, Utc};
3use sea_orm::ActiveValue;
4use std::{sync::Arc, time::Duration};
5use stripe::{CancellationDetailsReason, EventObject, EventType, ListEvents, SubscriptionStatus};
6use util::ResultExt;
7
8use crate::AppState;
9use crate::db::billing_subscription::{
10 StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
11};
12use crate::db::{
13 CreateBillingCustomerParams, CreateBillingSubscriptionParams, CreateProcessedStripeEventParams,
14 UpdateBillingCustomerParams, UpdateBillingSubscriptionParams, billing_customer,
15};
16use crate::rpc::{ResultExt as _, Server};
17use crate::stripe_client::{
18 StripeCancellationDetailsReason, StripeClient, StripeCustomerId, StripeSubscription,
19 StripeSubscriptionId,
20};
21
/// How long the poller sleeps between two consecutive polls of the Stripe
/// events API.
///
/// Choosing this value is a trade-off:
/// - a shorter interval makes us pick up changes in Stripe sooner;
/// - a longer interval uses up less of our Stripe rate limit.
///
/// For comparison, the Sequin folks report polling every **500ms**:
///
/// > We poll the Stripe /events endpoint every 500ms per account
/// >
/// > — https://blog.sequinstream.com/events-not-webhooks/
const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5);
34
/// Page size requested when listing Stripe events.
///
/// We ask for 100 events per page, the largest size allowed, so that fewer
/// requests to Stripe are needed.
///
/// > Limit can range between 1 and 100, and the default is 10.
const EVENTS_LIMIT_PER_PAGE: u64 = 100;
41
/// How many pages made up solely of already-processed events we tolerate
/// before we stop asking Stripe for further pages.
///
/// This cap keeps us from over-fetching the Stripe events API for events we
/// have already seen and processed.
const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
48
49/// Polls the Stripe events API periodically to reconcile the records in our
50/// database with the data in Stripe.
51pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
52 let Some(real_stripe_client) = app.real_stripe_client.clone() else {
53 log::warn!("failed to retrieve Stripe client");
54 return;
55 };
56 let Some(stripe_client) = app.stripe_client.clone() else {
57 log::warn!("failed to retrieve Stripe client");
58 return;
59 };
60
61 let executor = app.executor.clone();
62 executor.spawn_detached({
63 let executor = executor.clone();
64 async move {
65 loop {
66 poll_stripe_events(&app, &rpc_server, &stripe_client, &real_stripe_client)
67 .await
68 .log_err();
69
70 executor.sleep(POLL_EVENTS_INTERVAL).await;
71 }
72 }
73 });
74}
75
76async fn poll_stripe_events(
77 app: &Arc<AppState>,
78 rpc_server: &Arc<Server>,
79 stripe_client: &Arc<dyn StripeClient>,
80 real_stripe_client: &stripe::Client,
81) -> anyhow::Result<()> {
82 let feature_flags = app.db.list_feature_flags().await?;
83 let sync_events_using_cloud = feature_flags
84 .iter()
85 .any(|flag| flag.flag == "cloud-stripe-events-polling" && flag.enabled_for_all);
86 if sync_events_using_cloud {
87 return Ok(());
88 }
89
90 fn event_type_to_string(event_type: EventType) -> String {
91 // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
92 // so we need to unquote it.
93 event_type.to_string().trim_matches('"').to_string()
94 }
95
96 let event_types = [
97 EventType::CustomerCreated,
98 EventType::CustomerUpdated,
99 EventType::CustomerSubscriptionCreated,
100 EventType::CustomerSubscriptionUpdated,
101 EventType::CustomerSubscriptionPaused,
102 EventType::CustomerSubscriptionResumed,
103 EventType::CustomerSubscriptionDeleted,
104 ]
105 .into_iter()
106 .map(event_type_to_string)
107 .collect::<Vec<_>>();
108
109 let mut pages_of_already_processed_events = 0;
110 let mut unprocessed_events = Vec::new();
111
112 log::info!(
113 "Stripe events: starting retrieval for {}",
114 event_types.join(", ")
115 );
116 let mut params = ListEvents::new();
117 params.types = Some(event_types.clone());
118 params.limit = Some(EVENTS_LIMIT_PER_PAGE);
119
120 let mut event_pages = stripe::Event::list(&real_stripe_client, ¶ms)
121 .await?
122 .paginate(params);
123
124 loop {
125 let processed_event_ids = {
126 let event_ids = event_pages
127 .page
128 .data
129 .iter()
130 .map(|event| event.id.as_str())
131 .collect::<Vec<_>>();
132 app.db
133 .get_processed_stripe_events_by_event_ids(&event_ids)
134 .await?
135 .into_iter()
136 .map(|event| event.stripe_event_id)
137 .collect::<Vec<_>>()
138 };
139
140 let mut processed_events_in_page = 0;
141 let events_in_page = event_pages.page.data.len();
142 for event in &event_pages.page.data {
143 if processed_event_ids.contains(&event.id.to_string()) {
144 processed_events_in_page += 1;
145 log::debug!("Stripe events: already processed '{}', skipping", event.id);
146 } else {
147 unprocessed_events.push(event.clone());
148 }
149 }
150
151 if processed_events_in_page == events_in_page {
152 pages_of_already_processed_events += 1;
153 }
154
155 if event_pages.page.has_more {
156 if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
157 {
158 log::info!(
159 "Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events"
160 );
161 break;
162 } else {
163 log::info!("Stripe events: retrieving next page");
164 event_pages = event_pages.next(&real_stripe_client).await?;
165 }
166 } else {
167 break;
168 }
169 }
170
171 log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
172
173 // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
174 unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
175
176 for event in unprocessed_events {
177 let event_id = event.id.clone();
178 let processed_event_params = CreateProcessedStripeEventParams {
179 stripe_event_id: event.id.to_string(),
180 stripe_event_type: event_type_to_string(event.type_),
181 stripe_event_created_timestamp: event.created,
182 };
183
184 // If the event has happened too far in the past, we don't want to
185 // process it and risk overwriting other more-recent updates.
186 //
187 // 1 day was chosen arbitrarily. This could be made longer or shorter.
188 let one_day = Duration::from_secs(24 * 60 * 60);
189 let a_day_ago = Utc::now() - one_day;
190 if a_day_ago.timestamp() > event.created {
191 log::info!(
192 "Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
193 event_id
194 );
195 app.db
196 .create_processed_stripe_event(&processed_event_params)
197 .await?;
198
199 continue;
200 }
201
202 let process_result = match event.type_ {
203 EventType::CustomerCreated | EventType::CustomerUpdated => {
204 handle_customer_event(app, real_stripe_client, event).await
205 }
206 EventType::CustomerSubscriptionCreated
207 | EventType::CustomerSubscriptionUpdated
208 | EventType::CustomerSubscriptionPaused
209 | EventType::CustomerSubscriptionResumed
210 | EventType::CustomerSubscriptionDeleted => {
211 handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
212 }
213 _ => Ok(()),
214 };
215
216 if let Some(()) = process_result
217 .with_context(|| format!("failed to process event {event_id} successfully"))
218 .log_err()
219 {
220 app.db
221 .create_processed_stripe_event(&processed_event_params)
222 .await?;
223 }
224 }
225
226 Ok(())
227}
228
229async fn handle_customer_event(
230 app: &Arc<AppState>,
231 _stripe_client: &stripe::Client,
232 event: stripe::Event,
233) -> anyhow::Result<()> {
234 let EventObject::Customer(customer) = event.data.object else {
235 bail!("unexpected event payload for {}", event.id);
236 };
237
238 log::info!("handling Stripe {} event: {}", event.type_, event.id);
239
240 let Some(email) = customer.email else {
241 log::info!("Stripe customer has no email: skipping");
242 return Ok(());
243 };
244
245 let Some(user) = app.db.get_user_by_email(&email).await? else {
246 log::info!("no user found for email: skipping");
247 return Ok(());
248 };
249
250 if let Some(existing_customer) = app
251 .db
252 .get_billing_customer_by_stripe_customer_id(&customer.id)
253 .await?
254 {
255 app.db
256 .update_billing_customer(
257 existing_customer.id,
258 &UpdateBillingCustomerParams {
259 // For now we just leave the information as-is, as it is not
260 // likely to change.
261 ..Default::default()
262 },
263 )
264 .await?;
265 } else {
266 app.db
267 .create_billing_customer(&CreateBillingCustomerParams {
268 user_id: user.id,
269 stripe_customer_id: customer.id.to_string(),
270 })
271 .await?;
272 }
273
274 Ok(())
275}
276
277async fn sync_subscription(
278 app: &Arc<AppState>,
279 stripe_client: &Arc<dyn StripeClient>,
280 subscription: StripeSubscription,
281) -> anyhow::Result<billing_customer::Model> {
282 let subscription_kind = if let Some(stripe_billing) = &app.stripe_billing {
283 stripe_billing
284 .determine_subscription_kind(&subscription)
285 .await
286 } else {
287 None
288 };
289
290 let billing_customer =
291 find_or_create_billing_customer(app, stripe_client.as_ref(), &subscription.customer)
292 .await?
293 .context("billing customer not found")?;
294
295 if let Some(SubscriptionKind::ZedProTrial) = subscription_kind {
296 if subscription.status == SubscriptionStatus::Trialing {
297 let current_period_start =
298 DateTime::from_timestamp(subscription.current_period_start, 0)
299 .context("No trial subscription period start")?;
300
301 app.db
302 .update_billing_customer(
303 billing_customer.id,
304 &UpdateBillingCustomerParams {
305 trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())),
306 ..Default::default()
307 },
308 )
309 .await?;
310 }
311 }
312
313 let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
314 && subscription
315 .cancellation_details
316 .as_ref()
317 .and_then(|details| details.reason)
318 .map_or(false, |reason| {
319 reason == StripeCancellationDetailsReason::PaymentFailed
320 });
321
322 if was_canceled_due_to_payment_failure {
323 app.db
324 .update_billing_customer(
325 billing_customer.id,
326 &UpdateBillingCustomerParams {
327 has_overdue_invoices: ActiveValue::set(true),
328 ..Default::default()
329 },
330 )
331 .await?;
332 }
333
334 if let Some(existing_subscription) = app
335 .db
336 .get_billing_subscription_by_stripe_subscription_id(subscription.id.0.as_ref())
337 .await?
338 {
339 app.db
340 .update_billing_subscription(
341 existing_subscription.id,
342 &UpdateBillingSubscriptionParams {
343 billing_customer_id: ActiveValue::set(billing_customer.id),
344 kind: ActiveValue::set(subscription_kind),
345 stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
346 stripe_subscription_status: ActiveValue::set(subscription.status.into()),
347 stripe_cancel_at: ActiveValue::set(
348 subscription
349 .cancel_at
350 .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
351 .map(|time| time.naive_utc()),
352 ),
353 stripe_cancellation_reason: ActiveValue::set(
354 subscription
355 .cancellation_details
356 .and_then(|details| details.reason)
357 .map(|reason| reason.into()),
358 ),
359 stripe_current_period_start: ActiveValue::set(Some(
360 subscription.current_period_start,
361 )),
362 stripe_current_period_end: ActiveValue::set(Some(
363 subscription.current_period_end,
364 )),
365 },
366 )
367 .await?;
368 } else {
369 if let Some(existing_subscription) = app
370 .db
371 .get_active_billing_subscription(billing_customer.user_id)
372 .await?
373 {
374 if existing_subscription.kind == Some(SubscriptionKind::ZedFree)
375 && subscription_kind == Some(SubscriptionKind::ZedProTrial)
376 {
377 let stripe_subscription_id = StripeSubscriptionId(
378 existing_subscription.stripe_subscription_id.clone().into(),
379 );
380
381 stripe_client
382 .cancel_subscription(&stripe_subscription_id)
383 .await?;
384 } else {
385 // If the user already has an active billing subscription, ignore the
386 // event and return an `Ok` to signal that it was processed
387 // successfully.
388 //
389 // There is the possibility that this could cause us to not create a
390 // subscription in the following scenario:
391 //
392 // 1. User has an active subscription A
393 // 2. User cancels subscription A
394 // 3. User creates a new subscription B
395 // 4. We process the new subscription B before the cancellation of subscription A
396 // 5. User ends up with no subscriptions
397 //
398 // In theory this situation shouldn't arise as we try to process the events in the order they occur.
399
400 log::info!(
401 "user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
402 user_id = billing_customer.user_id,
403 subscription_id = subscription.id
404 );
405 return Ok(billing_customer);
406 }
407 }
408
409 app.db
410 .create_billing_subscription(&CreateBillingSubscriptionParams {
411 billing_customer_id: billing_customer.id,
412 kind: subscription_kind,
413 stripe_subscription_id: subscription.id.to_string(),
414 stripe_subscription_status: subscription.status.into(),
415 stripe_cancellation_reason: subscription
416 .cancellation_details
417 .and_then(|details| details.reason)
418 .map(|reason| reason.into()),
419 stripe_current_period_start: Some(subscription.current_period_start),
420 stripe_current_period_end: Some(subscription.current_period_end),
421 })
422 .await?;
423 }
424
425 if let Some(stripe_billing) = app.stripe_billing.as_ref() {
426 if subscription.status == SubscriptionStatus::Canceled
427 || subscription.status == SubscriptionStatus::Paused
428 {
429 let already_has_active_billing_subscription = app
430 .db
431 .has_active_billing_subscription(billing_customer.user_id)
432 .await?;
433 if !already_has_active_billing_subscription {
434 let stripe_customer_id =
435 StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
436
437 stripe_billing
438 .subscribe_to_zed_free(stripe_customer_id)
439 .await?;
440 }
441 }
442 }
443
444 Ok(billing_customer)
445}
446
447async fn handle_customer_subscription_event(
448 app: &Arc<AppState>,
449 rpc_server: &Arc<Server>,
450 stripe_client: &Arc<dyn StripeClient>,
451 event: stripe::Event,
452) -> anyhow::Result<()> {
453 let EventObject::Subscription(subscription) = event.data.object else {
454 bail!("unexpected event payload for {}", event.id);
455 };
456
457 log::info!("handling Stripe {} event: {}", event.type_, event.id);
458
459 let billing_customer = sync_subscription(app, stripe_client, subscription.into()).await?;
460
461 // When the user's subscription changes, push down any changes to their plan.
462 rpc_server
463 .update_plan_for_user_legacy(billing_customer.user_id)
464 .await
465 .trace_err();
466
467 // When the user's subscription changes, we want to refresh their LLM tokens
468 // to either grant/revoke access.
469 rpc_server
470 .refresh_llm_tokens_for_user(billing_customer.user_id)
471 .await;
472
473 Ok(())
474}
475
476impl From<SubscriptionStatus> for StripeSubscriptionStatus {
477 fn from(value: SubscriptionStatus) -> Self {
478 match value {
479 SubscriptionStatus::Incomplete => Self::Incomplete,
480 SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
481 SubscriptionStatus::Trialing => Self::Trialing,
482 SubscriptionStatus::Active => Self::Active,
483 SubscriptionStatus::PastDue => Self::PastDue,
484 SubscriptionStatus::Canceled => Self::Canceled,
485 SubscriptionStatus::Unpaid => Self::Unpaid,
486 SubscriptionStatus::Paused => Self::Paused,
487 }
488 }
489}
490
491impl From<CancellationDetailsReason> for StripeCancellationReason {
492 fn from(value: CancellationDetailsReason) -> Self {
493 match value {
494 CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
495 CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
496 CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
497 }
498 }
499}
500
501/// Finds or creates a billing customer using the provided customer.
502pub async fn find_or_create_billing_customer(
503 app: &Arc<AppState>,
504 stripe_client: &dyn StripeClient,
505 customer_id: &StripeCustomerId,
506) -> anyhow::Result<Option<billing_customer::Model>> {
507 // If we already have a billing customer record associated with the Stripe customer,
508 // there's nothing more we need to do.
509 if let Some(billing_customer) = app
510 .db
511 .get_billing_customer_by_stripe_customer_id(customer_id.0.as_ref())
512 .await?
513 {
514 return Ok(Some(billing_customer));
515 }
516
517 let customer = stripe_client.get_customer(customer_id).await?;
518
519 let Some(email) = customer.email else {
520 return Ok(None);
521 };
522
523 let Some(user) = app.db.get_user_by_email(&email).await? else {
524 return Ok(None);
525 };
526
527 let billing_customer = app
528 .db
529 .create_billing_customer(&CreateBillingCustomerParams {
530 user_id: user.id,
531 stripe_customer_id: customer.id.to_string(),
532 })
533 .await?;
534
535 Ok(Some(billing_customer))
536}