@@ -54,6 +54,10 @@ pub fn router() -> Router {
"/billing/subscriptions/manage",
post(manage_billing_subscription),
)
+ .route(
+ "/billing/subscriptions/migrate",
+ post(migrate_to_new_billing),
+ )
.route("/billing/monthly_spend", get(get_monthly_spend))
.route("/billing/usage", get(get_current_usage))
}
@@ -610,6 +614,85 @@ async fn manage_billing_subscription(
}))
}
+#[derive(Debug, Deserialize)]
+struct MigrateToNewBillingBody {
+ github_user_id: i32,
+}
+
+#[derive(Debug, Serialize)]
+struct MigrateToNewBillingResponse {
+ /// The ID of the subscription that was canceled.
+ canceled_subscription_id: String,
+}
+
+async fn migrate_to_new_billing(
+ Extension(app): Extension<Arc<AppState>>,
+ extract::Json(body): extract::Json<MigrateToNewBillingBody>,
+) -> Result<Json<MigrateToNewBillingResponse>> {
+ let Some(stripe_client) = app.stripe_client.clone() else {
+ log::error!("failed to retrieve Stripe client");
+ Err(Error::http(
+ StatusCode::NOT_IMPLEMENTED,
+ "not supported".into(),
+ ))?
+ };
+
+ let user = app
+ .db
+ .get_user_by_github_user_id(body.github_user_id)
+ .await?
+ .ok_or_else(|| anyhow!("user not found"))?;
+
+ let old_billing_subscriptions_by_user = app
+ .db
+ .get_active_billing_subscriptions(HashSet::from_iter([user.id]))
+ .await?;
+
+ let Some((_billing_customer, billing_subscription)) =
+ old_billing_subscriptions_by_user.get(&user.id)
+ else {
+ return Err(Error::http(
+ StatusCode::NOT_FOUND,
+ "No active billing subscriptions to migrate".into(),
+ ));
+ };
+
+ let stripe_subscription_id = billing_subscription
+ .stripe_subscription_id
+ .parse::<stripe::SubscriptionId>()
+ .context("failed to parse Stripe subscription ID from database")?;
+
+ Subscription::cancel(
+ &stripe_client,
+ &stripe_subscription_id,
+ stripe::CancelSubscription {
+ invoice_now: Some(true),
+ ..Default::default()
+ },
+ )
+ .await?;
+
+ let feature_flags = app.db.list_feature_flags().await?;
+
+ for feature_flag in ["new-billing", "assistant2"] {
+ let already_in_feature_flag = feature_flags.iter().any(|flag| flag.flag == feature_flag);
+ if already_in_feature_flag {
+ continue;
+ }
+
+ let feature_flag = feature_flags
+ .iter()
+ .find(|flag| flag.flag == feature_flag)
+ .context("failed to find feature flag: {feature_flag:?}")?;
+
+ app.db.add_user_flag(user.id, feature_flag.id).await?;
+ }
+
+ Ok(Json(MigrateToNewBillingResponse {
+ canceled_subscription_id: stripe_subscription_id.to_string(),
+ }))
+}
+
/// The amount of time we wait in between each poll of Stripe events.
///
/// This value should strike a balance between: