Detailed changes
@@ -17,9 +17,8 @@ use stripe::{
CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
- CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
- EventType, Expandable, ListEvents, PaymentMethod, Subscription, SubscriptionId,
- SubscriptionStatus,
+ CreateBillingPortalSessionFlowDataType, Customer, CustomerId, EventObject, EventType,
+ Expandable, ListEvents, PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus,
};
use util::{ResultExt, maybe};
@@ -310,13 +309,6 @@ async fn create_billing_subscription(
.await?
.ok_or_else(|| anyhow!("user not found"))?;
- let Some(stripe_client) = app.stripe_client.clone() else {
- log::error!("failed to retrieve Stripe client");
- Err(Error::http(
- StatusCode::NOT_IMPLEMENTED,
- "not supported".into(),
- ))?
- };
let Some(stripe_billing) = app.stripe_billing.clone() else {
log::error!("failed to retrieve Stripe billing object");
Err(Error::http(
@@ -351,35 +343,9 @@ async fn create_billing_subscription(
CustomerId::from_str(&existing_customer.stripe_customer_id)
.context("failed to parse customer ID")?
} else {
- let existing_customer = if let Some(email) = user.email_address.as_deref() {
- let customers = Customer::list(
- &stripe_client,
- &stripe::ListCustomers {
- email: Some(email),
- ..Default::default()
- },
- )
- .await?;
-
- customers.data.first().cloned()
- } else {
- None
- };
-
- if let Some(existing_customer) = existing_customer {
- existing_customer.id
- } else {
- let customer = Customer::create(
- &stripe_client,
- CreateCustomer {
- email: user.email_address.as_deref(),
- ..Default::default()
- },
- )
- .await?;
-
- customer.id
- }
+ stripe_billing
+ .find_or_create_customer_by_email(user.email_address.as_deref())
+ .await?
};
let success_url = format!(
@@ -1487,7 +1453,7 @@ impl From<CancellationDetailsReason> for StripeCancellationReason {
}
/// Finds or creates a billing customer using the provided customer.
-async fn find_or_create_billing_customer(
+pub async fn find_or_create_billing_customer(
app: &Arc<AppState>,
stripe_client: &stripe::Client,
customer_or_id: Expandable<Customer>,
@@ -32,9 +32,9 @@ impl Database {
pub async fn create_billing_subscription(
&self,
params: &CreateBillingSubscriptionParams,
- ) -> Result<()> {
+ ) -> Result<billing_subscription::Model> {
self.transaction(|tx| async move {
- billing_subscription::Entity::insert(billing_subscription::ActiveModel {
+ let id = billing_subscription::Entity::insert(billing_subscription::ActiveModel {
billing_customer_id: ActiveValue::set(params.billing_customer_id),
kind: ActiveValue::set(params.kind),
stripe_subscription_id: ActiveValue::set(params.stripe_subscription_id.clone()),
@@ -44,10 +44,14 @@ impl Database {
stripe_current_period_end: ActiveValue::set(params.stripe_current_period_end),
..Default::default()
})
- .exec_without_returning(&*tx)
- .await?;
+ .exec(&*tx)
+ .await?
+ .last_insert_id;
- Ok(())
+ Ok(billing_subscription::Entity::find_by_id(id)
+ .one(&*tx)
+ .await?
+ .ok_or_else(|| anyhow!("failed to retrieve inserted billing subscription"))?)
})
.await
}
@@ -42,7 +42,7 @@ impl LlmTokenClaims {
is_staff: bool,
billing_preferences: Option<billing_preference::Model>,
feature_flags: &Vec<String>,
- subscription: Option<billing_subscription::Model>,
+ subscription: billing_subscription::Model,
system_id: Option<String>,
config: &Config,
) -> Result<String> {
@@ -54,17 +54,14 @@ impl LlmTokenClaims {
let plan = if is_staff {
Plan::ZedPro
} else {
- subscription
- .as_ref()
- .and_then(|subscription| subscription.kind)
- .map_or(Plan::ZedFree, |kind| match kind {
- SubscriptionKind::ZedFree => Plan::ZedFree,
- SubscriptionKind::ZedPro => Plan::ZedPro,
- SubscriptionKind::ZedProTrial => Plan::ZedProTrial,
- })
+ subscription.kind.map_or(Plan::ZedFree, |kind| match kind {
+ SubscriptionKind::ZedFree => Plan::ZedFree,
+ SubscriptionKind::ZedPro => Plan::ZedPro,
+ SubscriptionKind::ZedProTrial => Plan::ZedProTrial,
+ })
};
let subscription_period =
- billing_subscription::Model::current_period(subscription, is_staff)
+ billing_subscription::Model::current_period(Some(subscription), is_staff)
.map(|(start, end)| (start.naive_utc(), end.naive_utc()))
.ok_or_else(|| anyhow!("A plan is required to use Zed's hosted models or edit predictions. Visit https://zed.dev/account to get started."))?;
@@ -1,5 +1,6 @@
mod connection_pool;
+use crate::api::billing::find_or_create_billing_customer;
use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
use crate::db::billing_subscription::SubscriptionKind;
use crate::llm::db::LlmDatabase;
@@ -4024,7 +4025,56 @@ async fn get_llm_api_token(
Err(anyhow!("terms of service not accepted"))?
}
- let billing_subscription = db.get_active_billing_subscription(user.id).await?;
+ let Some(stripe_client) = session.app_state.stripe_client.as_ref() else {
+ Err(anyhow!("failed to retrieve Stripe client"))?
+ };
+
+ let Some(stripe_billing) = session.app_state.stripe_billing.as_ref() else {
+ Err(anyhow!("failed to retrieve Stripe billing object"))?
+ };
+
+ let billing_customer =
+ if let Some(billing_customer) = db.get_billing_customer_by_user_id(user.id).await? {
+ billing_customer
+ } else {
+ let customer_id = stripe_billing
+ .find_or_create_customer_by_email(user.email_address.as_deref())
+ .await?;
+
+ find_or_create_billing_customer(
+ &session.app_state,
+ &stripe_client,
+ stripe::Expandable::Id(customer_id),
+ )
+ .await?
+ .ok_or_else(|| anyhow!("billing customer not found"))?
+ };
+
+ let billing_subscription =
+ if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
+ billing_subscription
+ } else {
+ let stripe_customer_id = billing_customer
+ .stripe_customer_id
+ .parse::<stripe::CustomerId>()
+ .context("failed to parse Stripe customer ID from database")?;
+
+ let stripe_subscription = stripe_billing
+ .subscribe_to_zed_free(stripe_customer_id)
+ .await?;
+
+ db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
+ billing_customer_id: billing_customer.id,
+ kind: Some(SubscriptionKind::ZedFree),
+ stripe_subscription_id: stripe_subscription.id.to_string(),
+ stripe_subscription_status: stripe_subscription.status.into(),
+ stripe_cancellation_reason: None,
+ stripe_current_period_start: Some(stripe_subscription.current_period_start),
+ stripe_current_period_end: Some(stripe_subscription.current_period_end),
+ })
+ .await?
+ };
+
let billing_preferences = db.get_billing_preferences(user.id).await?;
let token = LlmTokenClaims::create(
@@ -7,7 +7,7 @@ use anyhow::{Context as _, anyhow};
use chrono::Utc;
use collections::HashMap;
use serde::{Deserialize, Serialize};
-use stripe::{PriceId, SubscriptionStatus};
+use stripe::{CreateCustomer, Customer, CustomerId, PriceId, SubscriptionStatus};
use tokio::sync::RwLock;
use uuid::Uuid;
@@ -122,6 +122,47 @@ impl StripeBilling {
})
}
+ /// Returns the Stripe customer associated with the provided email address, or creates a new customer, if one does
+ /// not already exist.
+ ///
+ /// Always returns a new Stripe customer if the email address is `None`.
+ pub async fn find_or_create_customer_by_email(
+ &self,
+ email_address: Option<&str>,
+ ) -> Result<CustomerId> {
+ let existing_customer = if let Some(email) = email_address {
+ let customers = Customer::list(
+ &self.client,
+ &stripe::ListCustomers {
+ email: Some(email),
+ ..Default::default()
+ },
+ )
+ .await?;
+
+ customers.data.first().cloned()
+ } else {
+ None
+ };
+
+ let customer_id = if let Some(existing_customer) = existing_customer {
+ existing_customer.id
+ } else {
+ let customer = Customer::create(
+ &self.client,
+ CreateCustomer {
+ email: email_address,
+ ..Default::default()
+ },
+ )
+ .await?;
+
+ customer.id
+ };
+
+ Ok(customer_id)
+ }
+
pub async fn subscribe_to_price(
&self,
subscription_id: &stripe::SubscriptionId,