Detailed changes
@@ -3324,7 +3324,6 @@ dependencies = [
"http_client",
"hyper 0.14.32",
"indoc",
- "jsonwebtoken",
"language",
"language_model",
"livekit_api",
@@ -3370,7 +3369,6 @@ dependencies = [
"telemetry_events",
"text",
"theme",
- "thiserror 2.0.12",
"time",
"tokio",
"toml 0.8.20",
@@ -39,7 +39,6 @@ futures.workspace = true
gpui.workspace = true
hex.workspace = true
http_client.workspace = true
-jsonwebtoken.workspace = true
livekit_api.workspace = true
log.workspace = true
nanoid.workspace = true
@@ -65,7 +64,6 @@ subtle.workspace = true
supermaven_api.workspace = true
telemetry_events.workspace = true
text.workspace = true
-thiserror.workspace = true
time.workspace = true
tokio = { workspace = true, features = ["full"] }
toml.workspace = true
@@ -1,7 +1,4 @@
pub mod db;
-mod token;
-
-pub use token::*;
pub const AGENT_EXTENDED_TRIAL_FEATURE_FLAG: &str = "agent-extended-trial";
@@ -1,146 +0,0 @@
-use crate::db::billing_subscription::SubscriptionKind;
-use crate::db::{billing_customer, billing_subscription, user};
-use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG};
-use crate::{Config, db::billing_preference};
-use anyhow::{Context as _, Result};
-use chrono::{NaiveDateTime, Utc};
-use cloud_llm_client::Plan;
-use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
-use serde::{Deserialize, Serialize};
-use std::time::Duration;
-use thiserror::Error;
-use uuid::Uuid;
-
-#[derive(Clone, Debug, Default, Serialize, Deserialize)]
-#[serde(rename_all = "camelCase")]
-pub struct LlmTokenClaims {
- pub iat: u64,
- pub exp: u64,
- pub jti: String,
- pub user_id: u64,
- pub system_id: Option<String>,
- pub metrics_id: Uuid,
- pub github_user_login: String,
- pub account_created_at: NaiveDateTime,
- pub is_staff: bool,
- pub has_llm_closed_beta_feature_flag: bool,
- pub bypass_account_age_check: bool,
- pub use_llm_request_queue: bool,
- pub plan: Plan,
- pub has_extended_trial: bool,
- pub subscription_period: (NaiveDateTime, NaiveDateTime),
- pub enable_model_request_overages: bool,
- pub model_request_overages_spend_limit_in_cents: u32,
- pub can_use_web_search_tool: bool,
- #[serde(default)]
- pub has_overdue_invoices: bool,
-}
-
-const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60);
-
-impl LlmTokenClaims {
- pub fn create(
- user: &user::Model,
- is_staff: bool,
- billing_customer: billing_customer::Model,
- billing_preferences: Option<billing_preference::Model>,
- feature_flags: &Vec<String>,
- subscription: billing_subscription::Model,
- system_id: Option<String>,
- config: &Config,
- ) -> Result<String> {
- let secret = config
- .llm_api_secret
- .as_ref()
- .context("no LLM API secret")?;
-
- let plan = if is_staff {
- Plan::ZedPro
- } else {
- subscription.kind.map_or(Plan::ZedFree, |kind| match kind {
- SubscriptionKind::ZedFree => Plan::ZedFree,
- SubscriptionKind::ZedPro => Plan::ZedPro,
- SubscriptionKind::ZedProTrial => Plan::ZedProTrial,
- })
- };
- let subscription_period =
- billing_subscription::Model::current_period(Some(subscription), is_staff)
- .map(|(start, end)| (start.naive_utc(), end.naive_utc()))
- .context("A plan is required to use Zed's hosted models or edit predictions. Visit https://zed.dev/account to get started.")?;
-
- let now = Utc::now();
- let claims = Self {
- iat: now.timestamp() as u64,
- exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
- jti: uuid::Uuid::new_v4().to_string(),
- user_id: user.id.to_proto(),
- system_id,
- metrics_id: user.metrics_id,
- github_user_login: user.github_login.clone(),
- account_created_at: user.account_created_at(),
- is_staff,
- has_llm_closed_beta_feature_flag: feature_flags
- .iter()
- .any(|flag| flag == "llm-closed-beta"),
- bypass_account_age_check: feature_flags
- .iter()
- .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG),
- can_use_web_search_tool: true,
- use_llm_request_queue: feature_flags.iter().any(|flag| flag == "llm-request-queue"),
- plan,
- has_extended_trial: feature_flags
- .iter()
- .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG),
- subscription_period,
- enable_model_request_overages: billing_preferences
- .as_ref()
- .map_or(false, |preferences| {
- preferences.model_request_overages_enabled
- }),
- model_request_overages_spend_limit_in_cents: billing_preferences
- .as_ref()
- .map_or(0, |preferences| {
- preferences.model_request_overages_spend_limit_in_cents as u32
- }),
- has_overdue_invoices: billing_customer.has_overdue_invoices,
- };
-
- Ok(jsonwebtoken::encode(
- &Header::default(),
- &claims,
- &EncodingKey::from_secret(secret.as_ref()),
- )?)
- }
-
- pub fn validate(token: &str, config: &Config) -> Result<LlmTokenClaims, ValidateLlmTokenError> {
- let secret = config
- .llm_api_secret
- .as_ref()
- .context("no LLM API secret")?;
-
- match jsonwebtoken::decode::<Self>(
- token,
- &DecodingKey::from_secret(secret.as_ref()),
- &Validation::default(),
- ) {
- Ok(token) => Ok(token.claims),
- Err(e) => {
- if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature {
- Err(ValidateLlmTokenError::Expired)
- } else {
- Err(ValidateLlmTokenError::JwtError(e))
- }
- }
- }
- }
-}
-
-#[derive(Error, Debug)]
-pub enum ValidateLlmTokenError {
- #[error("access token is expired")]
- Expired,
- #[error("access token validation error: {0}")]
- JwtError(#[from] jsonwebtoken::errors::Error),
- #[error("{0}")]
- Other(#[from] anyhow::Error),
-}
@@ -1,14 +1,12 @@
mod connection_pool;
-use crate::api::billing::find_or_create_billing_customer;
use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
use crate::db::billing_subscription::SubscriptionKind;
use crate::llm::db::LlmDatabase;
use crate::llm::{
- AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
+ AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG,
MIN_ACCOUNT_AGE_FOR_LLM_USE,
};
-use crate::stripe_client::StripeCustomerId;
use crate::{
AppState, Error, Result, auth,
db::{
@@ -218,6 +216,7 @@ struct Session {
/// The GeoIP country code for the user.
#[allow(unused)]
geoip_country_code: Option<String>,
+ #[allow(unused)]
system_id: Option<String>,
_executor: Executor,
}
@@ -464,7 +463,6 @@ impl Server {
.add_message_handler(unfollow)
.add_message_handler(update_followers)
.add_request_handler(get_private_user_info)
- .add_request_handler(get_llm_api_token)
.add_request_handler(accept_terms_of_service)
.add_message_handler(acknowledge_channel_message)
.add_message_handler(acknowledge_buffer_version)
@@ -4251,96 +4249,6 @@ async fn accept_terms_of_service(
accepted_tos_at: accepted_tos_at.timestamp() as u64,
})?;
- // When the user accepts the terms of service, we want to refresh their LLM
- // token to grant access.
- session
- .peer
- .send(session.connection_id, proto::RefreshLlmToken {})?;
-
- Ok(())
-}
-
-async fn get_llm_api_token(
- _request: proto::GetLlmToken,
- response: Response<proto::GetLlmToken>,
- session: MessageContext,
-) -> Result<()> {
- let db = session.db().await;
-
- let flags = db.get_user_flags(session.user_id()).await?;
-
- let user_id = session.user_id();
- let user = db
- .get_user_by_id(user_id)
- .await?
- .with_context(|| format!("user {user_id} not found"))?;
-
- if user.accepted_tos_at.is_none() {
- Err(anyhow!("terms of service not accepted"))?
- }
-
- let stripe_client = session
- .app_state
- .stripe_client
- .as_ref()
- .context("failed to retrieve Stripe client")?;
-
- let stripe_billing = session
- .app_state
- .stripe_billing
- .as_ref()
- .context("failed to retrieve Stripe billing object")?;
-
- let billing_customer = if let Some(billing_customer) =
- db.get_billing_customer_by_user_id(user.id).await?
- {
- billing_customer
- } else {
- let customer_id = stripe_billing
- .find_or_create_customer_by_email(user.email_address.as_deref())
- .await?;
-
- find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
- .await?
- .context("billing customer not found")?
- };
-
- let billing_subscription =
- if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
- billing_subscription
- } else {
- let stripe_customer_id =
- StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
-
- let stripe_subscription = stripe_billing
- .subscribe_to_zed_free(stripe_customer_id)
- .await?;
-
- db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
- billing_customer_id: billing_customer.id,
- kind: Some(SubscriptionKind::ZedFree),
- stripe_subscription_id: stripe_subscription.id.to_string(),
- stripe_subscription_status: stripe_subscription.status.into(),
- stripe_cancellation_reason: None,
- stripe_current_period_start: Some(stripe_subscription.current_period_start),
- stripe_current_period_end: Some(stripe_subscription.current_period_end),
- })
- .await?
- };
-
- let billing_preferences = db.get_billing_preferences(user.id).await?;
-
- let token = LlmTokenClaims::create(
- &user,
- session.is_staff(),
- billing_customer,
- billing_preferences,
- &flags,
- billing_subscription,
- session.system_id.clone(),
- &session.app_state.config,
- )?;
- response.send(proto::GetLlmTokenResponse { token })?;
Ok(())
}
@@ -158,14 +158,6 @@ message SynchronizeContextsResponse {
repeated ContextVersion contexts = 1;
}
-message GetLlmToken {}
-
-message GetLlmTokenResponse {
- string token = 1;
-}
-
-message RefreshLlmToken {}
-
enum LanguageModelRole {
LanguageModelUser = 0;
LanguageModelAssistant = 1;
@@ -250,10 +250,6 @@ message Envelope {
AddWorktree add_worktree = 222;
AddWorktreeResponse add_worktree_response = 223;
- GetLlmToken get_llm_token = 235;
- GetLlmTokenResponse get_llm_token_response = 236;
- RefreshLlmToken refresh_llm_token = 259;
-
LspExtSwitchSourceHeader lsp_ext_switch_source_header = 241;
LspExtSwitchSourceHeaderResponse lsp_ext_switch_source_header_response = 242;
@@ -419,7 +415,9 @@ message Envelope {
reserved 221;
reserved 224 to 229;
reserved 230 to 231;
+ reserved 235 to 236;
reserved 246;
+ reserved 259;
reserved 270;
reserved 247 to 254;
reserved 255 to 256;
@@ -119,8 +119,6 @@ messages!(
(GetTypeDefinitionResponse, Background),
(GetImplementation, Background),
(GetImplementationResponse, Background),
- (GetLlmToken, Background),
- (GetLlmTokenResponse, Background),
(OpenUnstagedDiff, Foreground),
(OpenUnstagedDiffResponse, Foreground),
(OpenUncommittedDiff, Foreground),
@@ -196,7 +194,6 @@ messages!(
(PrepareRenameResponse, Background),
(ProjectEntryResponse, Foreground),
(RefreshInlayHints, Foreground),
- (RefreshLlmToken, Background),
(RegisterBufferWithLanguageServers, Background),
(RejoinChannelBuffers, Foreground),
(RejoinChannelBuffersResponse, Foreground),
@@ -354,7 +351,6 @@ request_messages!(
(GetDocumentHighlights, GetDocumentHighlightsResponse),
(GetDocumentSymbols, GetDocumentSymbolsResponse),
(GetHover, GetHoverResponse),
- (GetLlmToken, GetLlmTokenResponse),
(GetNotifications, GetNotificationsResponse),
(GetPrivateUserInfo, GetPrivateUserInfoResponse),
(GetProjectSymbols, GetProjectSymbolsResponse),