Detailed changes
@@ -2865,6 +2865,7 @@ dependencies = [
"chrono",
"clock",
"cloud_api_client",
+ "cloud_api_types",
"cloud_llm_client",
"collections",
"credentials_provider",
@@ -2878,6 +2879,7 @@ dependencies = [
"http_client",
"http_client_tls",
"httparse",
+ "language_model",
"log",
"objc2-foundation",
"parking_lot",
@@ -9335,7 +9337,6 @@ dependencies = [
"anthropic",
"anyhow",
"base64 0.22.1",
- "client",
"cloud_api_client",
"cloud_api_types",
"cloud_llm_client",
@@ -6,7 +6,7 @@ use acp_thread::{
use agent_client_protocol::{self as acp};
use agent_settings::AgentProfileId;
use anyhow::Result;
-use client::{Client, UserStore};
+use client::{Client, RefreshLlmTokenListener, UserStore};
use collections::IndexMap;
use context_server::{ContextServer, ContextServerCommand, ContextServerId};
use feature_flags::FeatureFlagAppExt as _;
@@ -3253,7 +3253,8 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
let clock = Arc::new(clock::FakeSystemClock::new());
let client = Client::new(clock, http_client, cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
- language_model::init(user_store.clone(), client.clone(), cx);
+ language_model::init(cx);
+ RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
language_models::init(user_store, client.clone(), cx);
LanguageModelRegistry::test(cx);
});
@@ -3982,7 +3983,8 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
cx.set_http_client(Arc::new(http_client));
let client = Client::production(cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
- language_model::init(user_store.clone(), client.clone(), cx);
+ language_model::init(cx);
+ RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
language_models::init(user_store, client.clone(), cx);
}
};
@@ -1,6 +1,7 @@
use crate::{AgentServer, AgentServerDelegate};
use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
use agent_client_protocol as acp;
+use client::RefreshLlmTokenListener;
use futures::{FutureExt, StreamExt, channel::mpsc, select};
use gpui::AppContext;
use gpui::{Entity, TestAppContext};
@@ -413,7 +414,8 @@ pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
cx.set_http_client(Arc::new(http_client));
let client = client::Client::production(cx);
let user_store = cx.new(|cx| client::UserStore::new(client.clone(), cx));
- language_model::init(user_store, client, cx);
+ language_model::init(cx);
+ RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
#[cfg(test)]
project::agent_server_store::AllAgentServersSettings::override_global(
@@ -815,7 +815,7 @@ mod tests {
cx.set_global(store);
theme_settings::init(theme::LoadThemes::JustBase, cx);
- language_model::init_settings(cx);
+ language_model::init(cx);
editor::init(cx);
});
@@ -1808,7 +1808,7 @@ mod tests {
cx.set_global(settings_store);
prompt_store::init(cx);
theme_settings::init(theme::LoadThemes::JustBase, cx);
- language_model::init_settings(cx);
+ language_model::init(cx);
});
let fs = FakeFs::new(cx.executor());
@@ -1965,7 +1965,7 @@ mod tests {
cx.set_global(settings_store);
prompt_store::init(cx);
theme_settings::init(theme::LoadThemes::JustBase, cx);
- language_model::init_settings(cx);
+ language_model::init(cx);
workspace::register_project_item::<Editor>(cx);
});
@@ -2114,7 +2114,8 @@ pub mod evals {
client::init(&client, cx);
workspace::init(app_state.clone(), cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
- language_model::init(user_store.clone(), client.clone(), cx);
+ language_model::init(cx);
+ RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
language_models::init(user_store, client.clone(), cx);
cx.set_global(inline_assistant);
@@ -22,6 +22,7 @@ base64.workspace = true
chrono = { workspace = true, features = ["serde"] }
clock.workspace = true
cloud_api_client.workspace = true
+cloud_api_types.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
credentials_provider.workspace = true
@@ -35,6 +36,7 @@ gpui_tokio.workspace = true
http_client.workspace = true
http_client_tls.workspace = true
httparse = "1.10"
+language_model.workspace = true
log.workspace = true
parking_lot.workspace = true
paths.workspace = true
@@ -1,6 +1,7 @@
#[cfg(any(test, feature = "test-support"))]
pub mod test;
+mod llm_token;
mod proxy;
pub mod telemetry;
pub mod user;
@@ -13,8 +14,9 @@ use async_tungstenite::tungstenite::{
http::{HeaderValue, Request, StatusCode},
};
use clock::SystemClock;
-use cloud_api_client::CloudApiClient;
use cloud_api_client::websocket_protocol::MessageToClient;
+use cloud_api_client::{ClientApiError, CloudApiClient};
+use cloud_api_types::OrganizationId;
use credentials_provider::CredentialsProvider;
use feature_flags::FeatureFlagAppExt as _;
use futures::{
@@ -24,6 +26,7 @@ use futures::{
};
use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions};
use http_client::{HttpClient, HttpClientWithUrl, http, read_proxy_from_env};
+use language_model::LlmApiToken;
use parking_lot::{Mutex, RwLock};
use postage::watch;
use proxy::connect_proxy_stream;
@@ -51,6 +54,7 @@ use tokio::net::TcpStream;
use url::Url;
use util::{ConnectionResult, ResultExt};
+pub use llm_token::*;
pub use rpc::*;
pub use telemetry_events::Event;
pub use user::*;
@@ -1517,6 +1521,66 @@ impl Client {
})
}
+ pub async fn acquire_llm_token(
+ &self,
+ llm_token: &LlmApiToken,
+ organization_id: Option<OrganizationId>,
+ ) -> Result<String> {
+ let system_id = self.telemetry().system_id().map(|x| x.to_string());
+ let cloud_client = self.cloud_client();
+ match llm_token
+ .acquire(&cloud_client, system_id, organization_id)
+ .await
+ {
+ Ok(token) => Ok(token),
+ Err(ClientApiError::Unauthorized) => {
+ self.request_sign_out();
+ Err(ClientApiError::Unauthorized).context("Failed to create LLM token")
+ }
+ Err(err) => Err(anyhow::Error::from(err)),
+ }
+ }
+
+ pub async fn refresh_llm_token(
+ &self,
+ llm_token: &LlmApiToken,
+ organization_id: Option<OrganizationId>,
+ ) -> Result<String> {
+ let system_id = self.telemetry().system_id().map(|x| x.to_string());
+ let cloud_client = self.cloud_client();
+ match llm_token
+ .refresh(&cloud_client, system_id, organization_id)
+ .await
+ {
+ Ok(token) => Ok(token),
+ Err(ClientApiError::Unauthorized) => {
+ self.request_sign_out();
+                Err(ClientApiError::Unauthorized).context("Failed to create LLM token")
+ }
+            Err(err) => Err(anyhow::Error::from(err)),
+ }
+ }
+
+ pub async fn clear_and_refresh_llm_token(
+ &self,
+ llm_token: &LlmApiToken,
+ organization_id: Option<OrganizationId>,
+ ) -> Result<String> {
+ let system_id = self.telemetry().system_id().map(|x| x.to_string());
+ let cloud_client = self.cloud_client();
+ match llm_token
+ .clear_and_refresh(&cloud_client, system_id, organization_id)
+ .await
+ {
+ Ok(token) => Ok(token),
+ Err(ClientApiError::Unauthorized) => {
+ self.request_sign_out();
+                Err(ClientApiError::Unauthorized).context("Failed to create LLM token")
+ }
+            Err(err) => Err(anyhow::Error::from(err)),
+ }
+ }
+
pub async fn sign_out(self: &Arc<Self>, cx: &AsyncApp) {
self.state.write().credentials = None;
self.cloud_client.clear_credentials();
@@ -0,0 +1,116 @@
+use super::{Client, UserStore};
+use cloud_api_types::websocket_protocol::MessageToClient;
+use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
+use gpui::{
+ App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription,
+};
+use language_model::LlmApiToken;
+use std::sync::Arc;
+
+pub trait NeedsLlmTokenRefresh {
+ /// Returns whether the LLM token needs to be refreshed.
+ fn needs_llm_token_refresh(&self) -> bool;
+}
+
+impl NeedsLlmTokenRefresh for http_client::Response<http_client::AsyncBody> {
+ fn needs_llm_token_refresh(&self) -> bool {
+ self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some()
+ || self.headers().get(OUTDATED_LLM_TOKEN_HEADER_NAME).is_some()
+ }
+}
+
+enum TokenRefreshMode {
+ Refresh,
+ ClearAndRefresh,
+}
+
+pub fn global_llm_token(cx: &App) -> LlmApiToken {
+ RefreshLlmTokenListener::global(cx)
+ .read(cx)
+ .llm_api_token
+ .clone()
+}
+
+struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
+
+impl Global for GlobalRefreshLlmTokenListener {}
+
+pub struct LlmTokenRefreshedEvent;
+
+pub struct RefreshLlmTokenListener {
+ client: Arc<Client>,
+ user_store: Entity<UserStore>,
+ llm_api_token: LlmApiToken,
+ _subscription: Subscription,
+}
+
+impl EventEmitter<LlmTokenRefreshedEvent> for RefreshLlmTokenListener {}
+
+impl RefreshLlmTokenListener {
+ pub fn register(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
+ let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, user_store, cx));
+ cx.set_global(GlobalRefreshLlmTokenListener(listener));
+ }
+
+ pub fn global(cx: &App) -> Entity<Self> {
+ GlobalRefreshLlmTokenListener::global(cx).0.clone()
+ }
+
+ fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
+ client.add_message_to_client_handler({
+ let this = cx.weak_entity();
+ move |message, cx| {
+ if let Some(this) = this.upgrade() {
+ Self::handle_refresh_llm_token(this, message, cx);
+ }
+ }
+ });
+
+ let subscription = cx.subscribe(&user_store, |this, _user_store, event, cx| {
+ if matches!(event, super::user::Event::OrganizationChanged) {
+ this.refresh(TokenRefreshMode::ClearAndRefresh, cx);
+ }
+ });
+
+ Self {
+ client,
+ user_store,
+ llm_api_token: LlmApiToken::default(),
+ _subscription: subscription,
+ }
+ }
+
+ fn refresh(&self, mode: TokenRefreshMode, cx: &mut Context<Self>) {
+ let client = self.client.clone();
+ let llm_api_token = self.llm_api_token.clone();
+ let organization_id = self
+ .user_store
+ .read(cx)
+ .current_organization()
+ .map(|organization| organization.id.clone());
+ cx.spawn(async move |this, cx| {
+ match mode {
+ TokenRefreshMode::Refresh => {
+ client
+ .refresh_llm_token(&llm_api_token, organization_id)
+ .await?;
+ }
+ TokenRefreshMode::ClearAndRefresh => {
+ client
+ .clear_and_refresh_llm_token(&llm_api_token, organization_id)
+ .await?;
+ }
+ }
+ this.update(cx, |_this, cx| cx.emit(LlmTokenRefreshedEvent))
+ })
+ .detach_and_log_err(cx);
+ }
+
+ fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
+ match message {
+ MessageToClient::UserUpdated => {
+ this.update(cx, |this, cx| this.refresh(TokenRefreshMode::Refresh, cx));
+ }
+ }
+ }
+}
@@ -1,6 +1,8 @@
use crate::{StoredEvent, example_spec::ExampleSpec};
use anyhow::Result;
use buffer_diff::BufferDiffSnapshot;
+#[cfg(test)]
+use client::RefreshLlmTokenListener;
use collections::HashMap;
use gpui::{App, Entity, Task};
use language::Buffer;
@@ -548,7 +550,8 @@ mod tests {
let http_client = FakeHttpClient::with_404_response();
let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
- language_model::init(user_store.clone(), client.clone(), cx);
+ language_model::init(cx);
+ RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
EditPredictionStore::global(&client, &user_store, cx);
})
}
@@ -1,5 +1,8 @@
use anyhow::Result;
-use client::{Client, EditPredictionUsage, UserStore};
+use client::{
+ Client, EditPredictionUsage, NeedsLlmTokenRefresh, UserStore,
+ global_llm_token as global_llm_api_token,
+};
use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
use cloud_llm_client::predict_edits_v3::{
PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
@@ -31,7 +34,7 @@ use heapless::Vec as ArrayVec;
use language::language_settings::all_language_settings;
use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
use language::{BufferSnapshot, OffsetRangeExt};
-use language_model::{LlmApiToken, NeedsLlmTokenRefresh};
+use language_model::LlmApiToken;
use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
use release_channel::AppVersion;
use semver::Version;
@@ -748,7 +751,7 @@ impl EditPredictionStore {
pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
let data_collection_choice = Self::load_data_collection_choice(cx);
- let llm_token = LlmApiToken::global(cx);
+ let llm_token = global_llm_api_token(cx);
let (reject_tx, reject_rx) = mpsc::unbounded();
cx.background_spawn({
@@ -877,7 +880,9 @@ impl EditPredictionStore {
let experiments = cx
.background_spawn(async move {
let http_client = client.http_client();
- let token = llm_token.acquire(&client, organization_id).await?;
+ let token = client
+ .acquire_llm_token(&llm_token, organization_id.clone())
+ .await?;
let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
let request = http_client::Request::builder()
.method(Method::GET)
@@ -2539,12 +2544,15 @@ impl EditPredictionStore {
Res: DeserializeOwned,
{
let http_client = client.http_client();
-
let mut token = if require_auth {
- Some(llm_token.acquire(&client, organization_id.clone()).await?)
+ Some(
+ client
+ .acquire_llm_token(&llm_token, organization_id.clone())
+ .await?,
+ )
} else {
- llm_token
- .acquire(&client, organization_id.clone())
+ client
+ .acquire_llm_token(&llm_token, organization_id.clone())
.await
.ok()
};
@@ -2588,7 +2596,11 @@ impl EditPredictionStore {
return Ok((serde_json::from_slice(&body)?, usage));
} else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
did_retry = true;
- token = Some(llm_token.refresh(&client, organization_id.clone()).await?);
+ token = Some(
+ client
+ .refresh_llm_token(&llm_token, organization_id.clone())
+ .await?,
+ );
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
@@ -1,6 +1,6 @@
use super::*;
use crate::udiff::apply_diff_to_string;
-use client::{UserStore, test::FakeServer};
+use client::{RefreshLlmTokenListener, UserStore, test::FakeServer};
use clock::FakeSystemClock;
use clock::ReplicaId;
use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
@@ -23,7 +23,7 @@ use language::{
Anchor, Buffer, Capability, CursorShape, Diagnostic, DiagnosticEntry, DiagnosticSet,
DiagnosticSeverity, Operation, Point, Selection, SelectionGoal,
};
-use language_model::RefreshLlmTokenListener;
+
use lsp::LanguageServerId;
use parking_lot::Mutex;
use pretty_assertions::{assert_eq, assert_matches};
@@ -2439,7 +2439,8 @@ fn init_test_with_fake_client(
client.cloud_client().set_credentials(1, "test".into());
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
- language_model::init(user_store.clone(), client.clone(), cx);
+ language_model::init(cx);
+ RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
let ep_store = EditPredictionStore::global(&client, &user_store, cx);
(
@@ -2891,7 +2892,7 @@ async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut
cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
let user_store = cx.update(|cx| cx.new(|cx| client::UserStore::new(client.clone(), cx)));
cx.update(|cx| {
- language_model::RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
+ RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
});
let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
@@ -1,4 +1,4 @@
-use client::{Client, ProxySettings, UserStore};
+use client::{Client, ProxySettings, RefreshLlmTokenListener, UserStore};
use db::AppDatabase;
use extension::ExtensionHostProxy;
use fs::RealFs;
@@ -109,7 +109,8 @@ pub fn init(cx: &mut App) -> EpAppState {
debug_adapter_extension::init(extension_host_proxy.clone(), cx);
language_extension::init(LspAccess::Noop, extension_host_proxy, languages.clone());
- language_model::init(user_store.clone(), client.clone(), cx);
+ language_model::init(cx);
+ RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
language_models::init(user_store.clone(), client.clone(), cx);
languages::init(languages.clone(), fs.clone(), node_runtime.clone(), cx);
prompt_store::init(cx);
@@ -1,7 +1,7 @@
use std::path::PathBuf;
use std::sync::Arc;
-use client::{Client, ProxySettings, UserStore};
+use client::{Client, ProxySettings, RefreshLlmTokenListener, UserStore};
use db::AppDatabase;
use extension::ExtensionHostProxy;
use fs::RealFs;
@@ -108,7 +108,8 @@ pub fn init(cx: &mut App) -> Arc<AgentCliAppState> {
let extension_host_proxy = ExtensionHostProxy::global(cx);
debug_adapter_extension::init(extension_host_proxy.clone(), cx);
language_extension::init(LspAccess::Noop, extension_host_proxy, languages.clone());
- language_model::init(user_store.clone(), client.clone(), cx);
+ language_model::init(cx);
+ RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
language_models::init(user_store.clone(), client.clone(), cx);
languages::init(languages.clone(), fs.clone(), node_runtime.clone(), cx);
prompt_store::init(cx);
@@ -20,7 +20,6 @@ anthropic = { workspace = true, features = ["schemars"] }
anyhow.workspace = true
credentials_provider.workspace = true
base64.workspace = true
-client.workspace = true
cloud_api_client.workspace = true
cloud_api_types.workspace = true
cloud_llm_client.workspace = true
@@ -11,12 +11,10 @@ pub mod tool_schema;
pub mod fake_provider;
use anyhow::{Result, anyhow};
-use client::Client;
-use client::UserStore;
use cloud_llm_client::CompletionRequestStatus;
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
-use gpui::{AnyView, App, AsyncApp, Entity, SharedString, Task, Window};
+use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::{StatusCode, http};
use icons::IconName;
use parking_lot::Mutex;
@@ -39,12 +37,7 @@ pub use crate::tool_schema::LanguageModelToolSchemaFormat;
pub use env_var::{EnvVar, env_var};
pub use provider::*;
-pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
- init_settings(cx);
- RefreshLlmTokenListener::register(client, user_store, cx);
-}
-
-pub fn init_settings(cx: &mut App) {
+pub fn init(cx: &mut App) {
registry::init(cx);
}
@@ -1,16 +1,9 @@
use std::fmt;
use std::sync::Arc;
-use anyhow::{Context as _, Result};
-use client::Client;
-use client::UserStore;
use cloud_api_client::ClientApiError;
+use cloud_api_client::CloudApiClient;
use cloud_api_types::OrganizationId;
-use cloud_api_types::websocket_protocol::MessageToClient;
-use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
-use gpui::{
- App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription,
-};
use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
use thiserror::Error;
@@ -30,18 +23,12 @@ impl fmt::Display for PaymentRequiredError {
pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
impl LlmApiToken {
- pub fn global(cx: &App) -> Self {
- RefreshLlmTokenListener::global(cx)
- .read(cx)
- .llm_api_token
- .clone()
- }
-
pub async fn acquire(
&self,
- client: &Arc<Client>,
+ client: &CloudApiClient,
+ system_id: Option<String>,
organization_id: Option<OrganizationId>,
- ) -> Result<String> {
+ ) -> Result<String, ClientApiError> {
let lock = self.0.upgradable_read().await;
if let Some(token) = lock.as_ref() {
Ok(token.to_string())
@@ -49,6 +36,7 @@ impl LlmApiToken {
Self::fetch(
RwLockUpgradableReadGuard::upgrade(lock).await,
client,
+ system_id,
organization_id,
)
.await
@@ -57,10 +45,11 @@ impl LlmApiToken {
pub async fn refresh(
&self,
- client: &Arc<Client>,
+ client: &CloudApiClient,
+ system_id: Option<String>,
organization_id: Option<OrganizationId>,
- ) -> Result<String> {
- Self::fetch(self.0.write().await, client, organization_id).await
+ ) -> Result<String, ClientApiError> {
+ Self::fetch(self.0.write().await, client, system_id, organization_id).await
}
/// Clears the existing token before attempting to fetch a new one.
@@ -69,28 +58,22 @@ impl LlmApiToken {
/// leave a token for the wrong organization.
pub async fn clear_and_refresh(
&self,
- client: &Arc<Client>,
+ client: &CloudApiClient,
+ system_id: Option<String>,
organization_id: Option<OrganizationId>,
- ) -> Result<String> {
+ ) -> Result<String, ClientApiError> {
let mut lock = self.0.write().await;
*lock = None;
- Self::fetch(lock, client, organization_id).await
+ Self::fetch(lock, client, system_id, organization_id).await
}
async fn fetch(
mut lock: RwLockWriteGuard<'_, Option<String>>,
- client: &Arc<Client>,
+ client: &CloudApiClient,
+ system_id: Option<String>,
organization_id: Option<OrganizationId>,
- ) -> Result<String> {
- let system_id = client
- .telemetry()
- .system_id()
- .map(|system_id| system_id.to_string());
-
- let result = client
- .cloud_client()
- .create_llm_token(system_id, organization_id)
- .await;
+ ) -> Result<String, ClientApiError> {
+ let result = client.create_llm_token(system_id, organization_id).await;
match result {
Ok(response) => {
*lock = Some(response.token.0.clone());
@@ -98,112 +81,7 @@ impl LlmApiToken {
}
Err(err) => {
*lock = None;
- match err {
- ClientApiError::Unauthorized => {
- client.request_sign_out();
- Err(err).context("Failed to create LLM token")
- }
- ClientApiError::Other(err) => Err(err),
- }
- }
- }
- }
-}
-
-pub trait NeedsLlmTokenRefresh {
- /// Returns whether the LLM token needs to be refreshed.
- fn needs_llm_token_refresh(&self) -> bool;
-}
-
-impl NeedsLlmTokenRefresh for http_client::Response<http_client::AsyncBody> {
- fn needs_llm_token_refresh(&self) -> bool {
- self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some()
- || self.headers().get(OUTDATED_LLM_TOKEN_HEADER_NAME).is_some()
- }
-}
-
-enum TokenRefreshMode {
- Refresh,
- ClearAndRefresh,
-}
-
-struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
-
-impl Global for GlobalRefreshLlmTokenListener {}
-
-pub struct LlmTokenRefreshedEvent;
-
-pub struct RefreshLlmTokenListener {
- client: Arc<Client>,
- user_store: Entity<UserStore>,
- llm_api_token: LlmApiToken,
- _subscription: Subscription,
-}
-
-impl EventEmitter<LlmTokenRefreshedEvent> for RefreshLlmTokenListener {}
-
-impl RefreshLlmTokenListener {
- pub fn register(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
- let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, user_store, cx));
- cx.set_global(GlobalRefreshLlmTokenListener(listener));
- }
-
- pub fn global(cx: &App) -> Entity<Self> {
- GlobalRefreshLlmTokenListener::global(cx).0.clone()
- }
-
- fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
- client.add_message_to_client_handler({
- let this = cx.weak_entity();
- move |message, cx| {
- if let Some(this) = this.upgrade() {
- Self::handle_refresh_llm_token(this, message, cx);
- }
- }
- });
-
- let subscription = cx.subscribe(&user_store, |this, _user_store, event, cx| {
- if matches!(event, client::user::Event::OrganizationChanged) {
- this.refresh(TokenRefreshMode::ClearAndRefresh, cx);
- }
- });
-
- Self {
- client,
- user_store,
- llm_api_token: LlmApiToken::default(),
- _subscription: subscription,
- }
- }
-
- fn refresh(&self, mode: TokenRefreshMode, cx: &mut Context<Self>) {
- let client = self.client.clone();
- let llm_api_token = self.llm_api_token.clone();
- let organization_id = self
- .user_store
- .read(cx)
- .current_organization()
- .map(|organization| organization.id.clone());
- cx.spawn(async move |this, cx| {
- match mode {
- TokenRefreshMode::Refresh => {
- llm_api_token.refresh(&client, organization_id).await?;
- }
- TokenRefreshMode::ClearAndRefresh => {
- llm_api_token
- .clear_and_refresh(&client, organization_id)
- .await?;
- }
- }
- this.update(cx, |_this, cx| cx.emit(LlmTokenRefreshedEvent))
- })
- .detach_and_log_err(cx);
- }
-
- fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
- match message {
- MessageToClient::UserUpdated => {
- this.update(cx, |this, cx| this.refresh(TokenRefreshMode::Refresh, cx));
+ Err(err)
}
}
}
@@ -1,7 +1,10 @@
use ai_onboarding::YoungAccountBanner;
use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow};
-use client::{Client, UserStore, zed_urls};
+use client::{
+ Client, NeedsLlmTokenRefresh, RefreshLlmTokenListener, UserStore,
+ global_llm_token as global_llm_api_token, zed_urls,
+};
use cloud_api_types::{OrganizationId, Plan};
use cloud_llm_client::{
CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME,
@@ -24,10 +27,9 @@ use language_model::{
LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelEffortLevel,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
- LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, NeedsLlmTokenRefresh,
- OPEN_AI_PROVIDER_ID, OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter,
- RefreshLlmTokenListener, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME, ZED_CLOUD_PROVIDER_ID,
- ZED_CLOUD_PROVIDER_NAME,
+ LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, OPEN_AI_PROVIDER_ID,
+ OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME,
+ ZED_CLOUD_PROVIDER_ID, ZED_CLOUD_PROVIDER_NAME,
};
use release_channel::AppVersion;
use schemars::JsonSchema;
@@ -111,7 +113,7 @@ impl State {
cx: &mut Context<Self>,
) -> Self {
let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
- let llm_api_token = LlmApiToken::global(cx);
+ let llm_api_token = global_llm_api_token(cx);
Self {
client: client.clone(),
llm_api_token,
@@ -226,7 +228,9 @@ impl State {
organization_id: Option<OrganizationId>,
) -> Result<ListModelsResponse> {
let http_client = &client.http_client();
- let token = llm_api_token.acquire(&client, organization_id).await?;
+ let token = client
+ .acquire_llm_token(&llm_api_token, organization_id)
+ .await?;
let request = http_client::Request::builder()
.method(Method::GET)
@@ -414,8 +418,8 @@ impl CloudLanguageModel {
) -> Result<PerformLlmCompletionResponse> {
let http_client = &client.http_client();
- let mut token = llm_api_token
- .acquire(&client, organization_id.clone())
+ let mut token = client
+ .acquire_llm_token(&llm_api_token, organization_id.clone())
.await?;
let mut refreshed_token = false;
@@ -447,8 +451,8 @@ impl CloudLanguageModel {
}
if !refreshed_token && response.needs_llm_token_refresh() {
- token = llm_api_token
- .refresh(&client, organization_id.clone())
+ token = client
+ .refresh_llm_token(&llm_api_token, organization_id.clone())
.await?;
refreshed_token = true;
continue;
@@ -713,7 +717,9 @@ impl LanguageModel for CloudLanguageModel {
into_google(request, model_id.clone(), GoogleModelMode::Default);
async move {
let http_client = &client.http_client();
- let token = llm_api_token.acquire(&client, organization_id).await?;
+ let token = client
+ .acquire_llm_token(&llm_api_token, organization_id)
+ .await?;
let request_body = CountTokensBody {
provider: cloud_llm_client::LanguageModelProvider::Google,
@@ -1,13 +1,13 @@
use std::sync::Arc;
use anyhow::{Context as _, Result};
-use client::{Client, UserStore};
+use client::{Client, NeedsLlmTokenRefresh, UserStore, global_llm_token as global_llm_api_token};
use cloud_api_types::OrganizationId;
use cloud_llm_client::{WebSearchBody, WebSearchResponse};
use futures::AsyncReadExt as _;
use gpui::{App, AppContext, Context, Entity, Task};
use http_client::{HttpClient, Method};
-use language_model::{LlmApiToken, NeedsLlmTokenRefresh};
+use language_model::LlmApiToken;
use web_search::{WebSearchProvider, WebSearchProviderId};
pub struct CloudWebSearchProvider {
@@ -30,7 +30,7 @@ pub struct State {
impl State {
pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
- let llm_api_token = LlmApiToken::global(cx);
+ let llm_api_token = global_llm_api_token(cx);
Self {
client,
@@ -73,8 +73,8 @@ async fn perform_web_search(
let http_client = &client.http_client();
let mut retries_remaining = MAX_RETRIES;
- let mut token = llm_api_token
- .acquire(&client, organization_id.clone())
+ let mut token = client
+ .acquire_llm_token(&llm_api_token, organization_id.clone())
.await?;
loop {
@@ -100,8 +100,8 @@ async fn perform_web_search(
response.body_mut().read_to_string(&mut body).await?;
return Ok(serde_json::from_str(&body)?);
} else if response.needs_llm_token_refresh() {
- token = llm_api_token
- .refresh(&client, organization_id.clone())
+ token = client
+ .refresh_llm_token(&llm_api_token, organization_id.clone())
.await?;
retries_remaining -= 1;
} else {
@@ -10,7 +10,7 @@ use agent_ui::AgentPanel;
use anyhow::{Context as _, Error, Result};
use clap::Parser;
use cli::FORCE_CLI_MODE_ENV_VAR_NAME;
-use client::{Client, ProxySettings, UserStore, parse_zed_link};
+use client::{Client, ProxySettings, RefreshLlmTokenListener, UserStore, parse_zed_link};
use collab_ui::channel_view::ChannelView;
use collections::HashMap;
use crashes::InitCrashHandler;
@@ -664,7 +664,12 @@ fn main() {
);
copilot_ui::init(&app_state, cx);
- language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
+ language_model::init(cx);
+ RefreshLlmTokenListener::register(
+ app_state.client.clone(),
+ app_state.user_store.clone(),
+ cx,
+ );
language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
acp_tools::init(cx);
zed::telemetry_log::init(cx);
@@ -201,7 +201,12 @@ fn run_visual_tests(project_path: PathBuf, update_baseline: bool) -> Result<()>
});
prompt_store::init(cx);
let prompt_builder = prompt_store::PromptBuilder::load(app_state.fs.clone(), false, cx);
- language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
+ language_model::init(cx);
+ client::RefreshLlmTokenListener::register(
+ app_state.client.clone(),
+ app_state.user_store.clone(),
+ cx,
+ );
language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
git_ui::init(cx);
project::AgentRegistryStore::init_global(
@@ -5015,7 +5015,12 @@ mod tests {
cx,
);
image_viewer::init(cx);
- language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
+ language_model::init(cx);
+ client::RefreshLlmTokenListener::register(
+ app_state.client.clone(),
+ app_state.user_store.clone(),
+ cx,
+ );
language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
web_search::init(cx);
git_graph::init(cx);
@@ -313,7 +313,12 @@ mod tests {
let app_state = cx.update(|cx| {
let app_state = AppState::test(cx);
client::init(&app_state.client, cx);
- language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
+ language_model::init(cx);
+ client::RefreshLlmTokenListener::register(
+ app_state.client.clone(),
+ app_state.user_store.clone(),
+ cx,
+ );
editor::init(cx);
app_state
});