diff --git a/crates/client/src/user.rs b/crates/client/src/user.rs index b213354a1bbc2a53a56712a357dc38ec67f1ac12..55d4e9b214e0a852fe97bdee7e74588da5a0a74b 100644 --- a/crates/client/src/user.rs +++ b/crates/client/src/user.rs @@ -101,6 +101,8 @@ pub struct UserStore { participant_indices: HashMap, update_contacts_tx: mpsc::UnboundedSender, current_plan: Option, + trial_started_at: Option>, + is_usage_based_billing_enabled: Option, current_user: watch::Receiver>>, accepted_tos_at: Option>>, contacts: Vec>, @@ -160,6 +162,8 @@ impl UserStore { by_github_login: Default::default(), current_user: current_user_rx, current_plan: None, + trial_started_at: None, + is_usage_based_billing_enabled: None, accepted_tos_at: None, contacts: Default::default(), incoming_contact_requests: Default::default(), @@ -321,6 +325,11 @@ impl UserStore { ) -> Result<()> { this.update(&mut cx, |this, cx| { this.current_plan = Some(message.payload.plan()); + this.trial_started_at = message + .payload + .trial_started_at + .and_then(|trial_started_at| DateTime::from_timestamp(trial_started_at as i64, 0)); + this.is_usage_based_billing_enabled = message.payload.is_usage_based_billing_enabled; cx.notify(); })?; Ok(()) diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 4646069d74c79ba3b0ba5d1475c0657fa011dc98..1d8344045bcb7f2b3d6bad1cac92a9519128d25a 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -482,7 +482,9 @@ CREATE TABLE IF NOT EXISTS billing_preferences ( id INTEGER PRIMARY KEY AUTOINCREMENT, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, user_id INTEGER NOT NULL REFERENCES users (id), - max_monthly_llm_usage_spending_in_cents INTEGER NOT NULL + max_monthly_llm_usage_spending_in_cents INTEGER NOT NULL, + model_request_overages_enabled bool NOT NULL DEFAULT FALSE, + model_request_overages_spend_limit_in_cents integer NOT NULL DEFAULT 0 ); CREATE UNIQUE INDEX "uix_billing_preferences_on_user_id" ON billing_preferences (user_id); diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 16a884bd1013c7c1d353a501fc291191a5bf5c8d..68d92f4be62ded0e422b0ae6627ae702b2be4e9d 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -2,7 +2,7 @@ mod connection_pool; use crate::api::{CloudflareIpCountryHeader, SystemIdHeader}; use crate::db::billing_subscription::SubscriptionKind; -use crate::llm::LlmTokenClaims; +use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims}; use crate::{ AppState, Error, Result, auth, db::{ @@ -37,6 +37,7 @@ use core::fmt::{self, Debug, Formatter}; use reqwest_client::ReqwestClient; use rpc::proto::split_repository_update; use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi}; +use util::maybe; use futures::{ FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture, @@ -2701,14 +2702,111 @@ fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool { version.0.minor() < 139 } -async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> { - let plan = session.current_plan(&session.db().await).await?; +async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> { + let db = session.db().await; + + let feature_flags = db.get_user_flags(user_id).await?; + let plan = session.current_plan(&db).await?; + let billing_customer = db.get_billing_customer_by_user_id(user_id).await?; + let billing_preferences = db.get_billing_preferences(user_id).await?; + + let usage = if let Some(llm_db) = session.app_state.llm_db.clone() { + let subscription = db.get_active_billing_subscription(user_id).await?; + + let subscription_period = maybe!({ + let subscription = subscription?; + let period_start_at = subscription.current_period_start_at()?; + let period_end_at = subscription.current_period_end_at()?; + + Some((period_start_at, period_end_at)) + }); + + if let Some((period_start_at, period_end_at)) = subscription_period { + llm_db + .get_subscription_usage_for_period(user_id, period_start_at, period_end_at) + .await? + } else { + None + } + } else { + None + }; session .peer .send( session.connection_id, - proto::UpdateUserPlan { plan: plan.into() }, + proto::UpdateUserPlan { + plan: plan.into(), + trial_started_at: billing_customer + .and_then(|billing_customer| billing_customer.trial_started_at) + .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64), + is_usage_based_billing_enabled: billing_preferences + .map(|preferences| preferences.model_request_overages_enabled), + usage: usage.map(|usage| { + let plan = match plan { + proto::Plan::Free => zed_llm_client::Plan::Free, + proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro, + proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial, + }; + + let model_requests_limit = match plan.model_requests_limit() { + zed_llm_client::UsageLimit::Limited(limit) => { + let limit = if plan == zed_llm_client::Plan::ZedProTrial + && feature_flags + .iter() + .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG) + { + 1_000 + } else { + limit + }; + + zed_llm_client::UsageLimit::Limited(limit) + } + zed_llm_client::UsageLimit::Unlimited => { + zed_llm_client::UsageLimit::Unlimited + } + }; + + proto::SubscriptionUsage { + model_requests_usage_amount: usage.model_requests as u32, + model_requests_usage_limit: Some(proto::UsageLimit { + variant: Some(match model_requests_limit { + zed_llm_client::UsageLimit::Limited(limit) => { + proto::usage_limit::Variant::Limited( + proto::usage_limit::Limited { + limit: limit as u32, + }, + ) + } + zed_llm_client::UsageLimit::Unlimited => { + proto::usage_limit::Variant::Unlimited( + proto::usage_limit::Unlimited {}, + ) + } + }), + }), + edit_predictions_usage_amount: usage.edit_predictions as u32, + edit_predictions_usage_limit: Some(proto::UsageLimit { + variant: Some(match plan.edit_predictions_limit() { + zed_llm_client::UsageLimit::Limited(limit) => { + proto::usage_limit::Variant::Limited( + proto::usage_limit::Limited { + limit: limit as u32, + }, + ) + } + zed_llm_client::UsageLimit::Unlimited => { + proto::usage_limit::Variant::Unlimited( + proto::usage_limit::Unlimited {}, + ) + } + }), + }), + } + }), + }, ) .trace_err(); diff --git a/crates/proto/proto/app.proto b/crates/proto/proto/app.proto index bb1d8357c0a0181ede289f592ff634aff8a53037..fd187416b13a6e4591eb5b3e577b00cd8878fcfd 100644 --- a/crates/proto/proto/app.proto +++ b/crates/proto/proto/app.proto @@ -23,6 +23,29 @@ enum Plan { message UpdateUserPlan { Plan plan = 1; + optional uint64 trial_started_at = 2; + optional bool is_usage_based_billing_enabled = 3; + optional SubscriptionUsage usage = 4; +} + +message SubscriptionUsage { + uint32 model_requests_usage_amount = 1; + UsageLimit model_requests_usage_limit = 2; + uint32 edit_predictions_usage_amount = 3; + UsageLimit edit_predictions_usage_limit = 4; +} + +message UsageLimit { + oneof variant { + Limited limited = 1; + Unlimited unlimited = 2; + } + + message Limited { + uint32 limit = 1; + } + + message Unlimited {} } message AcceptTermsOfService {}