billing.rs

  1use std::str::FromStr;
  2use std::sync::Arc;
  3use std::time::Duration;
  4
  5use anyhow::{anyhow, bail, Context};
  6use axum::{
  7    extract::{self, Query},
  8    routing::{get, post},
  9    Extension, Json, Router,
 10};
 11use reqwest::StatusCode;
 12use sea_orm::ActiveValue;
 13use serde::{Deserialize, Serialize};
 14use stripe::{
 15    BillingPortalSession, CheckoutSession, CreateBillingPortalSession,
 16    CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
 17    CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
 18    CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems,
 19    CreateCustomer, Customer, CustomerId, EventObject, EventType, Expandable, ListEvents,
 20    SubscriptionStatus,
 21};
 22use util::ResultExt;
 23
 24use crate::db::billing_subscription::StripeSubscriptionStatus;
 25use crate::db::{
 26    billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
 27    CreateBillingSubscriptionParams, CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
 28    UpdateBillingSubscriptionParams,
 29};
 30use crate::{AppState, Error, Result};
 31
 32pub fn router() -> Router {
 33    Router::new()
 34        .route(
 35            "/billing/subscriptions",
 36            get(list_billing_subscriptions).post(create_billing_subscription),
 37        )
 38        .route(
 39            "/billing/subscriptions/manage",
 40            post(manage_billing_subscription),
 41        )
 42}
 43
 44#[derive(Debug, Deserialize)]
 45struct ListBillingSubscriptionsParams {
 46    github_user_id: i32,
 47}
 48
 49#[derive(Debug, Serialize)]
 50struct BillingSubscriptionJson {
 51    id: BillingSubscriptionId,
 52    name: String,
 53    status: StripeSubscriptionStatus,
 54    /// Whether this subscription can be canceled.
 55    is_cancelable: bool,
 56}
 57
 58#[derive(Debug, Serialize)]
 59struct ListBillingSubscriptionsResponse {
 60    subscriptions: Vec<BillingSubscriptionJson>,
 61}
 62
 63async fn list_billing_subscriptions(
 64    Extension(app): Extension<Arc<AppState>>,
 65    Query(params): Query<ListBillingSubscriptionsParams>,
 66) -> Result<Json<ListBillingSubscriptionsResponse>> {
 67    let user = app
 68        .db
 69        .get_user_by_github_user_id(params.github_user_id)
 70        .await?
 71        .ok_or_else(|| anyhow!("user not found"))?;
 72
 73    let subscriptions = app.db.get_billing_subscriptions(user.id).await?;
 74
 75    Ok(Json(ListBillingSubscriptionsResponse {
 76        subscriptions: subscriptions
 77            .into_iter()
 78            .map(|subscription| BillingSubscriptionJson {
 79                id: subscription.id,
 80                name: "Zed Pro".to_string(),
 81                status: subscription.stripe_subscription_status,
 82                is_cancelable: subscription.stripe_subscription_status.is_cancelable(),
 83            })
 84            .collect(),
 85    }))
 86}
 87
 88#[derive(Debug, Deserialize)]
 89struct CreateBillingSubscriptionBody {
 90    github_user_id: i32,
 91}
 92
 93#[derive(Debug, Serialize)]
 94struct CreateBillingSubscriptionResponse {
 95    checkout_session_url: String,
 96}
 97
 98/// Initiates a Stripe Checkout session for creating a billing subscription.
 99async fn create_billing_subscription(
100    Extension(app): Extension<Arc<AppState>>,
101    extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
102) -> Result<Json<CreateBillingSubscriptionResponse>> {
103    let user = app
104        .db
105        .get_user_by_github_user_id(body.github_user_id)
106        .await?
107        .ok_or_else(|| anyhow!("user not found"))?;
108
109    let Some((stripe_client, stripe_price_id)) = app
110        .stripe_client
111        .clone()
112        .zip(app.config.stripe_price_id.clone())
113    else {
114        log::error!("failed to retrieve Stripe client or price ID");
115        Err(Error::Http(
116            StatusCode::NOT_IMPLEMENTED,
117            "not supported".into(),
118        ))?
119    };
120
121    let customer_id =
122        if let Some(existing_customer) = app.db.get_billing_customer_by_user_id(user.id).await? {
123            CustomerId::from_str(&existing_customer.stripe_customer_id)
124                .context("failed to parse customer ID")?
125        } else {
126            let customer = Customer::create(
127                &stripe_client,
128                CreateCustomer {
129                    email: user.email_address.as_deref(),
130                    ..Default::default()
131                },
132            )
133            .await?;
134
135            customer.id
136        };
137
138    let checkout_session = {
139        let mut params = CreateCheckoutSession::new();
140        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
141        params.customer = Some(customer_id);
142        params.client_reference_id = Some(user.github_login.as_str());
143        params.line_items = Some(vec![CreateCheckoutSessionLineItems {
144            price: Some(stripe_price_id.to_string()),
145            quantity: Some(1),
146            ..Default::default()
147        }]);
148        params.success_url = Some("https://zed.dev/billing/success");
149
150        CheckoutSession::create(&stripe_client, params).await?
151    };
152
153    Ok(Json(CreateBillingSubscriptionResponse {
154        checkout_session_url: checkout_session
155            .url
156            .ok_or_else(|| anyhow!("no checkout session URL"))?,
157    }))
158}
159
160#[derive(Debug, Deserialize)]
161#[serde(rename_all = "snake_case")]
162enum ManageSubscriptionIntent {
163    /// The user intends to cancel their subscription.
164    Cancel,
165}
166
167#[derive(Debug, Deserialize)]
168struct ManageBillingSubscriptionBody {
169    github_user_id: i32,
170    intent: ManageSubscriptionIntent,
171    /// The ID of the subscription to manage.
172    ///
173    /// If not provided, we will try to use the active subscription (if there is only one).
174    subscription_id: Option<BillingSubscriptionId>,
175}
176
177#[derive(Debug, Serialize)]
178struct ManageBillingSubscriptionResponse {
179    billing_portal_session_url: String,
180}
181
182/// Initiates a Stripe customer portal session for managing a billing subscription.
183async fn manage_billing_subscription(
184    Extension(app): Extension<Arc<AppState>>,
185    extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
186) -> Result<Json<ManageBillingSubscriptionResponse>> {
187    let user = app
188        .db
189        .get_user_by_github_user_id(body.github_user_id)
190        .await?
191        .ok_or_else(|| anyhow!("user not found"))?;
192
193    let Some(stripe_client) = app.stripe_client.clone() else {
194        log::error!("failed to retrieve Stripe client");
195        Err(Error::Http(
196            StatusCode::NOT_IMPLEMENTED,
197            "not supported".into(),
198        ))?
199    };
200
201    let customer = app
202        .db
203        .get_billing_customer_by_user_id(user.id)
204        .await?
205        .ok_or_else(|| anyhow!("billing customer not found"))?;
206    let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
207        .context("failed to parse customer ID")?;
208
209    let subscription = if let Some(subscription_id) = body.subscription_id {
210        app.db
211            .get_billing_subscription_by_id(subscription_id)
212            .await?
213            .ok_or_else(|| anyhow!("subscription not found"))?
214    } else {
215        // If no subscription ID was provided, try to find the only active subscription ID.
216        let subscriptions = app.db.get_active_billing_subscriptions(user.id).await?;
217        if subscriptions.len() > 1 {
218            Err(anyhow!("user has multiple active subscriptions"))?;
219        }
220
221        subscriptions
222            .into_iter()
223            .next()
224            .ok_or_else(|| anyhow!("user has no active subscriptions"))?
225    };
226
227    let flow = match body.intent {
228        ManageSubscriptionIntent::Cancel => CreateBillingPortalSessionFlowData {
229            type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
230            after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
231                type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
232                redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
233                    return_url: "https://zed.dev/settings".into(),
234                }),
235                ..Default::default()
236            }),
237            subscription_cancel: Some(
238                stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
239                    subscription: subscription.stripe_subscription_id,
240                    retention: None,
241                },
242            ),
243            ..Default::default()
244        },
245    };
246
247    let mut params = CreateBillingPortalSession::new(customer_id);
248    params.flow_data = Some(flow);
249    params.return_url = Some("https://zed.dev/settings");
250
251    let session = BillingPortalSession::create(&stripe_client, params).await?;
252
253    Ok(Json(ManageBillingSubscriptionResponse {
254        billing_portal_session_url: session.url,
255    }))
256}
257
258const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5 * 60);
259
260/// Polls the Stripe events API periodically to reconcile the records in our
261/// database with the data in Stripe.
262pub fn poll_stripe_events_periodically(app: Arc<AppState>) {
263    let Some(stripe_client) = app.stripe_client.clone() else {
264        log::warn!("failed to retrieve Stripe client");
265        return;
266    };
267
268    let executor = app.executor.clone();
269    executor.spawn_detached({
270        let executor = executor.clone();
271        async move {
272            loop {
273                poll_stripe_events(&app, &stripe_client).await.log_err();
274
275                executor.sleep(POLL_EVENTS_INTERVAL).await;
276            }
277        }
278    });
279}
280
281async fn poll_stripe_events(
282    app: &Arc<AppState>,
283    stripe_client: &stripe::Client,
284) -> anyhow::Result<()> {
285    fn event_type_to_string(event_type: EventType) -> String {
286        // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
287        // so we need to unquote it.
288        event_type.to_string().trim_matches('"').to_string()
289    }
290
291    let event_types = [
292        EventType::CustomerCreated,
293        EventType::CustomerUpdated,
294        EventType::CustomerSubscriptionCreated,
295        EventType::CustomerSubscriptionUpdated,
296        EventType::CustomerSubscriptionPaused,
297        EventType::CustomerSubscriptionResumed,
298        EventType::CustomerSubscriptionDeleted,
299    ]
300    .into_iter()
301    .map(event_type_to_string)
302    .collect::<Vec<_>>();
303
304    let mut unprocessed_events = Vec::new();
305
306    loop {
307        log::info!("retrieving events from Stripe: {}", event_types.join(", "));
308
309        let mut params = ListEvents::new();
310        params.types = Some(event_types.clone());
311        params.limit = Some(100);
312
313        let events = stripe::Event::list(stripe_client, &params).await?;
314
315        let processed_event_ids = {
316            let event_ids = &events
317                .data
318                .iter()
319                .map(|event| event.id.as_str())
320                .collect::<Vec<_>>();
321
322            app.db
323                .get_processed_stripe_events_by_event_ids(event_ids)
324                .await?
325                .into_iter()
326                .map(|event| event.stripe_event_id)
327                .collect::<Vec<_>>()
328        };
329
330        for event in events.data {
331            if processed_event_ids.contains(&event.id.to_string()) {
332                log::info!("Stripe event {} already processed: skipping", event.id);
333            } else {
334                unprocessed_events.push(event);
335            }
336        }
337
338        if !events.has_more {
339            break;
340        }
341    }
342
343    log::info!(
344        "unprocessed events from Stripe: {}",
345        unprocessed_events.len()
346    );
347
348    // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
349    unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
350
351    for event in unprocessed_events {
352        let processed_event_params = CreateProcessedStripeEventParams {
353            stripe_event_id: event.id.to_string(),
354            stripe_event_type: event_type_to_string(event.type_),
355            stripe_event_created_timestamp: event.created,
356        };
357
358        match event.type_ {
359            EventType::CustomerCreated | EventType::CustomerUpdated => {
360                handle_customer_event(app, stripe_client, event)
361                    .await
362                    .log_err();
363            }
364            EventType::CustomerSubscriptionCreated
365            | EventType::CustomerSubscriptionUpdated
366            | EventType::CustomerSubscriptionPaused
367            | EventType::CustomerSubscriptionResumed
368            | EventType::CustomerSubscriptionDeleted => {
369                handle_customer_subscription_event(app, stripe_client, event)
370                    .await
371                    .log_err();
372            }
373            _ => {}
374        }
375
376        app.db
377            .create_processed_stripe_event(&processed_event_params)
378            .await?;
379    }
380
381    Ok(())
382}
383
384async fn handle_customer_event(
385    app: &Arc<AppState>,
386    _stripe_client: &stripe::Client,
387    event: stripe::Event,
388) -> anyhow::Result<()> {
389    let EventObject::Customer(customer) = event.data.object else {
390        bail!("unexpected event payload for {}", event.id);
391    };
392
393    log::info!("handling Stripe {} event: {}", event.type_, event.id);
394
395    let Some(email) = customer.email else {
396        log::info!("Stripe customer has no email: skipping");
397        return Ok(());
398    };
399
400    let Some(user) = app.db.get_user_by_email(&email).await? else {
401        log::info!("no user found for email: skipping");
402        return Ok(());
403    };
404
405    if let Some(existing_customer) = app
406        .db
407        .get_billing_customer_by_stripe_customer_id(&customer.id)
408        .await?
409    {
410        app.db
411            .update_billing_customer(
412                existing_customer.id,
413                &UpdateBillingCustomerParams {
414                    // For now we just leave the information as-is, as it is not
415                    // likely to change.
416                    ..Default::default()
417                },
418            )
419            .await?;
420    } else {
421        app.db
422            .create_billing_customer(&CreateBillingCustomerParams {
423                user_id: user.id,
424                stripe_customer_id: customer.id.to_string(),
425            })
426            .await?;
427    }
428
429    Ok(())
430}
431
432async fn handle_customer_subscription_event(
433    app: &Arc<AppState>,
434    stripe_client: &stripe::Client,
435    event: stripe::Event,
436) -> anyhow::Result<()> {
437    let EventObject::Subscription(subscription) = event.data.object else {
438        bail!("unexpected event payload for {}", event.id);
439    };
440
441    log::info!("handling Stripe {} event: {}", event.type_, event.id);
442
443    let billing_customer =
444        find_or_create_billing_customer(app, stripe_client, subscription.customer)
445            .await?
446            .ok_or_else(|| anyhow!("billing customer not found"))?;
447
448    if let Some(existing_subscription) = app
449        .db
450        .get_billing_subscription_by_stripe_subscription_id(&subscription.id)
451        .await?
452    {
453        app.db
454            .update_billing_subscription(
455                existing_subscription.id,
456                &UpdateBillingSubscriptionParams {
457                    billing_customer_id: ActiveValue::set(billing_customer.id),
458                    stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
459                    stripe_subscription_status: ActiveValue::set(subscription.status.into()),
460                },
461            )
462            .await?;
463    } else {
464        app.db
465            .create_billing_subscription(&CreateBillingSubscriptionParams {
466                billing_customer_id: billing_customer.id,
467                stripe_subscription_id: subscription.id.to_string(),
468                stripe_subscription_status: subscription.status.into(),
469            })
470            .await?;
471    }
472
473    Ok(())
474}
475
476impl From<SubscriptionStatus> for StripeSubscriptionStatus {
477    fn from(value: SubscriptionStatus) -> Self {
478        match value {
479            SubscriptionStatus::Incomplete => Self::Incomplete,
480            SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
481            SubscriptionStatus::Trialing => Self::Trialing,
482            SubscriptionStatus::Active => Self::Active,
483            SubscriptionStatus::PastDue => Self::PastDue,
484            SubscriptionStatus::Canceled => Self::Canceled,
485            SubscriptionStatus::Unpaid => Self::Unpaid,
486            SubscriptionStatus::Paused => Self::Paused,
487        }
488    }
489}
490
491/// Finds or creates a billing customer using the provided customer.
492async fn find_or_create_billing_customer(
493    app: &Arc<AppState>,
494    stripe_client: &stripe::Client,
495    customer_or_id: Expandable<Customer>,
496) -> anyhow::Result<Option<billing_customer::Model>> {
497    let customer_id = match &customer_or_id {
498        Expandable::Id(id) => id,
499        Expandable::Object(customer) => customer.id.as_ref(),
500    };
501
502    // If we already have a billing customer record associated with the Stripe customer,
503    // there's nothing more we need to do.
504    if let Some(billing_customer) = app
505        .db
506        .get_billing_customer_by_stripe_customer_id(&customer_id)
507        .await?
508    {
509        return Ok(Some(billing_customer));
510    }
511
512    // If all we have is a customer ID, resolve it to a full customer record by
513    // hitting the Stripe API.
514    let customer = match customer_or_id {
515        Expandable::Id(id) => Customer::retrieve(&stripe_client, &id, &[]).await?,
516        Expandable::Object(customer) => *customer,
517    };
518
519    let Some(email) = customer.email else {
520        return Ok(None);
521    };
522
523    let Some(user) = app.db.get_user_by_email(&email).await? else {
524        return Ok(None);
525    };
526
527    let billing_customer = app
528        .db
529        .create_billing_customer(&CreateBillingCustomerParams {
530            user_id: user.id,
531            stripe_customer_id: customer.id.to_string(),
532        })
533        .await?;
534
535    Ok(Some(billing_customer))
536}