Detailed changes
@@ -505,7 +505,10 @@ CREATE TABLE IF NOT EXISTS billing_subscriptions (
stripe_subscription_id TEXT NOT NULL,
stripe_subscription_status TEXT NOT NULL,
stripe_cancel_at TIMESTAMP,
- stripe_cancellation_reason TEXT
+ stripe_cancellation_reason TEXT,
+ kind TEXT,
+ stripe_current_period_start BIGINT,
+ stripe_current_period_end BIGINT
);
CREATE INDEX "ix_billing_subscriptions_on_billing_customer_id" ON billing_subscriptions (billing_customer_id);
@@ -0,0 +1,4 @@
+alter table billing_subscriptions
+ add column kind text,
+ add column stripe_current_period_start bigint,
+ add column stripe_current_period_end bigint;
@@ -21,7 +21,9 @@ use stripe::{
use util::ResultExt;
use crate::api::events::SnowflakeRow;
-use crate::db::billing_subscription::{StripeCancellationReason, StripeSubscriptionStatus};
+use crate::db::billing_subscription::{
+ StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
+};
use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
use crate::rpc::{ResultExt as _, Server};
use crate::{AppState, Cents, Error, Result};
@@ -184,7 +186,10 @@ async fn list_billing_subscriptions(
.into_iter()
.map(|subscription| BillingSubscriptionJson {
id: subscription.id,
- name: "Zed LLM Usage".to_string(),
+ name: match subscription.kind {
+ Some(SubscriptionKind::ZedPro) => "Zed Pro".to_string(),
+ None => "Zed LLM Usage".to_string(),
+ },
status: subscription.stripe_subscription_status,
cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
cancel_at
@@ -691,6 +696,23 @@ async fn handle_customer_subscription_event(
log::info!("handling Stripe {} event: {}", event.type_, event.id);
+ let subscription_kind =
+ if let Some(zed_pro_price_id) = app.config.stripe_zed_pro_price_id.as_deref() {
+ let has_zed_pro_price = subscription.items.data.iter().any(|item| {
+ item.price
+ .as_ref()
+ .map_or(false, |price| price.id.as_str() == zed_pro_price_id)
+ });
+
+ if has_zed_pro_price {
+ Some(SubscriptionKind::ZedPro)
+ } else {
+ None
+ }
+ } else {
+ None
+ };
+
let billing_customer =
find_or_create_billing_customer(app, stripe_client, subscription.customer)
.await?
@@ -727,6 +749,7 @@ async fn handle_customer_subscription_event(
existing_subscription.id,
&UpdateBillingSubscriptionParams {
billing_customer_id: ActiveValue::set(billing_customer.id),
+ kind: ActiveValue::set(subscription_kind),
stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
stripe_subscription_status: ActiveValue::set(subscription.status.into()),
stripe_cancel_at: ActiveValue::set(
@@ -741,6 +764,12 @@ async fn handle_customer_subscription_event(
.and_then(|details| details.reason)
.map(|reason| reason.into()),
),
+ stripe_current_period_start: ActiveValue::set(Some(
+ subscription.current_period_start,
+ )),
+ stripe_current_period_end: ActiveValue::set(Some(
+ subscription.current_period_end,
+ )),
},
)
.await?;
@@ -775,12 +804,15 @@ async fn handle_customer_subscription_event(
app.db
.create_billing_subscription(&CreateBillingSubscriptionParams {
billing_customer_id: billing_customer.id,
+ kind: subscription_kind,
stripe_subscription_id: subscription.id.to_string(),
stripe_subscription_status: subscription.status.into(),
stripe_cancellation_reason: subscription
.cancellation_details
.and_then(|details| details.reason)
.map(|reason| reason.into()),
+ stripe_current_period_start: Some(subscription.current_period_start),
+ stripe_current_period_end: Some(subscription.current_period_end),
})
.await?;
}
@@ -1,22 +1,30 @@
-use crate::db::billing_subscription::{StripeCancellationReason, StripeSubscriptionStatus};
+use crate::db::billing_subscription::{
+ StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
+};
use super::*;
#[derive(Debug)]
pub struct CreateBillingSubscriptionParams {
pub billing_customer_id: BillingCustomerId,
+ pub kind: Option<SubscriptionKind>,
pub stripe_subscription_id: String,
pub stripe_subscription_status: StripeSubscriptionStatus,
pub stripe_cancellation_reason: Option<StripeCancellationReason>,
+ pub stripe_current_period_start: Option<i64>,
+ pub stripe_current_period_end: Option<i64>,
}
#[derive(Debug, Default)]
pub struct UpdateBillingSubscriptionParams {
pub billing_customer_id: ActiveValue<BillingCustomerId>,
+ pub kind: ActiveValue<Option<SubscriptionKind>>,
pub stripe_subscription_id: ActiveValue<String>,
pub stripe_subscription_status: ActiveValue<StripeSubscriptionStatus>,
pub stripe_cancel_at: ActiveValue<Option<DateTime>>,
pub stripe_cancellation_reason: ActiveValue<Option<StripeCancellationReason>>,
+ pub stripe_current_period_start: ActiveValue<Option<i64>>,
+ pub stripe_current_period_end: ActiveValue<Option<i64>>,
}
impl Database {
@@ -28,9 +36,12 @@ impl Database {
self.transaction(|tx| async move {
billing_subscription::Entity::insert(billing_subscription::ActiveModel {
billing_customer_id: ActiveValue::set(params.billing_customer_id),
+ kind: ActiveValue::set(params.kind),
stripe_subscription_id: ActiveValue::set(params.stripe_subscription_id.clone()),
stripe_subscription_status: ActiveValue::set(params.stripe_subscription_status),
stripe_cancellation_reason: ActiveValue::set(params.stripe_cancellation_reason),
+ stripe_current_period_start: ActiveValue::set(params.stripe_current_period_start),
+ stripe_current_period_end: ActiveValue::set(params.stripe_current_period_end),
..Default::default()
})
.exec_without_returning(&*tx)
@@ -9,10 +9,13 @@ pub struct Model {
#[sea_orm(primary_key)]
pub id: BillingSubscriptionId,
pub billing_customer_id: BillingCustomerId,
+ pub kind: Option<SubscriptionKind>,
pub stripe_subscription_id: String,
pub stripe_subscription_status: StripeSubscriptionStatus,
pub stripe_cancel_at: Option<DateTime>,
pub stripe_cancellation_reason: Option<StripeCancellationReason>,
+ pub stripe_current_period_start: Option<i64>,
+ pub stripe_current_period_end: Option<i64>,
pub created_at: DateTime,
}
@@ -34,6 +37,14 @@ impl Related<super::billing_customer::Entity> for Entity {
impl ActiveModelBehavior for ActiveModel {}
+#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Hash, Serialize)]
+#[sea_orm(rs_type = "String", db_type = "String(StringLen::None)")]
+#[serde(rename_all = "snake_case")]
+pub enum SubscriptionKind {
+ #[sea_orm(string_value = "zed_pro")]
+ ZedPro,
+}
+
/// The status of a Stripe subscription.
///
/// [Stripe docs](https://docs.stripe.com/api/subscriptions/object#subscription_object-status)
@@ -39,9 +39,12 @@ async fn test_get_active_billing_subscriptions(db: &Arc<Database>) {
db.create_billing_subscription(&CreateBillingSubscriptionParams {
billing_customer_id: customer.id,
+ kind: None,
stripe_subscription_id: "sub_active_user".into(),
stripe_subscription_status: StripeSubscriptionStatus::Active,
stripe_cancellation_reason: None,
+ stripe_current_period_start: None,
+ stripe_current_period_end: None,
})
.await
.unwrap();
@@ -74,9 +77,12 @@ async fn test_get_active_billing_subscriptions(db: &Arc<Database>) {
db.create_billing_subscription(&CreateBillingSubscriptionParams {
billing_customer_id: customer.id,
+ kind: None,
stripe_subscription_id: "sub_past_due_user".into(),
stripe_subscription_status: StripeSubscriptionStatus::PastDue,
stripe_cancellation_reason: None,
+ stripe_current_period_start: None,
+ stripe_current_period_end: None,
})
.await
.unwrap();