billing.rs

  1use std::str::FromStr;
  2use std::sync::Arc;
  3use std::time::Duration;
  4
  5use anyhow::{anyhow, bail, Context};
  6use axum::{extract, routing::post, Extension, Json, Router};
  7use reqwest::StatusCode;
  8use sea_orm::ActiveValue;
  9use serde::{Deserialize, Serialize};
 10use stripe::{
 11    BillingPortalSession, CheckoutSession, CreateBillingPortalSession,
 12    CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
 13    CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
 14    CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems,
 15    CreateCustomer, Customer, CustomerId, EventId, EventObject, EventType, Expandable, ListEvents,
 16    SubscriptionStatus,
 17};
 18use util::ResultExt;
 19
 20use crate::db::billing_subscription::StripeSubscriptionStatus;
 21use crate::db::{
 22    billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
 23    CreateBillingSubscriptionParams, UpdateBillingCustomerParams, UpdateBillingSubscriptionParams,
 24};
 25use crate::{AppState, Error, Result};
 26
 27pub fn router() -> Router {
 28    Router::new()
 29        .route("/billing/subscriptions", post(create_billing_subscription))
 30        .route(
 31            "/billing/subscriptions/manage",
 32            post(manage_billing_subscription),
 33        )
 34}
 35
 36#[derive(Debug, Deserialize)]
 37struct CreateBillingSubscriptionBody {
 38    github_user_id: i32,
 39}
 40
 41#[derive(Debug, Serialize)]
 42struct CreateBillingSubscriptionResponse {
 43    checkout_session_url: String,
 44}
 45
 46/// Initiates a Stripe Checkout session for creating a billing subscription.
 47async fn create_billing_subscription(
 48    Extension(app): Extension<Arc<AppState>>,
 49    extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
 50) -> Result<Json<CreateBillingSubscriptionResponse>> {
 51    let user = app
 52        .db
 53        .get_user_by_github_user_id(body.github_user_id)
 54        .await?
 55        .ok_or_else(|| anyhow!("user not found"))?;
 56
 57    let Some((stripe_client, stripe_price_id)) = app
 58        .stripe_client
 59        .clone()
 60        .zip(app.config.stripe_price_id.clone())
 61    else {
 62        log::error!("failed to retrieve Stripe client or price ID");
 63        Err(Error::Http(
 64            StatusCode::NOT_IMPLEMENTED,
 65            "not supported".into(),
 66        ))?
 67    };
 68
 69    let customer_id =
 70        if let Some(existing_customer) = app.db.get_billing_customer_by_user_id(user.id).await? {
 71            CustomerId::from_str(&existing_customer.stripe_customer_id)
 72                .context("failed to parse customer ID")?
 73        } else {
 74            let customer = Customer::create(
 75                &stripe_client,
 76                CreateCustomer {
 77                    email: user.email_address.as_deref(),
 78                    ..Default::default()
 79                },
 80            )
 81            .await?;
 82
 83            customer.id
 84        };
 85
 86    let checkout_session = {
 87        let mut params = CreateCheckoutSession::new();
 88        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
 89        params.customer = Some(customer_id);
 90        params.client_reference_id = Some(user.github_login.as_str());
 91        params.line_items = Some(vec![CreateCheckoutSessionLineItems {
 92            price: Some(stripe_price_id.to_string()),
 93            quantity: Some(1),
 94            ..Default::default()
 95        }]);
 96        params.success_url = Some("https://zed.dev/billing/success");
 97
 98        CheckoutSession::create(&stripe_client, params).await?
 99    };
100
101    Ok(Json(CreateBillingSubscriptionResponse {
102        checkout_session_url: checkout_session
103            .url
104            .ok_or_else(|| anyhow!("no checkout session URL"))?,
105    }))
106}
107
108#[derive(Debug, Deserialize)]
109#[serde(rename_all = "snake_case")]
110enum ManageSubscriptionIntent {
111    /// The user intends to cancel their subscription.
112    Cancel,
113}
114
115#[derive(Debug, Deserialize)]
116struct ManageBillingSubscriptionBody {
117    github_user_id: i32,
118    intent: ManageSubscriptionIntent,
119    /// The ID of the subscription to manage.
120    ///
121    /// If not provided, we will try to use the active subscription (if there is only one).
122    subscription_id: Option<BillingSubscriptionId>,
123}
124
125#[derive(Debug, Serialize)]
126struct ManageBillingSubscriptionResponse {
127    billing_portal_session_url: String,
128}
129
130/// Initiates a Stripe customer portal session for managing a billing subscription.
131async fn manage_billing_subscription(
132    Extension(app): Extension<Arc<AppState>>,
133    extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
134) -> Result<Json<ManageBillingSubscriptionResponse>> {
135    let user = app
136        .db
137        .get_user_by_github_user_id(body.github_user_id)
138        .await?
139        .ok_or_else(|| anyhow!("user not found"))?;
140
141    let Some(stripe_client) = app.stripe_client.clone() else {
142        log::error!("failed to retrieve Stripe client");
143        Err(Error::Http(
144            StatusCode::NOT_IMPLEMENTED,
145            "not supported".into(),
146        ))?
147    };
148
149    let customer = app
150        .db
151        .get_billing_customer_by_user_id(user.id)
152        .await?
153        .ok_or_else(|| anyhow!("billing customer not found"))?;
154    let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
155        .context("failed to parse customer ID")?;
156
157    let subscription = if let Some(subscription_id) = body.subscription_id {
158        app.db
159            .get_billing_subscription_by_id(subscription_id)
160            .await?
161            .ok_or_else(|| anyhow!("subscription not found"))?
162    } else {
163        // If no subscription ID was provided, try to find the only active subscription ID.
164        let subscriptions = app.db.get_active_billing_subscriptions(user.id).await?;
165        if subscriptions.len() > 1 {
166            Err(anyhow!("user has multiple active subscriptions"))?;
167        }
168
169        subscriptions
170            .into_iter()
171            .next()
172            .ok_or_else(|| anyhow!("user has no active subscriptions"))?
173    };
174
175    let flow = match body.intent {
176        ManageSubscriptionIntent::Cancel => CreateBillingPortalSessionFlowData {
177            type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
178            after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
179                type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
180                redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
181                    return_url: "https://zed.dev/billing".into(),
182                }),
183                ..Default::default()
184            }),
185            subscription_cancel: Some(
186                stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
187                    subscription: subscription.stripe_subscription_id,
188                    retention: None,
189                },
190            ),
191            ..Default::default()
192        },
193    };
194
195    let mut params = CreateBillingPortalSession::new(customer_id);
196    params.flow_data = Some(flow);
197    params.return_url = Some("https://zed.dev/billing");
198
199    let session = BillingPortalSession::create(&stripe_client, params).await?;
200
201    Ok(Json(ManageBillingSubscriptionResponse {
202        billing_portal_session_url: session.url,
203    }))
204}
205
/// Delay between consecutive polls of the Stripe events API (five minutes).
const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(300);
207
208/// Polls the Stripe events API periodically to reconcile the records in our
209/// database with the data in Stripe.
210pub fn poll_stripe_events_periodically(app: Arc<AppState>) {
211    let Some(stripe_client) = app.stripe_client.clone() else {
212        log::warn!("failed to retrieve Stripe client");
213        return;
214    };
215
216    let executor = app.executor.clone();
217    executor.spawn_detached({
218        let executor = executor.clone();
219        async move {
220            loop {
221                poll_stripe_events(&app, &stripe_client).await.log_err();
222
223                executor.sleep(POLL_EVENTS_INTERVAL).await;
224            }
225        }
226    });
227}
228
229async fn poll_stripe_events(
230    app: &Arc<AppState>,
231    stripe_client: &stripe::Client,
232) -> anyhow::Result<()> {
233    let event_types = [
234        EventType::CustomerCreated.to_string(),
235        EventType::CustomerUpdated.to_string(),
236        EventType::CustomerSubscriptionCreated.to_string(),
237        EventType::CustomerSubscriptionUpdated.to_string(),
238        EventType::CustomerSubscriptionPaused.to_string(),
239        EventType::CustomerSubscriptionResumed.to_string(),
240        EventType::CustomerSubscriptionDeleted.to_string(),
241    ]
242    .into_iter()
243    .map(|event_type| {
244        // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
245        // so we need to unquote it.
246        event_type.trim_matches('"').to_string()
247    })
248    .collect::<Vec<_>>();
249
250    loop {
251        log::info!("retrieving events from Stripe: {}", event_types.join(", "));
252
253        let mut params = ListEvents::new();
254        params.types = Some(event_types.clone());
255        params.limit = Some(100);
256
257        let events = stripe::Event::list(stripe_client, &params).await?;
258        for event in events.data {
259            match event.type_ {
260                EventType::CustomerCreated | EventType::CustomerUpdated => {
261                    handle_customer_event(app, stripe_client, event)
262                        .await
263                        .log_err();
264                }
265                EventType::CustomerSubscriptionCreated
266                | EventType::CustomerSubscriptionUpdated
267                | EventType::CustomerSubscriptionPaused
268                | EventType::CustomerSubscriptionResumed
269                | EventType::CustomerSubscriptionDeleted => {
270                    handle_customer_subscription_event(app, stripe_client, event)
271                        .await
272                        .log_err();
273                }
274                _ => {}
275            }
276        }
277
278        if !events.has_more {
279            break;
280        }
281    }
282
283    Ok(())
284}
285
286async fn handle_customer_event(
287    app: &Arc<AppState>,
288    _stripe_client: &stripe::Client,
289    event: stripe::Event,
290) -> anyhow::Result<()> {
291    let EventObject::Customer(customer) = event.data.object else {
292        bail!("unexpected event payload for {}", event.id);
293    };
294
295    log::info!("handling Stripe {} event: {}", event.type_, event.id);
296
297    let Some(email) = customer.email else {
298        log::info!("Stripe customer has no email: skipping");
299        return Ok(());
300    };
301
302    let Some(user) = app.db.get_user_by_email(&email).await? else {
303        log::info!("no user found for email: skipping");
304        return Ok(());
305    };
306
307    if let Some(existing_customer) = app
308        .db
309        .get_billing_customer_by_stripe_customer_id(&customer.id)
310        .await?
311    {
312        if should_ignore_event(&event.id, existing_customer.last_stripe_event_id.as_deref()) {
313            log::info!(
314                "ignoring Stripe event {} based on last seen event ID",
315                event.id
316            );
317            return Ok(());
318        }
319
320        app.db
321            .update_billing_customer(
322                existing_customer.id,
323                &UpdateBillingCustomerParams {
324                    // For now we just update the last event ID for the customer
325                    // and leave the rest of the information as-is, as it is not
326                    // likely to change.
327                    last_stripe_event_id: ActiveValue::set(Some(event.id.to_string())),
328                    ..Default::default()
329                },
330            )
331            .await?;
332    } else {
333        app.db
334            .create_billing_customer(&CreateBillingCustomerParams {
335                user_id: user.id,
336                stripe_customer_id: customer.id.to_string(),
337                last_stripe_event_id: Some(event.id.to_string()),
338            })
339            .await?;
340    }
341
342    Ok(())
343}
344
345async fn handle_customer_subscription_event(
346    app: &Arc<AppState>,
347    stripe_client: &stripe::Client,
348    event: stripe::Event,
349) -> anyhow::Result<()> {
350    let EventObject::Subscription(subscription) = event.data.object else {
351        bail!("unexpected event payload for {}", event.id);
352    };
353
354    log::info!("handling Stripe {} event: {}", event.type_, event.id);
355
356    let billing_customer = find_or_create_billing_customer(
357        app,
358        stripe_client,
359        // Even though we're handling a subscription event, we can still set
360        // the ID as the last seen event ID on the customer in the event that
361        // we have to create it.
362        //
363        // This is done to avoid any potential rollback in the customer's values
364        // if we then see an older event that pertains to the customer.
365        &event.id,
366        subscription.customer,
367    )
368    .await?
369    .ok_or_else(|| anyhow!("billing customer not found"))?;
370
371    if let Some(existing_subscription) = app
372        .db
373        .get_billing_subscription_by_stripe_subscription_id(&subscription.id)
374        .await?
375    {
376        if should_ignore_event(
377            &event.id,
378            existing_subscription.last_stripe_event_id.as_deref(),
379        ) {
380            log::info!(
381                "ignoring Stripe event {} based on last seen event ID",
382                event.id
383            );
384            return Ok(());
385        }
386
387        app.db
388            .update_billing_subscription(
389                existing_subscription.id,
390                &UpdateBillingSubscriptionParams {
391                    billing_customer_id: ActiveValue::set(billing_customer.id),
392                    stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
393                    stripe_subscription_status: ActiveValue::set(subscription.status.into()),
394                    last_stripe_event_id: ActiveValue::set(Some(event.id.to_string())),
395                },
396            )
397            .await?;
398    } else {
399        app.db
400            .create_billing_subscription(&CreateBillingSubscriptionParams {
401                billing_customer_id: billing_customer.id,
402                stripe_subscription_id: subscription.id.to_string(),
403                stripe_subscription_status: subscription.status.into(),
404                last_stripe_event_id: Some(event.id.to_string()),
405            })
406            .await?;
407    }
408
409    Ok(())
410}
411
412impl From<SubscriptionStatus> for StripeSubscriptionStatus {
413    fn from(value: SubscriptionStatus) -> Self {
414        match value {
415            SubscriptionStatus::Incomplete => Self::Incomplete,
416            SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
417            SubscriptionStatus::Trialing => Self::Trialing,
418            SubscriptionStatus::Active => Self::Active,
419            SubscriptionStatus::PastDue => Self::PastDue,
420            SubscriptionStatus::Canceled => Self::Canceled,
421            SubscriptionStatus::Unpaid => Self::Unpaid,
422            SubscriptionStatus::Paused => Self::Paused,
423        }
424    }
425}
426
427/// Finds or creates a billing customer using the provided customer.
428async fn find_or_create_billing_customer(
429    app: &Arc<AppState>,
430    stripe_client: &stripe::Client,
431    event_id: &EventId,
432    customer_or_id: Expandable<Customer>,
433) -> anyhow::Result<Option<billing_customer::Model>> {
434    let customer_id = match &customer_or_id {
435        Expandable::Id(id) => id,
436        Expandable::Object(customer) => customer.id.as_ref(),
437    };
438
439    // If we already have a billing customer record associated with the Stripe customer,
440    // there's nothing more we need to do.
441    if let Some(billing_customer) = app
442        .db
443        .get_billing_customer_by_stripe_customer_id(&customer_id)
444        .await?
445    {
446        return Ok(Some(billing_customer));
447    }
448
449    // If all we have is a customer ID, resolve it to a full customer record by
450    // hitting the Stripe API.
451    let customer = match customer_or_id {
452        Expandable::Id(id) => Customer::retrieve(&stripe_client, &id, &[]).await?,
453        Expandable::Object(customer) => *customer,
454    };
455
456    let Some(email) = customer.email else {
457        return Ok(None);
458    };
459
460    let Some(user) = app.db.get_user_by_email(&email).await? else {
461        return Ok(None);
462    };
463
464    let billing_customer = app
465        .db
466        .create_billing_customer(&CreateBillingCustomerParams {
467            user_id: user.id,
468            stripe_customer_id: customer.id.to_string(),
469            last_stripe_event_id: Some(event_id.to_string()),
470        })
471        .await?;
472
473    Ok(Some(billing_customer))
474}
475
476/// Returns whether an [`Event`] should be ignored, based on its ID and the last
477/// seen event ID for this object.
478#[inline]
479fn should_ignore_event(event_id: &EventId, last_event_id: Option<&str>) -> bool {
480    !should_apply_event(event_id, last_event_id)
481}
482
483/// Returns whether an [`Event`] should be applied, based on its ID and the last
484/// seen event ID for this object.
485fn should_apply_event(event_id: &EventId, last_event_id: Option<&str>) -> bool {
486    let Some(last_event_id) = last_event_id else {
487        return true;
488    };
489
490    event_id.as_str() < last_event_id
491}
492
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_should_apply_event() {
        let subscription_created_event = EventId::from_str("evt_1Pi5s9RxOf7d5PNafuZSGsmh").unwrap();
        let subscription_updated_event = EventId::from_str("evt_1Pi5s9RxOf7d5PNa5UZLSsto").unwrap();

        assert!(
            !should_apply_event(
                &subscription_created_event,
                Some(subscription_created_event.as_str())
            ),
            "Events should not be applied when the IDs are the same."
        );

        assert!(
            !should_apply_event(
                &subscription_created_event,
                Some(subscription_updated_event.as_str())
            ),
            "Events should not be applied when the last event ID is newer than the event ID."
        );

        assert!(
            should_apply_event(&subscription_created_event, None),
            "Events should be applied when we don't have a last event ID."
        );

        assert!(
            should_apply_event(
                &subscription_updated_event,
                Some(subscription_created_event.as_str())
            ),
            "Events should be applied when the event ID is newer than the last event ID."
        );
    }

    /// `should_ignore_event` must always be the inverse of `should_apply_event`.
    #[test]
    fn test_should_ignore_event() {
        let subscription_created_event = EventId::from_str("evt_1Pi5s9RxOf7d5PNafuZSGsmh").unwrap();
        let subscription_updated_event = EventId::from_str("evt_1Pi5s9RxOf7d5PNa5UZLSsto").unwrap();

        assert!(
            should_ignore_event(
                &subscription_created_event,
                Some(subscription_created_event.as_str())
            ),
            "Events should be ignored when the IDs are the same."
        );

        assert!(
            !should_ignore_event(&subscription_created_event, None),
            "Events should not be ignored when we don't have a last event ID."
        );

        assert!(
            !should_ignore_event(
                &subscription_updated_event,
                Some(subscription_created_event.as_str())
            ),
            "Events should not be ignored when the event ID is newer than the last event ID."
        );
    }
}