@@ -101,6 +101,8 @@ pub struct UserStore {
participant_indices: HashMap<u64, ParticipantIndex>,
update_contacts_tx: mpsc::UnboundedSender<UpdateContacts>,
current_plan: Option<proto::Plan>,
+ trial_started_at: Option<DateTime<Utc>>,
+ is_usage_based_billing_enabled: Option<bool>,
current_user: watch::Receiver<Option<Arc<User>>>,
accepted_tos_at: Option<Option<DateTime<Utc>>>,
contacts: Vec<Arc<Contact>>,
@@ -160,6 +162,8 @@ impl UserStore {
by_github_login: Default::default(),
current_user: current_user_rx,
current_plan: None,
+ trial_started_at: None,
+ is_usage_based_billing_enabled: None,
accepted_tos_at: None,
contacts: Default::default(),
incoming_contact_requests: Default::default(),
@@ -321,6 +325,11 @@ impl UserStore {
) -> Result<()> {
this.update(&mut cx, |this, cx| {
this.current_plan = Some(message.payload.plan());
+ this.trial_started_at = message
+ .payload
+ .trial_started_at
+ .and_then(|trial_started_at| DateTime::from_timestamp(trial_started_at as i64, 0));
+ this.is_usage_based_billing_enabled = message.payload.is_usage_based_billing_enabled;
cx.notify();
})?;
Ok(())
@@ -482,7 +482,9 @@ CREATE TABLE IF NOT EXISTS billing_preferences (
id INTEGER PRIMARY KEY AUTOINCREMENT,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
user_id INTEGER NOT NULL REFERENCES users (id),
- max_monthly_llm_usage_spending_in_cents INTEGER NOT NULL
+ max_monthly_llm_usage_spending_in_cents INTEGER NOT NULL,
+ model_request_overages_enabled bool NOT NULL DEFAULT FALSE,
+ model_request_overages_spend_limit_in_cents integer NOT NULL DEFAULT 0
);
CREATE UNIQUE INDEX "uix_billing_preferences_on_user_id" ON billing_preferences (user_id);
@@ -2,7 +2,7 @@ mod connection_pool;
use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
use crate::db::billing_subscription::SubscriptionKind;
-use crate::llm::LlmTokenClaims;
+use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
use crate::{
AppState, Error, Result, auth,
db::{
@@ -37,6 +37,7 @@ use core::fmt::{self, Debug, Formatter};
use reqwest_client::ReqwestClient;
use rpc::proto::split_repository_update;
use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
+use util::maybe;
use futures::{
FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
@@ -2693,14 +2694,111 @@ fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
version.0.minor() < 139
}
-async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
- let plan = session.current_plan(&session.db().await).await?;
+async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
+ let db = session.db().await;
+
+ let feature_flags = db.get_user_flags(user_id).await?;
+ let plan = session.current_plan(&db).await?;
+ let billing_customer = db.get_billing_customer_by_user_id(user_id).await?;
+ let billing_preferences = db.get_billing_preferences(user_id).await?;
+
+ let usage = if let Some(llm_db) = session.app_state.llm_db.clone() {
+ let subscription = db.get_active_billing_subscription(user_id).await?;
+
+ let subscription_period = maybe!({
+ let subscription = subscription?;
+ let period_start_at = subscription.current_period_start_at()?;
+ let period_end_at = subscription.current_period_end_at()?;
+
+ Some((period_start_at, period_end_at))
+ });
+
+ if let Some((period_start_at, period_end_at)) = subscription_period {
+ llm_db
+ .get_subscription_usage_for_period(user_id, period_start_at, period_end_at)
+ .await?
+ } else {
+ None
+ }
+ } else {
+ None
+ };
session
.peer
.send(
session.connection_id,
- proto::UpdateUserPlan { plan: plan.into() },
+ proto::UpdateUserPlan {
+ plan: plan.into(),
+ trial_started_at: billing_customer
+ .and_then(|billing_customer| billing_customer.trial_started_at)
+ .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
+ is_usage_based_billing_enabled: billing_preferences
+ .map(|preferences| preferences.model_request_overages_enabled),
+ usage: usage.map(|usage| {
+ let plan = match plan {
+ proto::Plan::Free => zed_llm_client::Plan::Free,
+ proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
+ proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
+ };
+
+ let model_requests_limit = match plan.model_requests_limit() {
+ zed_llm_client::UsageLimit::Limited(limit) => {
+ let limit = if plan == zed_llm_client::Plan::ZedProTrial
+ && feature_flags
+ .iter()
+ .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
+ {
+ 1_000
+ } else {
+ limit
+ };
+
+ zed_llm_client::UsageLimit::Limited(limit)
+ }
+ zed_llm_client::UsageLimit::Unlimited => {
+ zed_llm_client::UsageLimit::Unlimited
+ }
+ };
+
+ proto::SubscriptionUsage {
+ model_requests_usage_amount: usage.model_requests as u32,
+ model_requests_usage_limit: Some(proto::UsageLimit {
+ variant: Some(match model_requests_limit {
+ zed_llm_client::UsageLimit::Limited(limit) => {
+ proto::usage_limit::Variant::Limited(
+ proto::usage_limit::Limited {
+ limit: limit as u32,
+ },
+ )
+ }
+ zed_llm_client::UsageLimit::Unlimited => {
+ proto::usage_limit::Variant::Unlimited(
+ proto::usage_limit::Unlimited {},
+ )
+ }
+ }),
+ }),
+ edit_predictions_usage_amount: usage.edit_predictions as u32,
+ edit_predictions_usage_limit: Some(proto::UsageLimit {
+ variant: Some(match plan.edit_predictions_limit() {
+ zed_llm_client::UsageLimit::Limited(limit) => {
+ proto::usage_limit::Variant::Limited(
+ proto::usage_limit::Limited {
+ limit: limit as u32,
+ },
+ )
+ }
+ zed_llm_client::UsageLimit::Unlimited => {
+ proto::usage_limit::Variant::Unlimited(
+ proto::usage_limit::Unlimited {},
+ )
+ }
+ }),
+ }),
+ }
+ }),
+ },
)
.trace_err();