collab: Add endpoint for managing a billing subscription (#15455)

Marshall Bowers created

This PR adds a new `POST /billing/subscriptions/manage` endpoint that
can be used to manage a billing subscription.

The endpoint accepts a `github_user_id` to identify the user, as well as
an optional `subscription_id` for managing a specific subscription. If
`subscription_id` is not provided, it try and use the active
subscription, if there is only one.

Right now the endpoint only supports cancelling an active subscription.
This is done by passing `"intent": "cancel"` in the request body.

The endpoint will return the URL to a Stripe customer portal session,
which the caller can redirect the user to.

Here's an example of how to call it:

```sh
curl -X POST "http://localhost:8080/billing/subscriptions/manage" \
     -H "Authorization: <ADMIN_TOKEN>" \
     -H "Content-Type: application/json" \
     -d '{"github_user_id": 12345, "intent": "cancel"}'
```

Release Notes:

- N/A

Change summary

crates/collab/src/api/billing.rs                      | 113 ++++++++++++
crates/collab/src/db/queries/billing_subscriptions.rs |  15 +
2 files changed, 124 insertions(+), 4 deletions(-)

Detailed changes

crates/collab/src/api/billing.rs 🔗

@@ -1,17 +1,29 @@
 use std::str::FromStr;
 use std::sync::Arc;
 
-use anyhow::anyhow;
+use anyhow::{anyhow, Context};
 use axum::{extract, routing::post, Extension, Json, Router};
 use collections::HashSet;
 use reqwest::StatusCode;
 use serde::{Deserialize, Serialize};
-use stripe::{CheckoutSession, CreateCheckoutSession, CreateCheckoutSessionLineItems, CustomerId};
+use stripe::{
+    BillingPortalSession, CheckoutSession, CreateBillingPortalSession,
+    CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
+    CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
+    CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems,
+    CustomerId,
+};
 
+use crate::db::BillingSubscriptionId;
 use crate::{AppState, Error, Result};
 
 pub fn router() -> Router {
-    Router::new().route("/billing/subscriptions", post(create_billing_subscription))
+    Router::new()
+        .route("/billing/subscriptions", post(create_billing_subscription))
+        .route(
+            "/billing/subscriptions/manage",
+            post(manage_billing_subscription),
+        )
 }
 
 #[derive(Debug, Deserialize)]
@@ -61,7 +73,7 @@ async fn create_billing_subscription(
         distinct_customer_ids
             .into_iter()
             .next()
-            .map(|id| CustomerId::from_str(id).map_err(|err| anyhow!(err)))
+            .map(|id| CustomerId::from_str(id).context("failed to parse customer ID"))
             .transpose()
     }?;
 
@@ -86,3 +98,96 @@ async fn create_billing_subscription(
             .ok_or_else(|| anyhow!("no checkout session URL"))?,
     }))
 }
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "snake_case")]
+enum ManageSubscriptionIntent {
+    /// The user intends to cancel their subscription.
+    Cancel,
+}
+
+#[derive(Debug, Deserialize)]
+struct ManageBillingSubscriptionBody {
+    github_user_id: i32,
+    intent: ManageSubscriptionIntent,
+    /// The ID of the subscription to manage.
+    ///
+    /// If not provided, we will try to use the active subscription (if there is only one).
+    subscription_id: Option<BillingSubscriptionId>,
+}
+
+#[derive(Debug, Serialize)]
+struct ManageBillingSubscriptionResponse {
+    billing_portal_session_url: String,
+}
+
+/// Initiates a Stripe customer portal session for managing a billing subscription.
+async fn manage_billing_subscription(
+    Extension(app): Extension<Arc<AppState>>,
+    extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
+) -> Result<Json<ManageBillingSubscriptionResponse>> {
+    let user = app
+        .db
+        .get_user_by_github_user_id(body.github_user_id)
+        .await?
+        .ok_or_else(|| anyhow!("user not found"))?;
+
+    let Some(stripe_client) = app.stripe_client.clone() else {
+        log::error!("failed to retrieve Stripe client");
+        Err(Error::Http(
+            StatusCode::NOT_IMPLEMENTED,
+            "not supported".into(),
+        ))?
+    };
+
+    let subscription = if let Some(subscription_id) = body.subscription_id {
+        app.db
+            .get_billing_subscription_by_id(subscription_id)
+            .await?
+            .ok_or_else(|| anyhow!("subscription not found"))?
+    } else {
+        // If no subscription ID was provided, try to find the only active subscription ID.
+        let subscriptions = app.db.get_active_billing_subscriptions(user.id).await?;
+        if subscriptions.len() > 1 {
+            Err(anyhow!("user has multiple active subscriptions"))?;
+        }
+
+        subscriptions
+            .into_iter()
+            .next()
+            .ok_or_else(|| anyhow!("user has no active subscriptions"))?
+    };
+
+    let customer_id = CustomerId::from_str(&subscription.stripe_customer_id)
+        .context("failed to parse customer ID")?;
+
+    let flow = match body.intent {
+        ManageSubscriptionIntent::Cancel => CreateBillingPortalSessionFlowData {
+            type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
+            after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
+                type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
+                redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
+                    return_url: "https://zed.dev/billing".into(),
+                }),
+                ..Default::default()
+            }),
+            subscription_cancel: Some(
+                stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
+                    subscription: subscription.stripe_subscription_id,
+                    retention: None,
+                },
+            ),
+            ..Default::default()
+        },
+    };
+
+    let mut params = CreateBillingPortalSession::new(customer_id);
+    params.flow_data = Some(flow);
+    params.return_url = Some("https://zed.dev/billing");
+
+    let session = BillingPortalSession::create(&stripe_client, params).await?;
+
+    Ok(Json(ManageBillingSubscriptionResponse {
+        billing_portal_session_url: session.url,
+    }))
+}

crates/collab/src/db/queries/billing_subscriptions.rs 🔗

@@ -32,6 +32,19 @@ impl Database {
         .await
     }
 
+    /// Returns the billing subscription with the specified ID.
+    pub async fn get_billing_subscription_by_id(
+        &self,
+        id: BillingSubscriptionId,
+    ) -> Result<Option<billing_subscription::Model>> {
+        self.transaction(|tx| async move {
+            Ok(billing_subscription::Entity::find_by_id(id)
+                .one(&*tx)
+                .await?)
+        })
+        .await
+    }
+
     /// Returns all of the billing subscriptions for the user with the specified ID.
     ///
     /// Note that this returns the subscriptions regardless of their status.
@@ -44,6 +57,7 @@ impl Database {
         self.transaction(|tx| async move {
             let subscriptions = billing_subscription::Entity::find()
                 .filter(billing_subscription::Column::UserId.eq(user_id))
+                .order_by_asc(billing_subscription::Column::Id)
                 .all(&*tx)
                 .await?;
 
@@ -65,6 +79,7 @@ impl Database {
                             .eq(StripeSubscriptionStatus::Active),
                     ),
                 )
+                .order_by_asc(billing_subscription::Column::Id)
                 .all(&*tx)
                 .await?;