@@ -1,17 +1,29 @@
use std::str::FromStr;
use std::sync::Arc;
-use anyhow::anyhow;
+use anyhow::{anyhow, Context};
use axum::{extract, routing::post, Extension, Json, Router};
use collections::HashSet;
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
-use stripe::{CheckoutSession, CreateCheckoutSession, CreateCheckoutSessionLineItems, CustomerId};
+use stripe::{
+ BillingPortalSession, CheckoutSession, CreateBillingPortalSession,
+ CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
+ CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
+ CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems,
+ CustomerId,
+};
+use crate::db::BillingSubscriptionId;
use crate::{AppState, Error, Result};
pub fn router() -> Router {
- Router::new().route("/billing/subscriptions", post(create_billing_subscription))
+ Router::new()
+ .route("/billing/subscriptions", post(create_billing_subscription))
+ .route(
+ "/billing/subscriptions/manage",
+ post(manage_billing_subscription),
+ )
}
#[derive(Debug, Deserialize)]
@@ -61,7 +73,7 @@ async fn create_billing_subscription(
distinct_customer_ids
.into_iter()
.next()
- .map(|id| CustomerId::from_str(id).map_err(|err| anyhow!(err)))
+ .map(|id| CustomerId::from_str(id).context("failed to parse customer ID"))
.transpose()
}?;
@@ -86,3 +98,96 @@ async fn create_billing_subscription(
.ok_or_else(|| anyhow!("no checkout session URL"))?,
}))
}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "snake_case")]
+enum ManageSubscriptionIntent {
+ /// The user intends to cancel their subscription.
+ Cancel,
+}
+
+#[derive(Debug, Deserialize)]
+struct ManageBillingSubscriptionBody {
+ github_user_id: i32,
+ intent: ManageSubscriptionIntent,
+ /// The ID of the subscription to manage.
+ ///
+ /// If not provided, we will try to use the active subscription (if there is only one).
+ subscription_id: Option<BillingSubscriptionId>,
+}
+
+#[derive(Debug, Serialize)]
+struct ManageBillingSubscriptionResponse {
+ billing_portal_session_url: String,
+}
+
+/// Initiates a Stripe customer portal session for managing a billing subscription.
+async fn manage_billing_subscription(
+ Extension(app): Extension<Arc<AppState>>,
+ extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
+) -> Result<Json<ManageBillingSubscriptionResponse>> {
+ let user = app
+ .db
+ .get_user_by_github_user_id(body.github_user_id)
+ .await?
+ .ok_or_else(|| anyhow!("user not found"))?;
+
+ let Some(stripe_client) = app.stripe_client.clone() else {
+ log::error!("failed to retrieve Stripe client");
+ Err(Error::Http(
+ StatusCode::NOT_IMPLEMENTED,
+ "not supported".into(),
+ ))?
+ };
+
+ let subscription = if let Some(subscription_id) = body.subscription_id {
+ app.db
+ .get_billing_subscription_by_id(subscription_id)
+ .await?
+ .ok_or_else(|| anyhow!("subscription not found"))?
+ } else {
+ // If no subscription ID was provided, try to find the only active subscription ID.
+ let subscriptions = app.db.get_active_billing_subscriptions(user.id).await?;
+ if subscriptions.len() > 1 {
+ Err(anyhow!("user has multiple active subscriptions"))?;
+ }
+
+ subscriptions
+ .into_iter()
+ .next()
+ .ok_or_else(|| anyhow!("user has no active subscriptions"))?
+ };
+
+ let customer_id = CustomerId::from_str(&subscription.stripe_customer_id)
+ .context("failed to parse customer ID")?;
+
+ let flow = match body.intent {
+ ManageSubscriptionIntent::Cancel => CreateBillingPortalSessionFlowData {
+ type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
+ after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
+ type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
+ redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
+ return_url: "https://zed.dev/billing".into(),
+ }),
+ ..Default::default()
+ }),
+ subscription_cancel: Some(
+ stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
+ subscription: subscription.stripe_subscription_id,
+ retention: None,
+ },
+ ),
+ ..Default::default()
+ },
+ };
+
+ let mut params = CreateBillingPortalSession::new(customer_id);
+ params.flow_data = Some(flow);
+ params.return_url = Some("https://zed.dev/billing");
+
+ let session = BillingPortalSession::create(&stripe_client, params).await?;
+
+ Ok(Json(ManageBillingSubscriptionResponse {
+ billing_portal_session_url: session.url,
+ }))
+}
@@ -32,6 +32,19 @@ impl Database {
.await
}
+ /// Returns the billing subscription with the specified ID.
+ pub async fn get_billing_subscription_by_id(
+ &self,
+ id: BillingSubscriptionId,
+ ) -> Result<Option<billing_subscription::Model>> {
+ self.transaction(|tx| async move {
+ Ok(billing_subscription::Entity::find_by_id(id)
+ .one(&*tx)
+ .await?)
+ })
+ .await
+ }
+
/// Returns all of the billing subscriptions for the user with the specified ID.
///
/// Note that this returns the subscriptions regardless of their status.
@@ -44,6 +57,7 @@ impl Database {
self.transaction(|tx| async move {
let subscriptions = billing_subscription::Entity::find()
.filter(billing_subscription::Column::UserId.eq(user_id))
+ .order_by_asc(billing_subscription::Column::Id)
.all(&*tx)
.await?;
@@ -65,6 +79,7 @@ impl Database {
.eq(StripeSubscriptionStatus::Active),
),
)
+ .order_by_asc(billing_subscription::Column::Id)
.all(&*tx)
.await?;