@@ -110,6 +110,13 @@ pub enum Principal {
}
impl Principal {
+ fn user(&self) -> &User {
+ match self {
+ Principal::User(user) => user,
+ Principal::Impersonated { user, .. } => user,
+ }
+ }
+
fn update_span(&self, span: &tracing::Span) {
match &self {
Principal::User(user) => {
@@ -741,7 +748,7 @@ impl Server {
supermaven_client,
};
- if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
+ if let Err(error) = this.send_initial_client_update(connection_id, zed_version, send_connection_id, &session).await {
tracing::error!(?error, "failed to send initial client update");
return;
}
@@ -825,7 +832,6 @@ impl Server {
async fn send_initial_client_update(
&self,
connection_id: ConnectionId,
- principal: &Principal,
zed_version: ZedVersion,
mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
session: &Session,
@@ -841,7 +847,7 @@ impl Server {
let _ = send_connection_id.send(connection_id);
}
- match principal {
+ match &session.principal {
Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
if !user.connected_once {
self.peer.send(connection_id, proto::ShowContacts {})?;
@@ -851,7 +857,7 @@ impl Server {
.await?;
}
- update_user_plan(user.id, session).await?;
+ update_user_plan(session).await?;
let contacts = self.app_state.db.get_contacts(user.id).await?;
@@ -941,10 +947,10 @@ impl Server {
.context("user not found")?;
let update_user_plan = make_update_user_plan_message(
+ &user,
+ user.admin,
&self.app_state.db,
self.app_state.llm_db.clone(),
- user_id,
- user.admin,
)
.await?;
@@ -2707,26 +2713,25 @@ async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Re
}
async fn make_update_user_plan_message(
+ user: &User,
+ is_staff: bool,
db: &Arc<Database>,
llm_db: Option<Arc<LlmDatabase>>,
- user_id: UserId,
- is_staff: bool,
) -> Result<proto::UpdateUserPlan> {
- let feature_flags = db.get_user_flags(user_id).await?;
- let plan = current_plan(db, user_id, is_staff).await?;
- let billing_customer = db.get_billing_customer_by_user_id(user_id).await?;
- let billing_preferences = db.get_billing_preferences(user_id).await?;
- let user = db.get_user_by_id(user_id).await?;
+ let feature_flags = db.get_user_flags(user.id).await?;
+ let plan = current_plan(db, user.id, is_staff).await?;
+ let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
+ let billing_preferences = db.get_billing_preferences(user.id).await?;
let (subscription_period, usage) = if let Some(llm_db) = llm_db {
- let subscription = db.get_active_billing_subscription(user_id).await?;
+ let subscription = db.get_active_billing_subscription(user.id).await?;
let subscription_period =
crate::db::billing_subscription::Model::current_period(subscription, is_staff);
let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
llm_db
- .get_subscription_usage_for_period(user_id, period_start_at, period_end_at)
+ .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
.await?
} else {
None
@@ -2737,17 +2742,8 @@ async fn make_update_user_plan_message(
(None, None)
};
- // Calculate account_too_young
- let account_too_young = if matches!(plan, proto::Plan::ZedPro) {
- // If they have paid, then we allow them to use all of the features
- false
- } else if let Some(user) = user {
- // If we have access to the profile age, we use that
- chrono::Utc::now().naive_utc() - user.account_created_at() < MIN_ACCOUNT_AGE_FOR_LLM_USE
- } else {
- // Default to false otherwise
- false
- };
+ let account_too_young =
+ !matches!(plan, proto::Plan::ZedPro) && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;
Ok(proto::UpdateUserPlan {
plan: plan.into(),
@@ -2822,14 +2818,14 @@ async fn make_update_user_plan_message(
})
}
-async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
+async fn update_user_plan(session: &Session) -> Result<()> {
let db = session.db().await;
let update_user_plan = make_update_user_plan_message(
+ session.principal.user(),
+ session.is_staff(),
&db.0,
session.app_state.llm_db.clone(),
- user_id,
- session.is_staff(),
)
.await?;