collab: Transfer existing usage from trial to Pro (#28884)

Marshall Bowers and Mikayla created

This PR adds support for transferring any existing usage from a trial
subscription to a Zed Pro subscription when the user upgrades.

Release Notes:

- N/A

---------

Co-authored-by: Mikayla <mikayla@zed.dev>

Change summary

crates/collab/src/api/billing.rs                           |  22 +
crates/collab/src/llm/db/queries/subscription_usages.rs    | 158 +++++++
crates/collab/src/llm/db/tests.rs                          |   1 
crates/collab/src/llm/db/tests/subscription_usage_tests.rs |  69 +++
4 files changed, 243 insertions(+), 7 deletions(-)

Detailed changes

crates/collab/src/api/billing.rs 🔗

@@ -843,6 +843,28 @@ async fn handle_customer_subscription_event(
         .get_billing_subscription_by_stripe_subscription_id(&subscription.id)
         .await?
     {
+        let llm_db = app
+            .llm_db
+            .clone()
+            .ok_or_else(|| anyhow!("LLM DB not initialized"))?;
+
+        let new_period_start_at =
+            chrono::DateTime::from_timestamp(subscription.current_period_start, 0)
+                .ok_or_else(|| anyhow!("No subscription period start"))?;
+        let new_period_end_at =
+            chrono::DateTime::from_timestamp(subscription.current_period_end, 0)
+                .ok_or_else(|| anyhow!("No subscription period end"))?;
+
+        llm_db
+            .transfer_existing_subscription_usage(
+                billing_customer.user_id,
+                &existing_subscription,
+                subscription_kind,
+                new_period_start_at,
+                new_period_end_at,
+            )
+            .await?;
+
         app.db
             .update_billing_subscription(
                 existing_subscription.id,

crates/collab/src/llm/db/queries/subscription_usages.rs 🔗

@@ -1,8 +1,87 @@
-use crate::db::UserId;
+use chrono::Timelike;
+use time::PrimitiveDateTime;
+
+use crate::db::billing_subscription::SubscriptionKind;
+use crate::db::{UserId, billing_subscription};
 
 use super::*;
 
+fn convert_chrono_to_time(datetime: DateTimeUtc) -> anyhow::Result<PrimitiveDateTime> {
+    use chrono::{Datelike as _, Timelike as _};
+
+    let date = time::Date::from_calendar_date(
+        datetime.year(),
+        time::Month::try_from(datetime.month() as u8).unwrap(),
+        datetime.day() as u8,
+    )?;
+
+    let time = time::Time::from_hms_nano(
+        datetime.hour() as u8,
+        datetime.minute() as u8,
+        datetime.second() as u8,
+        datetime.nanosecond(),
+    )?;
+
+    Ok(PrimitiveDateTime::new(date, time))
+}
+
 impl LlmDatabase {
+    pub async fn create_subscription_usage(
+        &self,
+        user_id: UserId,
+        period_start_at: DateTimeUtc,
+        period_end_at: DateTimeUtc,
+        plan: SubscriptionKind,
+        model_requests: i32,
+        edit_predictions: i32,
+    ) -> Result<subscription_usage::Model> {
+        self.transaction(|tx| async move {
+            self.create_subscription_usage_in_tx(
+                user_id,
+                period_start_at,
+                period_end_at,
+                plan,
+                model_requests,
+                edit_predictions,
+                &tx,
+            )
+            .await
+        })
+        .await
+    }
+
+    async fn create_subscription_usage_in_tx(
+        &self,
+        user_id: UserId,
+        period_start_at: DateTimeUtc,
+        period_end_at: DateTimeUtc,
+        plan: SubscriptionKind,
+        model_requests: i32,
+        edit_predictions: i32,
+        tx: &DatabaseTransaction,
+    ) -> Result<subscription_usage::Model> {
+        // Clear out the nanoseconds so that these timestamps are comparable with Unix timestamps.
+        let period_start_at = period_start_at.with_nanosecond(0).unwrap();
+        let period_end_at = period_end_at.with_nanosecond(0).unwrap();
+
+        let period_start_at = convert_chrono_to_time(period_start_at)?;
+        let period_end_at = convert_chrono_to_time(period_end_at)?;
+
+        Ok(
+            subscription_usage::Entity::insert(subscription_usage::ActiveModel {
+                id: ActiveValue::not_set(),
+                user_id: ActiveValue::set(user_id),
+                period_start_at: ActiveValue::set(period_start_at),
+                period_end_at: ActiveValue::set(period_end_at),
+                plan: ActiveValue::set(plan),
+                model_requests: ActiveValue::set(model_requests),
+                edit_predictions: ActiveValue::set(edit_predictions),
+            })
+            .exec_with_returning(tx)
+            .await?,
+        )
+    }
+
     pub async fn get_subscription_usage_for_period(
         &self,
         user_id: UserId,
@@ -10,12 +89,77 @@ impl LlmDatabase {
         period_end_at: DateTimeUtc,
     ) -> Result<Option<subscription_usage::Model>> {
         self.transaction(|tx| async move {
-            Ok(subscription_usage::Entity::find()
-                .filter(subscription_usage::Column::UserId.eq(user_id))
-                .filter(subscription_usage::Column::PeriodStartAt.eq(period_start_at))
-                .filter(subscription_usage::Column::PeriodEndAt.eq(period_end_at))
-                .one(&*tx)
-                .await?)
+            self.get_subscription_usage_for_period_in_tx(
+                user_id,
+                period_start_at,
+                period_end_at,
+                &tx,
+            )
+            .await
+        })
+        .await
+    }
+
+    async fn get_subscription_usage_for_period_in_tx(
+        &self,
+        user_id: UserId,
+        period_start_at: DateTimeUtc,
+        period_end_at: DateTimeUtc,
+        tx: &DatabaseTransaction,
+    ) -> Result<Option<subscription_usage::Model>> {
+        Ok(subscription_usage::Entity::find()
+            .filter(subscription_usage::Column::UserId.eq(user_id))
+            .filter(subscription_usage::Column::PeriodStartAt.eq(period_start_at))
+            .filter(subscription_usage::Column::PeriodEndAt.eq(period_end_at))
+            .one(tx)
+            .await?)
+    }
+
+    pub async fn transfer_existing_subscription_usage(
+        &self,
+        user_id: UserId,
+        existing_subscription: &billing_subscription::Model,
+        new_subscription_kind: Option<SubscriptionKind>,
+        new_period_start_at: DateTimeUtc,
+        new_period_end_at: DateTimeUtc,
+    ) -> Result<Option<subscription_usage::Model>> {
+        self.transaction(|tx| async move {
+            match existing_subscription.kind {
+                Some(SubscriptionKind::ZedProTrial) => {
+                    let trial_period_start_at = existing_subscription
+                        .current_period_start_at()
+                        .ok_or_else(|| anyhow!("No trial subscription period start"))?;
+                    let trial_period_end_at = existing_subscription
+                        .current_period_end_at()
+                        .ok_or_else(|| anyhow!("No trial subscription period end"))?;
+
+                    let existing_usage = self
+                        .get_subscription_usage_for_period_in_tx(
+                            user_id,
+                            trial_period_start_at,
+                            trial_period_end_at,
+                            &tx,
+                        )
+                        .await?;
+                    if let Some(existing_usage) = existing_usage {
+                        return Ok(Some(
+                            self.create_subscription_usage_in_tx(
+                                user_id,
+                                new_period_start_at,
+                                new_period_end_at,
+                                new_subscription_kind.unwrap_or(existing_usage.plan),
+                                existing_usage.model_requests,
+                                existing_usage.edit_predictions,
+                                &tx,
+                            )
+                            .await?,
+                        ));
+                    }
+                }
+                _ => {}
+            }
+
+            Ok(None)
         })
         .await
     }

crates/collab/src/llm/db/tests/subscription_usage_tests.rs 🔗

@@ -0,0 +1,69 @@
+use chrono::{Duration, Utc};
+use pretty_assertions::assert_eq;
+
+use crate::db::billing_subscription::SubscriptionKind;
+use crate::db::{UserId, billing_subscription};
+use crate::llm::db::LlmDatabase;
+use crate::test_llm_db;
+
+test_llm_db!(
+    test_transfer_existing_subscription_usage,
+    test_transfer_existing_subscription_usage_postgres
+);
+
+async fn test_transfer_existing_subscription_usage(db: &mut LlmDatabase) {
+    let user_id = UserId(1);
+
+    let now = Utc::now();
+
+    let trial_period_start_at = now - Duration::days(14);
+    let trial_period_end_at = now;
+
+    let new_period_start_at = now;
+    let new_period_end_at = now + Duration::days(30);
+
+    let existing_subscription = billing_subscription::Model {
+        kind: Some(SubscriptionKind::ZedProTrial),
+        stripe_current_period_start: Some(trial_period_start_at.timestamp()),
+        stripe_current_period_end: Some(trial_period_end_at.timestamp()),
+        ..Default::default()
+    };
+
+    let existing_usage = db
+        .create_subscription_usage(
+            user_id,
+            trial_period_start_at,
+            trial_period_end_at,
+            SubscriptionKind::ZedProTrial,
+            25,
+            1_000,
+        )
+        .await
+        .unwrap();
+
+    let transferred_usage = db
+        .transfer_existing_subscription_usage(
+            user_id,
+            &existing_subscription,
+            Some(SubscriptionKind::ZedPro),
+            new_period_start_at,
+            new_period_end_at,
+        )
+        .await
+        .unwrap();
+
+    assert!(
+        transferred_usage.is_some(),
+        "subscription usage not transferred successfully"
+    );
+    let transferred_usage = transferred_usage.unwrap();
+
+    assert_eq!(
+        transferred_usage.model_requests,
+        existing_usage.model_requests
+    );
+    assert_eq!(
+        transferred_usage.edit_predictions,
+        existing_usage.edit_predictions
+    );
+}