billing.rs

  1use std::str::FromStr;
  2use std::sync::Arc;
  3use std::time::Duration;
  4
  5use anyhow::{anyhow, bail, Context};
  6use axum::{extract, routing::post, Extension, Json, Router};
  7use reqwest::StatusCode;
  8use serde::{Deserialize, Serialize};
  9use stripe::{
 10    BillingPortalSession, CheckoutSession, CreateBillingPortalSession,
 11    CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
 12    CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
 13    CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems,
 14    CreateCustomer, Customer, CustomerId, EventObject, EventType, Expandable, ListEvents,
 15    SubscriptionStatus,
 16};
 17use util::ResultExt;
 18
 19use crate::db::billing_subscription::StripeSubscriptionStatus;
 20use crate::db::{
 21    billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
 22    CreateBillingSubscriptionParams,
 23};
 24use crate::{AppState, Error, Result};
 25
 26pub fn router() -> Router {
 27    Router::new()
 28        .route("/billing/subscriptions", post(create_billing_subscription))
 29        .route(
 30            "/billing/subscriptions/manage",
 31            post(manage_billing_subscription),
 32        )
 33}
 34
 35#[derive(Debug, Deserialize)]
 36struct CreateBillingSubscriptionBody {
 37    github_user_id: i32,
 38}
 39
 40#[derive(Debug, Serialize)]
 41struct CreateBillingSubscriptionResponse {
 42    checkout_session_url: String,
 43}
 44
 45/// Initiates a Stripe Checkout session for creating a billing subscription.
 46async fn create_billing_subscription(
 47    Extension(app): Extension<Arc<AppState>>,
 48    extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
 49) -> Result<Json<CreateBillingSubscriptionResponse>> {
 50    let user = app
 51        .db
 52        .get_user_by_github_user_id(body.github_user_id)
 53        .await?
 54        .ok_or_else(|| anyhow!("user not found"))?;
 55
 56    let Some((stripe_client, stripe_price_id)) = app
 57        .stripe_client
 58        .clone()
 59        .zip(app.config.stripe_price_id.clone())
 60    else {
 61        log::error!("failed to retrieve Stripe client or price ID");
 62        Err(Error::Http(
 63            StatusCode::NOT_IMPLEMENTED,
 64            "not supported".into(),
 65        ))?
 66    };
 67
 68    let customer_id =
 69        if let Some(existing_customer) = app.db.get_billing_customer_by_user_id(user.id).await? {
 70            CustomerId::from_str(&existing_customer.stripe_customer_id)
 71                .context("failed to parse customer ID")?
 72        } else {
 73            let customer = Customer::create(
 74                &stripe_client,
 75                CreateCustomer {
 76                    email: user.email_address.as_deref(),
 77                    ..Default::default()
 78                },
 79            )
 80            .await?;
 81
 82            customer.id
 83        };
 84
 85    let checkout_session = {
 86        let mut params = CreateCheckoutSession::new();
 87        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
 88        params.customer = Some(customer_id);
 89        params.client_reference_id = Some(user.github_login.as_str());
 90        params.line_items = Some(vec![CreateCheckoutSessionLineItems {
 91            price: Some(stripe_price_id.to_string()),
 92            quantity: Some(1),
 93            ..Default::default()
 94        }]);
 95        params.success_url = Some("https://zed.dev/billing/success");
 96
 97        CheckoutSession::create(&stripe_client, params).await?
 98    };
 99
100    Ok(Json(CreateBillingSubscriptionResponse {
101        checkout_session_url: checkout_session
102            .url
103            .ok_or_else(|| anyhow!("no checkout session URL"))?,
104    }))
105}
106
107#[derive(Debug, Deserialize)]
108#[serde(rename_all = "snake_case")]
109enum ManageSubscriptionIntent {
110    /// The user intends to cancel their subscription.
111    Cancel,
112}
113
114#[derive(Debug, Deserialize)]
115struct ManageBillingSubscriptionBody {
116    github_user_id: i32,
117    intent: ManageSubscriptionIntent,
118    /// The ID of the subscription to manage.
119    ///
120    /// If not provided, we will try to use the active subscription (if there is only one).
121    subscription_id: Option<BillingSubscriptionId>,
122}
123
124#[derive(Debug, Serialize)]
125struct ManageBillingSubscriptionResponse {
126    billing_portal_session_url: String,
127}
128
129/// Initiates a Stripe customer portal session for managing a billing subscription.
130async fn manage_billing_subscription(
131    Extension(app): Extension<Arc<AppState>>,
132    extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
133) -> Result<Json<ManageBillingSubscriptionResponse>> {
134    let user = app
135        .db
136        .get_user_by_github_user_id(body.github_user_id)
137        .await?
138        .ok_or_else(|| anyhow!("user not found"))?;
139
140    let Some(stripe_client) = app.stripe_client.clone() else {
141        log::error!("failed to retrieve Stripe client");
142        Err(Error::Http(
143            StatusCode::NOT_IMPLEMENTED,
144            "not supported".into(),
145        ))?
146    };
147
148    let customer = app
149        .db
150        .get_billing_customer_by_user_id(user.id)
151        .await?
152        .ok_or_else(|| anyhow!("billing customer not found"))?;
153    let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
154        .context("failed to parse customer ID")?;
155
156    let subscription = if let Some(subscription_id) = body.subscription_id {
157        app.db
158            .get_billing_subscription_by_id(subscription_id)
159            .await?
160            .ok_or_else(|| anyhow!("subscription not found"))?
161    } else {
162        // If no subscription ID was provided, try to find the only active subscription ID.
163        let subscriptions = app.db.get_active_billing_subscriptions(user.id).await?;
164        if subscriptions.len() > 1 {
165            Err(anyhow!("user has multiple active subscriptions"))?;
166        }
167
168        subscriptions
169            .into_iter()
170            .next()
171            .ok_or_else(|| anyhow!("user has no active subscriptions"))?
172    };
173
174    let flow = match body.intent {
175        ManageSubscriptionIntent::Cancel => CreateBillingPortalSessionFlowData {
176            type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
177            after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
178                type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
179                redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
180                    return_url: "https://zed.dev/billing".into(),
181                }),
182                ..Default::default()
183            }),
184            subscription_cancel: Some(
185                stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
186                    subscription: subscription.stripe_subscription_id,
187                    retention: None,
188                },
189            ),
190            ..Default::default()
191        },
192    };
193
194    let mut params = CreateBillingPortalSession::new(customer_id);
195    params.flow_data = Some(flow);
196    params.return_url = Some("https://zed.dev/billing");
197
198    let session = BillingPortalSession::create(&stripe_client, params).await?;
199
200    Ok(Json(ManageBillingSubscriptionResponse {
201        billing_portal_session_url: session.url,
202    }))
203}
204
/// Time to wait between consecutive polls of the Stripe events API (five minutes).
const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(300);
206
207/// Polls the Stripe events API periodically to reconcile the records in our
208/// database with the data in Stripe.
209pub fn poll_stripe_events_periodically(app: Arc<AppState>) {
210    let Some(stripe_client) = app.stripe_client.clone() else {
211        log::warn!("failed to retrieve Stripe client");
212        return;
213    };
214
215    let executor = app.executor.clone();
216    executor.spawn_detached({
217        let executor = executor.clone();
218        async move {
219            loop {
220                poll_stripe_events(&app, &stripe_client).await.log_err();
221
222                executor.sleep(POLL_EVENTS_INTERVAL).await;
223            }
224        }
225    });
226}
227
228async fn poll_stripe_events(
229    app: &Arc<AppState>,
230    stripe_client: &stripe::Client,
231) -> anyhow::Result<()> {
232    let event_types = [
233        EventType::CustomerCreated.to_string(),
234        EventType::CustomerSubscriptionCreated.to_string(),
235        EventType::CustomerSubscriptionUpdated.to_string(),
236        EventType::CustomerSubscriptionPaused.to_string(),
237        EventType::CustomerSubscriptionResumed.to_string(),
238        EventType::CustomerSubscriptionDeleted.to_string(),
239    ]
240    .into_iter()
241    .map(|event_type| {
242        // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
243        // so we need to unquote it.
244        event_type.trim_matches('"').to_string()
245    })
246    .collect::<Vec<_>>();
247
248    loop {
249        log::info!("retrieving events from Stripe: {}", event_types.join(", "));
250
251        let mut params = ListEvents::new();
252        params.types = Some(event_types.clone());
253        params.limit = Some(100);
254
255        let events = stripe::Event::list(stripe_client, &params).await?;
256        for event in events.data {
257            match event.type_ {
258                EventType::CustomerCreated => {
259                    handle_customer_event(app, stripe_client, event)
260                        .await
261                        .log_err();
262                }
263                EventType::CustomerSubscriptionCreated
264                | EventType::CustomerSubscriptionUpdated
265                | EventType::CustomerSubscriptionPaused
266                | EventType::CustomerSubscriptionResumed
267                | EventType::CustomerSubscriptionDeleted => {
268                    handle_customer_subscription_event(app, stripe_client, event)
269                        .await
270                        .log_err();
271                }
272                _ => {}
273            }
274        }
275
276        if !events.has_more {
277            break;
278        }
279    }
280
281    Ok(())
282}
283
284async fn handle_customer_event(
285    app: &Arc<AppState>,
286    stripe_client: &stripe::Client,
287    event: stripe::Event,
288) -> anyhow::Result<()> {
289    let EventObject::Customer(customer) = event.data.object else {
290        bail!("unexpected event payload for {}", event.id);
291    };
292
293    find_or_create_billing_customer(app, stripe_client, Expandable::Object(Box::new(customer)))
294        .await?;
295
296    Ok(())
297}
298
299async fn handle_customer_subscription_event(
300    app: &Arc<AppState>,
301    stripe_client: &stripe::Client,
302    event: stripe::Event,
303) -> anyhow::Result<()> {
304    let EventObject::Subscription(subscription) = event.data.object else {
305        bail!("unexpected event payload for {}", event.id);
306    };
307
308    let billing_customer =
309        find_or_create_billing_customer(app, stripe_client, subscription.customer)
310            .await?
311            .ok_or_else(|| anyhow!("billing customer not found"))?;
312
313    app.db
314        .upsert_billing_subscription_by_stripe_subscription_id(&CreateBillingSubscriptionParams {
315            billing_customer_id: billing_customer.id,
316            stripe_subscription_id: subscription.id.to_string(),
317            stripe_subscription_status: subscription.status.into(),
318        })
319        .await?;
320
321    Ok(())
322}
323
324impl From<SubscriptionStatus> for StripeSubscriptionStatus {
325    fn from(value: SubscriptionStatus) -> Self {
326        match value {
327            SubscriptionStatus::Incomplete => Self::Incomplete,
328            SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
329            SubscriptionStatus::Trialing => Self::Trialing,
330            SubscriptionStatus::Active => Self::Active,
331            SubscriptionStatus::PastDue => Self::PastDue,
332            SubscriptionStatus::Canceled => Self::Canceled,
333            SubscriptionStatus::Unpaid => Self::Unpaid,
334            SubscriptionStatus::Paused => Self::Paused,
335        }
336    }
337}
338
339/// Finds or creates a billing customer using the provided customer.
340async fn find_or_create_billing_customer(
341    app: &Arc<AppState>,
342    stripe_client: &stripe::Client,
343    customer_or_id: Expandable<Customer>,
344) -> anyhow::Result<Option<billing_customer::Model>> {
345    let customer_id = match &customer_or_id {
346        Expandable::Id(id) => id,
347        Expandable::Object(customer) => customer.id.as_ref(),
348    };
349
350    // If we already have a billing customer record associated with the Stripe customer,
351    // there's nothing more we need to do.
352    if let Some(billing_customer) = app
353        .db
354        .get_billing_customer_by_stripe_customer_id(&customer_id)
355        .await?
356    {
357        return Ok(Some(billing_customer));
358    }
359
360    // If all we have is a customer ID, resolve it to a full customer record by
361    // hitting the Stripe API.
362    let customer = match customer_or_id {
363        Expandable::Id(id) => Customer::retrieve(&stripe_client, &id, &[]).await?,
364        Expandable::Object(customer) => *customer,
365    };
366
367    let Some(email) = customer.email else {
368        return Ok(None);
369    };
370
371    let Some(user) = app.db.get_user_by_email(&email).await? else {
372        return Ok(None);
373    };
374
375    let billing_customer = app
376        .db
377        .create_billing_customer(&CreateBillingCustomerParams {
378            user_id: user.id,
379            stripe_customer_id: customer.id.to_string(),
380        })
381        .await?;
382
383    Ok(Some(billing_customer))
384}