billing.rs

  1use std::str::FromStr;
  2use std::sync::Arc;
  3use std::time::Duration;
  4
  5use anyhow::{anyhow, bail, Context};
  6use axum::{
  7    extract::{self, Query},
  8    routing::{get, post},
  9    Extension, Json, Router,
 10};
 11use reqwest::StatusCode;
 12use sea_orm::ActiveValue;
 13use serde::{Deserialize, Serialize};
 14use stripe::{
 15    BillingPortalSession, CheckoutSession, CreateBillingPortalSession,
 16    CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
 17    CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
 18    CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems,
 19    CreateCustomer, Customer, CustomerId, EventObject, EventType, Expandable, ListEvents,
 20    SubscriptionStatus,
 21};
 22use util::ResultExt;
 23
 24use crate::db::billing_subscription::StripeSubscriptionStatus;
 25use crate::db::{
 26    billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
 27    CreateBillingSubscriptionParams, CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
 28    UpdateBillingSubscriptionParams,
 29};
 30use crate::{AppState, Error, Result};
 31
 32pub fn router() -> Router {
 33    Router::new()
 34        .route(
 35            "/billing/subscriptions",
 36            get(list_billing_subscriptions).post(create_billing_subscription),
 37        )
 38        .route(
 39            "/billing/subscriptions/manage",
 40            post(manage_billing_subscription),
 41        )
 42}
 43
/// Query parameters for `GET /billing/subscriptions`.
#[derive(Debug, Deserialize)]
struct ListBillingSubscriptionsParams {
    /// GitHub user ID of the user whose subscriptions are listed.
    github_user_id: i32,
}

/// A single billing subscription as serialized in API responses.
#[derive(Debug, Serialize)]
struct BillingSubscriptionJson {
    /// Our database ID for the subscription.
    id: BillingSubscriptionId,
    /// Display name of the subscription (currently always `"Zed Pro"`).
    name: String,
    /// The subscription status as last recorded from Stripe.
    status: StripeSubscriptionStatus,
    /// Whether this subscription can be canceled.
    is_cancelable: bool,
}

/// Response body for `GET /billing/subscriptions`.
#[derive(Debug, Serialize)]
struct ListBillingSubscriptionsResponse {
    /// The user's billing subscriptions, as returned by `get_billing_subscriptions`.
    subscriptions: Vec<BillingSubscriptionJson>,
}
 62
 63async fn list_billing_subscriptions(
 64    Extension(app): Extension<Arc<AppState>>,
 65    Query(params): Query<ListBillingSubscriptionsParams>,
 66) -> Result<Json<ListBillingSubscriptionsResponse>> {
 67    let user = app
 68        .db
 69        .get_user_by_github_user_id(params.github_user_id)
 70        .await?
 71        .ok_or_else(|| anyhow!("user not found"))?;
 72
 73    let subscriptions = app.db.get_billing_subscriptions(user.id).await?;
 74
 75    Ok(Json(ListBillingSubscriptionsResponse {
 76        subscriptions: subscriptions
 77            .into_iter()
 78            .map(|subscription| BillingSubscriptionJson {
 79                id: subscription.id,
 80                name: "Zed Pro".to_string(),
 81                status: subscription.stripe_subscription_status,
 82                is_cancelable: subscription.stripe_subscription_status.is_cancelable(),
 83            })
 84            .collect(),
 85    }))
 86}
 87
/// Request body for `POST /billing/subscriptions`.
#[derive(Debug, Deserialize)]
struct CreateBillingSubscriptionBody {
    /// GitHub user ID of the user starting the subscription checkout.
    github_user_id: i32,
}

/// Response body for `POST /billing/subscriptions`.
#[derive(Debug, Serialize)]
struct CreateBillingSubscriptionResponse {
    /// URL of the Stripe Checkout session that was created for the user.
    checkout_session_url: String,
}
 97
 98/// Initiates a Stripe Checkout session for creating a billing subscription.
 99async fn create_billing_subscription(
100    Extension(app): Extension<Arc<AppState>>,
101    extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
102) -> Result<Json<CreateBillingSubscriptionResponse>> {
103    let user = app
104        .db
105        .get_user_by_github_user_id(body.github_user_id)
106        .await?
107        .ok_or_else(|| anyhow!("user not found"))?;
108
109    let Some((stripe_client, stripe_price_id)) = app
110        .stripe_client
111        .clone()
112        .zip(app.config.stripe_price_id.clone())
113    else {
114        log::error!("failed to retrieve Stripe client or price ID");
115        Err(Error::Http(
116            StatusCode::NOT_IMPLEMENTED,
117            "not supported".into(),
118        ))?
119    };
120
121    let customer_id =
122        if let Some(existing_customer) = app.db.get_billing_customer_by_user_id(user.id).await? {
123            CustomerId::from_str(&existing_customer.stripe_customer_id)
124                .context("failed to parse customer ID")?
125        } else {
126            let customer = Customer::create(
127                &stripe_client,
128                CreateCustomer {
129                    email: user.email_address.as_deref(),
130                    ..Default::default()
131                },
132            )
133            .await?;
134
135            customer.id
136        };
137
138    let checkout_session = {
139        let mut params = CreateCheckoutSession::new();
140        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
141        params.customer = Some(customer_id);
142        params.client_reference_id = Some(user.github_login.as_str());
143        params.line_items = Some(vec![CreateCheckoutSessionLineItems {
144            price: Some(stripe_price_id.to_string()),
145            quantity: Some(1),
146            ..Default::default()
147        }]);
148        params.success_url = Some("https://zed.dev/billing/success");
149
150        CheckoutSession::create(&stripe_client, params).await?
151    };
152
153    Ok(Json(CreateBillingSubscriptionResponse {
154        checkout_session_url: checkout_session
155            .url
156            .ok_or_else(|| anyhow!("no checkout session URL"))?,
157    }))
158}
159
/// What the user wants to do with a subscription in the Stripe customer portal.
///
/// Deserialized from a snake_case string (e.g. `"cancel"`).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ManageSubscriptionIntent {
    /// The user intends to cancel their subscription.
    Cancel,
}

/// Request body for `POST /billing/subscriptions/manage`.
#[derive(Debug, Deserialize)]
struct ManageBillingSubscriptionBody {
    /// GitHub user ID of the user managing the subscription.
    github_user_id: i32,
    /// The action the user wants to perform; selects the portal flow.
    intent: ManageSubscriptionIntent,
    /// The ID of the subscription to manage.
    subscription_id: BillingSubscriptionId,
}

/// Response body for `POST /billing/subscriptions/manage`.
#[derive(Debug, Serialize)]
struct ManageBillingSubscriptionResponse {
    /// URL of the Stripe billing portal session that was created.
    billing_portal_session_url: String,
}
179
180/// Initiates a Stripe customer portal session for managing a billing subscription.
181async fn manage_billing_subscription(
182    Extension(app): Extension<Arc<AppState>>,
183    extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
184) -> Result<Json<ManageBillingSubscriptionResponse>> {
185    let user = app
186        .db
187        .get_user_by_github_user_id(body.github_user_id)
188        .await?
189        .ok_or_else(|| anyhow!("user not found"))?;
190
191    let Some(stripe_client) = app.stripe_client.clone() else {
192        log::error!("failed to retrieve Stripe client");
193        Err(Error::Http(
194            StatusCode::NOT_IMPLEMENTED,
195            "not supported".into(),
196        ))?
197    };
198
199    let customer = app
200        .db
201        .get_billing_customer_by_user_id(user.id)
202        .await?
203        .ok_or_else(|| anyhow!("billing customer not found"))?;
204    let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
205        .context("failed to parse customer ID")?;
206
207    let subscription = app
208        .db
209        .get_billing_subscription_by_id(body.subscription_id)
210        .await?
211        .ok_or_else(|| anyhow!("subscription not found"))?;
212
213    let flow = match body.intent {
214        ManageSubscriptionIntent::Cancel => CreateBillingPortalSessionFlowData {
215            type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
216            after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
217                type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
218                redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
219                    return_url: "https://zed.dev/settings".into(),
220                }),
221                ..Default::default()
222            }),
223            subscription_cancel: Some(
224                stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
225                    subscription: subscription.stripe_subscription_id,
226                    retention: None,
227                },
228            ),
229            ..Default::default()
230        },
231    };
232
233    let mut params = CreateBillingPortalSession::new(customer_id);
234    params.flow_data = Some(flow);
235    params.return_url = Some("https://zed.dev/settings");
236
237    let session = BillingPortalSession::create(&stripe_client, params).await?;
238
239    Ok(Json(ManageBillingSubscriptionResponse {
240        billing_portal_session_url: session.url,
241    }))
242}
243
244const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5 * 60);
245
246/// Polls the Stripe events API periodically to reconcile the records in our
247/// database with the data in Stripe.
248pub fn poll_stripe_events_periodically(app: Arc<AppState>) {
249    let Some(stripe_client) = app.stripe_client.clone() else {
250        log::warn!("failed to retrieve Stripe client");
251        return;
252    };
253
254    let executor = app.executor.clone();
255    executor.spawn_detached({
256        let executor = executor.clone();
257        async move {
258            loop {
259                poll_stripe_events(&app, &stripe_client).await.log_err();
260
261                executor.sleep(POLL_EVENTS_INTERVAL).await;
262            }
263        }
264    });
265}
266
267async fn poll_stripe_events(
268    app: &Arc<AppState>,
269    stripe_client: &stripe::Client,
270) -> anyhow::Result<()> {
271    fn event_type_to_string(event_type: EventType) -> String {
272        // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
273        // so we need to unquote it.
274        event_type.to_string().trim_matches('"').to_string()
275    }
276
277    let event_types = [
278        EventType::CustomerCreated,
279        EventType::CustomerUpdated,
280        EventType::CustomerSubscriptionCreated,
281        EventType::CustomerSubscriptionUpdated,
282        EventType::CustomerSubscriptionPaused,
283        EventType::CustomerSubscriptionResumed,
284        EventType::CustomerSubscriptionDeleted,
285    ]
286    .into_iter()
287    .map(event_type_to_string)
288    .collect::<Vec<_>>();
289
290    let mut unprocessed_events = Vec::new();
291
292    loop {
293        log::info!("retrieving events from Stripe: {}", event_types.join(", "));
294
295        let mut params = ListEvents::new();
296        params.types = Some(event_types.clone());
297        params.limit = Some(100);
298
299        let events = stripe::Event::list(stripe_client, &params).await?;
300
301        let processed_event_ids = {
302            let event_ids = &events
303                .data
304                .iter()
305                .map(|event| event.id.as_str())
306                .collect::<Vec<_>>();
307
308            app.db
309                .get_processed_stripe_events_by_event_ids(event_ids)
310                .await?
311                .into_iter()
312                .map(|event| event.stripe_event_id)
313                .collect::<Vec<_>>()
314        };
315
316        for event in events.data {
317            if processed_event_ids.contains(&event.id.to_string()) {
318                log::info!("Stripe event {} already processed: skipping", event.id);
319            } else {
320                unprocessed_events.push(event);
321            }
322        }
323
324        if !events.has_more {
325            break;
326        }
327    }
328
329    log::info!(
330        "unprocessed events from Stripe: {}",
331        unprocessed_events.len()
332    );
333
334    // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
335    unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
336
337    for event in unprocessed_events {
338        let processed_event_params = CreateProcessedStripeEventParams {
339            stripe_event_id: event.id.to_string(),
340            stripe_event_type: event_type_to_string(event.type_),
341            stripe_event_created_timestamp: event.created,
342        };
343
344        match event.type_ {
345            EventType::CustomerCreated | EventType::CustomerUpdated => {
346                handle_customer_event(app, stripe_client, event)
347                    .await
348                    .log_err();
349            }
350            EventType::CustomerSubscriptionCreated
351            | EventType::CustomerSubscriptionUpdated
352            | EventType::CustomerSubscriptionPaused
353            | EventType::CustomerSubscriptionResumed
354            | EventType::CustomerSubscriptionDeleted => {
355                handle_customer_subscription_event(app, stripe_client, event)
356                    .await
357                    .log_err();
358            }
359            _ => {}
360        }
361
362        app.db
363            .create_processed_stripe_event(&processed_event_params)
364            .await?;
365    }
366
367    Ok(())
368}
369
370async fn handle_customer_event(
371    app: &Arc<AppState>,
372    _stripe_client: &stripe::Client,
373    event: stripe::Event,
374) -> anyhow::Result<()> {
375    let EventObject::Customer(customer) = event.data.object else {
376        bail!("unexpected event payload for {}", event.id);
377    };
378
379    log::info!("handling Stripe {} event: {}", event.type_, event.id);
380
381    let Some(email) = customer.email else {
382        log::info!("Stripe customer has no email: skipping");
383        return Ok(());
384    };
385
386    let Some(user) = app.db.get_user_by_email(&email).await? else {
387        log::info!("no user found for email: skipping");
388        return Ok(());
389    };
390
391    if let Some(existing_customer) = app
392        .db
393        .get_billing_customer_by_stripe_customer_id(&customer.id)
394        .await?
395    {
396        app.db
397            .update_billing_customer(
398                existing_customer.id,
399                &UpdateBillingCustomerParams {
400                    // For now we just leave the information as-is, as it is not
401                    // likely to change.
402                    ..Default::default()
403                },
404            )
405            .await?;
406    } else {
407        app.db
408            .create_billing_customer(&CreateBillingCustomerParams {
409                user_id: user.id,
410                stripe_customer_id: customer.id.to_string(),
411            })
412            .await?;
413    }
414
415    Ok(())
416}
417
418async fn handle_customer_subscription_event(
419    app: &Arc<AppState>,
420    stripe_client: &stripe::Client,
421    event: stripe::Event,
422) -> anyhow::Result<()> {
423    let EventObject::Subscription(subscription) = event.data.object else {
424        bail!("unexpected event payload for {}", event.id);
425    };
426
427    log::info!("handling Stripe {} event: {}", event.type_, event.id);
428
429    let billing_customer =
430        find_or_create_billing_customer(app, stripe_client, subscription.customer)
431            .await?
432            .ok_or_else(|| anyhow!("billing customer not found"))?;
433
434    if let Some(existing_subscription) = app
435        .db
436        .get_billing_subscription_by_stripe_subscription_id(&subscription.id)
437        .await?
438    {
439        app.db
440            .update_billing_subscription(
441                existing_subscription.id,
442                &UpdateBillingSubscriptionParams {
443                    billing_customer_id: ActiveValue::set(billing_customer.id),
444                    stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
445                    stripe_subscription_status: ActiveValue::set(subscription.status.into()),
446                },
447            )
448            .await?;
449    } else {
450        app.db
451            .create_billing_subscription(&CreateBillingSubscriptionParams {
452                billing_customer_id: billing_customer.id,
453                stripe_subscription_id: subscription.id.to_string(),
454                stripe_subscription_status: subscription.status.into(),
455            })
456            .await?;
457    }
458
459    Ok(())
460}
461
462impl From<SubscriptionStatus> for StripeSubscriptionStatus {
463    fn from(value: SubscriptionStatus) -> Self {
464        match value {
465            SubscriptionStatus::Incomplete => Self::Incomplete,
466            SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
467            SubscriptionStatus::Trialing => Self::Trialing,
468            SubscriptionStatus::Active => Self::Active,
469            SubscriptionStatus::PastDue => Self::PastDue,
470            SubscriptionStatus::Canceled => Self::Canceled,
471            SubscriptionStatus::Unpaid => Self::Unpaid,
472            SubscriptionStatus::Paused => Self::Paused,
473        }
474    }
475}
476
477/// Finds or creates a billing customer using the provided customer.
478async fn find_or_create_billing_customer(
479    app: &Arc<AppState>,
480    stripe_client: &stripe::Client,
481    customer_or_id: Expandable<Customer>,
482) -> anyhow::Result<Option<billing_customer::Model>> {
483    let customer_id = match &customer_or_id {
484        Expandable::Id(id) => id,
485        Expandable::Object(customer) => customer.id.as_ref(),
486    };
487
488    // If we already have a billing customer record associated with the Stripe customer,
489    // there's nothing more we need to do.
490    if let Some(billing_customer) = app
491        .db
492        .get_billing_customer_by_stripe_customer_id(&customer_id)
493        .await?
494    {
495        return Ok(Some(billing_customer));
496    }
497
498    // If all we have is a customer ID, resolve it to a full customer record by
499    // hitting the Stripe API.
500    let customer = match customer_or_id {
501        Expandable::Id(id) => Customer::retrieve(&stripe_client, &id, &[]).await?,
502        Expandable::Object(customer) => *customer,
503    };
504
505    let Some(email) = customer.email else {
506        return Ok(None);
507    };
508
509    let Some(user) = app.db.get_user_by_email(&email).await? else {
510        return Ok(None);
511    };
512
513    let billing_customer = app
514        .db
515        .create_billing_customer(&CreateBillingCustomerParams {
516            user_id: user.id,
517            stripe_customer_id: customer.id.to_string(),
518        })
519        .await?;
520
521    Ok(Some(billing_customer))
522}