@@ -2,6 +2,7 @@ mod connection_pool;
use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
use crate::db::billing_subscription::SubscriptionKind;
+use crate::llm::db::LlmDatabase;
use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
use crate::{
AppState, Error, Result, auth,
@@ -67,7 +68,7 @@ use std::{
time::{Duration, Instant},
};
use time::OffsetDateTime;
-use tokio::sync::{MutexGuard, Semaphore, watch};
+use tokio::sync::{Semaphore, watch};
use tower::ServiceBuilder;
use tracing::{
Instrument,
@@ -166,29 +167,6 @@ impl Session {
}
}
- pub async fn current_plan(&self, db: &MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
- if self.is_staff() {
- return Ok(proto::Plan::ZedPro);
- }
-
- let user_id = self.user_id();
-
- let subscription = db.get_active_billing_subscription(user_id).await?;
- let subscription_kind = subscription.and_then(|subscription| subscription.kind);
-
- let plan = if let Some(subscription_kind) = subscription_kind {
- match subscription_kind {
- SubscriptionKind::ZedPro => proto::Plan::ZedPro,
- SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
- SubscriptionKind::ZedFree => proto::Plan::Free,
- }
- } else {
- proto::Plan::Free
- };
-
- Ok(plan)
- }
-
fn user_id(&self) -> UserId {
match &self.principal {
Principal::User(user) => user.id,
@@ -953,6 +931,32 @@ impl Server {
Ok(())
}
+ pub async fn update_plan_for_user(self: &Arc<Self>, user_id: UserId) -> Result<()> {
+ let user = self
+ .app_state
+ .db
+ .get_user_by_id(user_id)
+ .await?
+ .ok_or_else(|| anyhow!("user not found"))?;
+
+ let update_user_plan = make_update_user_plan_message(
+ &self.app_state.db,
+ self.app_state.llm_db.clone(),
+ user_id,
+ user.admin,
+ )
+ .await?;
+
+ let pool = self.connection_pool.lock();
+ for connection_id in pool.user_connection_ids(user_id) {
+ self.peer
+ .send(connection_id, update_user_plan.clone())
+ .trace_err();
+ }
+
+ Ok(())
+ }
+
pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
let pool = self.connection_pool.lock();
for connection_id in pool.user_connection_ids(user_id) {
@@ -2688,21 +2692,43 @@ fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
version.0.minor() < 139
}
-async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
- let db = session.db().await;
+async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
+ if is_staff {
+ return Ok(proto::Plan::ZedPro);
+ }
+
+ let subscription = db.get_active_billing_subscription(user_id).await?;
+ let subscription_kind = subscription.and_then(|subscription| subscription.kind);
+
+ let plan = if let Some(subscription_kind) = subscription_kind {
+ match subscription_kind {
+ SubscriptionKind::ZedPro => proto::Plan::ZedPro,
+ SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
+ SubscriptionKind::ZedFree => proto::Plan::Free,
+ }
+ } else {
+ proto::Plan::Free
+ };
+ Ok(plan)
+}
+
+async fn make_update_user_plan_message(
+ db: &Arc<Database>,
+ llm_db: Option<Arc<LlmDatabase>>,
+ user_id: UserId,
+ is_staff: bool,
+) -> Result<proto::UpdateUserPlan> {
let feature_flags = db.get_user_flags(user_id).await?;
- let plan = session.current_plan(&db).await?;
+ let plan = current_plan(db, user_id, is_staff).await?;
let billing_customer = db.get_billing_customer_by_user_id(user_id).await?;
let billing_preferences = db.get_billing_preferences(user_id).await?;
- let (subscription_period, usage) = if let Some(llm_db) = session.app_state.llm_db.clone() {
+ let (subscription_period, usage) = if let Some(llm_db) = llm_db {
let subscription = db.get_active_billing_subscription(user_id).await?;
- let subscription_period = crate::db::billing_subscription::Model::current_period(
- subscription,
- session.is_staff(),
- );
+ let subscription_period =
+ crate::db::billing_subscription::Model::current_period(subscription, is_staff);
let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
llm_db
@@ -2717,92 +2743,92 @@ async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
(None, None)
};
- session
- .peer
- .send(
- session.connection_id,
- proto::UpdateUserPlan {
- plan: plan.into(),
- trial_started_at: billing_customer
- .and_then(|billing_customer| billing_customer.trial_started_at)
- .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
- is_usage_based_billing_enabled: if session.is_staff() {
- Some(true)
- } else {
- billing_preferences
- .map(|preferences| preferences.model_request_overages_enabled)
- },
- subscription_period: subscription_period.map(|(started_at, ended_at)| {
- proto::SubscriptionPeriod {
- started_at: started_at.timestamp() as u64,
- ended_at: ended_at.timestamp() as u64,
- }
- }),
- usage: usage.map(|usage| {
- let plan = match plan {
- proto::Plan::Free => zed_llm_client::Plan::ZedFree,
- proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
- proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
+ Ok(proto::UpdateUserPlan {
+ plan: plan.into(),
+ trial_started_at: billing_customer
+ .and_then(|billing_customer| billing_customer.trial_started_at)
+ .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
+ is_usage_based_billing_enabled: if is_staff {
+ Some(true)
+ } else {
+ billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
+ },
+ subscription_period: subscription_period.map(|(started_at, ended_at)| {
+ proto::SubscriptionPeriod {
+ started_at: started_at.timestamp() as u64,
+ ended_at: ended_at.timestamp() as u64,
+ }
+ }),
+ usage: usage.map(|usage| {
+ let plan = match plan {
+ proto::Plan::Free => zed_llm_client::Plan::ZedFree,
+ proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
+ proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
+ };
+
+ let model_requests_limit = match plan.model_requests_limit() {
+ zed_llm_client::UsageLimit::Limited(limit) => {
+ let limit = if plan == zed_llm_client::Plan::ZedProTrial
+ && feature_flags
+ .iter()
+ .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
+ {
+ 1_000
+ } else {
+ limit
};
- let model_requests_limit = match plan.model_requests_limit() {
- zed_llm_client::UsageLimit::Limited(limit) => {
- let limit = if plan == zed_llm_client::Plan::ZedProTrial
- && feature_flags
- .iter()
- .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
- {
- 1_000
- } else {
- limit
- };
+ zed_llm_client::UsageLimit::Limited(limit)
+ }
+ zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
+ };
- zed_llm_client::UsageLimit::Limited(limit)
+ proto::SubscriptionUsage {
+ model_requests_usage_amount: usage.model_requests as u32,
+ model_requests_usage_limit: Some(proto::UsageLimit {
+ variant: Some(match model_requests_limit {
+ zed_llm_client::UsageLimit::Limited(limit) => {
+ proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
+ limit: limit as u32,
+ })
}
zed_llm_client::UsageLimit::Unlimited => {
- zed_llm_client::UsageLimit::Unlimited
+ proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
}
- };
-
- proto::SubscriptionUsage {
- model_requests_usage_amount: usage.model_requests as u32,
- model_requests_usage_limit: Some(proto::UsageLimit {
- variant: Some(match model_requests_limit {
- zed_llm_client::UsageLimit::Limited(limit) => {
- proto::usage_limit::Variant::Limited(
- proto::usage_limit::Limited {
- limit: limit as u32,
- },
- )
- }
- zed_llm_client::UsageLimit::Unlimited => {
- proto::usage_limit::Variant::Unlimited(
- proto::usage_limit::Unlimited {},
- )
- }
- }),
- }),
- edit_predictions_usage_amount: usage.edit_predictions as u32,
- edit_predictions_usage_limit: Some(proto::UsageLimit {
- variant: Some(match plan.edit_predictions_limit() {
- zed_llm_client::UsageLimit::Limited(limit) => {
- proto::usage_limit::Variant::Limited(
- proto::usage_limit::Limited {
- limit: limit as u32,
- },
- )
- }
- zed_llm_client::UsageLimit::Unlimited => {
- proto::usage_limit::Variant::Unlimited(
- proto::usage_limit::Unlimited {},
- )
- }
- }),
- }),
- }
+ }),
}),
- },
- )
+ edit_predictions_usage_amount: usage.edit_predictions as u32,
+ edit_predictions_usage_limit: Some(proto::UsageLimit {
+ variant: Some(match plan.edit_predictions_limit() {
+ zed_llm_client::UsageLimit::Limited(limit) => {
+ proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
+ limit: limit as u32,
+ })
+ }
+ zed_llm_client::UsageLimit::Unlimited => {
+ proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
+ }
+ }),
+ }),
+ }
+ }),
+ })
+}
+
+async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
+ let db = session.db().await;
+
+ let update_user_plan = make_update_user_plan_message(
+ &db.0,
+ session.app_state.llm_db.clone(),
+ user_id,
+ session.is_staff(),
+ )
+ .await?;
+
+ session
+ .peer
+ .send(session.connection_id, update_user_plan)
.trace_err();
Ok(())