collab: Pass down billing information in `UpdatePlan` message (#29929)

Marshall Bowers created

This PR updates the `UpdatePlan` message to include some additional
information about the user's billing subscription usage.

Release Notes:

- N/A

Change summary

crates/client/src/user.rs                                      |   9 
crates/collab/migrations.sqlite/20221109000000_test_schema.sql |   4 
crates/collab/src/rpc.rs                                       | 106 +++
crates/proto/proto/app.proto                                   |  23 
4 files changed, 137 insertions(+), 5 deletions(-)

Detailed changes

crates/client/src/user.rs 🔗

@@ -101,6 +101,8 @@ pub struct UserStore {
     participant_indices: HashMap<u64, ParticipantIndex>,
     update_contacts_tx: mpsc::UnboundedSender<UpdateContacts>,
     current_plan: Option<proto::Plan>,
+    trial_started_at: Option<DateTime<Utc>>,
+    is_usage_based_billing_enabled: Option<bool>,
     current_user: watch::Receiver<Option<Arc<User>>>,
     accepted_tos_at: Option<Option<DateTime<Utc>>>,
     contacts: Vec<Arc<Contact>>,
@@ -160,6 +162,8 @@ impl UserStore {
             by_github_login: Default::default(),
             current_user: current_user_rx,
             current_plan: None,
+            trial_started_at: None,
+            is_usage_based_billing_enabled: None,
             accepted_tos_at: None,
             contacts: Default::default(),
             incoming_contact_requests: Default::default(),
@@ -321,6 +325,11 @@ impl UserStore {
     ) -> Result<()> {
         this.update(&mut cx, |this, cx| {
             this.current_plan = Some(message.payload.plan());
+            this.trial_started_at = message
+                .payload
+                .trial_started_at
+                .and_then(|trial_started_at| DateTime::from_timestamp(trial_started_at as i64, 0));
+            this.is_usage_based_billing_enabled = message.payload.is_usage_based_billing_enabled;
             cx.notify();
         })?;
         Ok(())

crates/collab/migrations.sqlite/20221109000000_test_schema.sql 🔗

@@ -482,7 +482,9 @@ CREATE TABLE IF NOT EXISTS billing_preferences (
     id INTEGER PRIMARY KEY AUTOINCREMENT,
     created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
     user_id INTEGER NOT NULL REFERENCES users (id),
-    max_monthly_llm_usage_spending_in_cents INTEGER NOT NULL
+    max_monthly_llm_usage_spending_in_cents INTEGER NOT NULL,
+    model_request_overages_enabled bool NOT NULL DEFAULT FALSE,
+    model_request_overages_spend_limit_in_cents integer NOT NULL DEFAULT 0
 );
 
 CREATE UNIQUE INDEX "uix_billing_preferences_on_user_id" ON billing_preferences (user_id);

crates/collab/src/rpc.rs 🔗

@@ -2,7 +2,7 @@ mod connection_pool;
 
 use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
 use crate::db::billing_subscription::SubscriptionKind;
-use crate::llm::LlmTokenClaims;
+use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
 use crate::{
     AppState, Error, Result, auth,
     db::{
@@ -37,6 +37,7 @@ use core::fmt::{self, Debug, Formatter};
 use reqwest_client::ReqwestClient;
 use rpc::proto::split_repository_update;
 use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
+use util::maybe;
 
 use futures::{
     FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
@@ -2693,14 +2694,111 @@ fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
     version.0.minor() < 139
 }
 
-async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
-    let plan = session.current_plan(&session.db().await).await?;
+async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
+    let db = session.db().await;
+
+    let feature_flags = db.get_user_flags(user_id).await?;
+    let plan = session.current_plan(&db).await?;
+    let billing_customer = db.get_billing_customer_by_user_id(user_id).await?;
+    let billing_preferences = db.get_billing_preferences(user_id).await?;
+
+    let usage = if let Some(llm_db) = session.app_state.llm_db.clone() {
+        let subscription = db.get_active_billing_subscription(user_id).await?;
+
+        let subscription_period = maybe!({
+            let subscription = subscription?;
+            let period_start_at = subscription.current_period_start_at()?;
+            let period_end_at = subscription.current_period_end_at()?;
+
+            Some((period_start_at, period_end_at))
+        });
+
+        if let Some((period_start_at, period_end_at)) = subscription_period {
+            llm_db
+                .get_subscription_usage_for_period(user_id, period_start_at, period_end_at)
+                .await?
+        } else {
+            None
+        }
+    } else {
+        None
+    };
 
     session
         .peer
         .send(
             session.connection_id,
-            proto::UpdateUserPlan { plan: plan.into() },
+            proto::UpdateUserPlan {
+                plan: plan.into(),
+                trial_started_at: billing_customer
+                    .and_then(|billing_customer| billing_customer.trial_started_at)
+                    .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
+                is_usage_based_billing_enabled: billing_preferences
+                    .map(|preferences| preferences.model_request_overages_enabled),
+                usage: usage.map(|usage| {
+                    let plan = match plan {
+                        proto::Plan::Free => zed_llm_client::Plan::Free,
+                        proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
+                        proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
+                    };
+
+                    let model_requests_limit = match plan.model_requests_limit() {
+                        zed_llm_client::UsageLimit::Limited(limit) => {
+                            let limit = if plan == zed_llm_client::Plan::ZedProTrial
+                                && feature_flags
+                                    .iter()
+                                    .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
+                            {
+                                1_000
+                            } else {
+                                limit
+                            };
+
+                            zed_llm_client::UsageLimit::Limited(limit)
+                        }
+                        zed_llm_client::UsageLimit::Unlimited => {
+                            zed_llm_client::UsageLimit::Unlimited
+                        }
+                    };
+
+                    proto::SubscriptionUsage {
+                        model_requests_usage_amount: usage.model_requests as u32,
+                        model_requests_usage_limit: Some(proto::UsageLimit {
+                            variant: Some(match model_requests_limit {
+                                zed_llm_client::UsageLimit::Limited(limit) => {
+                                    proto::usage_limit::Variant::Limited(
+                                        proto::usage_limit::Limited {
+                                            limit: limit as u32,
+                                        },
+                                    )
+                                }
+                                zed_llm_client::UsageLimit::Unlimited => {
+                                    proto::usage_limit::Variant::Unlimited(
+                                        proto::usage_limit::Unlimited {},
+                                    )
+                                }
+                            }),
+                        }),
+                        edit_predictions_usage_amount: usage.edit_predictions as u32,
+                        edit_predictions_usage_limit: Some(proto::UsageLimit {
+                            variant: Some(match plan.edit_predictions_limit() {
+                                zed_llm_client::UsageLimit::Limited(limit) => {
+                                    proto::usage_limit::Variant::Limited(
+                                        proto::usage_limit::Limited {
+                                            limit: limit as u32,
+                                        },
+                                    )
+                                }
+                                zed_llm_client::UsageLimit::Unlimited => {
+                                    proto::usage_limit::Variant::Unlimited(
+                                        proto::usage_limit::Unlimited {},
+                                    )
+                                }
+                            }),
+                        }),
+                    }
+                }),
+            },
         )
         .trace_err();
 

crates/proto/proto/app.proto 🔗

@@ -23,6 +23,29 @@ enum Plan {
 
 message UpdateUserPlan {
     Plan plan = 1;
+    optional uint64 trial_started_at = 2;
+    optional bool is_usage_based_billing_enabled = 3;
+    optional SubscriptionUsage usage = 4;
+}
+
+message SubscriptionUsage {
+    uint32 model_requests_usage_amount = 1;
+    UsageLimit model_requests_usage_limit = 2;
+    uint32 edit_predictions_usage_amount = 3;
+    UsageLimit edit_predictions_usage_limit = 4;
+}
+
+message UsageLimit {
+    oneof variant {
+        Limited limited = 1;
+        Unlimited unlimited = 2;
+    }
+
+    message Limited {
+        uint32 limit = 1;
+    }
+
+    message Unlimited {}
 }
 
 message AcceptTermsOfService {}