collab: Create Zed Free subscription when issuing an LLM token (#30975)

Marshall Bowers and Max Brunsfeld created

This PR makes it so we create a Zed Free subscription when issuing an
LLM token, if one does not already exist.

Release Notes:

- N/A

---------

Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>

Change summary

crates/collab/src/api/billing.rs                      | 46 +----------
crates/collab/src/db/queries/billing_subscriptions.rs | 14 ++-
crates/collab/src/llm/token.rs                        | 17 +--
crates/collab/src/rpc.rs                              | 52 ++++++++++++
crates/collab/src/stripe_billing.rs                   | 43 ++++++++++
5 files changed, 115 insertions(+), 57 deletions(-)

Detailed changes

crates/collab/src/api/billing.rs 🔗

@@ -17,9 +17,8 @@ use stripe::{
     CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
     CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
     CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
-    CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
-    EventType, Expandable, ListEvents, PaymentMethod, Subscription, SubscriptionId,
-    SubscriptionStatus,
+    CreateBillingPortalSessionFlowDataType, Customer, CustomerId, EventObject, EventType,
+    Expandable, ListEvents, PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus,
 };
 use util::{ResultExt, maybe};
 
@@ -310,13 +309,6 @@ async fn create_billing_subscription(
         .await?
         .ok_or_else(|| anyhow!("user not found"))?;
 
-    let Some(stripe_client) = app.stripe_client.clone() else {
-        log::error!("failed to retrieve Stripe client");
-        Err(Error::http(
-            StatusCode::NOT_IMPLEMENTED,
-            "not supported".into(),
-        ))?
-    };
     let Some(stripe_billing) = app.stripe_billing.clone() else {
         log::error!("failed to retrieve Stripe billing object");
         Err(Error::http(
@@ -351,35 +343,9 @@ async fn create_billing_subscription(
         CustomerId::from_str(&existing_customer.stripe_customer_id)
             .context("failed to parse customer ID")?
     } else {
-        let existing_customer = if let Some(email) = user.email_address.as_deref() {
-            let customers = Customer::list(
-                &stripe_client,
-                &stripe::ListCustomers {
-                    email: Some(email),
-                    ..Default::default()
-                },
-            )
-            .await?;
-
-            customers.data.first().cloned()
-        } else {
-            None
-        };
-
-        if let Some(existing_customer) = existing_customer {
-            existing_customer.id
-        } else {
-            let customer = Customer::create(
-                &stripe_client,
-                CreateCustomer {
-                    email: user.email_address.as_deref(),
-                    ..Default::default()
-                },
-            )
-            .await?;
-
-            customer.id
-        }
+        stripe_billing
+            .find_or_create_customer_by_email(user.email_address.as_deref())
+            .await?
     };
 
     let success_url = format!(
@@ -1487,7 +1453,7 @@ impl From<CancellationDetailsReason> for StripeCancellationReason {
 }
 
 /// Finds or creates a billing customer using the provided customer.
-async fn find_or_create_billing_customer(
+pub async fn find_or_create_billing_customer(
     app: &Arc<AppState>,
     stripe_client: &stripe::Client,
     customer_or_id: Expandable<Customer>,

crates/collab/src/db/queries/billing_subscriptions.rs 🔗

@@ -32,9 +32,9 @@ impl Database {
     pub async fn create_billing_subscription(
         &self,
         params: &CreateBillingSubscriptionParams,
-    ) -> Result<()> {
+    ) -> Result<billing_subscription::Model> {
         self.transaction(|tx| async move {
-            billing_subscription::Entity::insert(billing_subscription::ActiveModel {
+            let id = billing_subscription::Entity::insert(billing_subscription::ActiveModel {
                 billing_customer_id: ActiveValue::set(params.billing_customer_id),
                 kind: ActiveValue::set(params.kind),
                 stripe_subscription_id: ActiveValue::set(params.stripe_subscription_id.clone()),
@@ -44,10 +44,14 @@ impl Database {
                 stripe_current_period_end: ActiveValue::set(params.stripe_current_period_end),
                 ..Default::default()
             })
-            .exec_without_returning(&*tx)
-            .await?;
+            .exec(&*tx)
+            .await?
+            .last_insert_id;
 
-            Ok(())
+            Ok(billing_subscription::Entity::find_by_id(id)
+                .one(&*tx)
+                .await?
+                .ok_or_else(|| anyhow!("failed to retrieve inserted billing subscription"))?)
         })
         .await
     }

crates/collab/src/llm/token.rs 🔗

@@ -42,7 +42,7 @@ impl LlmTokenClaims {
         is_staff: bool,
         billing_preferences: Option<billing_preference::Model>,
         feature_flags: &Vec<String>,
-        subscription: Option<billing_subscription::Model>,
+        subscription: billing_subscription::Model,
         system_id: Option<String>,
         config: &Config,
     ) -> Result<String> {
@@ -54,17 +54,14 @@ impl LlmTokenClaims {
         let plan = if is_staff {
             Plan::ZedPro
         } else {
-            subscription
-                .as_ref()
-                .and_then(|subscription| subscription.kind)
-                .map_or(Plan::ZedFree, |kind| match kind {
-                    SubscriptionKind::ZedFree => Plan::ZedFree,
-                    SubscriptionKind::ZedPro => Plan::ZedPro,
-                    SubscriptionKind::ZedProTrial => Plan::ZedProTrial,
-                })
+            subscription.kind.map_or(Plan::ZedFree, |kind| match kind {
+                SubscriptionKind::ZedFree => Plan::ZedFree,
+                SubscriptionKind::ZedPro => Plan::ZedPro,
+                SubscriptionKind::ZedProTrial => Plan::ZedProTrial,
+            })
         };
         let subscription_period =
-            billing_subscription::Model::current_period(subscription, is_staff)
+            billing_subscription::Model::current_period(Some(subscription), is_staff)
                 .map(|(start, end)| (start.naive_utc(), end.naive_utc()))
                 .ok_or_else(|| anyhow!("A plan is required to use Zed's hosted models or edit predictions. Visit https://zed.dev/account to get started."))?;
 

crates/collab/src/rpc.rs 🔗

@@ -1,5 +1,6 @@
 mod connection_pool;
 
+use crate::api::billing::find_or_create_billing_customer;
 use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
 use crate::db::billing_subscription::SubscriptionKind;
 use crate::llm::db::LlmDatabase;
@@ -4024,7 +4025,56 @@ async fn get_llm_api_token(
         Err(anyhow!("terms of service not accepted"))?
     }
 
-    let billing_subscription = db.get_active_billing_subscription(user.id).await?;
+    let Some(stripe_client) = session.app_state.stripe_client.as_ref() else {
+        Err(anyhow!("failed to retrieve Stripe client"))?
+    };
+
+    let Some(stripe_billing) = session.app_state.stripe_billing.as_ref() else {
+        Err(anyhow!("failed to retrieve Stripe billing object"))?
+    };
+
+    let billing_customer =
+        if let Some(billing_customer) = db.get_billing_customer_by_user_id(user.id).await? {
+            billing_customer
+        } else {
+            let customer_id = stripe_billing
+                .find_or_create_customer_by_email(user.email_address.as_deref())
+                .await?;
+
+            find_or_create_billing_customer(
+                &session.app_state,
+                &stripe_client,
+                stripe::Expandable::Id(customer_id),
+            )
+            .await?
+            .ok_or_else(|| anyhow!("billing customer not found"))?
+        };
+
+    let billing_subscription =
+        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
+            billing_subscription
+        } else {
+            let stripe_customer_id = billing_customer
+                .stripe_customer_id
+                .parse::<stripe::CustomerId>()
+                .context("failed to parse Stripe customer ID from database")?;
+
+            let stripe_subscription = stripe_billing
+                .subscribe_to_zed_free(stripe_customer_id)
+                .await?;
+
+            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
+                billing_customer_id: billing_customer.id,
+                kind: Some(SubscriptionKind::ZedFree),
+                stripe_subscription_id: stripe_subscription.id.to_string(),
+                stripe_subscription_status: stripe_subscription.status.into(),
+                stripe_cancellation_reason: None,
+                stripe_current_period_start: Some(stripe_subscription.current_period_start),
+                stripe_current_period_end: Some(stripe_subscription.current_period_end),
+            })
+            .await?
+        };
+
     let billing_preferences = db.get_billing_preferences(user.id).await?;
 
     let token = LlmTokenClaims::create(

crates/collab/src/stripe_billing.rs 🔗

@@ -7,7 +7,7 @@ use anyhow::{Context as _, anyhow};
 use chrono::Utc;
 use collections::HashMap;
 use serde::{Deserialize, Serialize};
-use stripe::{PriceId, SubscriptionStatus};
+use stripe::{CreateCustomer, Customer, CustomerId, PriceId, SubscriptionStatus};
 use tokio::sync::RwLock;
 use uuid::Uuid;
 
@@ -122,6 +122,47 @@ impl StripeBilling {
         })
     }
 
+    /// Returns the Stripe customer associated with the provided email address, or creates a new customer, if one does
+    /// not already exist.
+    ///
+    /// Always returns a new Stripe customer if the email address is `None`.
+    pub async fn find_or_create_customer_by_email(
+        &self,
+        email_address: Option<&str>,
+    ) -> Result<CustomerId> {
+        let existing_customer = if let Some(email) = email_address {
+            let customers = Customer::list(
+                &self.client,
+                &stripe::ListCustomers {
+                    email: Some(email),
+                    ..Default::default()
+                },
+            )
+            .await?;
+
+            customers.data.first().cloned()
+        } else {
+            None
+        };
+
+        let customer_id = if let Some(existing_customer) = existing_customer {
+            existing_customer.id
+        } else {
+            let customer = Customer::create(
+                &self.client,
+                CreateCustomer {
+                    email: email_address,
+                    ..Default::default()
+                },
+            )
+            .await?;
+
+            customer.id
+        };
+
+        Ok(customer_id)
+    }
+
     pub async fn subscribe_to_price(
         &self,
         subscription_id: &stripe::SubscriptionId,