Detailed changes
@@ -260,7 +260,6 @@ dependencies = [
"chrono",
"client",
"collections",
- "credentials_provider",
"env_logger 0.11.8",
"feature_flags",
"fs",
@@ -289,6 +288,7 @@ dependencies = [
"util",
"uuid",
"watch",
+ "zed_credentials_provider",
]
[[package]]
@@ -2856,6 +2856,7 @@ dependencies = [
"chrono",
"clock",
"cloud_api_client",
+ "cloud_api_types",
"cloud_llm_client",
"collections",
"credentials_provider",
@@ -2869,6 +2870,7 @@ dependencies = [
"http_client",
"http_client_tls",
"httparse",
+ "language_model",
"log",
"objc2-foundation",
"parking_lot",
@@ -2900,6 +2902,7 @@ dependencies = [
"util",
"windows 0.61.3",
"worktree",
+ "zed_credentials_provider",
]
[[package]]
@@ -3059,6 +3062,7 @@ dependencies = [
"serde",
"serde_json",
"text",
+ "zed_credentials_provider",
"zeta_prompt",
]
@@ -4035,12 +4039,8 @@ name = "credentials_provider"
version = "0.1.0"
dependencies = [
"anyhow",
- "futures 0.3.31",
"gpui",
- "paths",
- "release_channel",
"serde",
- "serde_json",
]
[[package]]
@@ -5115,6 +5115,7 @@ dependencies = [
"collections",
"copilot",
"copilot_ui",
+ "credentials_provider",
"ctor",
"db",
"edit_prediction_context",
@@ -5157,6 +5158,7 @@ dependencies = [
"workspace",
"worktree",
"zed_actions",
+ "zed_credentials_provider",
"zeta_prompt",
"zlog",
"zstd",
@@ -5583,6 +5585,13 @@ dependencies = [
"log",
]
+[[package]]
+name = "env_var"
+version = "0.1.0"
+dependencies = [
+ "gpui",
+]
+
[[package]]
name = "envy"
version = "0.4.2"
@@ -9315,12 +9324,12 @@ dependencies = [
"anthropic",
"anyhow",
"base64 0.22.1",
- "client",
"cloud_api_client",
"cloud_api_types",
"cloud_llm_client",
"collections",
"credentials_provider",
+ "env_var",
"futures 0.3.31",
"gpui",
"http_client",
@@ -9336,7 +9345,6 @@ dependencies = [
"smol",
"thiserror 2.0.17",
"util",
- "zed_env_vars",
]
[[package]]
@@ -13137,6 +13145,7 @@ dependencies = [
"wax",
"which 6.0.3",
"worktree",
+ "zed_credentials_provider",
"zeroize",
"zlog",
"ztracing",
@@ -15746,6 +15755,7 @@ dependencies = [
"util",
"workspace",
"zed_actions",
+ "zed_credentials_provider",
]
[[package]]
@@ -22180,10 +22190,24 @@ dependencies = [
]
[[package]]
-name = "zed_env_vars"
+name = "zed_credentials_provider"
version = "0.1.0"
dependencies = [
+ "anyhow",
+ "credentials_provider",
+ "futures 0.3.31",
"gpui",
+ "paths",
+ "release_channel",
+ "serde",
+ "serde_json",
+]
+
+[[package]]
+name = "zed_env_vars"
+version = "0.1.0"
+dependencies = [
+ "env_var",
]
[[package]]
@@ -61,6 +61,7 @@ members = [
"crates/edit_prediction_ui",
"crates/editor",
"crates/encoding_selector",
+ "crates/env_var",
"crates/etw_tracing",
"crates/eval_cli",
"crates/eval_utils",
@@ -220,6 +221,7 @@ members = [
"crates/x_ai",
"crates/zed",
"crates/zed_actions",
+ "crates/zed_credentials_provider",
"crates/zed_env_vars",
"crates/zeta_prompt",
"crates/zlog",
@@ -309,6 +311,7 @@ dev_container = { path = "crates/dev_container" }
diagnostics = { path = "crates/diagnostics" }
editor = { path = "crates/editor" }
encoding_selector = { path = "crates/encoding_selector" }
+env_var = { path = "crates/env_var" }
etw_tracing = { path = "crates/etw_tracing" }
eval_utils = { path = "crates/eval_utils" }
extension = { path = "crates/extension" }
@@ -465,6 +468,7 @@ worktree = { path = "crates/worktree" }
x_ai = { path = "crates/x_ai" }
zed = { path = "crates/zed" }
zed_actions = { path = "crates/zed_actions" }
+zed_credentials_provider = { path = "crates/zed_credentials_provider" }
zed_env_vars = { path = "crates/zed_env_vars" }
edit_prediction = { path = "crates/edit_prediction" }
zeta_prompt = { path = "crates/zeta_prompt" }
@@ -4,7 +4,7 @@ use crate::{
ListDirectoryTool, ListDirectoryToolInput, ReadFileTool, ReadFileToolInput,
};
use Role::*;
-use client::{Client, UserStore};
+use client::{Client, RefreshLlmTokenListener, UserStore};
use eval_utils::{EvalOutput, EvalOutputProcessor, OutcomeKind};
use fs::FakeFs;
use futures::{FutureExt, future::LocalBoxFuture};
@@ -1423,7 +1423,8 @@ impl EditAgentTest {
let client = Client::production(cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
settings::init(cx);
- language_model::init(user_store.clone(), client.clone(), cx);
+ language_model::init(cx);
+ RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
language_models::init(user_store, client.clone(), cx);
});
@@ -6,7 +6,7 @@ use acp_thread::{
use agent_client_protocol::{self as acp};
use agent_settings::AgentProfileId;
use anyhow::Result;
-use client::{Client, UserStore};
+use client::{Client, RefreshLlmTokenListener, UserStore};
use collections::IndexMap;
use context_server::{ContextServer, ContextServerCommand, ContextServerId};
use feature_flags::FeatureFlagAppExt as _;
@@ -3253,7 +3253,8 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
let clock = Arc::new(clock::FakeSystemClock::new());
let client = Client::new(clock, http_client, cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
- language_model::init(user_store.clone(), client.clone(), cx);
+ language_model::init(cx);
+ RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
language_models::init(user_store, client.clone(), cx);
LanguageModelRegistry::test(cx);
});
@@ -3982,7 +3983,8 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
cx.set_http_client(Arc::new(http_client));
let client = Client::production(cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
- language_model::init(user_store.clone(), client.clone(), cx);
+ language_model::init(cx);
+ RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
language_models::init(user_store, client.clone(), cx);
}
};
@@ -6,7 +6,7 @@ use crate::{
};
use Role::*;
use anyhow::{Context as _, Result};
-use client::{Client, UserStore};
+use client::{Client, RefreshLlmTokenListener, UserStore};
use fs::FakeFs;
use futures::{FutureExt, StreamExt, future::LocalBoxFuture};
use gpui::{AppContext as _, AsyncApp, Entity, TestAppContext, UpdateGlobal as _};
@@ -274,7 +274,8 @@ impl StreamingEditToolTest {
cx.set_http_client(http_client);
let client = Client::production(cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
- language_model::init(user_store.clone(), client.clone(), cx);
+ language_model::init(cx);
+ RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
language_models::init(user_store, client, cx);
});
@@ -32,7 +32,6 @@ futures.workspace = true
gpui.workspace = true
feature_flags.workspace = true
gpui_tokio = { workspace = true, optional = true }
-credentials_provider.workspace = true
google_ai.workspace = true
http_client.workspace = true
indoc.workspace = true
@@ -53,6 +52,7 @@ terminal.workspace = true
uuid.workspace = true
util.workspace = true
watch.workspace = true
+zed_credentials_provider.workspace = true
[target.'cfg(unix)'.dependencies]
libc.workspace = true
@@ -3,7 +3,6 @@ use acp_thread::AgentConnection;
use agent_client_protocol as acp;
use anyhow::{Context as _, Result};
use collections::HashSet;
-use credentials_provider::CredentialsProvider;
use fs::Fs;
use gpui::{App, AppContext as _, Entity, Task};
use language_model::{ApiKey, EnvVar};
@@ -392,7 +391,7 @@ fn api_key_for_gemini_cli(cx: &mut App) -> Task<Result<String>> {
if let Some(key) = env_var.value {
return Task::ready(Ok(key));
}
- let credentials_provider = <dyn CredentialsProvider>::global(cx);
+ let credentials_provider = zed_credentials_provider::global(cx);
let api_url = google_ai::API_URL.to_string();
cx.spawn(async move |cx| {
Ok(
@@ -1,6 +1,7 @@
use crate::{AgentServer, AgentServerDelegate};
use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
use agent_client_protocol as acp;
+use client::RefreshLlmTokenListener;
use futures::{FutureExt, StreamExt, channel::mpsc, select};
use gpui::AppContext;
use gpui::{Entity, TestAppContext};
@@ -413,7 +414,8 @@ pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
cx.set_http_client(Arc::new(http_client));
let client = client::Client::production(cx);
let user_store = cx.new(|cx| client::UserStore::new(client.clone(), cx));
- language_model::init(user_store, client, cx);
+ language_model::init(cx);
+ RefreshLlmTokenListener::register(client.clone(), user_store, cx);
#[cfg(test)]
project::agent_server_store::AllAgentServersSettings::override_global(
@@ -815,7 +815,7 @@ mod tests {
cx.set_global(store);
theme_settings::init(theme::LoadThemes::JustBase, cx);
- language_model::init_settings(cx);
+ language_model::init(cx);
editor::init(cx);
});
@@ -1809,7 +1809,7 @@ mod tests {
cx.set_global(settings_store);
prompt_store::init(cx);
theme_settings::init(theme::LoadThemes::JustBase, cx);
- language_model::init_settings(cx);
+ language_model::init(cx);
});
let fs = FakeFs::new(cx.executor());
@@ -1966,7 +1966,7 @@ mod tests {
cx.set_global(settings_store);
prompt_store::init(cx);
theme_settings::init(theme::LoadThemes::JustBase, cx);
- language_model::init_settings(cx);
+ language_model::init(cx);
workspace::register_project_item::<Editor>(cx);
});
@@ -2025,7 +2025,7 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
pub mod evals {
use crate::InlineAssistant;
use agent::ThreadStore;
- use client::{Client, UserStore};
+ use client::{Client, RefreshLlmTokenListener, UserStore};
use editor::{Editor, MultiBuffer, MultiBufferOffset};
use eval_utils::{EvalOutput, NoProcessor};
use fs::FakeFs;
@@ -2091,7 +2091,8 @@ pub mod evals {
client::init(&client, cx);
workspace::init(app_state.clone(), cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
- language_model::init(user_store.clone(), client.clone(), cx);
+ language_model::init(cx);
+ RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
language_models::init(user_store, client.clone(), cx);
cx.set_global(inline_assistant);
@@ -22,6 +22,7 @@ base64.workspace = true
chrono = { workspace = true, features = ["serde"] }
clock.workspace = true
cloud_api_client.workspace = true
+cloud_api_types.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
credentials_provider.workspace = true
@@ -35,6 +36,7 @@ gpui_tokio.workspace = true
http_client.workspace = true
http_client_tls.workspace = true
httparse = "1.10"
+language_model.workspace = true
log.workspace = true
parking_lot.workspace = true
paths.workspace = true
@@ -60,6 +62,7 @@ tokio.workspace = true
url.workspace = true
util.workspace = true
worktree.workspace = true
+zed_credentials_provider.workspace = true
[dev-dependencies]
clock = { workspace = true, features = ["test-support"] }
@@ -1,6 +1,7 @@
#[cfg(any(test, feature = "test-support"))]
pub mod test;
+mod llm_token;
mod proxy;
pub mod telemetry;
pub mod user;
@@ -13,8 +14,9 @@ use async_tungstenite::tungstenite::{
http::{HeaderValue, Request, StatusCode},
};
use clock::SystemClock;
-use cloud_api_client::CloudApiClient;
use cloud_api_client::websocket_protocol::MessageToClient;
+use cloud_api_client::{ClientApiError, CloudApiClient};
+use cloud_api_types::OrganizationId;
use credentials_provider::CredentialsProvider;
use feature_flags::FeatureFlagAppExt as _;
use futures::{
@@ -24,6 +26,7 @@ use futures::{
};
use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions};
use http_client::{HttpClient, HttpClientWithUrl, http, read_proxy_from_env};
+use language_model::LlmApiToken;
use parking_lot::{Mutex, RwLock};
use postage::watch;
use proxy::connect_proxy_stream;
@@ -51,6 +54,7 @@ use tokio::net::TcpStream;
use url::Url;
use util::{ConnectionResult, ResultExt};
+pub use llm_token::*;
pub use rpc::*;
pub use telemetry_events::Event;
pub use user::*;
@@ -339,7 +343,7 @@ pub struct ClientCredentialsProvider {
impl ClientCredentialsProvider {
pub fn new(cx: &App) -> Self {
Self {
- provider: <dyn CredentialsProvider>::global(cx),
+ provider: zed_credentials_provider::global(cx),
}
}
@@ -568,6 +572,10 @@ impl Client {
self.http.clone()
}
+ pub fn credentials_provider(&self) -> Arc<dyn CredentialsProvider> {
+ self.credentials_provider.provider.clone()
+ }
+
pub fn cloud_client(&self) -> Arc<CloudApiClient> {
self.cloud_client.clone()
}
@@ -1513,6 +1521,66 @@ impl Client {
})
}
+ pub async fn acquire_llm_token(
+ &self,
+ llm_token: &LlmApiToken,
+ organization_id: Option<OrganizationId>,
+ ) -> Result<String> {
+ let system_id = self.telemetry().system_id().map(|x| x.to_string());
+ let cloud_client = self.cloud_client();
+ match llm_token
+ .acquire(&cloud_client, system_id, organization_id)
+ .await
+ {
+ Ok(token) => Ok(token),
+ Err(ClientApiError::Unauthorized) => {
+ self.request_sign_out();
+ Err(ClientApiError::Unauthorized).context("Failed to create LLM token")
+ }
+ Err(err) => Err(anyhow::Error::from(err)),
+ }
+ }
+
+ pub async fn refresh_llm_token(
+ &self,
+ llm_token: &LlmApiToken,
+ organization_id: Option<OrganizationId>,
+ ) -> Result<String> {
+ let system_id = self.telemetry().system_id().map(|x| x.to_string());
+ let cloud_client = self.cloud_client();
+ match llm_token
+ .refresh(&cloud_client, system_id, organization_id)
+ .await
+ {
+ Ok(token) => Ok(token),
+ Err(ClientApiError::Unauthorized) => {
+ self.request_sign_out();
+ return Err(ClientApiError::Unauthorized).context("Failed to create LLM token");
+ }
+ Err(err) => return Err(anyhow::Error::from(err)),
+ }
+ }
+
+ pub async fn clear_and_refresh_llm_token(
+ &self,
+ llm_token: &LlmApiToken,
+ organization_id: Option<OrganizationId>,
+ ) -> Result<String> {
+ let system_id = self.telemetry().system_id().map(|x| x.to_string());
+ let cloud_client = self.cloud_client();
+ match llm_token
+ .clear_and_refresh(&cloud_client, system_id, organization_id)
+ .await
+ {
+ Ok(token) => Ok(token),
+ Err(ClientApiError::Unauthorized) => {
+ self.request_sign_out();
+ return Err(ClientApiError::Unauthorized).context("Failed to create LLM token");
+ }
+ Err(err) => return Err(anyhow::Error::from(err)),
+ }
+ }
+
pub async fn sign_out(self: &Arc<Self>, cx: &AsyncApp) {
self.state.write().credentials = None;
self.cloud_client.clear_credentials();
@@ -0,0 +1,116 @@
+use super::{Client, UserStore};
+use cloud_api_types::websocket_protocol::MessageToClient;
+use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
+use gpui::{
+ App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription,
+};
+use language_model::LlmApiToken;
+use std::sync::Arc;
+
+pub trait NeedsLlmTokenRefresh {
+ /// Returns whether the LLM token needs to be refreshed.
+ fn needs_llm_token_refresh(&self) -> bool;
+}
+
+impl NeedsLlmTokenRefresh for http_client::Response<http_client::AsyncBody> {
+ fn needs_llm_token_refresh(&self) -> bool {
+ self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some()
+ || self.headers().get(OUTDATED_LLM_TOKEN_HEADER_NAME).is_some()
+ }
+}
+
+enum TokenRefreshMode {
+ Refresh,
+ ClearAndRefresh,
+}
+
+pub fn global_llm_token(cx: &App) -> LlmApiToken {
+ RefreshLlmTokenListener::global(cx)
+ .read(cx)
+ .llm_api_token
+ .clone()
+}
+
+struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
+
+impl Global for GlobalRefreshLlmTokenListener {}
+
+pub struct LlmTokenRefreshedEvent;
+
+pub struct RefreshLlmTokenListener {
+ client: Arc<Client>,
+ user_store: Entity<UserStore>,
+ llm_api_token: LlmApiToken,
+ _subscription: Subscription,
+}
+
+impl EventEmitter<LlmTokenRefreshedEvent> for RefreshLlmTokenListener {}
+
+impl RefreshLlmTokenListener {
+ pub fn register(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
+ let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, user_store, cx));
+ cx.set_global(GlobalRefreshLlmTokenListener(listener));
+ }
+
+ pub fn global(cx: &App) -> Entity<Self> {
+ GlobalRefreshLlmTokenListener::global(cx).0.clone()
+ }
+
+ fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
+ client.add_message_to_client_handler({
+ let this = cx.weak_entity();
+ move |message, cx| {
+ if let Some(this) = this.upgrade() {
+ Self::handle_refresh_llm_token(this, message, cx);
+ }
+ }
+ });
+
+ let subscription = cx.subscribe(&user_store, |this, _user_store, event, cx| {
+ if matches!(event, super::user::Event::OrganizationChanged) {
+ this.refresh(TokenRefreshMode::ClearAndRefresh, cx);
+ }
+ });
+
+ Self {
+ client,
+ user_store,
+ llm_api_token: LlmApiToken::default(),
+ _subscription: subscription,
+ }
+ }
+
+ fn refresh(&self, mode: TokenRefreshMode, cx: &mut Context<Self>) {
+ let client = self.client.clone();
+ let llm_api_token = self.llm_api_token.clone();
+ let organization_id = self
+ .user_store
+ .read(cx)
+ .current_organization()
+ .map(|organization| organization.id.clone());
+ cx.spawn(async move |this, cx| {
+ match mode {
+ TokenRefreshMode::Refresh => {
+ client
+ .refresh_llm_token(&llm_api_token, organization_id)
+ .await?;
+ }
+ TokenRefreshMode::ClearAndRefresh => {
+ client
+ .clear_and_refresh_llm_token(&llm_api_token, organization_id)
+ .await?;
+ }
+ }
+ this.update(cx, |_this, cx| cx.emit(LlmTokenRefreshedEvent))
+ })
+ .detach_and_log_err(cx);
+ }
+
+ fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
+ match message {
+ MessageToClient::UserUpdated => {
+ this.update(cx, |this, cx| this.refresh(TokenRefreshMode::Refresh, cx));
+ }
+ }
+ }
+}
@@ -22,6 +22,7 @@ log.workspace = true
serde.workspace = true
serde_json.workspace = true
text.workspace = true
+zed_credentials_provider.workspace = true
zeta_prompt.workspace = true
[dev-dependencies]
@@ -48,9 +48,10 @@ pub fn codestral_api_key(cx: &App) -> Option<Arc<str>> {
}
pub fn load_codestral_api_key(cx: &mut App) -> Task<Result<(), AuthenticateError>> {
+ let credentials_provider = zed_credentials_provider::global(cx);
let api_url = codestral_api_url(cx);
codestral_api_key_state(cx).update(cx, |key_state, cx| {
- key_state.load_if_needed(api_url, |s| s, cx)
+ key_state.load_if_needed(api_url, |s| s, credentials_provider, cx)
})
}
@@ -13,9 +13,5 @@ path = "src/credentials_provider.rs"
[dependencies]
anyhow.workspace = true
-futures.workspace = true
gpui.workspace = true
-paths.workspace = true
-release_channel.workspace = true
serde.workspace = true
-serde_json.workspace = true
@@ -1,26 +1,8 @@
-use std::collections::HashMap;
use std::future::Future;
-use std::path::PathBuf;
use std::pin::Pin;
-use std::sync::{Arc, LazyLock};
use anyhow::Result;
-use futures::FutureExt as _;
-use gpui::{App, AsyncApp};
-use release_channel::ReleaseChannel;
-
-/// An environment variable whose presence indicates that the system keychain
-/// should be used in development.
-///
-/// By default, running Zed in development uses the development credentials
-/// provider. Setting this environment variable allows you to interact with the
-/// system keychain (for instance, if you need to test something).
-///
-/// Only works in development. Setting this environment variable in other
-/// release channels is a no-op.
-static ZED_DEVELOPMENT_USE_KEYCHAIN: LazyLock<bool> = LazyLock::new(|| {
- std::env::var("ZED_DEVELOPMENT_USE_KEYCHAIN").is_ok_and(|value| !value.is_empty())
-});
+use gpui::AsyncApp;
/// A provider for credentials.
///
@@ -50,150 +32,3 @@ pub trait CredentialsProvider: Send + Sync {
cx: &'a AsyncApp,
) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>>;
}
-
-impl dyn CredentialsProvider {
- /// Returns the global [`CredentialsProvider`].
- pub fn global(cx: &App) -> Arc<Self> {
- // The `CredentialsProvider` trait has `Send + Sync` bounds on it, so it
- // seems like this is a false positive from Clippy.
- #[allow(clippy::arc_with_non_send_sync)]
- Self::new(cx)
- }
-
- fn new(cx: &App) -> Arc<Self> {
- let use_development_provider = match ReleaseChannel::try_global(cx) {
- Some(ReleaseChannel::Dev) => {
- // In development we default to using the development
- // credentials provider to avoid getting spammed by relentless
- // keychain access prompts.
- //
- // However, if the `ZED_DEVELOPMENT_USE_KEYCHAIN` environment
- // variable is set, we will use the actual keychain.
- !*ZED_DEVELOPMENT_USE_KEYCHAIN
- }
- Some(ReleaseChannel::Nightly | ReleaseChannel::Preview | ReleaseChannel::Stable)
- | None => false,
- };
-
- if use_development_provider {
- Arc::new(DevelopmentCredentialsProvider::new())
- } else {
- Arc::new(KeychainCredentialsProvider)
- }
- }
-}
-
-/// A credentials provider that stores credentials in the system keychain.
-struct KeychainCredentialsProvider;
-
-impl CredentialsProvider for KeychainCredentialsProvider {
- fn read_credentials<'a>(
- &'a self,
- url: &'a str,
- cx: &'a AsyncApp,
- ) -> Pin<Box<dyn Future<Output = Result<Option<(String, Vec<u8>)>>> + 'a>> {
- async move { cx.update(|cx| cx.read_credentials(url)).await }.boxed_local()
- }
-
- fn write_credentials<'a>(
- &'a self,
- url: &'a str,
- username: &'a str,
- password: &'a [u8],
- cx: &'a AsyncApp,
- ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
- async move {
- cx.update(move |cx| cx.write_credentials(url, username, password))
- .await
- }
- .boxed_local()
- }
-
- fn delete_credentials<'a>(
- &'a self,
- url: &'a str,
- cx: &'a AsyncApp,
- ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
- async move { cx.update(move |cx| cx.delete_credentials(url)).await }.boxed_local()
- }
-}
-
-/// A credentials provider that stores credentials in a local file.
-///
-/// This MUST only be used in development, as this is not a secure way of storing
-/// credentials on user machines.
-///
-/// Its existence is purely to work around the annoyance of having to constantly
-/// re-allow access to the system keychain when developing Zed.
-struct DevelopmentCredentialsProvider {
- path: PathBuf,
-}
-
-impl DevelopmentCredentialsProvider {
- fn new() -> Self {
- let path = paths::config_dir().join("development_credentials");
-
- Self { path }
- }
-
- fn load_credentials(&self) -> Result<HashMap<String, (String, Vec<u8>)>> {
- let json = std::fs::read(&self.path)?;
- let credentials: HashMap<String, (String, Vec<u8>)> = serde_json::from_slice(&json)?;
-
- Ok(credentials)
- }
-
- fn save_credentials(&self, credentials: &HashMap<String, (String, Vec<u8>)>) -> Result<()> {
- let json = serde_json::to_string(credentials)?;
- std::fs::write(&self.path, json)?;
-
- Ok(())
- }
-}
-
-impl CredentialsProvider for DevelopmentCredentialsProvider {
- fn read_credentials<'a>(
- &'a self,
- url: &'a str,
- _cx: &'a AsyncApp,
- ) -> Pin<Box<dyn Future<Output = Result<Option<(String, Vec<u8>)>>> + 'a>> {
- async move {
- Ok(self
- .load_credentials()
- .unwrap_or_default()
- .get(url)
- .cloned())
- }
- .boxed_local()
- }
-
- fn write_credentials<'a>(
- &'a self,
- url: &'a str,
- username: &'a str,
- password: &'a [u8],
- _cx: &'a AsyncApp,
- ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
- async move {
- let mut credentials = self.load_credentials().unwrap_or_default();
- credentials.insert(url.to_string(), (username.to_string(), password.to_vec()));
-
- self.save_credentials(&credentials)
- }
- .boxed_local()
- }
-
- fn delete_credentials<'a>(
- &'a self,
- url: &'a str,
- _cx: &'a AsyncApp,
- ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
- async move {
- let mut credentials = self.load_credentials()?;
- credentials.remove(url);
-
- self.save_credentials(&credentials)
- }
- .boxed_local()
- }
-}
@@ -26,6 +26,7 @@ cloud_llm_client.workspace = true
collections.workspace = true
copilot.workspace = true
copilot_ui.workspace = true
+credentials_provider.workspace = true
db.workspace = true
edit_prediction_types.workspace = true
edit_prediction_context.workspace = true
@@ -65,6 +66,7 @@ uuid.workspace = true
workspace.workspace = true
worktree.workspace = true
zed_actions.workspace = true
+zed_credentials_provider.workspace = true
zeta_prompt.workspace = true
zstd.workspace = true
@@ -258,6 +258,7 @@ fn generate_timestamp_name() -> String {
mod tests {
use super::*;
use crate::EditPredictionStore;
+ use client::RefreshLlmTokenListener;
use client::{Client, UserStore};
use clock::FakeSystemClock;
use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
@@ -548,7 +549,8 @@ mod tests {
let http_client = FakeHttpClient::with_404_response();
let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
- language_model::init(user_store.clone(), client.clone(), cx);
+ language_model::init(cx);
+ RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
EditPredictionStore::global(&client, &user_store, cx);
})
}
@@ -1,5 +1,5 @@
use anyhow::Result;
-use client::{Client, EditPredictionUsage, UserStore};
+use client::{Client, EditPredictionUsage, NeedsLlmTokenRefresh, UserStore, global_llm_token};
use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
use cloud_llm_client::predict_edits_v3::{
PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
@@ -11,6 +11,7 @@ use cloud_llm_client::{
};
use collections::{HashMap, HashSet};
use copilot::{Copilot, Reinstall, SignIn, SignOut};
+use credentials_provider::CredentialsProvider;
use db::kvp::{Dismissable, KeyValueStore};
use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
@@ -30,7 +31,7 @@ use heapless::Vec as ArrayVec;
use language::language_settings::all_language_settings;
use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
use language::{BufferSnapshot, OffsetRangeExt};
-use language_model::{LlmApiToken, NeedsLlmTokenRefresh};
+use language_model::LlmApiToken;
use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
use release_channel::AppVersion;
use semver::Version;
@@ -150,6 +151,7 @@ pub struct EditPredictionStore {
rated_predictions: HashSet<EditPredictionId>,
#[cfg(test)]
settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
+ credentials_provider: Arc<dyn CredentialsProvider>,
}
pub(crate) struct EditPredictionRejectionPayload {
@@ -746,7 +748,7 @@ impl EditPredictionStore {
pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
let data_collection_choice = Self::load_data_collection_choice(cx);
- let llm_token = LlmApiToken::global(cx);
+ let llm_token = global_llm_token(cx);
let (reject_tx, reject_rx) = mpsc::unbounded();
cx.background_spawn({
@@ -787,6 +789,8 @@ impl EditPredictionStore {
.log_err();
});
+ let credentials_provider = zed_credentials_provider::global(cx);
+
let this = Self {
projects: HashMap::default(),
client,
@@ -807,6 +811,8 @@ impl EditPredictionStore {
shown_predictions: Default::default(),
#[cfg(test)]
settled_event_callback: None,
+
+ credentials_provider,
};
this
@@ -871,7 +877,9 @@ impl EditPredictionStore {
let experiments = cx
.background_spawn(async move {
let http_client = client.http_client();
- let token = llm_token.acquire(&client, organization_id).await?;
+ let token = client
+ .acquire_llm_token(&llm_token, organization_id.clone())
+ .await?;
let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
let request = http_client::Request::builder()
.method(Method::GET)
@@ -2315,7 +2323,10 @@ impl EditPredictionStore {
zeta::request_prediction_with_zeta(self, inputs, capture_data, cx)
}
EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
- EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
+ EditPredictionModel::Mercury => {
+ self.mercury
+ .request_prediction(inputs, self.credentials_provider.clone(), cx)
+ }
};
cx.spawn(async move |this, cx| {
@@ -2536,12 +2547,15 @@ impl EditPredictionStore {
Res: DeserializeOwned,
{
let http_client = client.http_client();
-
let mut token = if require_auth {
- Some(llm_token.acquire(&client, organization_id.clone()).await?)
+ Some(
+ client
+ .acquire_llm_token(&llm_token, organization_id.clone())
+ .await?,
+ )
} else {
- llm_token
- .acquire(&client, organization_id.clone())
+ client
+ .acquire_llm_token(&llm_token, organization_id.clone())
.await
.ok()
};
@@ -2585,7 +2599,11 @@ impl EditPredictionStore {
return Ok((serde_json::from_slice(&body)?, usage));
} else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
did_retry = true;
- token = Some(llm_token.refresh(&client, organization_id.clone()).await?);
+ token = Some(
+ client
+ .refresh_llm_token(&llm_token, organization_id.clone())
+ .await?,
+ );
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
@@ -1,6 +1,6 @@
use super::*;
use crate::udiff::apply_diff_to_string;
-use client::{UserStore, test::FakeServer};
+use client::{RefreshLlmTokenListener, UserStore, test::FakeServer};
use clock::FakeSystemClock;
use clock::ReplicaId;
use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
@@ -23,7 +23,7 @@ use language::{
Anchor, Buffer, Capability, CursorShape, Diagnostic, DiagnosticEntry, DiagnosticSet,
DiagnosticSeverity, Operation, Point, Selection, SelectionGoal,
};
-use language_model::RefreshLlmTokenListener;
+
use lsp::LanguageServerId;
use parking_lot::Mutex;
use pretty_assertions::{assert_eq, assert_matches};
@@ -2439,7 +2439,8 @@ fn init_test_with_fake_client(
client.cloud_client().set_credentials(1, "test".into());
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
- language_model::init(user_store.clone(), client.clone(), cx);
+ language_model::init(cx);
+ RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
let ep_store = EditPredictionStore::global(&client, &user_store, cx);
(
@@ -2891,7 +2892,7 @@ async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut
cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
let user_store = cx.update(|cx| cx.new(|cx| client::UserStore::new(client.clone(), cx)));
cx.update(|cx| {
- language_model::RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
+ RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
});
let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
@@ -5,6 +5,7 @@ use crate::{
};
use anyhow::{Context as _, Result};
use cloud_llm_client::EditPredictionRejectReason;
+use credentials_provider::CredentialsProvider;
use futures::AsyncReadExt as _;
use gpui::{
App, AppContext as _, Context, Entity, Global, SharedString, Task,
@@ -51,10 +52,11 @@ impl Mercury {
debug_tx,
..
}: EditPredictionModelInput,
+ credentials_provider: Arc<dyn CredentialsProvider>,
cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
self.api_token.update(cx, |key_state, cx| {
- _ = key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, cx);
+ _ = key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, credentials_provider, cx);
});
let Some(api_token) = self.api_token.read(cx).key(&MERCURY_CREDENTIALS_URL) else {
return Task::ready(Ok(None));
@@ -387,8 +389,9 @@ pub fn mercury_api_token(cx: &mut App) -> Entity<ApiKeyState> {
}
pub fn load_mercury_api_token(cx: &mut App) -> Task<Result<(), language_model::AuthenticateError>> {
+ let credentials_provider = zed_credentials_provider::global(cx);
mercury_api_token(cx).update(cx, |key_state, cx| {
- key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, cx)
+ key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, credentials_provider, cx)
})
}
@@ -42,9 +42,10 @@ pub fn open_ai_compatible_api_token(cx: &mut App) -> Entity<ApiKeyState> {
pub fn load_open_ai_compatible_api_token(
cx: &mut App,
) -> Task<Result<(), language_model::AuthenticateError>> {
+ let credentials_provider = zed_credentials_provider::global(cx);
let api_url = open_ai_compatible_api_url(cx);
open_ai_compatible_api_token(cx).update(cx, |key_state, cx| {
- key_state.load_if_needed(api_url, |s| s, cx)
+ key_state.load_if_needed(api_url, |s| s, credentials_provider, cx)
})
}
@@ -1,4 +1,4 @@
-use client::{Client, ProxySettings, UserStore};
+use client::{Client, ProxySettings, RefreshLlmTokenListener, UserStore};
use db::AppDatabase;
use extension::ExtensionHostProxy;
use fs::RealFs;
@@ -109,7 +109,8 @@ pub fn init(cx: &mut App) -> EpAppState {
debug_adapter_extension::init(extension_host_proxy.clone(), cx);
language_extension::init(LspAccess::Noop, extension_host_proxy, languages.clone());
- language_model::init(user_store.clone(), client.clone(), cx);
+ language_model::init(cx);
+ RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
language_models::init(user_store.clone(), client.clone(), cx);
languages::init(languages.clone(), fs.clone(), node_runtime.clone(), cx);
prompt_store::init(cx);
@@ -0,0 +1,15 @@
+[package]
+name = "env_var"
+version = "0.1.0"
+edition.workspace = true
+publish.workspace = true
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/env_var.rs"
+
+[dependencies]
+gpui.workspace = true
@@ -0,0 +1 @@
+../../LICENSE-GPL
@@ -0,0 +1,40 @@
+use gpui::SharedString;
+
+#[derive(Clone)]
+pub struct EnvVar {
+ pub name: SharedString,
+ /// Value of the environment variable. Also `None` when set to an empty string.
+ pub value: Option<String>,
+}
+
+impl EnvVar {
+ pub fn new(name: SharedString) -> Self {
+ let value = std::env::var(name.as_str()).ok();
+ if value.as_ref().is_some_and(|v| v.is_empty()) {
+ Self { name, value: None }
+ } else {
+ Self { name, value }
+ }
+ }
+
+ pub fn or(self, other: EnvVar) -> EnvVar {
+ if self.value.is_some() { self } else { other }
+ }
+}
+
+/// Creates a `LazyLock<EnvVar>` expression for use in a `static` declaration.
+#[macro_export]
+macro_rules! env_var {
+ ($name:expr) => {
+ ::std::sync::LazyLock::new(|| $crate::EnvVar::new(($name).into()))
+ };
+}
+
+/// Generates a `LazyLock<bool>` expression for use in a `static` declaration. Checks if the
+/// environment variable exists and is non-empty.
+#[macro_export]
+macro_rules! bool_env_var {
+ ($name:expr) => {
+ ::std::sync::LazyLock::new(|| $crate::EnvVar::new(($name).into()).value.is_some())
+ };
+}
@@ -1,7 +1,7 @@
use std::path::PathBuf;
use std::sync::Arc;
-use client::{Client, ProxySettings, UserStore};
+use client::{Client, ProxySettings, RefreshLlmTokenListener, UserStore};
use db::AppDatabase;
use extension::ExtensionHostProxy;
use fs::RealFs;
@@ -108,7 +108,8 @@ pub fn init(cx: &mut App) -> Arc<AgentCliAppState> {
let extension_host_proxy = ExtensionHostProxy::global(cx);
debug_adapter_extension::init(extension_host_proxy.clone(), cx);
language_extension::init(LspAccess::Noop, extension_host_proxy, languages.clone());
- language_model::init(user_store.clone(), client.clone(), cx);
+ language_model::init(cx);
+ RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
language_models::init(user_store.clone(), client.clone(), cx);
languages::init(languages.clone(), fs.clone(), node_runtime.clone(), cx);
prompt_store::init(cx);
@@ -20,11 +20,11 @@ anthropic = { workspace = true, features = ["schemars"] }
anyhow.workspace = true
credentials_provider.workspace = true
base64.workspace = true
-client.workspace = true
cloud_api_client.workspace = true
cloud_api_types.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
+env_var.workspace = true
futures.workspace = true
gpui.workspace = true
http_client.workspace = true
@@ -40,7 +40,6 @@ serde_json.workspace = true
smol.workspace = true
thiserror.workspace = true
util.workspace = true
-zed_env_vars.workspace = true
[dev-dependencies]
gpui = { workspace = true, features = ["test-support"] }
@@ -1,5 +1,6 @@
use anyhow::{Result, anyhow};
use credentials_provider::CredentialsProvider;
+use env_var::EnvVar;
use futures::{FutureExt, future};
use gpui::{AsyncApp, Context, SharedString, Task};
use std::{
@@ -7,7 +8,6 @@ use std::{
sync::Arc,
};
use util::ResultExt as _;
-use zed_env_vars::EnvVar;
use crate::AuthenticateError;
@@ -101,6 +101,7 @@ impl ApiKeyState {
url: SharedString,
key: Option<String>,
get_this: impl Fn(&mut Ent) -> &mut Self + 'static,
+ provider: Arc<dyn CredentialsProvider>,
cx: &Context<Ent>,
) -> Task<Result<()>> {
if self.is_from_env_var() {
@@ -108,18 +109,14 @@ impl ApiKeyState {
"bug: attempted to store API key in system keychain when API key is from env var",
)));
}
- let credentials_provider = <dyn CredentialsProvider>::global(cx);
cx.spawn(async move |ent, cx| {
if let Some(key) = &key {
- credentials_provider
+ provider
.write_credentials(&url, "Bearer", key.as_bytes(), cx)
.await
.log_err();
} else {
- credentials_provider
- .delete_credentials(&url, cx)
- .await
- .log_err();
+ provider.delete_credentials(&url, cx).await.log_err();
}
ent.update(cx, |ent, cx| {
let this = get_this(ent);
@@ -144,12 +141,13 @@ impl ApiKeyState {
&mut self,
url: SharedString,
get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static,
+ provider: Arc<dyn CredentialsProvider>,
cx: &mut Context<Ent>,
) {
if url != self.url {
if !self.is_from_env_var() {
// loading will continue even though this result task is dropped
- let _task = self.load_if_needed(url, get_this, cx);
+ let _task = self.load_if_needed(url, get_this, provider, cx);
}
}
}
@@ -163,6 +161,7 @@ impl ApiKeyState {
&mut self,
url: SharedString,
get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static,
+ provider: Arc<dyn CredentialsProvider>,
cx: &mut Context<Ent>,
) -> Task<Result<(), AuthenticateError>> {
if let LoadStatus::Loaded { .. } = &self.load_status
@@ -185,7 +184,7 @@ impl ApiKeyState {
let task = if let Some(load_task) = &self.load_task {
load_task.clone()
} else {
- let load_task = Self::load(url.clone(), get_this.clone(), cx).shared();
+ let load_task = Self::load(url.clone(), get_this.clone(), provider, cx).shared();
self.url = url;
self.load_status = LoadStatus::NotPresent;
self.load_task = Some(load_task.clone());
@@ -206,14 +205,13 @@ impl ApiKeyState {
fn load<Ent: 'static>(
url: SharedString,
get_this: impl Fn(&mut Ent) -> &mut Self + 'static,
+ provider: Arc<dyn CredentialsProvider>,
cx: &Context<Ent>,
) -> Task<()> {
- let credentials_provider = <dyn CredentialsProvider>::global(cx);
cx.spawn({
async move |ent, cx| {
let load_status =
- ApiKey::load_from_system_keychain_impl(&url, credentials_provider.as_ref(), cx)
- .await;
+ ApiKey::load_from_system_keychain_impl(&url, provider.as_ref(), cx).await;
ent.update(cx, |ent, cx| {
let this = get_this(ent);
this.url = url;
@@ -11,12 +11,10 @@ pub mod tool_schema;
pub mod fake_provider;
use anyhow::{Result, anyhow};
-use client::Client;
-use client::UserStore;
use cloud_llm_client::CompletionRequestStatus;
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
-use gpui::{AnyView, App, AsyncApp, Entity, SharedString, Task, Window};
+use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::{StatusCode, http};
use icons::IconName;
use parking_lot::Mutex;
@@ -36,15 +34,10 @@ pub use crate::registry::*;
pub use crate::request::*;
pub use crate::role::*;
pub use crate::tool_schema::LanguageModelToolSchemaFormat;
+pub use env_var::{EnvVar, env_var};
pub use provider::*;
-pub use zed_env_vars::{EnvVar, env_var};
-pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
- init_settings(cx);
- RefreshLlmTokenListener::register(client, user_store, cx);
-}
-
-pub fn init_settings(cx: &mut App) {
+pub fn init(cx: &mut App) {
registry::init(cx);
}
@@ -1,16 +1,9 @@
use std::fmt;
use std::sync::Arc;
-use anyhow::{Context as _, Result};
-use client::Client;
-use client::UserStore;
use cloud_api_client::ClientApiError;
+use cloud_api_client::CloudApiClient;
use cloud_api_types::OrganizationId;
-use cloud_api_types::websocket_protocol::MessageToClient;
-use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
-use gpui::{
- App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription,
-};
use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
use thiserror::Error;
@@ -30,18 +23,12 @@ impl fmt::Display for PaymentRequiredError {
pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
impl LlmApiToken {
- pub fn global(cx: &App) -> Self {
- RefreshLlmTokenListener::global(cx)
- .read(cx)
- .llm_api_token
- .clone()
- }
-
pub async fn acquire(
&self,
- client: &Arc<Client>,
+ client: &CloudApiClient,
+ system_id: Option<String>,
organization_id: Option<OrganizationId>,
- ) -> Result<String> {
+ ) -> Result<String, ClientApiError> {
let lock = self.0.upgradable_read().await;
if let Some(token) = lock.as_ref() {
Ok(token.to_string())
@@ -49,6 +36,7 @@ impl LlmApiToken {
Self::fetch(
RwLockUpgradableReadGuard::upgrade(lock).await,
client,
+ system_id,
organization_id,
)
.await
@@ -57,10 +45,11 @@ impl LlmApiToken {
pub async fn refresh(
&self,
- client: &Arc<Client>,
+ client: &CloudApiClient,
+ system_id: Option<String>,
organization_id: Option<OrganizationId>,
- ) -> Result<String> {
- Self::fetch(self.0.write().await, client, organization_id).await
+ ) -> Result<String, ClientApiError> {
+ Self::fetch(self.0.write().await, client, system_id, organization_id).await
}
/// Clears the existing token before attempting to fetch a new one.
@@ -69,28 +58,22 @@ impl LlmApiToken {
/// leave a token for the wrong organization.
pub async fn clear_and_refresh(
&self,
- client: &Arc<Client>,
+ client: &CloudApiClient,
+ system_id: Option<String>,
organization_id: Option<OrganizationId>,
- ) -> Result<String> {
+ ) -> Result<String, ClientApiError> {
let mut lock = self.0.write().await;
*lock = None;
- Self::fetch(lock, client, organization_id).await
+ Self::fetch(lock, client, system_id, organization_id).await
}
async fn fetch(
mut lock: RwLockWriteGuard<'_, Option<String>>,
- client: &Arc<Client>,
+ client: &CloudApiClient,
+ system_id: Option<String>,
organization_id: Option<OrganizationId>,
- ) -> Result<String> {
- let system_id = client
- .telemetry()
- .system_id()
- .map(|system_id| system_id.to_string());
-
- let result = client
- .cloud_client()
- .create_llm_token(system_id, organization_id)
- .await;
+ ) -> Result<String, ClientApiError> {
+ let result = client.create_llm_token(system_id, organization_id).await;
match result {
Ok(response) => {
*lock = Some(response.token.0.clone());
@@ -98,112 +81,7 @@ impl LlmApiToken {
}
Err(err) => {
*lock = None;
- match err {
- ClientApiError::Unauthorized => {
- client.request_sign_out();
- Err(err).context("Failed to create LLM token")
- }
- ClientApiError::Other(err) => Err(err),
- }
- }
- }
- }
-}
-
-pub trait NeedsLlmTokenRefresh {
- /// Returns whether the LLM token needs to be refreshed.
- fn needs_llm_token_refresh(&self) -> bool;
-}
-
-impl NeedsLlmTokenRefresh for http_client::Response<http_client::AsyncBody> {
- fn needs_llm_token_refresh(&self) -> bool {
- self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some()
- || self.headers().get(OUTDATED_LLM_TOKEN_HEADER_NAME).is_some()
- }
-}
-
-enum TokenRefreshMode {
- Refresh,
- ClearAndRefresh,
-}
-
-struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
-
-impl Global for GlobalRefreshLlmTokenListener {}
-
-pub struct LlmTokenRefreshedEvent;
-
-pub struct RefreshLlmTokenListener {
- client: Arc<Client>,
- user_store: Entity<UserStore>,
- llm_api_token: LlmApiToken,
- _subscription: Subscription,
-}
-
-impl EventEmitter<LlmTokenRefreshedEvent> for RefreshLlmTokenListener {}
-
-impl RefreshLlmTokenListener {
- pub fn register(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
- let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, user_store, cx));
- cx.set_global(GlobalRefreshLlmTokenListener(listener));
- }
-
- pub fn global(cx: &App) -> Entity<Self> {
- GlobalRefreshLlmTokenListener::global(cx).0.clone()
- }
-
- fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
- client.add_message_to_client_handler({
- let this = cx.weak_entity();
- move |message, cx| {
- if let Some(this) = this.upgrade() {
- Self::handle_refresh_llm_token(this, message, cx);
- }
- }
- });
-
- let subscription = cx.subscribe(&user_store, |this, _user_store, event, cx| {
- if matches!(event, client::user::Event::OrganizationChanged) {
- this.refresh(TokenRefreshMode::ClearAndRefresh, cx);
- }
- });
-
- Self {
- client,
- user_store,
- llm_api_token: LlmApiToken::default(),
- _subscription: subscription,
- }
- }
-
- fn refresh(&self, mode: TokenRefreshMode, cx: &mut Context<Self>) {
- let client = self.client.clone();
- let llm_api_token = self.llm_api_token.clone();
- let organization_id = self
- .user_store
- .read(cx)
- .current_organization()
- .map(|organization| organization.id.clone());
- cx.spawn(async move |this, cx| {
- match mode {
- TokenRefreshMode::Refresh => {
- llm_api_token.refresh(&client, organization_id).await?;
- }
- TokenRefreshMode::ClearAndRefresh => {
- llm_api_token
- .clear_and_refresh(&client, organization_id)
- .await?;
- }
- }
- this.update(cx, |_this, cx| cx.emit(LlmTokenRefreshedEvent))
- })
- .detach_and_log_err(cx);
- }
-
- fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
- match message {
- MessageToClient::UserUpdated => {
- this.update(cx, |this, cx| this.refresh(TokenRefreshMode::Refresh, cx));
+ Err(err)
}
}
}
@@ -3,6 +3,7 @@ use std::sync::Arc;
use ::settings::{Settings, SettingsStore};
use client::{Client, UserStore};
use collections::HashSet;
+use credentials_provider::CredentialsProvider;
use gpui::{App, Context, Entity};
use language_model::{LanguageModelProviderId, LanguageModelRegistry};
use provider::deepseek::DeepSeekLanguageModelProvider;
@@ -31,9 +32,16 @@ use crate::provider::x_ai::XAiLanguageModelProvider;
pub use crate::settings::*;
pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
+ let credentials_provider = client.credentials_provider();
let registry = LanguageModelRegistry::global(cx);
registry.update(cx, |registry, cx| {
- register_language_model_providers(registry, user_store, client.clone(), cx);
+ register_language_model_providers(
+ registry,
+ user_store,
+ client.clone(),
+ credentials_provider.clone(),
+ cx,
+ );
});
// Subscribe to extension store events to track LLM extension installations
@@ -104,6 +112,7 @@ pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
&HashSet::default(),
&openai_compatible_providers,
client.clone(),
+ credentials_provider.clone(),
cx,
);
});
@@ -124,6 +133,7 @@ pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
&openai_compatible_providers,
&openai_compatible_providers_new,
client.clone(),
+ credentials_provider.clone(),
cx,
);
});
@@ -138,6 +148,7 @@ fn register_openai_compatible_providers(
old: &HashSet<Arc<str>>,
new: &HashSet<Arc<str>>,
client: Arc<Client>,
+ credentials_provider: Arc<dyn CredentialsProvider>,
cx: &mut Context<LanguageModelRegistry>,
) {
for provider_id in old {
@@ -152,6 +163,7 @@ fn register_openai_compatible_providers(
Arc::new(OpenAiCompatibleLanguageModelProvider::new(
provider_id.clone(),
client.http_client(),
+ credentials_provider.clone(),
cx,
)),
cx,
@@ -164,6 +176,7 @@ fn register_language_model_providers(
registry: &mut LanguageModelRegistry,
user_store: Entity<UserStore>,
client: Arc<Client>,
+ credentials_provider: Arc<dyn CredentialsProvider>,
cx: &mut Context<LanguageModelRegistry>,
) {
registry.register_provider(
@@ -177,62 +190,105 @@ fn register_language_model_providers(
registry.register_provider(
Arc::new(AnthropicLanguageModelProvider::new(
client.http_client(),
+ credentials_provider.clone(),
cx,
)),
cx,
);
registry.register_provider(
- Arc::new(OpenAiLanguageModelProvider::new(client.http_client(), cx)),
+ Arc::new(OpenAiLanguageModelProvider::new(
+ client.http_client(),
+ credentials_provider.clone(),
+ cx,
+ )),
cx,
);
registry.register_provider(
- Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)),
+ Arc::new(OllamaLanguageModelProvider::new(
+ client.http_client(),
+ credentials_provider.clone(),
+ cx,
+ )),
cx,
);
registry.register_provider(
- Arc::new(LmStudioLanguageModelProvider::new(client.http_client(), cx)),
+ Arc::new(LmStudioLanguageModelProvider::new(
+ client.http_client(),
+ credentials_provider.clone(),
+ cx,
+ )),
cx,
);
registry.register_provider(
- Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)),
+ Arc::new(DeepSeekLanguageModelProvider::new(
+ client.http_client(),
+ credentials_provider.clone(),
+ cx,
+ )),
cx,
);
registry.register_provider(
- Arc::new(GoogleLanguageModelProvider::new(client.http_client(), cx)),
+ Arc::new(GoogleLanguageModelProvider::new(
+ client.http_client(),
+ credentials_provider.clone(),
+ cx,
+ )),
cx,
);
registry.register_provider(
- MistralLanguageModelProvider::global(client.http_client(), cx),
+ MistralLanguageModelProvider::global(
+ client.http_client(),
+ credentials_provider.clone(),
+ cx,
+ ),
cx,
);
registry.register_provider(
- Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)),
+ Arc::new(BedrockLanguageModelProvider::new(
+ client.http_client(),
+ credentials_provider.clone(),
+ cx,
+ )),
cx,
);
registry.register_provider(
Arc::new(OpenRouterLanguageModelProvider::new(
client.http_client(),
+ credentials_provider.clone(),
cx,
)),
cx,
);
registry.register_provider(
- Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)),
+ Arc::new(VercelLanguageModelProvider::new(
+ client.http_client(),
+ credentials_provider.clone(),
+ cx,
+ )),
cx,
);
registry.register_provider(
Arc::new(VercelAiGatewayLanguageModelProvider::new(
client.http_client(),
+ credentials_provider.clone(),
cx,
)),
cx,
);
registry.register_provider(
- Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)),
+ Arc::new(XAiLanguageModelProvider::new(
+ client.http_client(),
+ credentials_provider.clone(),
+ cx,
+ )),
cx,
);
registry.register_provider(
- Arc::new(OpenCodeLanguageModelProvider::new(client.http_client(), cx)),
+ Arc::new(OpenCodeLanguageModelProvider::new(
+ client.http_client(),
+ credentials_provider,
+ cx,
+ )),
cx,
);
registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
@@ -6,6 +6,7 @@ use anthropic::{
};
use anyhow::Result;
use collections::{BTreeMap, HashMap};
+use credentials_provider::CredentialsProvider;
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Task};
use http_client::HttpClient;
@@ -51,6 +52,7 @@ static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
pub struct State {
api_key_state: ApiKeyState,
+ credentials_provider: Arc<dyn CredentialsProvider>,
}
impl State {
@@ -59,30 +61,51 @@ impl State {
}
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = AnthropicLanguageModelProvider::api_url(cx);
- self.api_key_state
- .store(api_url, api_key, |this| &mut this.api_key_state, cx)
+ self.api_key_state.store(
+ api_url,
+ api_key,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ )
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = AnthropicLanguageModelProvider::api_url(cx);
- self.api_key_state
- .load_if_needed(api_url, |this| &mut this.api_key_state, cx)
+ self.api_key_state.load_if_needed(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ )
}
}
impl AnthropicLanguageModelProvider {
- pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+ pub fn new(
+ http_client: Arc<dyn HttpClient>,
+ credentials_provider: Arc<dyn CredentialsProvider>,
+ cx: &mut App,
+ ) -> Self {
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
+ let credentials_provider = this.credentials_provider.clone();
let api_url = Self::api_url(cx);
- this.api_key_state
- .handle_url_change(api_url, |this| &mut this.api_key_state, cx);
+ this.api_key_state.handle_url_change(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ );
cx.notify();
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
+ credentials_provider,
}
});
@@ -195,12 +195,13 @@ pub struct State {
settings: Option<AmazonBedrockSettings>,
/// Whether credentials came from environment variables (only relevant for static credentials)
credentials_from_env: bool,
+ credentials_provider: Arc<dyn CredentialsProvider>,
_subscription: Subscription,
}
impl State {
fn reset_auth(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
- let credentials_provider = <dyn CredentialsProvider>::global(cx);
+ let credentials_provider = self.credentials_provider.clone();
cx.spawn(async move |this, cx| {
credentials_provider
.delete_credentials(AMAZON_AWS_URL, cx)
@@ -220,7 +221,7 @@ impl State {
cx: &mut Context<Self>,
) -> Task<Result<()>> {
let auth = credentials.clone().into_auth();
- let credentials_provider = <dyn CredentialsProvider>::global(cx);
+ let credentials_provider = self.credentials_provider.clone();
cx.spawn(async move |this, cx| {
credentials_provider
.write_credentials(
@@ -287,7 +288,7 @@ impl State {
&self,
cx: &mut Context<Self>,
) -> Task<Result<(), AuthenticateError>> {
- let credentials_provider = <dyn CredentialsProvider>::global(cx);
+ let credentials_provider = self.credentials_provider.clone();
cx.spawn(async move |this, cx| {
// Try environment variables first
let (auth, from_env) = if let Some(bearer_token) = &ZED_BEDROCK_BEARER_TOKEN_VAR.value {
@@ -400,11 +401,16 @@ pub struct BedrockLanguageModelProvider {
}
impl BedrockLanguageModelProvider {
- pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+ pub fn new(
+ http_client: Arc<dyn HttpClient>,
+ credentials_provider: Arc<dyn CredentialsProvider>,
+ cx: &mut App,
+ ) -> Self {
let state = cx.new(|cx| State {
auth: None,
settings: Some(AllLanguageModelSettings::get_global(cx).bedrock.clone()),
credentials_from_env: false,
+ credentials_provider,
_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
cx.notify();
}),
@@ -1,7 +1,9 @@
use ai_onboarding::YoungAccountBanner;
use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow};
-use client::{Client, UserStore, zed_urls};
+use client::{
+ Client, NeedsLlmTokenRefresh, RefreshLlmTokenListener, UserStore, global_llm_token, zed_urls,
+};
use cloud_api_types::{OrganizationId, Plan};
use cloud_llm_client::{
CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME,
@@ -24,10 +26,9 @@ use language_model::{
LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelEffortLevel,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
- LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, NeedsLlmTokenRefresh,
- OPEN_AI_PROVIDER_ID, OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter,
- RefreshLlmTokenListener, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME, ZED_CLOUD_PROVIDER_ID,
- ZED_CLOUD_PROVIDER_NAME,
+ LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, OPEN_AI_PROVIDER_ID,
+ OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME,
+ ZED_CLOUD_PROVIDER_ID, ZED_CLOUD_PROVIDER_NAME,
};
use release_channel::AppVersion;
use schemars::JsonSchema;
@@ -111,7 +112,7 @@ impl State {
cx: &mut Context<Self>,
) -> Self {
let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
- let llm_api_token = LlmApiToken::global(cx);
+ let llm_api_token = global_llm_token(cx);
Self {
client: client.clone(),
llm_api_token,
@@ -226,7 +227,9 @@ impl State {
organization_id: Option<OrganizationId>,
) -> Result<ListModelsResponse> {
let http_client = &client.http_client();
- let token = llm_api_token.acquire(&client, organization_id).await?;
+ let token = client
+ .acquire_llm_token(&llm_api_token, organization_id)
+ .await?;
let request = http_client::Request::builder()
.method(Method::GET)
@@ -414,8 +417,8 @@ impl CloudLanguageModel {
) -> Result<PerformLlmCompletionResponse> {
let http_client = &client.http_client();
- let mut token = llm_api_token
- .acquire(&client, organization_id.clone())
+ let mut token = client
+ .acquire_llm_token(&llm_api_token, organization_id.clone())
.await?;
let mut refreshed_token = false;
@@ -447,8 +450,8 @@ impl CloudLanguageModel {
}
if !refreshed_token && response.needs_llm_token_refresh() {
- token = llm_api_token
- .refresh(&client, organization_id.clone())
+ token = client
+ .refresh_llm_token(&llm_api_token, organization_id.clone())
.await?;
refreshed_token = true;
continue;
@@ -713,7 +716,9 @@ impl LanguageModel for CloudLanguageModel {
into_google(request, model_id.clone(), GoogleModelMode::Default);
async move {
let http_client = &client.http_client();
- let token = llm_api_token.acquire(&client, organization_id).await?;
+ let token = client
+ .acquire_llm_token(&llm_api_token, organization_id)
+ .await?;
let request_body = CountTokensBody {
provider: cloud_llm_client::LanguageModelProvider::Google,
@@ -1,5 +1,6 @@
use anyhow::{Result, anyhow};
use collections::{BTreeMap, HashMap};
+use credentials_provider::CredentialsProvider;
use deepseek::DEEPSEEK_API_URL;
use futures::Stream;
@@ -49,6 +50,7 @@ pub struct DeepSeekLanguageModelProvider {
pub struct State {
api_key_state: ApiKeyState,
+ credentials_provider: Arc<dyn CredentialsProvider>,
}
impl State {
@@ -57,30 +59,51 @@ impl State {
}
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = DeepSeekLanguageModelProvider::api_url(cx);
- self.api_key_state
- .store(api_url, api_key, |this| &mut this.api_key_state, cx)
+ self.api_key_state.store(
+ api_url,
+ api_key,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ )
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = DeepSeekLanguageModelProvider::api_url(cx);
- self.api_key_state
- .load_if_needed(api_url, |this| &mut this.api_key_state, cx)
+ self.api_key_state.load_if_needed(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ )
}
}
impl DeepSeekLanguageModelProvider {
- pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+ pub fn new(
+ http_client: Arc<dyn HttpClient>,
+ credentials_provider: Arc<dyn CredentialsProvider>,
+ cx: &mut App,
+ ) -> Self {
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
+ let credentials_provider = this.credentials_provider.clone();
let api_url = Self::api_url(cx);
- this.api_key_state
- .handle_url_change(api_url, |this| &mut this.api_key_state, cx);
+ this.api_key_state.handle_url_change(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ );
cx.notify();
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
+ credentials_provider,
}
});
@@ -1,5 +1,6 @@
use anyhow::{Context as _, Result};
use collections::BTreeMap;
+use credentials_provider::CredentialsProvider;
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
use google_ai::{
FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction,
@@ -60,6 +61,7 @@ pub struct GoogleLanguageModelProvider {
pub struct State {
api_key_state: ApiKeyState,
+ credentials_provider: Arc<dyn CredentialsProvider>,
}
const GEMINI_API_KEY_VAR_NAME: &str = "GEMINI_API_KEY";
@@ -76,30 +78,51 @@ impl State {
}
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = GoogleLanguageModelProvider::api_url(cx);
- self.api_key_state
- .store(api_url, api_key, |this| &mut this.api_key_state, cx)
+ self.api_key_state.store(
+ api_url,
+ api_key,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ )
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = GoogleLanguageModelProvider::api_url(cx);
- self.api_key_state
- .load_if_needed(api_url, |this| &mut this.api_key_state, cx)
+ self.api_key_state.load_if_needed(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ )
}
}
impl GoogleLanguageModelProvider {
- pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+ pub fn new(
+ http_client: Arc<dyn HttpClient>,
+ credentials_provider: Arc<dyn CredentialsProvider>,
+ cx: &mut App,
+ ) -> Self {
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
+ let credentials_provider = this.credentials_provider.clone();
let api_url = Self::api_url(cx);
- this.api_key_state
- .handle_url_change(api_url, |this| &mut this.api_key_state, cx);
+ this.api_key_state.handle_url_change(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ );
cx.notify();
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
+ credentials_provider,
}
});
@@ -1,5 +1,6 @@
use anyhow::{Result, anyhow};
use collections::HashMap;
+use credentials_provider::CredentialsProvider;
use fs::Fs;
use futures::Stream;
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
@@ -52,6 +53,7 @@ pub struct LmStudioLanguageModelProvider {
pub struct State {
api_key_state: ApiKeyState,
+ credentials_provider: Arc<dyn CredentialsProvider>,
http_client: Arc<dyn HttpClient>,
available_models: Vec<lmstudio::Model>,
fetch_model_task: Option<Task<Result<()>>>,
@@ -64,10 +66,15 @@ impl State {
}
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = LmStudioLanguageModelProvider::api_url(cx).into();
- let task = self
- .api_key_state
- .store(api_url, api_key, |this| &mut this.api_key_state, cx);
+ let task = self.api_key_state.store(
+ api_url,
+ api_key,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ );
self.restart_fetch_models_task(cx);
task
}
@@ -114,10 +121,14 @@ impl State {
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = LmStudioLanguageModelProvider::api_url(cx).into();
- let _task = self
- .api_key_state
- .load_if_needed(api_url, |this| &mut this.api_key_state, cx);
+ let _task = self.api_key_state.load_if_needed(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ );
if self.is_authenticated() {
return Task::ready(Ok(()));
@@ -152,16 +163,29 @@ impl State {
}
impl LmStudioLanguageModelProvider {
- pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+ pub fn new(
+ http_client: Arc<dyn HttpClient>,
+ credentials_provider: Arc<dyn CredentialsProvider>,
+ cx: &mut App,
+ ) -> Self {
let this = Self {
http_client: http_client.clone(),
state: cx.new(|cx| {
let subscription = cx.observe_global::<SettingsStore>({
let mut settings = AllLanguageModelSettings::get_global(cx).lmstudio.clone();
move |this: &mut State, cx| {
- let new_settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
- if &settings != new_settings {
- settings = new_settings.clone();
+ let new_settings =
+ AllLanguageModelSettings::get_global(cx).lmstudio.clone();
+ if settings != new_settings {
+ let credentials_provider = this.credentials_provider.clone();
+ let api_url = Self::api_url(cx).into();
+ this.api_key_state.handle_url_change(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ );
+ settings = new_settings;
this.restart_fetch_models_task(cx);
cx.notify();
}
@@ -173,6 +197,7 @@ impl LmStudioLanguageModelProvider {
Self::api_url(cx).into(),
(*API_KEY_ENV_VAR).clone(),
),
+ credentials_provider,
http_client,
available_models: Default::default(),
fetch_model_task: None,
@@ -1,5 +1,6 @@
use anyhow::{Result, anyhow};
use collections::BTreeMap;
+use credentials_provider::CredentialsProvider;
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Global, SharedString, Task, Window};
@@ -43,6 +44,7 @@ pub struct MistralLanguageModelProvider {
pub struct State {
api_key_state: ApiKeyState,
+ credentials_provider: Arc<dyn CredentialsProvider>,
}
impl State {
@@ -51,15 +53,26 @@ impl State {
}
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = MistralLanguageModelProvider::api_url(cx);
- self.api_key_state
- .store(api_url, api_key, |this| &mut this.api_key_state, cx)
+ self.api_key_state.store(
+ api_url,
+ api_key,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ )
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = MistralLanguageModelProvider::api_url(cx);
- self.api_key_state
- .load_if_needed(api_url, |this| &mut this.api_key_state, cx)
+ self.api_key_state.load_if_needed(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ )
}
}
@@ -73,20 +86,30 @@ impl MistralLanguageModelProvider {
.map(|this| &this.0)
}
- pub fn global(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Arc<Self> {
+ pub fn global(
+ http_client: Arc<dyn HttpClient>,
+ credentials_provider: Arc<dyn CredentialsProvider>,
+ cx: &mut App,
+ ) -> Arc<Self> {
if let Some(this) = cx.try_global::<GlobalMistralLanguageModelProvider>() {
return this.0.clone();
}
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
+ let credentials_provider = this.credentials_provider.clone();
let api_url = Self::api_url(cx);
- this.api_key_state
- .handle_url_change(api_url, |this| &mut this.api_key_state, cx);
+ this.api_key_state.handle_url_change(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ );
cx.notify();
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
+ credentials_provider,
}
});
@@ -1,4 +1,5 @@
use anyhow::{Result, anyhow};
+use credentials_provider::CredentialsProvider;
use fs::Fs;
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use futures::{Stream, TryFutureExt, stream};
@@ -54,6 +55,7 @@ pub struct OllamaLanguageModelProvider {
pub struct State {
api_key_state: ApiKeyState,
+ credentials_provider: Arc<dyn CredentialsProvider>,
http_client: Arc<dyn HttpClient>,
fetched_models: Vec<ollama::Model>,
fetch_model_task: Option<Task<Result<()>>>,
@@ -65,10 +67,15 @@ impl State {
}
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = OllamaLanguageModelProvider::api_url(cx);
- let task = self
- .api_key_state
- .store(api_url, api_key, |this| &mut this.api_key_state, cx);
+ let task = self.api_key_state.store(
+ api_url,
+ api_key,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ );
self.fetched_models.clear();
cx.spawn(async move |this, cx| {
@@ -80,10 +87,14 @@ impl State {
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = OllamaLanguageModelProvider::api_url(cx);
- let task = self
- .api_key_state
- .load_if_needed(api_url, |this| &mut this.api_key_state, cx);
+ let task = self.api_key_state.load_if_needed(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ );
// Always try to fetch models - if no API key is needed (local Ollama), it will work
// If API key is needed and provided, it will work
@@ -157,7 +168,11 @@ impl State {
}
impl OllamaLanguageModelProvider {
- pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+ pub fn new(
+ http_client: Arc<dyn HttpClient>,
+ credentials_provider: Arc<dyn CredentialsProvider>,
+ cx: &mut App,
+ ) -> Self {
let this = Self {
http_client: http_client.clone(),
state: cx.new(|cx| {
@@ -170,6 +185,14 @@ impl OllamaLanguageModelProvider {
let url_changed = last_settings.api_url != current_settings.api_url;
last_settings = current_settings.clone();
if url_changed {
+ let credentials_provider = this.credentials_provider.clone();
+ let api_url = Self::api_url(cx);
+ this.api_key_state.handle_url_change(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ );
this.fetched_models.clear();
this.authenticate(cx).detach();
}
@@ -184,6 +207,7 @@ impl OllamaLanguageModelProvider {
fetched_models: Default::default(),
fetch_model_task: None,
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
+ credentials_provider,
}
}),
};
@@ -1,5 +1,6 @@
use anyhow::{Result, anyhow};
use collections::{BTreeMap, HashMap};
+use credentials_provider::CredentialsProvider;
use futures::Stream;
use futures::{FutureExt, StreamExt, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
@@ -55,6 +56,7 @@ pub struct OpenAiLanguageModelProvider {
pub struct State {
api_key_state: ApiKeyState,
+ credentials_provider: Arc<dyn CredentialsProvider>,
}
impl State {
@@ -63,30 +65,51 @@ impl State {
}
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = OpenAiLanguageModelProvider::api_url(cx);
- self.api_key_state
- .store(api_url, api_key, |this| &mut this.api_key_state, cx)
+ self.api_key_state.store(
+ api_url,
+ api_key,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ )
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = OpenAiLanguageModelProvider::api_url(cx);
- self.api_key_state
- .load_if_needed(api_url, |this| &mut this.api_key_state, cx)
+ self.api_key_state.load_if_needed(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ )
}
}
impl OpenAiLanguageModelProvider {
- pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+ pub fn new(
+ http_client: Arc<dyn HttpClient>,
+ credentials_provider: Arc<dyn CredentialsProvider>,
+ cx: &mut App,
+ ) -> Self {
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
+ let credentials_provider = this.credentials_provider.clone();
let api_url = Self::api_url(cx);
- this.api_key_state
- .handle_url_change(api_url, |this| &mut this.api_key_state, cx);
+ this.api_key_state.handle_url_change(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ );
cx.notify();
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
+ credentials_provider,
}
});
@@ -1,5 +1,6 @@
use anyhow::Result;
use convert_case::{Case, Casing};
+use credentials_provider::CredentialsProvider;
use futures::{FutureExt, StreamExt, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::HttpClient;
@@ -44,6 +45,7 @@ pub struct State {
id: Arc<str>,
api_key_state: ApiKeyState,
settings: OpenAiCompatibleSettings,
+ credentials_provider: Arc<dyn CredentialsProvider>,
}
impl State {
@@ -52,20 +54,36 @@ impl State {
}
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = SharedString::new(self.settings.api_url.as_str());
- self.api_key_state
- .store(api_url, api_key, |this| &mut this.api_key_state, cx)
+ self.api_key_state.store(
+ api_url,
+ api_key,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ )
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = SharedString::new(self.settings.api_url.clone());
- self.api_key_state
- .load_if_needed(api_url, |this| &mut this.api_key_state, cx)
+ self.api_key_state.load_if_needed(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ )
}
}
impl OpenAiCompatibleLanguageModelProvider {
- pub fn new(id: Arc<str>, http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+ pub fn new(
+ id: Arc<str>,
+ http_client: Arc<dyn HttpClient>,
+ credentials_provider: Arc<dyn CredentialsProvider>,
+ cx: &mut App,
+ ) -> Self {
fn resolve_settings<'a>(id: &'a str, cx: &'a App) -> Option<&'a OpenAiCompatibleSettings> {
crate::AllLanguageModelSettings::get_global(cx)
.openai_compatible
@@ -79,10 +97,12 @@ impl OpenAiCompatibleLanguageModelProvider {
return;
};
if &this.settings != &settings {
+ let credentials_provider = this.credentials_provider.clone();
let api_url = SharedString::new(settings.api_url.as_str());
this.api_key_state.handle_url_change(
api_url,
|this| &mut this.api_key_state,
+ credentials_provider,
cx,
);
this.settings = settings;
@@ -98,6 +118,7 @@ impl OpenAiCompatibleLanguageModelProvider {
EnvVar::new(api_key_env_var_name),
),
settings,
+ credentials_provider,
}
});
@@ -1,5 +1,6 @@
use anyhow::Result;
use collections::HashMap;
+use credentials_provider::CredentialsProvider;
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task};
use http_client::HttpClient;
@@ -42,6 +43,7 @@ pub struct OpenRouterLanguageModelProvider {
pub struct State {
api_key_state: ApiKeyState,
+ credentials_provider: Arc<dyn CredentialsProvider>,
http_client: Arc<dyn HttpClient>,
available_models: Vec<open_router::Model>,
fetch_models_task: Option<Task<Result<(), LanguageModelCompletionError>>>,
@@ -53,16 +55,26 @@ impl State {
}
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = OpenRouterLanguageModelProvider::api_url(cx);
- self.api_key_state
- .store(api_url, api_key, |this| &mut this.api_key_state, cx)
+ self.api_key_state.store(
+ api_url,
+ api_key,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ )
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = OpenRouterLanguageModelProvider::api_url(cx);
- let task = self
- .api_key_state
- .load_if_needed(api_url, |this| &mut this.api_key_state, cx);
+ let task = self.api_key_state.load_if_needed(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ );
cx.spawn(async move |this, cx| {
let result = task.await;
@@ -114,7 +126,11 @@ impl State {
}
impl OpenRouterLanguageModelProvider {
- pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+ pub fn new(
+ http_client: Arc<dyn HttpClient>,
+ credentials_provider: Arc<dyn CredentialsProvider>,
+ cx: &mut App,
+ ) -> Self {
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>({
let mut last_settings = OpenRouterLanguageModelProvider::settings(cx).clone();
@@ -131,6 +147,7 @@ impl OpenRouterLanguageModelProvider {
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
+ credentials_provider,
http_client: http_client.clone(),
available_models: Vec::new(),
fetch_models_task: None,
@@ -1,5 +1,6 @@
use anyhow::Result;
use collections::BTreeMap;
+use credentials_provider::CredentialsProvider;
use futures::{FutureExt, StreamExt, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::HttpClient;
@@ -43,6 +44,7 @@ pub struct OpenCodeLanguageModelProvider {
pub struct State {
api_key_state: ApiKeyState,
+ credentials_provider: Arc<dyn CredentialsProvider>,
}
impl State {
@@ -51,30 +53,51 @@ impl State {
}
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = OpenCodeLanguageModelProvider::api_url(cx);
- self.api_key_state
- .store(api_url, api_key, |this| &mut this.api_key_state, cx)
+ self.api_key_state.store(
+ api_url,
+ api_key,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ )
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = OpenCodeLanguageModelProvider::api_url(cx);
- self.api_key_state
- .load_if_needed(api_url, |this| &mut this.api_key_state, cx)
+ self.api_key_state.load_if_needed(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ )
}
}
impl OpenCodeLanguageModelProvider {
- pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+ pub fn new(
+ http_client: Arc<dyn HttpClient>,
+ credentials_provider: Arc<dyn CredentialsProvider>,
+ cx: &mut App,
+ ) -> Self {
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
+ let credentials_provider = this.credentials_provider.clone();
let api_url = Self::api_url(cx);
- this.api_key_state
- .handle_url_change(api_url, |this| &mut this.api_key_state, cx);
+ this.api_key_state.handle_url_change(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ );
cx.notify();
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
+ credentials_provider,
}
});
@@ -1,5 +1,6 @@
use anyhow::Result;
use collections::BTreeMap;
+use credentials_provider::CredentialsProvider;
use futures::{FutureExt, StreamExt, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::HttpClient;
@@ -38,6 +39,7 @@ pub struct VercelLanguageModelProvider {
pub struct State {
api_key_state: ApiKeyState,
+ credentials_provider: Arc<dyn CredentialsProvider>,
}
impl State {
@@ -46,30 +48,51 @@ impl State {
}
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = VercelLanguageModelProvider::api_url(cx);
- self.api_key_state
- .store(api_url, api_key, |this| &mut this.api_key_state, cx)
+ self.api_key_state.store(
+ api_url,
+ api_key,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ )
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = VercelLanguageModelProvider::api_url(cx);
- self.api_key_state
- .load_if_needed(api_url, |this| &mut this.api_key_state, cx)
+ self.api_key_state.load_if_needed(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ )
}
}
impl VercelLanguageModelProvider {
- pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+ pub fn new(
+ http_client: Arc<dyn HttpClient>,
+ credentials_provider: Arc<dyn CredentialsProvider>,
+ cx: &mut App,
+ ) -> Self {
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
+ let credentials_provider = this.credentials_provider.clone();
let api_url = Self::api_url(cx);
- this.api_key_state
- .handle_url_change(api_url, |this| &mut this.api_key_state, cx);
+ this.api_key_state.handle_url_change(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ );
cx.notify();
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
+ credentials_provider,
}
});
@@ -1,5 +1,6 @@
use anyhow::Result;
use collections::BTreeMap;
+use credentials_provider::CredentialsProvider;
use futures::{AsyncReadExt, FutureExt, StreamExt, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
@@ -41,6 +42,7 @@ pub struct VercelAiGatewayLanguageModelProvider {
pub struct State {
api_key_state: ApiKeyState,
+ credentials_provider: Arc<dyn CredentialsProvider>,
http_client: Arc<dyn HttpClient>,
available_models: Vec<AvailableModel>,
fetch_models_task: Option<Task<Result<(), LanguageModelCompletionError>>>,
@@ -52,16 +54,26 @@ impl State {
}
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = VercelAiGatewayLanguageModelProvider::api_url(cx);
- self.api_key_state
- .store(api_url, api_key, |this| &mut this.api_key_state, cx)
+ self.api_key_state.store(
+ api_url,
+ api_key,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ )
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = VercelAiGatewayLanguageModelProvider::api_url(cx);
- let task = self
- .api_key_state
- .load_if_needed(api_url, |this| &mut this.api_key_state, cx);
+ let task = self.api_key_state.load_if_needed(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ );
cx.spawn(async move |this, cx| {
let result = task.await;
@@ -100,7 +112,11 @@ impl State {
}
impl VercelAiGatewayLanguageModelProvider {
- pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+ pub fn new(
+ http_client: Arc<dyn HttpClient>,
+ credentials_provider: Arc<dyn CredentialsProvider>,
+ cx: &mut App,
+ ) -> Self {
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>({
let mut last_settings = VercelAiGatewayLanguageModelProvider::settings(cx).clone();
@@ -116,6 +132,7 @@ impl VercelAiGatewayLanguageModelProvider {
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
+ credentials_provider,
http_client: http_client.clone(),
available_models: Vec::new(),
fetch_models_task: None,
@@ -1,5 +1,6 @@
use anyhow::Result;
use collections::BTreeMap;
+use credentials_provider::CredentialsProvider;
use futures::{FutureExt, StreamExt, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Task, Window};
use http_client::HttpClient;
@@ -39,6 +40,7 @@ pub struct XAiLanguageModelProvider {
pub struct State {
api_key_state: ApiKeyState,
+ credentials_provider: Arc<dyn CredentialsProvider>,
}
impl State {
@@ -47,30 +49,51 @@ impl State {
}
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = XAiLanguageModelProvider::api_url(cx);
- self.api_key_state
- .store(api_url, api_key, |this| &mut this.api_key_state, cx)
+ self.api_key_state.store(
+ api_url,
+ api_key,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ )
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = XAiLanguageModelProvider::api_url(cx);
- self.api_key_state
- .load_if_needed(api_url, |this| &mut this.api_key_state, cx)
+ self.api_key_state.load_if_needed(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ )
}
}
impl XAiLanguageModelProvider {
- pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+ pub fn new(
+ http_client: Arc<dyn HttpClient>,
+ credentials_provider: Arc<dyn CredentialsProvider>,
+ cx: &mut App,
+ ) -> Self {
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
+ let credentials_provider = this.credentials_provider.clone();
let api_url = Self::api_url(cx);
- this.api_key_state
- .handle_url_change(api_url, |this| &mut this.api_key_state, cx);
+ this.api_key_state.handle_url_change(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ );
cx.notify();
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
+ credentials_provider,
}
});
@@ -98,6 +98,7 @@ watch.workspace = true
wax.workspace = true
which.workspace = true
worktree.workspace = true
+zed_credentials_provider.workspace = true
zeroize.workspace = true
zlog.workspace = true
ztracing.workspace = true
@@ -684,7 +684,7 @@ impl ContextServerStore {
let server_url = url.clone();
let id = id.clone();
cx.spawn(async move |_this, cx| {
- let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx));
+ let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await
{
log::warn!("{} failed to clear OAuth session on removal: {}", id, err);
@@ -797,8 +797,7 @@ impl ContextServerStore {
if configuration.has_static_auth_header() {
None
} else {
- let credentials_provider =
- cx.update(|cx| <dyn CredentialsProvider>::global(cx));
+ let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
let http_client = cx.update(|cx| cx.http_client());
match Self::load_session(&credentials_provider, url, &cx).await {
@@ -1070,7 +1069,7 @@ impl ContextServerStore {
.context("Failed to start OAuth callback server")?;
let http_client = cx.update(|cx| cx.http_client());
- let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx));
+ let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
let server_url = match configuration.as_ref() {
ContextServerConfiguration::Http { url, .. } => url.clone(),
_ => anyhow::bail!("OAuth authentication only supported for HTTP servers"),
@@ -1233,7 +1232,7 @@ impl ContextServerStore {
self.stop_server(&id, cx)?;
cx.spawn(async move |this, cx| {
- let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx));
+ let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await {
log::error!("{} failed to clear OAuth session: {}", id, err);
}
@@ -1451,7 +1450,7 @@ async fn resolve_start_failure(
// (e.g. timeout because the server rejected the token silently). Clear it
// so the next start attempt can get a clean 401 and trigger the auth flow.
if www_authenticate.is_none() {
- let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx));
+ let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
match ContextServerStore::load_session(&credentials_provider, &server_url, cx).await {
Ok(Some(_)) => {
log::info!("{id} start failed with a cached OAuth session present; clearing it");
@@ -59,6 +59,7 @@ ui.workspace = true
util.workspace = true
workspace.workspace = true
zed_actions.workspace = true
+zed_credentials_provider.workspace = true
[dev-dependencies]
fs = { workspace = true, features = ["test-support"] }
@@ -185,9 +185,15 @@ fn render_api_key_provider(
cx: &mut Context<SettingsWindow>,
) -> impl IntoElement {
let weak_page = cx.weak_entity();
+ let credentials_provider = zed_credentials_provider::global(cx);
_ = window.use_keyed_state(current_url(cx), cx, |_, cx| {
let task = api_key_state.update(cx, |key_state, cx| {
- key_state.load_if_needed(current_url(cx), |state| state, cx)
+ key_state.load_if_needed(
+ current_url(cx),
+ |state| state,
+ credentials_provider.clone(),
+ cx,
+ )
});
cx.spawn(async move |_, cx| {
task.await.ok();
@@ -208,10 +214,17 @@ fn render_api_key_provider(
});
let write_key = move |api_key: Option<String>, cx: &mut App| {
+ let credentials_provider = zed_credentials_provider::global(cx);
api_key_state
.update(cx, |key_state, cx| {
let url = current_url(cx);
- key_state.store(url, api_key, |key_state| key_state, cx)
+ key_state.store(
+ url,
+ api_key,
+ |key_state| key_state,
+ credentials_provider,
+ cx,
+ )
})
.detach_and_log_err(cx);
};
@@ -1,13 +1,13 @@
use std::sync::Arc;
use anyhow::{Context as _, Result};
-use client::{Client, UserStore};
+use client::{Client, NeedsLlmTokenRefresh, UserStore, global_llm_token};
use cloud_api_types::OrganizationId;
use cloud_llm_client::{WebSearchBody, WebSearchResponse};
use futures::AsyncReadExt as _;
use gpui::{App, AppContext, Context, Entity, Task};
use http_client::{HttpClient, Method};
-use language_model::{LlmApiToken, NeedsLlmTokenRefresh};
+use language_model::LlmApiToken;
use web_search::{WebSearchProvider, WebSearchProviderId};
pub struct CloudWebSearchProvider {
@@ -30,7 +30,7 @@ pub struct State {
impl State {
pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
- let llm_api_token = LlmApiToken::global(cx);
+ let llm_api_token = global_llm_token(cx);
Self {
client,
@@ -73,8 +73,8 @@ async fn perform_web_search(
let http_client = &client.http_client();
let mut retries_remaining = MAX_RETRIES;
- let mut token = llm_api_token
- .acquire(&client, organization_id.clone())
+ let mut token = client
+ .acquire_llm_token(&llm_api_token, organization_id.clone())
.await?;
loop {
@@ -100,8 +100,8 @@ async fn perform_web_search(
response.body_mut().read_to_string(&mut body).await?;
return Ok(serde_json::from_str(&body)?);
} else if response.needs_llm_token_refresh() {
- token = llm_api_token
- .refresh(&client, organization_id.clone())
+ token = client
+ .refresh_llm_token(&llm_api_token, organization_id.clone())
.await?;
retries_remaining -= 1;
} else {
@@ -10,7 +10,7 @@ use agent_ui::AgentPanel;
use anyhow::{Context as _, Error, Result};
use clap::Parser;
use cli::FORCE_CLI_MODE_ENV_VAR_NAME;
-use client::{Client, ProxySettings, UserStore, parse_zed_link};
+use client::{Client, ProxySettings, RefreshLlmTokenListener, UserStore, parse_zed_link};
use collab_ui::channel_view::ChannelView;
use collections::HashMap;
use crashes::InitCrashHandler;
@@ -664,7 +664,12 @@ fn main() {
);
copilot_ui::init(&app_state, cx);
- language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
+ language_model::init(cx);
+ RefreshLlmTokenListener::register(
+ app_state.client.clone(),
+ app_state.user_store.clone(),
+ cx,
+ );
language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
acp_tools::init(cx);
zed::telemetry_log::init(cx);
@@ -201,7 +201,12 @@ fn run_visual_tests(project_path: PathBuf, update_baseline: bool) -> Result<()>
});
prompt_store::init(cx);
let prompt_builder = prompt_store::PromptBuilder::load(app_state.fs.clone(), false, cx);
- language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
+ language_model::init(cx);
+ client::RefreshLlmTokenListener::register(
+ app_state.client.clone(),
+ app_state.user_store.clone(),
+ cx,
+ );
language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
git_ui::init(cx);
project::AgentRegistryStore::init_global(
@@ -5189,7 +5189,12 @@ mod tests {
cx,
);
image_viewer::init(cx);
- language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
+ language_model::init(cx);
+ client::RefreshLlmTokenListener::register(
+ app_state.client.clone(),
+ app_state.user_store.clone(),
+ cx,
+ );
language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
web_search::init(cx);
git_graph::init(cx);
@@ -313,7 +313,12 @@ mod tests {
let app_state = cx.update(|cx| {
let app_state = AppState::test(cx);
client::init(&app_state.client, cx);
- language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
+ language_model::init(cx);
+ client::RefreshLlmTokenListener::register(
+ app_state.client.clone(),
+ app_state.user_store.clone(),
+ cx,
+ );
editor::init(cx);
app_state
});
@@ -0,0 +1,22 @@
+[package]
+name = "zed_credentials_provider"
+version = "0.1.0"
+edition.workspace = true
+publish.workspace = true
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/zed_credentials_provider.rs"
+
+[dependencies]
+anyhow.workspace = true
+credentials_provider.workspace = true
+futures.workspace = true
+gpui.workspace = true
+paths.workspace = true
+release_channel.workspace = true
+serde.workspace = true
+serde_json.workspace = true
@@ -0,0 +1 @@
+../../LICENSE-GPL
@@ -0,0 +1,181 @@
+use std::collections::HashMap;
+use std::future::Future;
+use std::path::PathBuf;
+use std::pin::Pin;
+use std::sync::{Arc, LazyLock};
+
+use anyhow::Result;
+use credentials_provider::CredentialsProvider;
+use futures::FutureExt as _;
+use gpui::{App, AsyncApp, Global};
+use release_channel::ReleaseChannel;
+
+/// An environment variable whose presence indicates that the system keychain
+/// should be used in development.
+///
+/// By default, running Zed in development uses the development credentials
+/// provider. Setting this environment variable allows you to interact with the
+/// system keychain (for instance, if you need to test something).
+///
+/// Only works in development. Setting this environment variable in other
+/// release channels is a no-op.
+static ZED_DEVELOPMENT_USE_KEYCHAIN: LazyLock<bool> = LazyLock::new(|| {
+ std::env::var("ZED_DEVELOPMENT_USE_KEYCHAIN").is_ok_and(|value| !value.is_empty())
+});
+
+pub struct ZedCredentialsProvider(pub Arc<dyn CredentialsProvider>);
+
+impl Global for ZedCredentialsProvider {}
+
+/// Returns the global [`CredentialsProvider`].
+pub fn init_global(cx: &mut App) {
+ // The `CredentialsProvider` trait has `Send + Sync` bounds on it, so it
+ // seems like this is a false positive from Clippy.
+ #[allow(clippy::arc_with_non_send_sync)]
+ let provider = new(cx);
+ cx.set_global(ZedCredentialsProvider(provider));
+}
+
+pub fn global(cx: &App) -> Arc<dyn CredentialsProvider> {
+ cx.try_global::<ZedCredentialsProvider>()
+ .map(|provider| provider.0.clone())
+ .unwrap_or_else(|| new(cx))
+}
+
+fn new(cx: &App) -> Arc<dyn CredentialsProvider> {
+ let use_development_provider = match ReleaseChannel::try_global(cx) {
+ Some(ReleaseChannel::Dev) => {
+ // In development we default to using the development
+ // credentials provider to avoid getting spammed by relentless
+ // keychain access prompts.
+ //
+ // However, if the `ZED_DEVELOPMENT_USE_KEYCHAIN` environment
+ // variable is set, we will use the actual keychain.
+ !*ZED_DEVELOPMENT_USE_KEYCHAIN
+ }
+ Some(ReleaseChannel::Nightly | ReleaseChannel::Preview | ReleaseChannel::Stable) | None => {
+ false
+ }
+ };
+
+ if use_development_provider {
+ Arc::new(DevelopmentCredentialsProvider::new())
+ } else {
+ Arc::new(KeychainCredentialsProvider)
+ }
+}
+
+/// A credentials provider that stores credentials in the system keychain.
+struct KeychainCredentialsProvider;
+
+impl CredentialsProvider for KeychainCredentialsProvider {
+ fn read_credentials<'a>(
+ &'a self,
+ url: &'a str,
+ cx: &'a AsyncApp,
+ ) -> Pin<Box<dyn Future<Output = Result<Option<(String, Vec<u8>)>>> + 'a>> {
+ async move { cx.update(|cx| cx.read_credentials(url)).await }.boxed_local()
+ }
+
+ fn write_credentials<'a>(
+ &'a self,
+ url: &'a str,
+ username: &'a str,
+ password: &'a [u8],
+ cx: &'a AsyncApp,
+ ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
+ async move {
+ cx.update(move |cx| cx.write_credentials(url, username, password))
+ .await
+ }
+ .boxed_local()
+ }
+
+ fn delete_credentials<'a>(
+ &'a self,
+ url: &'a str,
+ cx: &'a AsyncApp,
+ ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
+ async move { cx.update(move |cx| cx.delete_credentials(url)).await }.boxed_local()
+ }
+}
+
+/// A credentials provider that stores credentials in a local file.
+///
+/// This MUST only be used in development, as this is not a secure way of storing
+/// credentials on user machines.
+///
+/// Its existence is purely to work around the annoyance of having to constantly
+/// re-allow access to the system keychain when developing Zed.
+struct DevelopmentCredentialsProvider {
+ path: PathBuf,
+}
+
+impl DevelopmentCredentialsProvider {
+ fn new() -> Self {
+ let path = paths::config_dir().join("development_credentials");
+
+ Self { path }
+ }
+
+ fn load_credentials(&self) -> Result<HashMap<String, (String, Vec<u8>)>> {
+ let json = std::fs::read(&self.path)?;
+ let credentials: HashMap<String, (String, Vec<u8>)> = serde_json::from_slice(&json)?;
+
+ Ok(credentials)
+ }
+
+ fn save_credentials(&self, credentials: &HashMap<String, (String, Vec<u8>)>) -> Result<()> {
+ let json = serde_json::to_string(credentials)?;
+ std::fs::write(&self.path, json)?;
+
+ Ok(())
+ }
+}
+
+impl CredentialsProvider for DevelopmentCredentialsProvider {
+ fn read_credentials<'a>(
+ &'a self,
+ url: &'a str,
+ _cx: &'a AsyncApp,
+ ) -> Pin<Box<dyn Future<Output = Result<Option<(String, Vec<u8>)>>> + 'a>> {
+ async move {
+ Ok(self
+ .load_credentials()
+ .unwrap_or_default()
+ .get(url)
+ .cloned())
+ }
+ .boxed_local()
+ }
+
+ fn write_credentials<'a>(
+ &'a self,
+ url: &'a str,
+ username: &'a str,
+ password: &'a [u8],
+ _cx: &'a AsyncApp,
+ ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
+ async move {
+ let mut credentials = self.load_credentials().unwrap_or_default();
+ credentials.insert(url.to_string(), (username.to_string(), password.to_vec()));
+
+ self.save_credentials(&credentials)
+ }
+ .boxed_local()
+ }
+
+ fn delete_credentials<'a>(
+ &'a self,
+ url: &'a str,
+ _cx: &'a AsyncApp,
+ ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
+ async move {
+ let mut credentials = self.load_credentials()?;
+ credentials.remove(url);
+
+ self.save_credentials(&credentials)
+ }
+ .boxed_local()
+ }
+}
@@ -15,4 +15,4 @@ path = "src/zed_env_vars.rs"
default = []
[dependencies]
-gpui.workspace = true
+env_var.workspace = true
@@ -1,45 +1,6 @@
-use gpui::SharedString;
+pub use env_var::{EnvVar, bool_env_var, env_var};
use std::sync::LazyLock;
/// Whether Zed is running in stateless mode.
/// When true, Zed will use in-memory databases instead of persistent storage.
pub static ZED_STATELESS: LazyLock<bool> = bool_env_var!("ZED_STATELESS");
-
-#[derive(Clone)]
-pub struct EnvVar {
- pub name: SharedString,
- /// Value of the environment variable. Also `None` when set to an empty string.
- pub value: Option<String>,
-}
-
-impl EnvVar {
- pub fn new(name: SharedString) -> Self {
- let value = std::env::var(name.as_str()).ok();
- if value.as_ref().is_some_and(|v| v.is_empty()) {
- Self { name, value: None }
- } else {
- Self { name, value }
- }
- }
-
- pub fn or(self, other: EnvVar) -> EnvVar {
- if self.value.is_some() { self } else { other }
- }
-}
-
-/// Creates a `LazyLock<EnvVar>` expression for use in a `static` declaration.
-#[macro_export]
-macro_rules! env_var {
- ($name:expr) => {
- ::std::sync::LazyLock::new(|| $crate::EnvVar::new(($name).into()))
- };
-}
-
-/// Generates a `LazyLock<bool>` expression for use in a `static` declaration. Checks if the
-/// environment variable exists and is non-empty.
-#[macro_export]
-macro_rules! bool_env_var {
- ($name:expr) => {
- ::std::sync::LazyLock::new(|| $crate::EnvVar::new(($name).into()).value.is_some())
- };
-}