billing.rs

  1use std::str::FromStr;
  2use std::sync::Arc;
  3use std::time::Duration;
  4
  5use anyhow::{anyhow, bail, Context};
  6use axum::{extract, routing::post, Extension, Json, Router};
  7use reqwest::StatusCode;
  8use sea_orm::ActiveValue;
  9use serde::{Deserialize, Serialize};
 10use stripe::{
 11    BillingPortalSession, CheckoutSession, CreateBillingPortalSession,
 12    CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
 13    CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
 14    CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems,
 15    CreateCustomer, Customer, CustomerId, EventObject, EventType, Expandable, ListEvents,
 16    SubscriptionStatus,
 17};
 18use util::ResultExt;
 19
 20use crate::db::billing_subscription::StripeSubscriptionStatus;
 21use crate::db::{
 22    billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
 23    CreateBillingSubscriptionParams, CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
 24    UpdateBillingSubscriptionParams,
 25};
 26use crate::{AppState, Error, Result};
 27
 28pub fn router() -> Router {
 29    Router::new()
 30        .route("/billing/subscriptions", post(create_billing_subscription))
 31        .route(
 32            "/billing/subscriptions/manage",
 33            post(manage_billing_subscription),
 34        )
 35}
 36
 37#[derive(Debug, Deserialize)]
 38struct CreateBillingSubscriptionBody {
 39    github_user_id: i32,
 40}
 41
 42#[derive(Debug, Serialize)]
 43struct CreateBillingSubscriptionResponse {
 44    checkout_session_url: String,
 45}
 46
 47/// Initiates a Stripe Checkout session for creating a billing subscription.
 48async fn create_billing_subscription(
 49    Extension(app): Extension<Arc<AppState>>,
 50    extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
 51) -> Result<Json<CreateBillingSubscriptionResponse>> {
 52    let user = app
 53        .db
 54        .get_user_by_github_user_id(body.github_user_id)
 55        .await?
 56        .ok_or_else(|| anyhow!("user not found"))?;
 57
 58    let Some((stripe_client, stripe_price_id)) = app
 59        .stripe_client
 60        .clone()
 61        .zip(app.config.stripe_price_id.clone())
 62    else {
 63        log::error!("failed to retrieve Stripe client or price ID");
 64        Err(Error::Http(
 65            StatusCode::NOT_IMPLEMENTED,
 66            "not supported".into(),
 67        ))?
 68    };
 69
 70    let customer_id =
 71        if let Some(existing_customer) = app.db.get_billing_customer_by_user_id(user.id).await? {
 72            CustomerId::from_str(&existing_customer.stripe_customer_id)
 73                .context("failed to parse customer ID")?
 74        } else {
 75            let customer = Customer::create(
 76                &stripe_client,
 77                CreateCustomer {
 78                    email: user.email_address.as_deref(),
 79                    ..Default::default()
 80                },
 81            )
 82            .await?;
 83
 84            customer.id
 85        };
 86
 87    let checkout_session = {
 88        let mut params = CreateCheckoutSession::new();
 89        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
 90        params.customer = Some(customer_id);
 91        params.client_reference_id = Some(user.github_login.as_str());
 92        params.line_items = Some(vec![CreateCheckoutSessionLineItems {
 93            price: Some(stripe_price_id.to_string()),
 94            quantity: Some(1),
 95            ..Default::default()
 96        }]);
 97        params.success_url = Some("https://zed.dev/billing/success");
 98
 99        CheckoutSession::create(&stripe_client, params).await?
100    };
101
102    Ok(Json(CreateBillingSubscriptionResponse {
103        checkout_session_url: checkout_session
104            .url
105            .ok_or_else(|| anyhow!("no checkout session URL"))?,
106    }))
107}
108
109#[derive(Debug, Deserialize)]
110#[serde(rename_all = "snake_case")]
111enum ManageSubscriptionIntent {
112    /// The user intends to cancel their subscription.
113    Cancel,
114}
115
116#[derive(Debug, Deserialize)]
117struct ManageBillingSubscriptionBody {
118    github_user_id: i32,
119    intent: ManageSubscriptionIntent,
120    /// The ID of the subscription to manage.
121    ///
122    /// If not provided, we will try to use the active subscription (if there is only one).
123    subscription_id: Option<BillingSubscriptionId>,
124}
125
126#[derive(Debug, Serialize)]
127struct ManageBillingSubscriptionResponse {
128    billing_portal_session_url: String,
129}
130
131/// Initiates a Stripe customer portal session for managing a billing subscription.
132async fn manage_billing_subscription(
133    Extension(app): Extension<Arc<AppState>>,
134    extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
135) -> Result<Json<ManageBillingSubscriptionResponse>> {
136    let user = app
137        .db
138        .get_user_by_github_user_id(body.github_user_id)
139        .await?
140        .ok_or_else(|| anyhow!("user not found"))?;
141
142    let Some(stripe_client) = app.stripe_client.clone() else {
143        log::error!("failed to retrieve Stripe client");
144        Err(Error::Http(
145            StatusCode::NOT_IMPLEMENTED,
146            "not supported".into(),
147        ))?
148    };
149
150    let customer = app
151        .db
152        .get_billing_customer_by_user_id(user.id)
153        .await?
154        .ok_or_else(|| anyhow!("billing customer not found"))?;
155    let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
156        .context("failed to parse customer ID")?;
157
158    let subscription = if let Some(subscription_id) = body.subscription_id {
159        app.db
160            .get_billing_subscription_by_id(subscription_id)
161            .await?
162            .ok_or_else(|| anyhow!("subscription not found"))?
163    } else {
164        // If no subscription ID was provided, try to find the only active subscription ID.
165        let subscriptions = app.db.get_active_billing_subscriptions(user.id).await?;
166        if subscriptions.len() > 1 {
167            Err(anyhow!("user has multiple active subscriptions"))?;
168        }
169
170        subscriptions
171            .into_iter()
172            .next()
173            .ok_or_else(|| anyhow!("user has no active subscriptions"))?
174    };
175
176    let flow = match body.intent {
177        ManageSubscriptionIntent::Cancel => CreateBillingPortalSessionFlowData {
178            type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
179            after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
180                type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
181                redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
182                    return_url: "https://zed.dev/billing".into(),
183                }),
184                ..Default::default()
185            }),
186            subscription_cancel: Some(
187                stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
188                    subscription: subscription.stripe_subscription_id,
189                    retention: None,
190                },
191            ),
192            ..Default::default()
193        },
194    };
195
196    let mut params = CreateBillingPortalSession::new(customer_id);
197    params.flow_data = Some(flow);
198    params.return_url = Some("https://zed.dev/billing");
199
200    let session = BillingPortalSession::create(&stripe_client, params).await?;
201
202    Ok(Json(ManageBillingSubscriptionResponse {
203        billing_portal_session_url: session.url,
204    }))
205}
206
/// How long the background poller waits between Stripe event polls (five minutes).
const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(60 * 5);
208
209/// Polls the Stripe events API periodically to reconcile the records in our
210/// database with the data in Stripe.
211pub fn poll_stripe_events_periodically(app: Arc<AppState>) {
212    let Some(stripe_client) = app.stripe_client.clone() else {
213        log::warn!("failed to retrieve Stripe client");
214        return;
215    };
216
217    let executor = app.executor.clone();
218    executor.spawn_detached({
219        let executor = executor.clone();
220        async move {
221            loop {
222                poll_stripe_events(&app, &stripe_client).await.log_err();
223
224                executor.sleep(POLL_EVENTS_INTERVAL).await;
225            }
226        }
227    });
228}
229
230async fn poll_stripe_events(
231    app: &Arc<AppState>,
232    stripe_client: &stripe::Client,
233) -> anyhow::Result<()> {
234    fn event_type_to_string(event_type: EventType) -> String {
235        // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
236        // so we need to unquote it.
237        event_type.to_string().trim_matches('"').to_string()
238    }
239
240    let event_types = [
241        EventType::CustomerCreated,
242        EventType::CustomerUpdated,
243        EventType::CustomerSubscriptionCreated,
244        EventType::CustomerSubscriptionUpdated,
245        EventType::CustomerSubscriptionPaused,
246        EventType::CustomerSubscriptionResumed,
247        EventType::CustomerSubscriptionDeleted,
248    ]
249    .into_iter()
250    .map(event_type_to_string)
251    .collect::<Vec<_>>();
252
253    let mut unprocessed_events = Vec::new();
254
255    loop {
256        log::info!("retrieving events from Stripe: {}", event_types.join(", "));
257
258        let mut params = ListEvents::new();
259        params.types = Some(event_types.clone());
260        params.limit = Some(100);
261
262        let events = stripe::Event::list(stripe_client, &params).await?;
263
264        let processed_event_ids = {
265            let event_ids = &events
266                .data
267                .iter()
268                .map(|event| event.id.as_str())
269                .collect::<Vec<_>>();
270
271            app.db
272                .get_processed_stripe_events_by_event_ids(event_ids)
273                .await?
274                .into_iter()
275                .map(|event| event.stripe_event_id)
276                .collect::<Vec<_>>()
277        };
278
279        for event in events.data {
280            if processed_event_ids.contains(&event.id.to_string()) {
281                log::info!("Stripe event {} already processed: skipping", event.id);
282            } else {
283                unprocessed_events.push(event);
284            }
285        }
286
287        if !events.has_more {
288            break;
289        }
290    }
291
292    log::info!(
293        "unprocessed events from Stripe: {}",
294        unprocessed_events.len()
295    );
296
297    // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
298    unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
299
300    for event in unprocessed_events {
301        let processed_event_params = CreateProcessedStripeEventParams {
302            stripe_event_id: event.id.to_string(),
303            stripe_event_type: event_type_to_string(event.type_),
304            stripe_event_created_timestamp: event.created,
305        };
306
307        match event.type_ {
308            EventType::CustomerCreated | EventType::CustomerUpdated => {
309                handle_customer_event(app, stripe_client, event)
310                    .await
311                    .log_err();
312            }
313            EventType::CustomerSubscriptionCreated
314            | EventType::CustomerSubscriptionUpdated
315            | EventType::CustomerSubscriptionPaused
316            | EventType::CustomerSubscriptionResumed
317            | EventType::CustomerSubscriptionDeleted => {
318                handle_customer_subscription_event(app, stripe_client, event)
319                    .await
320                    .log_err();
321            }
322            _ => {}
323        }
324
325        app.db
326            .create_processed_stripe_event(&processed_event_params)
327            .await?;
328    }
329
330    Ok(())
331}
332
333async fn handle_customer_event(
334    app: &Arc<AppState>,
335    _stripe_client: &stripe::Client,
336    event: stripe::Event,
337) -> anyhow::Result<()> {
338    let EventObject::Customer(customer) = event.data.object else {
339        bail!("unexpected event payload for {}", event.id);
340    };
341
342    log::info!("handling Stripe {} event: {}", event.type_, event.id);
343
344    let Some(email) = customer.email else {
345        log::info!("Stripe customer has no email: skipping");
346        return Ok(());
347    };
348
349    let Some(user) = app.db.get_user_by_email(&email).await? else {
350        log::info!("no user found for email: skipping");
351        return Ok(());
352    };
353
354    if let Some(existing_customer) = app
355        .db
356        .get_billing_customer_by_stripe_customer_id(&customer.id)
357        .await?
358    {
359        app.db
360            .update_billing_customer(
361                existing_customer.id,
362                &UpdateBillingCustomerParams {
363                    // For now we just leave the information as-is, as it is not
364                    // likely to change.
365                    ..Default::default()
366                },
367            )
368            .await?;
369    } else {
370        app.db
371            .create_billing_customer(&CreateBillingCustomerParams {
372                user_id: user.id,
373                stripe_customer_id: customer.id.to_string(),
374            })
375            .await?;
376    }
377
378    Ok(())
379}
380
381async fn handle_customer_subscription_event(
382    app: &Arc<AppState>,
383    stripe_client: &stripe::Client,
384    event: stripe::Event,
385) -> anyhow::Result<()> {
386    let EventObject::Subscription(subscription) = event.data.object else {
387        bail!("unexpected event payload for {}", event.id);
388    };
389
390    log::info!("handling Stripe {} event: {}", event.type_, event.id);
391
392    let billing_customer =
393        find_or_create_billing_customer(app, stripe_client, subscription.customer)
394            .await?
395            .ok_or_else(|| anyhow!("billing customer not found"))?;
396
397    if let Some(existing_subscription) = app
398        .db
399        .get_billing_subscription_by_stripe_subscription_id(&subscription.id)
400        .await?
401    {
402        app.db
403            .update_billing_subscription(
404                existing_subscription.id,
405                &UpdateBillingSubscriptionParams {
406                    billing_customer_id: ActiveValue::set(billing_customer.id),
407                    stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
408                    stripe_subscription_status: ActiveValue::set(subscription.status.into()),
409                },
410            )
411            .await?;
412    } else {
413        app.db
414            .create_billing_subscription(&CreateBillingSubscriptionParams {
415                billing_customer_id: billing_customer.id,
416                stripe_subscription_id: subscription.id.to_string(),
417                stripe_subscription_status: subscription.status.into(),
418            })
419            .await?;
420    }
421
422    Ok(())
423}
424
425impl From<SubscriptionStatus> for StripeSubscriptionStatus {
426    fn from(value: SubscriptionStatus) -> Self {
427        match value {
428            SubscriptionStatus::Incomplete => Self::Incomplete,
429            SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
430            SubscriptionStatus::Trialing => Self::Trialing,
431            SubscriptionStatus::Active => Self::Active,
432            SubscriptionStatus::PastDue => Self::PastDue,
433            SubscriptionStatus::Canceled => Self::Canceled,
434            SubscriptionStatus::Unpaid => Self::Unpaid,
435            SubscriptionStatus::Paused => Self::Paused,
436        }
437    }
438}
439
440/// Finds or creates a billing customer using the provided customer.
441async fn find_or_create_billing_customer(
442    app: &Arc<AppState>,
443    stripe_client: &stripe::Client,
444    customer_or_id: Expandable<Customer>,
445) -> anyhow::Result<Option<billing_customer::Model>> {
446    let customer_id = match &customer_or_id {
447        Expandable::Id(id) => id,
448        Expandable::Object(customer) => customer.id.as_ref(),
449    };
450
451    // If we already have a billing customer record associated with the Stripe customer,
452    // there's nothing more we need to do.
453    if let Some(billing_customer) = app
454        .db
455        .get_billing_customer_by_stripe_customer_id(&customer_id)
456        .await?
457    {
458        return Ok(Some(billing_customer));
459    }
460
461    // If all we have is a customer ID, resolve it to a full customer record by
462    // hitting the Stripe API.
463    let customer = match customer_or_id {
464        Expandable::Id(id) => Customer::retrieve(&stripe_client, &id, &[]).await?,
465        Expandable::Object(customer) => *customer,
466    };
467
468    let Some(email) = customer.email else {
469        return Ok(None);
470    };
471
472    let Some(user) = app.db.get_user_by_email(&email).await? else {
473        return Ok(None);
474    };
475
476    let billing_customer = app
477        .db
478        .create_billing_customer(&CreateBillingCustomerParams {
479            user_id: user.id,
480            stripe_customer_id: customer.id.to_string(),
481        })
482        .await?;
483
484    Ok(Some(billing_customer))
485}