diff --git a/Cargo.lock b/Cargo.lock index ce645cae5bf4bbf76dac037880e9e7038df67df9..aae7afecc5ea6f6ba3d63453321c829b677e1c58 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -260,7 +260,6 @@ dependencies = [ "chrono", "client", "collections", - "credentials_provider", "env_logger 0.11.8", "feature_flags", "fs", @@ -289,6 +288,7 @@ dependencies = [ "util", "uuid", "watch", + "zed_credentials_provider", ] [[package]] @@ -2856,6 +2856,7 @@ dependencies = [ "chrono", "clock", "cloud_api_client", + "cloud_api_types", "cloud_llm_client", "collections", "credentials_provider", @@ -2869,6 +2870,7 @@ dependencies = [ "http_client", "http_client_tls", "httparse", + "language_model", "log", "objc2-foundation", "parking_lot", @@ -2900,6 +2902,7 @@ dependencies = [ "util", "windows 0.61.3", "worktree", + "zed_credentials_provider", ] [[package]] @@ -3059,6 +3062,7 @@ dependencies = [ "serde", "serde_json", "text", + "zed_credentials_provider", "zeta_prompt", ] @@ -4035,12 +4039,8 @@ name = "credentials_provider" version = "0.1.0" dependencies = [ "anyhow", - "futures 0.3.31", "gpui", - "paths", - "release_channel", "serde", - "serde_json", ] [[package]] @@ -5115,6 +5115,7 @@ dependencies = [ "collections", "copilot", "copilot_ui", + "credentials_provider", "ctor", "db", "edit_prediction_context", @@ -5157,6 +5158,7 @@ dependencies = [ "workspace", "worktree", "zed_actions", + "zed_credentials_provider", "zeta_prompt", "zlog", "zstd", @@ -5583,6 +5585,13 @@ dependencies = [ "log", ] +[[package]] +name = "env_var" +version = "0.1.0" +dependencies = [ + "gpui", +] + [[package]] name = "envy" version = "0.4.2" @@ -9315,12 +9324,12 @@ dependencies = [ "anthropic", "anyhow", "base64 0.22.1", - "client", "cloud_api_client", "cloud_api_types", "cloud_llm_client", "collections", "credentials_provider", + "env_var", "futures 0.3.31", "gpui", "http_client", @@ -9336,7 +9345,6 @@ dependencies = [ "smol", "thiserror 2.0.17", "util", - "zed_env_vars", ] [[package]] @@ -13137,6 +13145,7 @@ dependencies = [ "wax", "which 6.0.3", "worktree", + "zed_credentials_provider", "zeroize", "zlog", "ztracing", @@ -15746,6 +15755,7 @@ dependencies = [ "util", "workspace", "zed_actions", + "zed_credentials_provider", ] [[package]] @@ -22180,10 +22190,24 @@ dependencies = [ ] [[package]] -name = "zed_env_vars" +name = "zed_credentials_provider" version = "0.1.0" dependencies = [ + "anyhow", + "credentials_provider", + "futures 0.3.31", "gpui", + "paths", + "release_channel", + "serde", + "serde_json", +] + +[[package]] +name = "zed_env_vars" +version = "0.1.0" +dependencies = [ + "env_var", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 3a393237ab9f5a5a8cd4b02517f6d22382ff51ff..81bbb1176ddddcc117fc9082586cbc08dbb95d61 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,6 +61,7 @@ members = [ "crates/edit_prediction_ui", "crates/editor", "crates/encoding_selector", + "crates/env_var", "crates/etw_tracing", "crates/eval_cli", "crates/eval_utils", @@ -220,6 +221,7 @@ members = [ "crates/x_ai", "crates/zed", "crates/zed_actions", + "crates/zed_credentials_provider", "crates/zed_env_vars", "crates/zeta_prompt", "crates/zlog", @@ -309,6 +311,7 @@ dev_container = { path = "crates/dev_container" } diagnostics = { path = "crates/diagnostics" } editor = { path = "crates/editor" } encoding_selector = { path = "crates/encoding_selector" } +env_var = { path = "crates/env_var" } etw_tracing = { path = "crates/etw_tracing" } eval_utils = { path = "crates/eval_utils" } extension = { path = "crates/extension" } @@ -465,6 +468,7 @@ worktree = { path = "crates/worktree" } x_ai = { path = "crates/x_ai" } zed = { path = "crates/zed" } zed_actions = { path = "crates/zed_actions" } +zed_credentials_provider = { path = "crates/zed_credentials_provider" } zed_env_vars = { path = "crates/zed_env_vars" } edit_prediction = { path = "crates/edit_prediction" } zeta_prompt = { path = "crates/zeta_prompt" } diff --git a/crates/agent/src/edit_agent/evals.rs b/crates/agent/src/edit_agent/evals.rs index e7b67e37bf4a8b71664a78b99b757c6985794ec6..ba8b7ed867ea26bcdcdee7f8bf20390c2f9592b3 100644 --- a/crates/agent/src/edit_agent/evals.rs +++ b/crates/agent/src/edit_agent/evals.rs @@ -4,7 +4,7 @@ use crate::{ ListDirectoryTool, ListDirectoryToolInput, ReadFileTool, ReadFileToolInput, }; use Role::*; -use client::{Client, UserStore}; +use client::{Client, RefreshLlmTokenListener, UserStore}; use eval_utils::{EvalOutput, EvalOutputProcessor, OutcomeKind}; use fs::FakeFs; use futures::{FutureExt, future::LocalBoxFuture}; @@ -1423,7 +1423,8 @@ impl EditAgentTest { let client = Client::production(cx); let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); settings::init(cx); - language_model::init(user_store.clone(), client.clone(), cx); + language_model::init(cx); + RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx); language_models::init(user_store, client.clone(), cx); }); diff --git a/crates/agent/src/tests/mod.rs b/crates/agent/src/tests/mod.rs index 036a6f1030c43b16d51f864a1d0176891e90b772..9808b95dd0812f9a857da8a9c39e78fde40af1f9 100644 --- a/crates/agent/src/tests/mod.rs +++ b/crates/agent/src/tests/mod.rs @@ -6,7 +6,7 @@ use acp_thread::{ use agent_client_protocol::{self as acp}; use agent_settings::AgentProfileId; use anyhow::Result; -use client::{Client, UserStore}; +use client::{Client, RefreshLlmTokenListener, UserStore}; use collections::IndexMap; use context_server::{ContextServer, ContextServerCommand, ContextServerId}; use feature_flags::FeatureFlagAppExt as _; @@ -3253,7 +3253,8 @@ async fn test_agent_connection(cx: &mut TestAppContext) { let clock = Arc::new(clock::FakeSystemClock::new()); let client = Client::new(clock, http_client, cx); let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - language_model::init(user_store.clone(), client.clone(), cx); + language_model::init(cx); + RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx); language_models::init(user_store, client.clone(), cx); LanguageModelRegistry::test(cx); }); @@ -3982,7 +3983,8 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { cx.set_http_client(Arc::new(http_client)); let client = Client::production(cx); let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - language_model::init(user_store.clone(), client.clone(), cx); + language_model::init(cx); + RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx); language_models::init(user_store, client.clone(), cx); } }; diff --git a/crates/agent/src/tools/evals/streaming_edit_file.rs b/crates/agent/src/tools/evals/streaming_edit_file.rs index 6a55517037e54ae4166cd22427201d9325ef0f76..0c6290ec098f9c37a0f6a077daf0a041c013d8ff 100644 --- a/crates/agent/src/tools/evals/streaming_edit_file.rs +++ b/crates/agent/src/tools/evals/streaming_edit_file.rs @@ -6,7 +6,7 @@ use crate::{ }; use Role::*; use anyhow::{Context as _, Result}; -use client::{Client, UserStore}; +use client::{Client, RefreshLlmTokenListener, UserStore}; use fs::FakeFs; use futures::{FutureExt, StreamExt, future::LocalBoxFuture}; use gpui::{AppContext as _, AsyncApp, Entity, TestAppContext, UpdateGlobal as _}; @@ -274,7 +274,8 @@ impl StreamingEditToolTest { cx.set_http_client(http_client); let client = Client::production(cx); let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - language_model::init(user_store.clone(), client.clone(), cx); + language_model::init(cx); + RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx); language_models::init(user_store, client, cx); }); diff --git a/crates/agent_servers/Cargo.toml b/crates/agent_servers/Cargo.toml index 1542466be35bbce80983a73a3fc2e0998799160c..7151f0084b1cb7d9b206f57551ce715ef67483f7 100644 --- a/crates/agent_servers/Cargo.toml +++ b/crates/agent_servers/Cargo.toml @@ -32,7 +32,6 @@ futures.workspace = true gpui.workspace = true feature_flags.workspace = true gpui_tokio = { workspace = true, optional = true } -credentials_provider.workspace = true google_ai.workspace = true http_client.workspace = true indoc.workspace = true @@ -53,6 +52,7 @@ terminal.workspace = true uuid.workspace = true util.workspace = true watch.workspace = true +zed_credentials_provider.workspace = true [target.'cfg(unix)'.dependencies] libc.workspace = true diff --git a/crates/agent_servers/src/custom.rs b/crates/agent_servers/src/custom.rs index 0dcd2240d6ecf6dc052cdd55953cff8ec1442eae..fb8d0a515244576d2cf02e4989cbd71beca448c7 100644 --- a/crates/agent_servers/src/custom.rs +++ b/crates/agent_servers/src/custom.rs @@ -3,7 +3,6 @@ use acp_thread::AgentConnection; use agent_client_protocol as acp; use anyhow::{Context as _, Result}; use collections::HashSet; -use credentials_provider::CredentialsProvider; use fs::Fs; use gpui::{App, AppContext as _, Entity, Task}; use language_model::{ApiKey, EnvVar}; @@ -392,7 +391,7 @@ fn api_key_for_gemini_cli(cx: &mut App) -> Task> { if let Some(key) = env_var.value { return Task::ready(Ok(key)); } - let credentials_provider = ::global(cx); + let credentials_provider = zed_credentials_provider::global(cx); let api_url = google_ai::API_URL.to_string(); cx.spawn(async move |cx| { Ok( diff --git a/crates/agent_servers/src/e2e_tests.rs b/crates/agent_servers/src/e2e_tests.rs index 956d106df2a260bd2eb31c14f4f1f1705bf74cd6..aa29a0c230c13949b15f2b39a245ae41ead4884d 100644 --- a/crates/agent_servers/src/e2e_tests.rs +++ b/crates/agent_servers/src/e2e_tests.rs @@ -1,6 +1,7 @@ use crate::{AgentServer, AgentServerDelegate}; use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus}; use agent_client_protocol as acp; +use client::RefreshLlmTokenListener; use futures::{FutureExt, StreamExt, channel::mpsc, select}; use gpui::AppContext; use gpui::{Entity, TestAppContext}; @@ -413,7 +414,8 @@ pub async fn init_test(cx: &mut TestAppContext) -> Arc { cx.set_http_client(Arc::new(http_client)); let client = client::Client::production(cx); let user_store = cx.new(|cx| client::UserStore::new(client.clone(), cx)); - language_model::init(user_store, client, cx); + language_model::init(cx); + RefreshLlmTokenListener::register(client.clone(), user_store, cx); #[cfg(test)] project::agent_server_store::AllAgentServersSettings::override_global( diff --git a/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs b/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs index 4e3dd63b0337f9be54b550f4f4a6a5ca2e7cdd42..b97583377a00d28ea1a8aae6a1380cff3b69e6a0 100644 --- a/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs +++ b/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs @@ -815,7 +815,7 @@ mod tests { cx.set_global(store); theme_settings::init(theme::LoadThemes::JustBase, cx); - language_model::init_settings(cx); + language_model::init(cx); editor::init(cx); }); diff --git a/crates/agent_ui/src/agent_diff.rs b/crates/agent_ui/src/agent_diff.rs index d5cf63f6cdde9a85a54daaa29f8fc2c6833bdd77..7b70740dd1ac462614a9d08d9e48d7d13ac2ed32 100644 --- a/crates/agent_ui/src/agent_diff.rs +++ b/crates/agent_ui/src/agent_diff.rs @@ -1809,7 +1809,7 @@ mod tests { cx.set_global(settings_store); prompt_store::init(cx); theme_settings::init(theme::LoadThemes::JustBase, cx); - language_model::init_settings(cx); + language_model::init(cx); }); let fs = FakeFs::new(cx.executor()); @@ -1966,7 +1966,7 @@ mod tests { cx.set_global(settings_store); prompt_store::init(cx); theme_settings::init(theme::LoadThemes::JustBase, cx); - language_model::init_settings(cx); + language_model::init(cx); workspace::register_project_item::(cx); }); diff --git a/crates/agent_ui/src/inline_assistant.rs b/crates/agent_ui/src/inline_assistant.rs index 20e0b702978b7e72a8526b03570854965335310c..39d70790e0d4a18554b2a1c11510e529d921cd1b 100644 --- a/crates/agent_ui/src/inline_assistant.rs +++ b/crates/agent_ui/src/inline_assistant.rs @@ -2025,7 +2025,7 @@ fn merge_ranges(ranges: &mut Vec>, buffer: &MultiBufferSnapshot) { pub mod evals { use crate::InlineAssistant; use agent::ThreadStore; - use client::{Client, UserStore}; + use client::{Client, RefreshLlmTokenListener, UserStore}; use editor::{Editor, MultiBuffer, MultiBufferOffset}; use eval_utils::{EvalOutput, NoProcessor}; use fs::FakeFs; @@ -2091,7 +2091,8 @@ pub mod evals { client::init(&client, cx); workspace::init(app_state.clone(), cx); let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - language_model::init(user_store.clone(), client.clone(), cx); + language_model::init(cx); + RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx); language_models::init(user_store, client.clone(), cx); cx.set_global(inline_assistant); diff --git a/crates/client/Cargo.toml b/crates/client/Cargo.toml index 1edbb3399e4332e2ebd23f812c66697bda72d587..7bbaccb22e0e6c7508240186103e216f83be2f0c 100644 --- a/crates/client/Cargo.toml +++ b/crates/client/Cargo.toml @@ -22,6 +22,7 @@ base64.workspace = true chrono = { workspace = true, features = ["serde"] } clock.workspace = true cloud_api_client.workspace = true +cloud_api_types.workspace = true cloud_llm_client.workspace = true collections.workspace = true credentials_provider.workspace = true @@ -35,6 +36,7 @@ gpui_tokio.workspace = true http_client.workspace = true http_client_tls.workspace = true httparse = "1.10" +language_model.workspace = true log.workspace = true parking_lot.workspace = true paths.workspace = true @@ -60,6 +62,7 @@ tokio.workspace = true url.workspace = true util.workspace = true worktree.workspace = true +zed_credentials_provider.workspace = true [dev-dependencies] clock = { workspace = true, features = ["test-support"] } diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index 6a11a6b924eed3dfd79ff379638ed4085e2b7bcb..dfd9963a0ee52d167f8d4edb0b850f4debed7fd4 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -1,6 +1,7 @@ #[cfg(any(test, feature = "test-support"))] pub mod test; +mod llm_token; mod proxy; pub mod telemetry; pub mod user; @@ -13,8 +14,9 @@ use async_tungstenite::tungstenite::{ http::{HeaderValue, Request, StatusCode}, }; use clock::SystemClock; -use cloud_api_client::CloudApiClient; use cloud_api_client::websocket_protocol::MessageToClient; +use cloud_api_client::{ClientApiError, CloudApiClient}; +use cloud_api_types::OrganizationId; use credentials_provider::CredentialsProvider; use feature_flags::FeatureFlagAppExt as _; use futures::{ @@ -24,6 +26,7 @@ use futures::{ }; use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions}; use http_client::{HttpClient, HttpClientWithUrl, http, read_proxy_from_env}; +use language_model::LlmApiToken; use parking_lot::{Mutex, RwLock}; use postage::watch; use proxy::connect_proxy_stream; @@ -51,6 +54,7 @@ use tokio::net::TcpStream; use url::Url; use util::{ConnectionResult, ResultExt}; +pub use llm_token::*; pub use rpc::*; pub use telemetry_events::Event; pub use user::*; @@ -339,7 +343,7 @@ pub struct ClientCredentialsProvider { impl ClientCredentialsProvider { pub fn new(cx: &App) -> Self { Self { - provider: ::global(cx), + provider: zed_credentials_provider::global(cx), } } @@ -568,6 +572,10 @@ impl Client { self.http.clone() } + pub fn credentials_provider(&self) -> Arc { + self.credentials_provider.provider.clone() + } + pub fn cloud_client(&self) -> Arc { self.cloud_client.clone() } @@ -1513,6 +1521,66 @@ impl Client { }) } + pub async fn acquire_llm_token( + &self, + llm_token: &LlmApiToken, + organization_id: Option, + ) -> Result { + let system_id = self.telemetry().system_id().map(|x| x.to_string()); + let cloud_client = self.cloud_client(); + match llm_token + .acquire(&cloud_client, system_id, organization_id) + .await + { + Ok(token) => Ok(token), + Err(ClientApiError::Unauthorized) => { + self.request_sign_out(); + Err(ClientApiError::Unauthorized).context("Failed to create LLM token") + } + Err(err) => Err(anyhow::Error::from(err)), + } + } + + pub async fn refresh_llm_token( + &self, + llm_token: &LlmApiToken, + organization_id: Option, + ) -> Result { + let system_id = self.telemetry().system_id().map(|x| x.to_string()); + let cloud_client = self.cloud_client(); + match llm_token + .refresh(&cloud_client, system_id, organization_id) + .await + { + Ok(token) => Ok(token), + Err(ClientApiError::Unauthorized) => { + self.request_sign_out(); + return Err(ClientApiError::Unauthorized).context("Failed to create LLM token"); + } + Err(err) => return Err(anyhow::Error::from(err)), + } + } + + pub async fn clear_and_refresh_llm_token( + &self, + llm_token: &LlmApiToken, + organization_id: Option, + ) -> Result { + let system_id = self.telemetry().system_id().map(|x| x.to_string()); + let cloud_client = self.cloud_client(); + match llm_token + .clear_and_refresh(&cloud_client, system_id, organization_id) + .await + { + Ok(token) => Ok(token), + Err(ClientApiError::Unauthorized) => { + self.request_sign_out(); + return Err(ClientApiError::Unauthorized).context("Failed to create LLM token"); + } + Err(err) => return Err(anyhow::Error::from(err)), + } + } + pub async fn sign_out(self: &Arc, cx: &AsyncApp) { self.state.write().credentials = None; self.cloud_client.clear_credentials(); diff --git a/crates/client/src/llm_token.rs b/crates/client/src/llm_token.rs new file mode 100644 index 0000000000000000000000000000000000000000..f62aa6dd4dc3462bc3a0f6f46c35f0e4e5499816 --- /dev/null +++ b/crates/client/src/llm_token.rs @@ -0,0 +1,116 @@ +use super::{Client, UserStore}; +use cloud_api_types::websocket_protocol::MessageToClient; +use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME}; +use gpui::{ + App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription, +}; +use language_model::LlmApiToken; +use std::sync::Arc; + +pub trait NeedsLlmTokenRefresh { + /// Returns whether the LLM token needs to be refreshed. + fn needs_llm_token_refresh(&self) -> bool; +} + +impl NeedsLlmTokenRefresh for http_client::Response { + fn needs_llm_token_refresh(&self) -> bool { + self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some() + || self.headers().get(OUTDATED_LLM_TOKEN_HEADER_NAME).is_some() + } +} + +enum TokenRefreshMode { + Refresh, + ClearAndRefresh, +} + +pub fn global_llm_token(cx: &App) -> LlmApiToken { + RefreshLlmTokenListener::global(cx) + .read(cx) + .llm_api_token + .clone() +} + +struct GlobalRefreshLlmTokenListener(Entity); + +impl Global for GlobalRefreshLlmTokenListener {} + +pub struct LlmTokenRefreshedEvent; + +pub struct RefreshLlmTokenListener { + client: Arc, + user_store: Entity, + llm_api_token: LlmApiToken, + _subscription: Subscription, +} + +impl EventEmitter for RefreshLlmTokenListener {} + +impl RefreshLlmTokenListener { + pub fn register(client: Arc, user_store: Entity, cx: &mut App) { + let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, user_store, cx)); + cx.set_global(GlobalRefreshLlmTokenListener(listener)); + } + + pub fn global(cx: &App) -> Entity { + GlobalRefreshLlmTokenListener::global(cx).0.clone() + } + + fn new(client: Arc, user_store: Entity, cx: &mut Context) -> Self { + client.add_message_to_client_handler({ + let this = cx.weak_entity(); + move |message, cx| { + if let Some(this) = this.upgrade() { + Self::handle_refresh_llm_token(this, message, cx); + } + } + }); + + let subscription = cx.subscribe(&user_store, |this, _user_store, event, cx| { + if matches!(event, super::user::Event::OrganizationChanged) { + this.refresh(TokenRefreshMode::ClearAndRefresh, cx); + } + }); + + Self { + client, + user_store, + llm_api_token: LlmApiToken::default(), + _subscription: subscription, + } + } + + fn refresh(&self, mode: TokenRefreshMode, cx: &mut Context) { + let client = self.client.clone(); + let llm_api_token = self.llm_api_token.clone(); + let organization_id = self + .user_store + .read(cx) + .current_organization() + .map(|organization| organization.id.clone()); + cx.spawn(async move |this, cx| { + match mode { + TokenRefreshMode::Refresh => { + client + .refresh_llm_token(&llm_api_token, organization_id) + .await?; + } + TokenRefreshMode::ClearAndRefresh => { + client + .clear_and_refresh_llm_token(&llm_api_token, organization_id) + .await?; + } + } + this.update(cx, |_this, cx| cx.emit(LlmTokenRefreshedEvent)) + }) + .detach_and_log_err(cx); + } + + fn handle_refresh_llm_token(this: Entity, message: &MessageToClient, cx: &mut App) { + match message { + MessageToClient::UserUpdated => { + this.update(cx, |this, cx| this.refresh(TokenRefreshMode::Refresh, cx)); + } + } + } +} diff --git a/crates/codestral/Cargo.toml b/crates/codestral/Cargo.toml index 0daaee8fb1420c76757ca898655e8dd1a5244d7e..801221d3128b8aa2d25175e086a741d5d85da626 100644 --- a/crates/codestral/Cargo.toml +++ b/crates/codestral/Cargo.toml @@ -22,6 +22,7 @@ log.workspace = true serde.workspace = true serde_json.workspace = true text.workspace = true +zed_credentials_provider.workspace = true zeta_prompt.workspace = true [dev-dependencies] diff --git a/crates/codestral/src/codestral.rs b/crates/codestral/src/codestral.rs index 3930e2e873a91618bfae456bc188bbd90ffa64b9..7685fa8f5b1eae9e98a621484602e199c2b76f96 100644 --- a/crates/codestral/src/codestral.rs +++ b/crates/codestral/src/codestral.rs @@ -48,9 +48,10 @@ pub fn codestral_api_key(cx: &App) -> Option> { } pub fn load_codestral_api_key(cx: &mut App) -> Task> { + let credentials_provider = zed_credentials_provider::global(cx); let api_url = codestral_api_url(cx); codestral_api_key_state(cx).update(cx, |key_state, cx| { - key_state.load_if_needed(api_url, |s| s, cx) + key_state.load_if_needed(api_url, |s| s, credentials_provider, cx) }) } diff --git a/crates/credentials_provider/Cargo.toml b/crates/credentials_provider/Cargo.toml index bf47bb24b12b90d54bc04f766efe06489c730b43..da83c0cd79a1b71bbb84746b3e893f33094783d6 100644 --- a/crates/credentials_provider/Cargo.toml +++ b/crates/credentials_provider/Cargo.toml @@ -13,9 +13,5 @@ path = "src/credentials_provider.rs" [dependencies] anyhow.workspace = true -futures.workspace = true gpui.workspace = true -paths.workspace = true -release_channel.workspace = true serde.workspace = true -serde_json.workspace = true diff --git a/crates/credentials_provider/src/credentials_provider.rs b/crates/credentials_provider/src/credentials_provider.rs index 249b8333e114223aa558cd33637fd103294a8f8d..b98e97673cc11272826af24c76e8a0a6a38b9211 100644 --- a/crates/credentials_provider/src/credentials_provider.rs +++ b/crates/credentials_provider/src/credentials_provider.rs @@ -1,26 +1,8 @@ -use std::collections::HashMap; use std::future::Future; -use std::path::PathBuf; use std::pin::Pin; -use std::sync::{Arc, LazyLock}; use anyhow::Result; -use futures::FutureExt as _; -use gpui::{App, AsyncApp}; -use release_channel::ReleaseChannel; - -/// An environment variable whose presence indicates that the system keychain -/// should be used in development. -/// -/// By default, running Zed in development uses the development credentials -/// provider. Setting this environment variable allows you to interact with the -/// system keychain (for instance, if you need to test something). -/// -/// Only works in development. Setting this environment variable in other -/// release channels is a no-op. -static ZED_DEVELOPMENT_USE_KEYCHAIN: LazyLock = LazyLock::new(|| { - std::env::var("ZED_DEVELOPMENT_USE_KEYCHAIN").is_ok_and(|value| !value.is_empty()) -}); +use gpui::AsyncApp; /// A provider for credentials. /// @@ -50,150 +32,3 @@ pub trait CredentialsProvider: Send + Sync { cx: &'a AsyncApp, ) -> Pin> + 'a>>; } - -impl dyn CredentialsProvider { - /// Returns the global [`CredentialsProvider`]. - pub fn global(cx: &App) -> Arc { - // The `CredentialsProvider` trait has `Send + Sync` bounds on it, so it - // seems like this is a false positive from Clippy. - #[allow(clippy::arc_with_non_send_sync)] - Self::new(cx) - } - - fn new(cx: &App) -> Arc { - let use_development_provider = match ReleaseChannel::try_global(cx) { - Some(ReleaseChannel::Dev) => { - // In development we default to using the development - // credentials provider to avoid getting spammed by relentless - // keychain access prompts. - // - // However, if the `ZED_DEVELOPMENT_USE_KEYCHAIN` environment - // variable is set, we will use the actual keychain. - !*ZED_DEVELOPMENT_USE_KEYCHAIN - } - Some(ReleaseChannel::Nightly | ReleaseChannel::Preview | ReleaseChannel::Stable) - | None => false, - }; - - if use_development_provider { - Arc::new(DevelopmentCredentialsProvider::new()) - } else { - Arc::new(KeychainCredentialsProvider) - } - } -} - -/// A credentials provider that stores credentials in the system keychain. -struct KeychainCredentialsProvider; - -impl CredentialsProvider for KeychainCredentialsProvider { - fn read_credentials<'a>( - &'a self, - url: &'a str, - cx: &'a AsyncApp, - ) -> Pin)>>> + 'a>> { - async move { cx.update(|cx| cx.read_credentials(url)).await }.boxed_local() - } - - fn write_credentials<'a>( - &'a self, - url: &'a str, - username: &'a str, - password: &'a [u8], - cx: &'a AsyncApp, - ) -> Pin> + 'a>> { - async move { - cx.update(move |cx| cx.write_credentials(url, username, password)) - .await - } - .boxed_local() - } - - fn delete_credentials<'a>( - &'a self, - url: &'a str, - cx: &'a AsyncApp, - ) -> Pin> + 'a>> { - async move { cx.update(move |cx| cx.delete_credentials(url)).await }.boxed_local() - } -} - -/// A credentials provider that stores credentials in a local file. -/// -/// This MUST only be used in development, as this is not a secure way of storing -/// credentials on user machines. -/// -/// Its existence is purely to work around the annoyance of having to constantly -/// re-allow access to the system keychain when developing Zed. -struct DevelopmentCredentialsProvider { - path: PathBuf, -} - -impl DevelopmentCredentialsProvider { - fn new() -> Self { - let path = paths::config_dir().join("development_credentials"); - - Self { path } - } - - fn load_credentials(&self) -> Result)>> { - let json = std::fs::read(&self.path)?; - let credentials: HashMap)> = serde_json::from_slice(&json)?; - - Ok(credentials) - } - - fn save_credentials(&self, credentials: &HashMap)>) -> Result<()> { - let json = serde_json::to_string(credentials)?; - std::fs::write(&self.path, json)?; - - Ok(()) - } -} - -impl CredentialsProvider for DevelopmentCredentialsProvider { - fn read_credentials<'a>( - &'a self, - url: &'a str, - _cx: &'a AsyncApp, - ) -> Pin)>>> + 'a>> { - async move { - Ok(self - .load_credentials() - .unwrap_or_default() - .get(url) - .cloned()) - } - .boxed_local() - } - - fn write_credentials<'a>( - &'a self, - url: &'a str, - username: &'a str, - password: &'a [u8], - _cx: &'a AsyncApp, - ) -> Pin> + 'a>> { - async move { - let mut credentials = self.load_credentials().unwrap_or_default(); - credentials.insert(url.to_string(), (username.to_string(), password.to_vec())); - - self.save_credentials(&credentials) - } - .boxed_local() - } - - fn delete_credentials<'a>( - &'a self, - url: &'a str, - _cx: &'a AsyncApp, - ) -> Pin> + 'a>> { - async move { - let mut credentials = self.load_credentials()?; - credentials.remove(url); - - self.save_credentials(&credentials) - } - .boxed_local() - } -} diff --git a/crates/edit_prediction/Cargo.toml b/crates/edit_prediction/Cargo.toml index 75a589dea8f9c7fefe7bf13400cbdde54bf90bf1..eabb1641fd4fbec7b2f8ef0ba399a8fe9600dfa3 100644 --- a/crates/edit_prediction/Cargo.toml +++ b/crates/edit_prediction/Cargo.toml @@ -26,6 +26,7 @@ cloud_llm_client.workspace = true collections.workspace = true copilot.workspace = true copilot_ui.workspace = true +credentials_provider.workspace = true db.workspace = true edit_prediction_types.workspace = true edit_prediction_context.workspace = true @@ -65,6 +66,7 @@ uuid.workspace = true workspace.workspace = true worktree.workspace = true zed_actions.workspace = true +zed_credentials_provider.workspace = true zeta_prompt.workspace = true zstd.workspace = true diff --git a/crates/edit_prediction/src/capture_example.rs b/crates/edit_prediction/src/capture_example.rs index 5eb422246775c4409f7f15e3a672a2d407386acc..9463456132ce391b54aca8327cb6f900d81481d6 100644 --- a/crates/edit_prediction/src/capture_example.rs +++ b/crates/edit_prediction/src/capture_example.rs @@ -258,6 +258,7 @@ fn generate_timestamp_name() -> String { mod tests { use super::*; use crate::EditPredictionStore; + use client::RefreshLlmTokenListener; use client::{Client, UserStore}; use clock::FakeSystemClock; use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient}; @@ -548,7 +549,8 @@ mod tests { let http_client = FakeHttpClient::with_404_response(); let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx); let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - language_model::init(user_store.clone(), client.clone(), cx); + language_model::init(cx); + RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx); EditPredictionStore::global(&client, &user_store, cx); }) } diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index 61690c470829ca4bb16a6af9f1df2ea6e7cc6023..280427df006b510e1854ffb40cd7f995fcd9fdc6 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use client::{Client, EditPredictionUsage, UserStore}; +use client::{Client, EditPredictionUsage, NeedsLlmTokenRefresh, UserStore, global_llm_token}; use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody}; use cloud_llm_client::predict_edits_v3::{ PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse, @@ -11,6 +11,7 @@ use cloud_llm_client::{ }; use collections::{HashMap, HashSet}; use copilot::{Copilot, Reinstall, SignIn, SignOut}; +use credentials_provider::CredentialsProvider; use db::kvp::{Dismissable, KeyValueStore}; use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile}; use feature_flags::{FeatureFlag, FeatureFlagAppExt as _}; @@ -30,7 +31,7 @@ use heapless::Vec as ArrayVec; use language::language_settings::all_language_settings; use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint}; use language::{BufferSnapshot, OffsetRangeExt}; -use language_model::{LlmApiToken, NeedsLlmTokenRefresh}; +use language_model::LlmApiToken; use project::{DisableAiSettings, Project, ProjectPath, WorktreeId}; use release_channel::AppVersion; use semver::Version; @@ -150,6 +151,7 @@ pub struct EditPredictionStore { rated_predictions: HashSet, #[cfg(test)] settled_event_callback: Option>, + credentials_provider: Arc, } pub(crate) struct EditPredictionRejectionPayload { @@ -746,7 +748,7 @@ impl EditPredictionStore { pub fn new(client: Arc, user_store: Entity, cx: &mut Context) -> Self { let data_collection_choice = Self::load_data_collection_choice(cx); - let llm_token = LlmApiToken::global(cx); + let llm_token = global_llm_token(cx); let (reject_tx, reject_rx) = mpsc::unbounded(); cx.background_spawn({ @@ -787,6 +789,8 @@ impl EditPredictionStore { .log_err(); }); + let credentials_provider = zed_credentials_provider::global(cx); + let this = Self { projects: HashMap::default(), client, @@ -807,6 +811,8 @@ impl EditPredictionStore { shown_predictions: Default::default(), #[cfg(test)] settled_event_callback: None, + + credentials_provider, }; this @@ -871,7 +877,9 @@ impl EditPredictionStore { let experiments = cx .background_spawn(async move { let http_client = client.http_client(); - let token = llm_token.acquire(&client, organization_id).await?; + let token = client + .acquire_llm_token(&llm_token, organization_id.clone()) + .await?; let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?; let request = http_client::Request::builder() .method(Method::GET) @@ -2315,7 +2323,10 @@ impl EditPredictionStore { zeta::request_prediction_with_zeta(self, inputs, capture_data, cx) } EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx), - EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx), + EditPredictionModel::Mercury => { + self.mercury + .request_prediction(inputs, self.credentials_provider.clone(), cx) + } }; cx.spawn(async move |this, cx| { @@ -2536,12 +2547,15 @@ impl EditPredictionStore { Res: DeserializeOwned, { let http_client = client.http_client(); - let mut token = if require_auth { - Some(llm_token.acquire(&client, organization_id.clone()).await?) + Some( + client + .acquire_llm_token(&llm_token, organization_id.clone()) + .await?, + ) } else { - llm_token - .acquire(&client, organization_id.clone()) + client + .acquire_llm_token(&llm_token, organization_id.clone()) .await .ok() }; @@ -2585,7 +2599,11 @@ impl EditPredictionStore { return Ok((serde_json::from_slice(&body)?, usage)); } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() { did_retry = true; - token = Some(llm_token.refresh(&client, organization_id.clone()).await?); + token = Some( + client + .refresh_llm_token(&llm_token, organization_id.clone()) + .await?, + ); } else { let mut body = String::new(); response.body_mut().read_to_string(&mut body).await?; diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs index 6fe61338e764a40aec9cf6f3191f1191bafe9200..1ba8b27aa785024a47a09c3299a1f3786a028ccf 100644 --- a/crates/edit_prediction/src/edit_prediction_tests.rs +++ b/crates/edit_prediction/src/edit_prediction_tests.rs @@ -1,6 +1,6 @@ use super::*; use crate::udiff::apply_diff_to_string; -use client::{UserStore, test::FakeServer}; +use client::{RefreshLlmTokenListener, UserStore, test::FakeServer}; use clock::FakeSystemClock; use clock::ReplicaId; use cloud_api_types::{CreateLlmTokenResponse, LlmToken}; @@ -23,7 +23,7 @@ use language::{ Anchor, Buffer, Capability, CursorShape, Diagnostic, DiagnosticEntry, DiagnosticSet, DiagnosticSeverity, Operation, Point, Selection, SelectionGoal, }; -use language_model::RefreshLlmTokenListener; + use lsp::LanguageServerId; use parking_lot::Mutex; use pretty_assertions::{assert_eq, assert_matches}; @@ -2439,7 +2439,8 @@ fn init_test_with_fake_client( client.cloud_client().set_credentials(1, "test".into()); let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - language_model::init(user_store.clone(), client.clone(), cx); + language_model::init(cx); + RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx); let ep_store = EditPredictionStore::global(&client, &user_store, cx); ( @@ -2891,7 +2892,7 @@ async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx)); let user_store = cx.update(|cx| cx.new(|cx| client::UserStore::new(client.clone(), cx))); cx.update(|cx| { - language_model::RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx); + RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx); }); let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx)); diff --git a/crates/edit_prediction/src/mercury.rs b/crates/edit_prediction/src/mercury.rs index df47a38062344512a784c6d2feb563e9848afb27..155fd449904687081da0a9eae3d4731863f02254 100644 --- a/crates/edit_prediction/src/mercury.rs +++ b/crates/edit_prediction/src/mercury.rs @@ -5,6 +5,7 @@ use crate::{ }; use anyhow::{Context as _, Result}; use cloud_llm_client::EditPredictionRejectReason; +use credentials_provider::CredentialsProvider; use futures::AsyncReadExt as _; use gpui::{ App, AppContext as _, Context, Entity, Global, SharedString, Task, @@ -51,10 +52,11 @@ impl Mercury { debug_tx, .. }: EditPredictionModelInput, + credentials_provider: Arc, cx: &mut Context, ) -> Task>> { self.api_token.update(cx, |key_state, cx| { - _ = key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, cx); + _ = key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, credentials_provider, cx); }); let Some(api_token) = self.api_token.read(cx).key(&MERCURY_CREDENTIALS_URL) else { return Task::ready(Ok(None)); @@ -387,8 +389,9 @@ pub fn mercury_api_token(cx: &mut App) -> Entity { } pub fn load_mercury_api_token(cx: &mut App) -> Task> { + let credentials_provider = zed_credentials_provider::global(cx); mercury_api_token(cx).update(cx, |key_state, cx| { - key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, cx) + key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, credentials_provider, cx) }) } diff --git a/crates/edit_prediction/src/open_ai_compatible.rs b/crates/edit_prediction/src/open_ai_compatible.rs index ca378ba1fd0bc9bdbb3e85c7610e1b94c1be388f..9a11164822857d78c2fe0d9245faeb5d4f7400a0 100644 --- a/crates/edit_prediction/src/open_ai_compatible.rs +++ b/crates/edit_prediction/src/open_ai_compatible.rs @@ -42,9 +42,10 @@ pub fn open_ai_compatible_api_token(cx: &mut App) -> Entity { pub fn load_open_ai_compatible_api_token( cx: &mut App, ) -> Task> { + let credentials_provider = zed_credentials_provider::global(cx); let api_url = open_ai_compatible_api_url(cx); open_ai_compatible_api_token(cx).update(cx, |key_state, cx| { - key_state.load_if_needed(api_url, |s| s, cx) + key_state.load_if_needed(api_url, |s| s, credentials_provider, cx) }) } diff --git a/crates/edit_prediction_cli/src/headless.rs b/crates/edit_prediction_cli/src/headless.rs index 3a204a7052f8a41d6e7c2c49860b62f588358644..48b7381020f48d868d9f6413ef343b30718e5be6 100644 --- a/crates/edit_prediction_cli/src/headless.rs +++ b/crates/edit_prediction_cli/src/headless.rs @@ -1,4 +1,4 @@ -use client::{Client, ProxySettings, UserStore}; +use client::{Client, ProxySettings, RefreshLlmTokenListener, UserStore}; use db::AppDatabase; use extension::ExtensionHostProxy; use fs::RealFs; @@ -109,7 +109,8 @@ pub fn init(cx: &mut App) -> EpAppState { debug_adapter_extension::init(extension_host_proxy.clone(), cx); language_extension::init(LspAccess::Noop, extension_host_proxy, languages.clone()); - language_model::init(user_store.clone(), client.clone(), cx); + language_model::init(cx); + RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx); language_models::init(user_store.clone(), client.clone(), cx); languages::init(languages.clone(), fs.clone(), node_runtime.clone(), cx); prompt_store::init(cx); diff --git a/crates/env_var/Cargo.toml b/crates/env_var/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..2cbbd08c7833d3e57a09766d42ffffe35c620a93 --- /dev/null +++ b/crates/env_var/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "env_var" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/env_var.rs" + +[dependencies] +gpui.workspace = true diff --git a/crates/env_var/LICENSE-GPL b/crates/env_var/LICENSE-GPL new file mode 120000 index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4 --- /dev/null +++ b/crates/env_var/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/env_var/src/env_var.rs b/crates/env_var/src/env_var.rs new file mode 100644 index 0000000000000000000000000000000000000000..79f671e0147ebfaad4ab76a123cc477dc7e55cb7 --- /dev/null +++ b/crates/env_var/src/env_var.rs @@ -0,0 +1,40 @@ +use gpui::SharedString; + +#[derive(Clone)] +pub struct EnvVar { + pub name: SharedString, + /// Value of the environment variable. Also `None` when set to an empty string. + pub value: Option, +} + +impl EnvVar { + pub fn new(name: SharedString) -> Self { + let value = std::env::var(name.as_str()).ok(); + if value.as_ref().is_some_and(|v| v.is_empty()) { + Self { name, value: None } + } else { + Self { name, value } + } + } + + pub fn or(self, other: EnvVar) -> EnvVar { + if self.value.is_some() { self } else { other } + } +} + +/// Creates a `LazyLock` expression for use in a `static` declaration. +#[macro_export] +macro_rules! env_var { + ($name:expr) => { + ::std::sync::LazyLock::new(|| $crate::EnvVar::new(($name).into())) + }; +} + +/// Generates a `LazyLock` expression for use in a `static` declaration. Checks if the +/// environment variable exists and is non-empty. +#[macro_export] +macro_rules! bool_env_var { + ($name:expr) => { + ::std::sync::LazyLock::new(|| $crate::EnvVar::new(($name).into()).value.is_some()) + }; +} diff --git a/crates/eval_cli/src/headless.rs b/crates/eval_cli/src/headless.rs index 72feaacbae270224240f1da9e6e6c1008ba97c84..0ddd99e8f8abd9dbd73e1d7461526f3e7cb24f11 100644 --- a/crates/eval_cli/src/headless.rs +++ b/crates/eval_cli/src/headless.rs @@ -1,7 +1,7 @@ use std::path::PathBuf; use std::sync::Arc; -use client::{Client, ProxySettings, UserStore}; +use client::{Client, ProxySettings, RefreshLlmTokenListener, UserStore}; use db::AppDatabase; use extension::ExtensionHostProxy; use fs::RealFs; @@ -108,7 +108,8 @@ pub fn init(cx: &mut App) -> Arc { let extension_host_proxy = ExtensionHostProxy::global(cx); debug_adapter_extension::init(extension_host_proxy.clone(), cx); language_extension::init(LspAccess::Noop, extension_host_proxy, languages.clone()); - language_model::init(user_store.clone(), client.clone(), cx); + language_model::init(cx); + RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx); language_models::init(user_store.clone(), client.clone(), cx); languages::init(languages.clone(), fs.clone(), node_runtime.clone(), cx); prompt_store::init(cx); diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml index 911100fc25b498ba5471c85d6177052495974665..4712d86dff6c44f9cdd8576a08349ccfa7d0ecca 100644 --- a/crates/language_model/Cargo.toml +++ b/crates/language_model/Cargo.toml @@ -20,11 +20,11 @@ anthropic = { workspace = true, features = ["schemars"] } anyhow.workspace = true credentials_provider.workspace = true base64.workspace = true -client.workspace = true cloud_api_client.workspace = true cloud_api_types.workspace = true cloud_llm_client.workspace = true collections.workspace = true +env_var.workspace = true futures.workspace = true gpui.workspace = true http_client.workspace = true @@ -40,7 +40,6 @@ serde_json.workspace = true smol.workspace = true thiserror.workspace = true util.workspace = true -zed_env_vars.workspace = true [dev-dependencies] gpui = { workspace = true, features = ["test-support"] } diff --git a/crates/language_model/src/api_key.rs b/crates/language_model/src/api_key.rs index 754fde069295d8799820020bef286b1a1a3c590c..4be5a64d3db6231c98b830a524d5e299faace457 100644 --- a/crates/language_model/src/api_key.rs +++ b/crates/language_model/src/api_key.rs @@ -1,5 +1,6 @@ use anyhow::{Result, anyhow}; use credentials_provider::CredentialsProvider; +use env_var::EnvVar; use futures::{FutureExt, future}; use gpui::{AsyncApp, Context, SharedString, Task}; use std::{ @@ -7,7 +8,6 @@ use std::{ sync::Arc, }; use util::ResultExt as _; -use zed_env_vars::EnvVar; use crate::AuthenticateError; @@ -101,6 +101,7 @@ impl ApiKeyState { url: SharedString, key: Option, get_this: impl Fn(&mut Ent) -> &mut Self + 'static, + provider: Arc, cx: &Context, ) -> Task> { if self.is_from_env_var() { @@ -108,18 +109,14 @@ impl ApiKeyState { "bug: attempted to store API key in system keychain when API key is from env var", ))); } - let credentials_provider = ::global(cx); cx.spawn(async move |ent, cx| { if let Some(key) = &key { - credentials_provider + provider .write_credentials(&url, "Bearer", key.as_bytes(), cx) .await .log_err(); } else { - credentials_provider - .delete_credentials(&url, cx) - .await - .log_err(); + provider.delete_credentials(&url, cx).await.log_err(); } ent.update(cx, |ent, cx| { let this = get_this(ent); @@ -144,12 +141,13 @@ impl ApiKeyState { &mut self, url: SharedString, get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static, + provider: Arc, cx: &mut Context, ) { if url != self.url { if !self.is_from_env_var() { // loading will continue even though this result task is dropped - let _task = self.load_if_needed(url, get_this, cx); + let _task = self.load_if_needed(url, get_this, provider, cx); } } } @@ -163,6 +161,7 @@ impl ApiKeyState { &mut self, url: SharedString, get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static, + provider: Arc, cx: &mut Context, ) -> Task> { if let LoadStatus::Loaded { .. } = &self.load_status @@ -185,7 +184,7 @@ impl ApiKeyState { let task = if let Some(load_task) = &self.load_task { load_task.clone() } else { - let load_task = Self::load(url.clone(), get_this.clone(), cx).shared(); + let load_task = Self::load(url.clone(), get_this.clone(), provider, cx).shared(); self.url = url; self.load_status = LoadStatus::NotPresent; self.load_task = Some(load_task.clone()); @@ -206,14 +205,13 @@ impl ApiKeyState { fn load( url: SharedString, get_this: impl Fn(&mut Ent) -> &mut Self + 'static, + provider: Arc, cx: &Context, ) -> Task<()> { - let credentials_provider = ::global(cx); cx.spawn({ async move |ent, cx| { let load_status = - ApiKey::load_from_system_keychain_impl(&url, credentials_provider.as_ref(), cx) - .await; + ApiKey::load_from_system_keychain_impl(&url, provider.as_ref(), cx).await; ent.update(cx, |ent, cx| { let this = get_this(ent); this.url = url; diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index ce71cee6bcaf4f7ea1e210cc3756bd3162715f55..3f309b7b1d4152c54324efaaf0ad3bdb7035eea4 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -11,12 +11,10 @@ pub mod tool_schema; pub mod fake_provider; use anyhow::{Result, anyhow}; -use client::Client; -use client::UserStore; use cloud_llm_client::CompletionRequestStatus; use futures::FutureExt; use futures::{StreamExt, future::BoxFuture, stream::BoxStream}; -use gpui::{AnyView, App, AsyncApp, Entity, SharedString, Task, Window}; +use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window}; use http_client::{StatusCode, http}; use icons::IconName; use parking_lot::Mutex; @@ -36,15 +34,10 @@ pub use crate::registry::*; pub use crate::request::*; pub use crate::role::*; pub use crate::tool_schema::LanguageModelToolSchemaFormat; +pub use env_var::{EnvVar, env_var}; pub use provider::*; -pub use zed_env_vars::{EnvVar, env_var}; -pub fn init(user_store: Entity, client: Arc, cx: &mut App) { - init_settings(cx); - RefreshLlmTokenListener::register(client, user_store, cx); -} - -pub fn init_settings(cx: &mut App) { +pub fn init(cx: &mut App) { registry::init(cx); } diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index a1362d78292082522f4e883efe42b2ca1e0a0300..db926aab1f70a46a4e70b1b67c2c9e4c4f465c2c 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -1,16 +1,9 @@ use std::fmt; use std::sync::Arc; -use anyhow::{Context as _, Result}; -use client::Client; -use client::UserStore; use cloud_api_client::ClientApiError; +use cloud_api_client::CloudApiClient; use cloud_api_types::OrganizationId; -use cloud_api_types::websocket_protocol::MessageToClient; -use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME}; -use gpui::{ - App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription, -}; use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard}; use thiserror::Error; @@ -30,18 +23,12 @@ impl fmt::Display for PaymentRequiredError { pub struct LlmApiToken(Arc>>); impl LlmApiToken { - pub fn global(cx: &App) -> Self { - RefreshLlmTokenListener::global(cx) - .read(cx) - .llm_api_token - .clone() - } - pub async fn acquire( &self, - client: &Arc, + client: &CloudApiClient, + system_id: Option, organization_id: Option, - ) -> Result { + ) -> Result { let lock = self.0.upgradable_read().await; if let Some(token) = lock.as_ref() { Ok(token.to_string()) @@ -49,6 +36,7 @@ impl LlmApiToken { Self::fetch( RwLockUpgradableReadGuard::upgrade(lock).await, client, + system_id, organization_id, ) .await @@ -57,10 +45,11 @@ impl LlmApiToken { pub async fn refresh( &self, - client: &Arc, + client: &CloudApiClient, + system_id: Option, organization_id: Option, - ) -> Result { - Self::fetch(self.0.write().await, client, organization_id).await + ) -> Result { + Self::fetch(self.0.write().await, client, system_id, organization_id).await } /// Clears the existing token before attempting to fetch a new one. @@ -69,28 +58,22 @@ impl LlmApiToken { /// leave a token for the wrong organization. pub async fn clear_and_refresh( &self, - client: &Arc, + client: &CloudApiClient, + system_id: Option, organization_id: Option, - ) -> Result { + ) -> Result { let mut lock = self.0.write().await; *lock = None; - Self::fetch(lock, client, organization_id).await + Self::fetch(lock, client, system_id, organization_id).await } async fn fetch( mut lock: RwLockWriteGuard<'_, Option>, - client: &Arc, + client: &CloudApiClient, + system_id: Option, organization_id: Option, - ) -> Result { - let system_id = client - .telemetry() - .system_id() - .map(|system_id| system_id.to_string()); - - let result = client - .cloud_client() - .create_llm_token(system_id, organization_id) - .await; + ) -> Result { + let result = client.create_llm_token(system_id, organization_id).await; match result { Ok(response) => { *lock = Some(response.token.0.clone()); @@ -98,112 +81,7 @@ impl LlmApiToken { } Err(err) => { *lock = None; - match err { - ClientApiError::Unauthorized => { - client.request_sign_out(); - Err(err).context("Failed to create LLM token") - } - ClientApiError::Other(err) => Err(err), - } - } - } - } -} - -pub trait NeedsLlmTokenRefresh { - /// Returns whether the LLM token needs to be refreshed. - fn needs_llm_token_refresh(&self) -> bool; -} - -impl NeedsLlmTokenRefresh for http_client::Response { - fn needs_llm_token_refresh(&self) -> bool { - self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some() - || self.headers().get(OUTDATED_LLM_TOKEN_HEADER_NAME).is_some() - } -} - -enum TokenRefreshMode { - Refresh, - ClearAndRefresh, -} - -struct GlobalRefreshLlmTokenListener(Entity); - -impl Global for GlobalRefreshLlmTokenListener {} - -pub struct LlmTokenRefreshedEvent; - -pub struct RefreshLlmTokenListener { - client: Arc, - user_store: Entity, - llm_api_token: LlmApiToken, - _subscription: Subscription, -} - -impl EventEmitter for RefreshLlmTokenListener {} - -impl RefreshLlmTokenListener { - pub fn register(client: Arc, user_store: Entity, cx: &mut App) { - let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, user_store, cx)); - cx.set_global(GlobalRefreshLlmTokenListener(listener)); - } - - pub fn global(cx: &App) -> Entity { - GlobalRefreshLlmTokenListener::global(cx).0.clone() - } - - fn new(client: Arc, user_store: Entity, cx: &mut Context) -> Self { - client.add_message_to_client_handler({ - let this = cx.weak_entity(); - move |message, cx| { - if let Some(this) = this.upgrade() { - Self::handle_refresh_llm_token(this, message, cx); - } - } - }); - - let subscription = cx.subscribe(&user_store, |this, _user_store, event, cx| { - if matches!(event, client::user::Event::OrganizationChanged) { - this.refresh(TokenRefreshMode::ClearAndRefresh, cx); - } - }); - - Self { - client, - user_store, - llm_api_token: LlmApiToken::default(), - _subscription: subscription, - } - } - - fn refresh(&self, mode: TokenRefreshMode, cx: &mut Context) { - let client = self.client.clone(); - let llm_api_token = self.llm_api_token.clone(); - let organization_id = self - .user_store - .read(cx) - .current_organization() - .map(|organization| organization.id.clone()); - cx.spawn(async move |this, cx| { - match mode { - TokenRefreshMode::Refresh => { - llm_api_token.refresh(&client, organization_id).await?; - } - TokenRefreshMode::ClearAndRefresh => { - llm_api_token - .clear_and_refresh(&client, organization_id) - .await?; - } - } - this.update(cx, |_this, cx| cx.emit(LlmTokenRefreshedEvent)) - }) - .detach_and_log_err(cx); - } - - fn handle_refresh_llm_token(this: Entity, message: &MessageToClient, cx: &mut App) { - match message { - MessageToClient::UserUpdated => { - this.update(cx, |this, cx| this.refresh(TokenRefreshMode::Refresh, cx)); + Err(err) } } } diff --git a/crates/language_models/src/language_models.rs b/crates/language_models/src/language_models.rs index 4db1db8fa6ce5afb9d77a6685bfc0861d0fb8885..3154db91a43d1381f5b3f122a724be249adeb79b 100644 --- a/crates/language_models/src/language_models.rs +++ b/crates/language_models/src/language_models.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use ::settings::{Settings, SettingsStore}; use client::{Client, UserStore}; use collections::HashSet; +use credentials_provider::CredentialsProvider; use gpui::{App, Context, Entity}; use language_model::{LanguageModelProviderId, LanguageModelRegistry}; use provider::deepseek::DeepSeekLanguageModelProvider; @@ -31,9 +32,16 @@ use crate::provider::x_ai::XAiLanguageModelProvider; pub use crate::settings::*; pub fn init(user_store: Entity, client: Arc, cx: &mut App) { + let credentials_provider = client.credentials_provider(); let registry = LanguageModelRegistry::global(cx); registry.update(cx, |registry, cx| { - register_language_model_providers(registry, user_store, client.clone(), cx); + register_language_model_providers( + registry, + user_store, + client.clone(), + credentials_provider.clone(), + cx, + ); }); // Subscribe to extension store events to track LLM extension installations @@ -104,6 +112,7 @@ pub fn init(user_store: Entity, client: Arc, cx: &mut App) { &HashSet::default(), &openai_compatible_providers, client.clone(), + credentials_provider.clone(), cx, ); }); @@ -124,6 +133,7 @@ pub fn init(user_store: Entity, client: Arc, cx: &mut App) { &openai_compatible_providers, &openai_compatible_providers_new, client.clone(), + credentials_provider.clone(), cx, ); }); @@ -138,6 +148,7 @@ fn register_openai_compatible_providers( old: &HashSet>, new: &HashSet>, client: Arc, + credentials_provider: Arc, cx: &mut Context, ) { for provider_id in old { @@ -152,6 +163,7 @@ fn register_openai_compatible_providers( Arc::new(OpenAiCompatibleLanguageModelProvider::new( provider_id.clone(), client.http_client(), + credentials_provider.clone(), cx, )), cx, @@ -164,6 +176,7 @@ fn register_language_model_providers( registry: &mut LanguageModelRegistry, user_store: Entity, client: Arc, + credentials_provider: Arc, cx: &mut Context, ) { registry.register_provider( @@ -177,62 +190,105 @@ fn register_language_model_providers( registry.register_provider( Arc::new(AnthropicLanguageModelProvider::new( client.http_client(), + credentials_provider.clone(), cx, )), cx, ); registry.register_provider( - Arc::new(OpenAiLanguageModelProvider::new(client.http_client(), cx)), + Arc::new(OpenAiLanguageModelProvider::new( + client.http_client(), + credentials_provider.clone(), + cx, + )), cx, ); registry.register_provider( - Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)), + Arc::new(OllamaLanguageModelProvider::new( + client.http_client(), + credentials_provider.clone(), + cx, + )), cx, ); registry.register_provider( - Arc::new(LmStudioLanguageModelProvider::new(client.http_client(), cx)), + Arc::new(LmStudioLanguageModelProvider::new( + client.http_client(), + credentials_provider.clone(), + cx, + )), cx, ); registry.register_provider( - Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)), + Arc::new(DeepSeekLanguageModelProvider::new( + client.http_client(), + credentials_provider.clone(), + cx, + )), cx, ); registry.register_provider( - Arc::new(GoogleLanguageModelProvider::new(client.http_client(), cx)), + Arc::new(GoogleLanguageModelProvider::new( + client.http_client(), + credentials_provider.clone(), + cx, + )), cx, ); registry.register_provider( - MistralLanguageModelProvider::global(client.http_client(), cx), + MistralLanguageModelProvider::global( + client.http_client(), + credentials_provider.clone(), + cx, + ), cx, ); registry.register_provider( - Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)), + Arc::new(BedrockLanguageModelProvider::new( + client.http_client(), + credentials_provider.clone(), + cx, + )), cx, ); registry.register_provider( Arc::new(OpenRouterLanguageModelProvider::new( client.http_client(), + credentials_provider.clone(), cx, )), cx, ); registry.register_provider( - Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)), + Arc::new(VercelLanguageModelProvider::new( + client.http_client(), + credentials_provider.clone(), + cx, + )), cx, ); registry.register_provider( Arc::new(VercelAiGatewayLanguageModelProvider::new( client.http_client(), + credentials_provider.clone(), cx, )), cx, ); registry.register_provider( - Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)), + Arc::new(XAiLanguageModelProvider::new( + client.http_client(), + credentials_provider.clone(), + cx, + )), cx, ); registry.register_provider( - Arc::new(OpenCodeLanguageModelProvider::new(client.http_client(), cx)), + Arc::new(OpenCodeLanguageModelProvider::new( + client.http_client(), + credentials_provider, + cx, + )), cx, ); registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx); diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index a98a0ce142dfdbaaaddc056ab378455a45147830..c1b8bc1a3bb1b602b67ae5563d8acc3b05a94d47 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -6,6 +6,7 @@ use anthropic::{ }; use anyhow::Result; use collections::{BTreeMap, HashMap}; +use credentials_provider::CredentialsProvider; use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream}; use gpui::{AnyView, App, AsyncApp, Context, Entity, Task}; use http_client::HttpClient; @@ -51,6 +52,7 @@ static API_KEY_ENV_VAR: LazyLock = env_var!(API_KEY_ENV_VAR_NAME); pub struct State { api_key_state: ApiKeyState, + credentials_provider: Arc, } impl State { @@ -59,30 +61,51 @@ impl State { } fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let credentials_provider = self.credentials_provider.clone(); let api_url = AnthropicLanguageModelProvider::api_url(cx); - self.api_key_state - .store(api_url, api_key, |this| &mut this.api_key_state, cx) + self.api_key_state.store( + api_url, + api_key, + |this| &mut this.api_key_state, + credentials_provider, + cx, + ) } fn authenticate(&mut self, cx: &mut Context) -> Task> { + let credentials_provider = self.credentials_provider.clone(); let api_url = AnthropicLanguageModelProvider::api_url(cx); - self.api_key_state - .load_if_needed(api_url, |this| &mut this.api_key_state, cx) + self.api_key_state.load_if_needed( + api_url, + |this| &mut this.api_key_state, + credentials_provider, + cx, + ) } } impl AnthropicLanguageModelProvider { - pub fn new(http_client: Arc, cx: &mut App) -> Self { + pub fn new( + http_client: Arc, + credentials_provider: Arc, + cx: &mut App, + ) -> Self { let state = cx.new(|cx| { cx.observe_global::(|this: &mut State, cx| { + let credentials_provider = this.credentials_provider.clone(); let api_url = Self::api_url(cx); - this.api_key_state - .handle_url_change(api_url, |this| &mut this.api_key_state, cx); + this.api_key_state.handle_url_change( + api_url, + |this| &mut this.api_key_state, + credentials_provider, + cx, + ); cx.notify(); }) .detach(); State { api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()), + credentials_provider, } }); diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index f53f145dbd387aa948b977d854ba77f1cbe49ded..4320763e2c5c6de7f3fe9238d7a4991565c3bfcd 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -195,12 +195,13 @@ pub struct State { settings: Option, /// Whether credentials came from environment variables (only relevant for static credentials) credentials_from_env: bool, + credentials_provider: Arc, _subscription: Subscription, } impl State { fn reset_auth(&self, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); + let credentials_provider = self.credentials_provider.clone(); cx.spawn(async move |this, cx| { credentials_provider .delete_credentials(AMAZON_AWS_URL, cx) @@ -220,7 +221,7 @@ impl State { cx: &mut Context, ) -> Task> { let auth = credentials.clone().into_auth(); - let credentials_provider = ::global(cx); + let credentials_provider = self.credentials_provider.clone(); cx.spawn(async move |this, cx| { credentials_provider .write_credentials( @@ -287,7 +288,7 @@ impl State { &self, cx: &mut Context, ) -> Task> { - let credentials_provider = ::global(cx); + let credentials_provider = self.credentials_provider.clone(); cx.spawn(async move |this, cx| { // Try environment variables first let (auth, from_env) = if let Some(bearer_token) = &ZED_BEDROCK_BEARER_TOKEN_VAR.value { @@ -400,11 +401,16 @@ pub struct BedrockLanguageModelProvider { } impl BedrockLanguageModelProvider { - pub fn new(http_client: Arc, cx: &mut App) -> Self { + pub fn new( + http_client: Arc, + credentials_provider: Arc, + cx: &mut App, + ) -> Self { let state = cx.new(|cx| State { auth: None, settings: Some(AllLanguageModelSettings::get_global(cx).bedrock.clone()), credentials_from_env: false, + credentials_provider, _subscription: cx.observe_global::(|_, cx| { cx.notify(); }), diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index f9372a4d7ea9c078c58f633cc58bd5597ef49212..29623cc998ad0fe933e9a29c45c651f7be010b07 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -1,7 +1,9 @@ use ai_onboarding::YoungAccountBanner; use anthropic::AnthropicModelMode; use anyhow::{Context as _, Result, anyhow}; -use client::{Client, UserStore, zed_urls}; +use client::{ + Client, NeedsLlmTokenRefresh, RefreshLlmTokenListener, UserStore, global_llm_token, zed_urls, +}; use cloud_api_types::{OrganizationId, Plan}; use cloud_llm_client::{ CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, @@ -24,10 +26,9 @@ use language_model::{ LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelEffortLevel, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, - LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, NeedsLlmTokenRefresh, - OPEN_AI_PROVIDER_ID, OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter, - RefreshLlmTokenListener, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME, ZED_CLOUD_PROVIDER_ID, - ZED_CLOUD_PROVIDER_NAME, + LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, OPEN_AI_PROVIDER_ID, + OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME, + ZED_CLOUD_PROVIDER_ID, ZED_CLOUD_PROVIDER_NAME, }; use release_channel::AppVersion; use schemars::JsonSchema; @@ -111,7 +112,7 @@ impl State { cx: &mut Context, ) -> Self { let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx); - let llm_api_token = LlmApiToken::global(cx); + let llm_api_token = global_llm_token(cx); Self { client: client.clone(), llm_api_token, @@ -226,7 +227,9 @@ impl State { organization_id: Option, ) -> Result { let http_client = &client.http_client(); - let token = llm_api_token.acquire(&client, organization_id).await?; + let token = client + .acquire_llm_token(&llm_api_token, organization_id) + .await?; let request = http_client::Request::builder() .method(Method::GET) @@ -414,8 +417,8 @@ impl CloudLanguageModel { ) -> Result { let http_client = &client.http_client(); - let mut token = llm_api_token - .acquire(&client, organization_id.clone()) + let mut token = client + .acquire_llm_token(&llm_api_token, organization_id.clone()) .await?; let mut refreshed_token = false; @@ -447,8 +450,8 @@ impl CloudLanguageModel { } if !refreshed_token && response.needs_llm_token_refresh() { - token = llm_api_token - .refresh(&client, organization_id.clone()) + token = client + .refresh_llm_token(&llm_api_token, organization_id.clone()) .await?; refreshed_token = true; continue; @@ -713,7 +716,9 @@ impl LanguageModel for CloudLanguageModel { into_google(request, model_id.clone(), GoogleModelMode::Default); async move { let http_client = &client.http_client(); - let token = llm_api_token.acquire(&client, organization_id).await?; + let token = client + .acquire_llm_token(&llm_api_token, organization_id) + .await?; let request_body = CountTokensBody { provider: cloud_llm_client::LanguageModelProvider::Google, diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs index bd2469d865fd8421d6ad31208e6a4be413c0fe14..0cfb1af425c7cb0279d98fa124a589437f1bb1a1 100644 --- a/crates/language_models/src/provider/deepseek.rs +++ b/crates/language_models/src/provider/deepseek.rs @@ -1,5 +1,6 @@ use anyhow::{Result, anyhow}; use collections::{BTreeMap, HashMap}; +use credentials_provider::CredentialsProvider; use deepseek::DEEPSEEK_API_URL; use futures::Stream; @@ -49,6 +50,7 @@ pub struct DeepSeekLanguageModelProvider { pub struct State { api_key_state: ApiKeyState, + credentials_provider: Arc, } impl State { @@ -57,30 +59,51 @@ impl State { } fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let credentials_provider = self.credentials_provider.clone(); let api_url = DeepSeekLanguageModelProvider::api_url(cx); - self.api_key_state - .store(api_url, api_key, |this| &mut this.api_key_state, cx) + self.api_key_state.store( + api_url, + api_key, + |this| &mut this.api_key_state, + credentials_provider, + cx, + ) } fn authenticate(&mut self, cx: &mut Context) -> Task> { + let credentials_provider = self.credentials_provider.clone(); let api_url = DeepSeekLanguageModelProvider::api_url(cx); - self.api_key_state - .load_if_needed(api_url, |this| &mut this.api_key_state, cx) + self.api_key_state.load_if_needed( + api_url, + |this| &mut this.api_key_state, + credentials_provider, + cx, + ) } } impl DeepSeekLanguageModelProvider { - pub fn new(http_client: Arc, cx: &mut App) -> Self { + pub fn new( + http_client: Arc, + credentials_provider: Arc, + cx: &mut App, + ) -> Self { let state = cx.new(|cx| { cx.observe_global::(|this: &mut State, cx| { + let credentials_provider = this.credentials_provider.clone(); let api_url = Self::api_url(cx); - this.api_key_state - .handle_url_change(api_url, |this| &mut this.api_key_state, cx); + this.api_key_state.handle_url_change( + api_url, + |this| &mut this.api_key_state, + credentials_provider, + cx, + ); cx.notify(); }) .detach(); State { api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()), + credentials_provider, } }); diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index 8fdfb514ac6e872bd24968d33f2c1169401d5a9c..244f7835a85ff67f0c4826321910ea13516371cb 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -1,5 +1,6 @@ use anyhow::{Context as _, Result}; use collections::BTreeMap; +use credentials_provider::CredentialsProvider; use futures::{FutureExt, Stream, StreamExt, future::BoxFuture}; use google_ai::{ FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction, @@ -60,6 +61,7 @@ pub struct GoogleLanguageModelProvider { pub struct State { api_key_state: ApiKeyState, + credentials_provider: Arc, } const GEMINI_API_KEY_VAR_NAME: &str = "GEMINI_API_KEY"; @@ -76,30 +78,51 @@ impl State { } fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let credentials_provider = self.credentials_provider.clone(); let api_url = GoogleLanguageModelProvider::api_url(cx); - self.api_key_state - .store(api_url, api_key, |this| &mut this.api_key_state, cx) + self.api_key_state.store( + api_url, + api_key, + |this| &mut this.api_key_state, + credentials_provider, + cx, + ) } fn authenticate(&mut self, cx: &mut Context) -> Task> { + let credentials_provider = self.credentials_provider.clone(); let api_url = GoogleLanguageModelProvider::api_url(cx); - self.api_key_state - .load_if_needed(api_url, |this| &mut this.api_key_state, cx) + self.api_key_state.load_if_needed( + api_url, + |this| &mut this.api_key_state, + credentials_provider, + cx, + ) } } impl GoogleLanguageModelProvider { - pub fn new(http_client: Arc, cx: &mut App) -> Self { + pub fn new( + http_client: Arc, + credentials_provider: Arc, + cx: &mut App, + ) -> Self { let state = cx.new(|cx| { cx.observe_global::(|this: &mut State, cx| { + let credentials_provider = this.credentials_provider.clone(); let api_url = Self::api_url(cx); - this.api_key_state - .handle_url_change(api_url, |this| &mut this.api_key_state, cx); + this.api_key_state.handle_url_change( + api_url, + |this| &mut this.api_key_state, + credentials_provider, + cx, + ); cx.notify(); }) .detach(); State { api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()), + credentials_provider, } }); diff --git a/crates/language_models/src/provider/lmstudio.rs b/crates/language_models/src/provider/lmstudio.rs index 6c8d3c6e1c50185a4b09e9afc80c688f4c8d1381..0d60fef16791087e35bac7d846b2ec99821d5470 100644 --- a/crates/language_models/src/provider/lmstudio.rs +++ b/crates/language_models/src/provider/lmstudio.rs @@ -1,5 +1,6 @@ use anyhow::{Result, anyhow}; use collections::HashMap; +use credentials_provider::CredentialsProvider; use fs::Fs; use futures::Stream; use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream}; @@ -52,6 +53,7 @@ pub struct LmStudioLanguageModelProvider { pub struct State { api_key_state: ApiKeyState, + credentials_provider: Arc, http_client: Arc, available_models: Vec, fetch_model_task: Option>>, @@ -64,10 +66,15 @@ impl State { } fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let credentials_provider = self.credentials_provider.clone(); let api_url = LmStudioLanguageModelProvider::api_url(cx).into(); - let task = self - .api_key_state - .store(api_url, api_key, |this| &mut this.api_key_state, cx); + let task = self.api_key_state.store( + api_url, + api_key, + |this| &mut this.api_key_state, + credentials_provider, + cx, + ); self.restart_fetch_models_task(cx); task } @@ -114,10 +121,14 @@ impl State { } fn authenticate(&mut self, cx: &mut Context) -> Task> { + let credentials_provider = self.credentials_provider.clone(); let api_url = LmStudioLanguageModelProvider::api_url(cx).into(); - let _task = self - .api_key_state - .load_if_needed(api_url, |this| &mut this.api_key_state, cx); + let _task = self.api_key_state.load_if_needed( + api_url, + |this| &mut this.api_key_state, + credentials_provider, + cx, + ); if self.is_authenticated() { return Task::ready(Ok(())); @@ -152,16 +163,29 @@ impl State { } impl LmStudioLanguageModelProvider { - pub fn new(http_client: Arc, cx: &mut App) -> Self { + pub fn new( + http_client: Arc, + credentials_provider: Arc, + cx: &mut App, + ) -> Self { let this = Self { http_client: http_client.clone(), state: cx.new(|cx| { let subscription = cx.observe_global::({ let mut settings = AllLanguageModelSettings::get_global(cx).lmstudio.clone(); move |this: &mut State, cx| { - let new_settings = &AllLanguageModelSettings::get_global(cx).lmstudio; - if &settings != new_settings { - settings = new_settings.clone(); + let new_settings = + AllLanguageModelSettings::get_global(cx).lmstudio.clone(); + if settings != new_settings { + let credentials_provider = this.credentials_provider.clone(); + let api_url = Self::api_url(cx).into(); + this.api_key_state.handle_url_change( + api_url, + |this| &mut this.api_key_state, + credentials_provider, + cx, + ); + settings = new_settings; this.restart_fetch_models_task(cx); cx.notify(); } @@ -173,6 +197,7 @@ impl LmStudioLanguageModelProvider { Self::api_url(cx).into(), (*API_KEY_ENV_VAR).clone(), ), + credentials_provider, http_client, available_models: Default::default(), fetch_model_task: None, diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs index 72f0cae2993da4efb3e19cb19ec42b186290920d..4cd1375fe50cd792a3a7bc8c85ba7b5b5af9520a 100644 --- a/crates/language_models/src/provider/mistral.rs +++ b/crates/language_models/src/provider/mistral.rs @@ -1,5 +1,6 @@ use anyhow::{Result, anyhow}; use collections::BTreeMap; +use credentials_provider::CredentialsProvider; use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream}; use gpui::{AnyView, App, AsyncApp, Context, Entity, Global, SharedString, Task, Window}; @@ -43,6 +44,7 @@ pub struct MistralLanguageModelProvider { pub struct State { api_key_state: ApiKeyState, + credentials_provider: Arc, } impl State { @@ -51,15 +53,26 @@ impl State { } fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let credentials_provider = self.credentials_provider.clone(); let api_url = MistralLanguageModelProvider::api_url(cx); - self.api_key_state - .store(api_url, api_key, |this| &mut this.api_key_state, cx) + self.api_key_state.store( + api_url, + api_key, + |this| &mut this.api_key_state, + credentials_provider, + cx, + ) } fn authenticate(&mut self, cx: &mut Context) -> Task> { + let credentials_provider = self.credentials_provider.clone(); let api_url = MistralLanguageModelProvider::api_url(cx); - self.api_key_state - .load_if_needed(api_url, |this| &mut this.api_key_state, cx) + self.api_key_state.load_if_needed( + api_url, + |this| &mut this.api_key_state, + credentials_provider, + cx, + ) } } @@ -73,20 +86,30 @@ impl MistralLanguageModelProvider { .map(|this| &this.0) } - pub fn global(http_client: Arc, cx: &mut App) -> Arc { + pub fn global( + http_client: Arc, + credentials_provider: Arc, + cx: &mut App, + ) -> Arc { if let Some(this) = cx.try_global::() { return this.0.clone(); } let state = cx.new(|cx| { cx.observe_global::(|this: &mut State, cx| { + let credentials_provider = this.credentials_provider.clone(); let api_url = Self::api_url(cx); - this.api_key_state - .handle_url_change(api_url, |this| &mut this.api_key_state, cx); + this.api_key_state.handle_url_change( + api_url, + |this| &mut this.api_key_state, + credentials_provider, + cx, + ); cx.notify(); }) .detach(); State { api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()), + credentials_provider, } }); diff --git a/crates/language_models/src/provider/ollama.rs b/crates/language_models/src/provider/ollama.rs index 551fcd55358c11bdf64bf2f27b32fa9a7f702252..49c326683a225bf73f604a584307ea1316a710c4 100644 --- a/crates/language_models/src/provider/ollama.rs +++ b/crates/language_models/src/provider/ollama.rs @@ -1,4 +1,5 @@ use anyhow::{Result, anyhow}; +use credentials_provider::CredentialsProvider; use fs::Fs; use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream}; use futures::{Stream, TryFutureExt, stream}; @@ -54,6 +55,7 @@ pub struct OllamaLanguageModelProvider { pub struct State { api_key_state: ApiKeyState, + credentials_provider: Arc, http_client: Arc, fetched_models: Vec, fetch_model_task: Option>>, @@ -65,10 +67,15 @@ impl State { } fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let credentials_provider = self.credentials_provider.clone(); let api_url = OllamaLanguageModelProvider::api_url(cx); - let task = self - .api_key_state - .store(api_url, api_key, |this| &mut this.api_key_state, cx); + let task = self.api_key_state.store( + api_url, + api_key, + |this| &mut this.api_key_state, + credentials_provider, + cx, + ); self.fetched_models.clear(); cx.spawn(async move |this, cx| { @@ -80,10 +87,14 @@ impl State { } fn authenticate(&mut self, cx: &mut Context) -> Task> { + let credentials_provider = self.credentials_provider.clone(); let api_url = OllamaLanguageModelProvider::api_url(cx); - let task = self - .api_key_state - .load_if_needed(api_url, |this| &mut this.api_key_state, cx); + let task = self.api_key_state.load_if_needed( + api_url, + |this| &mut this.api_key_state, + credentials_provider, + cx, + ); // Always try to fetch models - if no API key is needed (local Ollama), it will work // If API key is needed and provided, it will work @@ -157,7 +168,11 @@ impl State { } impl OllamaLanguageModelProvider { - pub fn new(http_client: Arc, cx: &mut App) -> Self { + pub fn new( + http_client: Arc, + credentials_provider: Arc, + cx: &mut App, + ) -> Self { let this = Self { http_client: http_client.clone(), state: cx.new(|cx| { @@ -170,6 +185,14 @@ impl OllamaLanguageModelProvider { let url_changed = last_settings.api_url != current_settings.api_url; last_settings = current_settings.clone(); if url_changed { + let credentials_provider = this.credentials_provider.clone(); + let api_url = Self::api_url(cx); + this.api_key_state.handle_url_change( + api_url, + |this| &mut this.api_key_state, + credentials_provider, + cx, + ); this.fetched_models.clear(); this.authenticate(cx).detach(); } @@ -184,6 +207,7 @@ impl OllamaLanguageModelProvider { fetched_models: Default::default(), fetch_model_task: None, api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()), + credentials_provider, } }), }; diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index 9289c66b2a4c9213826d2d027555511c9746d00e..6a2313487f4a1922cdc2aa20d23ede01c4b7d158 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -1,5 +1,6 @@ use anyhow::{Result, anyhow}; use collections::{BTreeMap, HashMap}; +use credentials_provider::CredentialsProvider; use futures::Stream; use futures::{FutureExt, StreamExt, future::BoxFuture}; use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window}; @@ -55,6 +56,7 @@ pub struct OpenAiLanguageModelProvider { pub struct State { api_key_state: ApiKeyState, + credentials_provider: Arc, } impl State { @@ -63,30 +65,51 @@ impl State { } fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let credentials_provider = self.credentials_provider.clone(); let api_url = OpenAiLanguageModelProvider::api_url(cx); - self.api_key_state - .store(api_url, api_key, |this| &mut this.api_key_state, cx) + self.api_key_state.store( + api_url, + api_key, + |this| &mut this.api_key_state, + credentials_provider, + cx, + ) } fn authenticate(&mut self, cx: &mut Context) -> Task> { + let credentials_provider = self.credentials_provider.clone(); let api_url = OpenAiLanguageModelProvider::api_url(cx); - self.api_key_state - .load_if_needed(api_url, |this| &mut this.api_key_state, cx) + self.api_key_state.load_if_needed( + api_url, + |this| &mut this.api_key_state, + credentials_provider, + cx, + ) } } impl OpenAiLanguageModelProvider { - pub fn new(http_client: Arc, cx: &mut App) -> Self { + pub fn new( + http_client: Arc, + credentials_provider: Arc, + cx: &mut App, + ) -> Self { let state = cx.new(|cx| { cx.observe_global::(|this: &mut State, cx| { + let credentials_provider = this.credentials_provider.clone(); let api_url = Self::api_url(cx); - this.api_key_state - .handle_url_change(api_url, |this| &mut this.api_key_state, cx); + this.api_key_state.handle_url_change( + api_url, + |this| &mut this.api_key_state, + credentials_provider, + cx, + ); cx.notify(); }) .detach(); State { api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()), + credentials_provider, } }); diff --git a/crates/language_models/src/provider/open_ai_compatible.rs b/crates/language_models/src/provider/open_ai_compatible.rs index 87a08097782198238a5d2467af32cc66b3183664..9f63a1e1a039998c275637f3831b51474c8049ac 100644 --- a/crates/language_models/src/provider/open_ai_compatible.rs +++ b/crates/language_models/src/provider/open_ai_compatible.rs @@ -1,5 +1,6 @@ use anyhow::Result; use convert_case::{Case, Casing}; +use credentials_provider::CredentialsProvider; use futures::{FutureExt, StreamExt, future::BoxFuture}; use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window}; use http_client::HttpClient; @@ -44,6 +45,7 @@ pub struct State { id: Arc, api_key_state: ApiKeyState, settings: OpenAiCompatibleSettings, + credentials_provider: Arc, } impl State { @@ -52,20 +54,36 @@ impl State { } fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let credentials_provider = self.credentials_provider.clone(); let api_url = SharedString::new(self.settings.api_url.as_str()); - self.api_key_state - .store(api_url, api_key, |this| &mut this.api_key_state, cx) + self.api_key_state.store( + api_url, + api_key, + |this| &mut this.api_key_state, + credentials_provider, + cx, + ) } fn authenticate(&mut self, cx: &mut Context) -> Task> { + let credentials_provider = self.credentials_provider.clone(); let api_url = SharedString::new(self.settings.api_url.clone()); - self.api_key_state - .load_if_needed(api_url, |this| &mut this.api_key_state, cx) + self.api_key_state.load_if_needed( + api_url, + |this| &mut this.api_key_state, + credentials_provider, + cx, + ) } } impl OpenAiCompatibleLanguageModelProvider { - pub fn new(id: Arc, http_client: Arc, cx: &mut App) -> Self { + pub fn new( + id: Arc, + http_client: Arc, + credentials_provider: Arc, + cx: &mut App, + ) -> Self { fn resolve_settings<'a>(id: &'a str, cx: &'a App) -> Option<&'a OpenAiCompatibleSettings> { crate::AllLanguageModelSettings::get_global(cx) .openai_compatible @@ -79,10 +97,12 @@ impl OpenAiCompatibleLanguageModelProvider { return; }; if &this.settings != &settings { + let credentials_provider = this.credentials_provider.clone(); let api_url = SharedString::new(settings.api_url.as_str()); this.api_key_state.handle_url_change( api_url, |this| &mut this.api_key_state, + credentials_provider, cx, ); this.settings = settings; @@ -98,6 +118,7 @@ impl OpenAiCompatibleLanguageModelProvider { EnvVar::new(api_key_env_var_name), ), settings, + credentials_provider, } }); diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs index a4a679be73c0276351a6524ad7e8fc40e2c26860..09c8eb768d12c61ed1dc86a1251ad52114be6162 100644 --- a/crates/language_models/src/provider/open_router.rs +++ b/crates/language_models/src/provider/open_router.rs @@ -1,5 +1,6 @@ use anyhow::Result; use collections::HashMap; +use credentials_provider::CredentialsProvider; use futures::{FutureExt, Stream, StreamExt, future::BoxFuture}; use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task}; use http_client::HttpClient; @@ -42,6 +43,7 @@ pub struct OpenRouterLanguageModelProvider { pub struct State { api_key_state: ApiKeyState, + credentials_provider: Arc, http_client: Arc, available_models: Vec, fetch_models_task: Option>>, @@ -53,16 +55,26 @@ impl State { } fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let credentials_provider = self.credentials_provider.clone(); let api_url = OpenRouterLanguageModelProvider::api_url(cx); - self.api_key_state - .store(api_url, api_key, |this| &mut this.api_key_state, cx) + self.api_key_state.store( + api_url, + api_key, + |this| &mut this.api_key_state, + credentials_provider, + cx, + ) } fn authenticate(&mut self, cx: &mut Context) -> Task> { + let credentials_provider = self.credentials_provider.clone(); let api_url = OpenRouterLanguageModelProvider::api_url(cx); - let task = self - .api_key_state - .load_if_needed(api_url, |this| &mut this.api_key_state, cx); + let task = self.api_key_state.load_if_needed( + api_url, + |this| &mut this.api_key_state, + credentials_provider, + cx, + ); cx.spawn(async move |this, cx| { let result = task.await; @@ -114,7 +126,11 @@ impl State { } impl OpenRouterLanguageModelProvider { - pub fn new(http_client: Arc, cx: &mut App) -> Self { + pub fn new( + http_client: Arc, + credentials_provider: Arc, + cx: &mut App, + ) -> Self { let state = cx.new(|cx| { cx.observe_global::({ let mut last_settings = OpenRouterLanguageModelProvider::settings(cx).clone(); @@ -131,6 +147,7 @@ impl OpenRouterLanguageModelProvider { .detach(); State { api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()), + credentials_provider, http_client: http_client.clone(), available_models: Vec::new(), fetch_models_task: None, diff --git a/crates/language_models/src/provider/opencode.rs b/crates/language_models/src/provider/opencode.rs index f3953f3cafa4a1f59ff86004628c0a4022f6257e..aae3a552544ebf2cc59255da954d84cf7b78c7da 100644 --- a/crates/language_models/src/provider/opencode.rs +++ b/crates/language_models/src/provider/opencode.rs @@ -1,5 +1,6 @@ use anyhow::Result; use collections::BTreeMap; +use credentials_provider::CredentialsProvider; use futures::{FutureExt, StreamExt, future::BoxFuture}; use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window}; use http_client::HttpClient; @@ -43,6 +44,7 @@ pub struct OpenCodeLanguageModelProvider { pub struct State { api_key_state: ApiKeyState, + credentials_provider: Arc, } impl State { @@ -51,30 +53,51 @@ impl State { } fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let credentials_provider = self.credentials_provider.clone(); let api_url = OpenCodeLanguageModelProvider::api_url(cx); - self.api_key_state - .store(api_url, api_key, |this| &mut this.api_key_state, cx) + self.api_key_state.store( + api_url, + api_key, + |this| &mut this.api_key_state, + credentials_provider, + cx, + ) } fn authenticate(&mut self, cx: &mut Context) -> Task> { + let credentials_provider = self.credentials_provider.clone(); let api_url = OpenCodeLanguageModelProvider::api_url(cx); - self.api_key_state - .load_if_needed(api_url, |this| &mut this.api_key_state, cx) + self.api_key_state.load_if_needed( + api_url, + |this| &mut this.api_key_state, + credentials_provider, + cx, + ) } } impl OpenCodeLanguageModelProvider { - pub fn new(http_client: Arc, cx: &mut App) -> Self { + pub fn new( + http_client: Arc, + credentials_provider: Arc, + cx: &mut App, + ) -> Self { let state = cx.new(|cx| { cx.observe_global::(|this: &mut State, cx| { + let credentials_provider = this.credentials_provider.clone(); let api_url = Self::api_url(cx); - this.api_key_state - .handle_url_change(api_url, |this| &mut this.api_key_state, cx); + this.api_key_state.handle_url_change( + api_url, + |this| &mut this.api_key_state, + credentials_provider, + cx, + ); cx.notify(); }) .detach(); State { api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()), + credentials_provider, } }); diff --git a/crates/language_models/src/provider/vercel.rs b/crates/language_models/src/provider/vercel.rs index b71da5b7db05710ee30115ab54379c9ee4e4c750..cedbc9c3cb988375b90864ceb23a3b14fc50abdd 100644 --- a/crates/language_models/src/provider/vercel.rs +++ b/crates/language_models/src/provider/vercel.rs @@ -1,5 +1,6 @@ use anyhow::Result; use collections::BTreeMap; +use credentials_provider::CredentialsProvider; use futures::{FutureExt, StreamExt, future::BoxFuture}; use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window}; use http_client::HttpClient; @@ -38,6 +39,7 @@ pub struct VercelLanguageModelProvider { pub struct State { api_key_state: ApiKeyState, + credentials_provider: Arc, } impl State { @@ -46,30 +48,51 @@ impl State { } fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let credentials_provider = self.credentials_provider.clone(); let api_url = VercelLanguageModelProvider::api_url(cx); - self.api_key_state - .store(api_url, api_key, |this| &mut this.api_key_state, cx) + self.api_key_state.store( + api_url, + api_key, + |this| &mut this.api_key_state, + credentials_provider, + cx, + ) } fn authenticate(&mut self, cx: &mut Context) -> Task> { + let credentials_provider = self.credentials_provider.clone(); let api_url = VercelLanguageModelProvider::api_url(cx); - self.api_key_state - .load_if_needed(api_url, |this| &mut this.api_key_state, cx) + self.api_key_state.load_if_needed( + api_url, + |this| &mut this.api_key_state, + credentials_provider, + cx, + ) } } impl VercelLanguageModelProvider { - pub fn new(http_client: Arc, cx: &mut App) -> Self { + pub fn new( + http_client: Arc, + credentials_provider: Arc, + cx: &mut App, + ) -> Self { let state = cx.new(|cx| { cx.observe_global::(|this: &mut State, cx| { + let credentials_provider = this.credentials_provider.clone(); let api_url = Self::api_url(cx); - this.api_key_state - .handle_url_change(api_url, |this| &mut this.api_key_state, cx); + this.api_key_state.handle_url_change( + api_url, + |this| &mut this.api_key_state, + credentials_provider, + cx, + ); cx.notify(); }) .detach(); State { api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()), + credentials_provider, } }); diff --git a/crates/language_models/src/provider/vercel_ai_gateway.rs b/crates/language_models/src/provider/vercel_ai_gateway.rs index 78f900de0c94fd3bbbff3962e92d1a8cb9f3e118..66767edd809531b4b020263654922d742a1a04be 100644 --- a/crates/language_models/src/provider/vercel_ai_gateway.rs +++ b/crates/language_models/src/provider/vercel_ai_gateway.rs @@ -1,5 +1,6 @@ use anyhow::Result; use collections::BTreeMap; +use credentials_provider::CredentialsProvider; use futures::{AsyncReadExt, FutureExt, StreamExt, future::BoxFuture}; use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window}; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http}; @@ -41,6 +42,7 @@ pub struct VercelAiGatewayLanguageModelProvider { pub struct State { api_key_state: ApiKeyState, + credentials_provider: Arc, http_client: Arc, available_models: Vec, fetch_models_task: Option>>, @@ -52,16 +54,26 @@ impl State { } fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let credentials_provider = self.credentials_provider.clone(); let api_url = VercelAiGatewayLanguageModelProvider::api_url(cx); - self.api_key_state - .store(api_url, api_key, |this| &mut this.api_key_state, cx) + self.api_key_state.store( + api_url, + api_key, + |this| &mut this.api_key_state, + credentials_provider, + cx, + ) } fn authenticate(&mut self, cx: &mut Context) -> Task> { + let credentials_provider = self.credentials_provider.clone(); let api_url = VercelAiGatewayLanguageModelProvider::api_url(cx); - let task = self - .api_key_state - .load_if_needed(api_url, |this| &mut this.api_key_state, cx); + let task = self.api_key_state.load_if_needed( + api_url, + |this| &mut this.api_key_state, + credentials_provider, + cx, + ); cx.spawn(async move |this, cx| { let result = task.await; @@ -100,7 +112,11 @@ impl State { } impl VercelAiGatewayLanguageModelProvider { - pub fn new(http_client: Arc, cx: &mut App) -> Self { + pub fn new( + http_client: Arc, + credentials_provider: Arc, + cx: &mut App, + ) -> Self { let state = cx.new(|cx| { cx.observe_global::({ let mut last_settings = VercelAiGatewayLanguageModelProvider::settings(cx).clone(); @@ -116,6 +132,7 @@ impl VercelAiGatewayLanguageModelProvider { .detach(); State { api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()), + credentials_provider, http_client: http_client.clone(), available_models: Vec::new(), fetch_models_task: None, diff --git a/crates/language_models/src/provider/x_ai.rs b/crates/language_models/src/provider/x_ai.rs index c00637bce7e67b624f5cdcae9aebe43fb43971f8..88189864c7b4b650a24afb2b872c1d6105cf9782 100644 --- a/crates/language_models/src/provider/x_ai.rs +++ b/crates/language_models/src/provider/x_ai.rs @@ -1,5 +1,6 @@ use anyhow::Result; use collections::BTreeMap; +use credentials_provider::CredentialsProvider; use futures::{FutureExt, StreamExt, future::BoxFuture}; use gpui::{AnyView, App, AsyncApp, Context, Entity, Task, Window}; use http_client::HttpClient; @@ -39,6 +40,7 @@ pub struct XAiLanguageModelProvider { pub struct State { api_key_state: ApiKeyState, + credentials_provider: Arc, } impl State { @@ -47,30 +49,51 @@ impl State { } fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let credentials_provider = self.credentials_provider.clone(); let api_url = XAiLanguageModelProvider::api_url(cx); - self.api_key_state - .store(api_url, api_key, |this| &mut this.api_key_state, cx) + self.api_key_state.store( + api_url, + api_key, + |this| &mut this.api_key_state, + credentials_provider, + cx, + ) } fn authenticate(&mut self, cx: &mut Context) -> Task> { + let credentials_provider = self.credentials_provider.clone(); let api_url = XAiLanguageModelProvider::api_url(cx); - self.api_key_state - .load_if_needed(api_url, |this| &mut this.api_key_state, cx) + self.api_key_state.load_if_needed( + api_url, + |this| &mut this.api_key_state, + credentials_provider, + cx, + ) } } impl XAiLanguageModelProvider { - pub fn new(http_client: Arc, cx: &mut App) -> Self { + pub fn new( + http_client: Arc, + credentials_provider: Arc, + cx: &mut App, + ) -> Self { let state = cx.new(|cx| { cx.observe_global::(|this: &mut State, cx| { + let credentials_provider = this.credentials_provider.clone(); let api_url = Self::api_url(cx); - this.api_key_state - .handle_url_change(api_url, |this| &mut this.api_key_state, cx); + this.api_key_state.handle_url_change( + api_url, + |this| &mut this.api_key_state, + credentials_provider, + cx, + ); cx.notify(); }) .detach(); State { api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()), + credentials_provider, } }); diff --git a/crates/project/Cargo.toml b/crates/project/Cargo.toml index ccffbd29f4bd03b0d4bb0a070f4229a517597468..cd037786a399eb979fd5d9053c57efe3100dd473 100644 --- a/crates/project/Cargo.toml +++ b/crates/project/Cargo.toml @@ -98,6 +98,7 @@ watch.workspace = true wax.workspace = true which.workspace = true worktree.workspace = true +zed_credentials_provider.workspace = true zeroize.workspace = true zlog.workspace = true ztracing.workspace = true diff --git a/crates/project/src/context_server_store.rs b/crates/project/src/context_server_store.rs index 395056384a79d39c978e14643166148685ea0b90..7b9fc16f10022805ea62df2f8b3df279fc96ae3d 100644 --- a/crates/project/src/context_server_store.rs +++ b/crates/project/src/context_server_store.rs @@ -684,7 +684,7 @@ impl ContextServerStore { let server_url = url.clone(); let id = id.clone(); cx.spawn(async move |_this, cx| { - let credentials_provider = cx.update(|cx| ::global(cx)); + let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx)); if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await { log::warn!("{} failed to clear OAuth session on removal: {}", id, err); @@ -797,8 +797,7 @@ impl ContextServerStore { if configuration.has_static_auth_header() { None } else { - let credentials_provider = - cx.update(|cx| ::global(cx)); + let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx)); let http_client = cx.update(|cx| cx.http_client()); match Self::load_session(&credentials_provider, url, &cx).await { @@ -1070,7 +1069,7 @@ impl ContextServerStore { .context("Failed to start OAuth callback server")?; let http_client = cx.update(|cx| cx.http_client()); - let credentials_provider = cx.update(|cx| ::global(cx)); + let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx)); let server_url = match configuration.as_ref() { ContextServerConfiguration::Http { url, .. } => url.clone(), _ => anyhow::bail!("OAuth authentication only supported for HTTP servers"), @@ -1233,7 +1232,7 @@ impl ContextServerStore { self.stop_server(&id, cx)?; cx.spawn(async move |this, cx| { - let credentials_provider = cx.update(|cx| ::global(cx)); + let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx)); if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await { log::error!("{} failed to clear OAuth session: {}", id, err); } @@ -1451,7 +1450,7 @@ async fn resolve_start_failure( // (e.g. timeout because the server rejected the token silently). Clear it // so the next start attempt can get a clean 401 and trigger the auth flow. if www_authenticate.is_none() { - let credentials_provider = cx.update(|cx| ::global(cx)); + let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx)); match ContextServerStore::load_session(&credentials_provider, &server_url, cx).await { Ok(Some(_)) => { log::info!("{id} start failed with a cached OAuth session present; clearing it"); diff --git a/crates/settings_ui/Cargo.toml b/crates/settings_ui/Cargo.toml index 9d79481596f4b4259760ff6c2f19f8f5cf709d1e..0228f6886fc741505ffbe02fe82242d5f3e1dfd4 100644 --- a/crates/settings_ui/Cargo.toml +++ b/crates/settings_ui/Cargo.toml @@ -59,6 +59,7 @@ ui.workspace = true util.workspace = true workspace.workspace = true zed_actions.workspace = true +zed_credentials_provider.workspace = true [dev-dependencies] fs = { workspace = true, features = ["test-support"] } diff --git a/crates/settings_ui/src/pages/edit_prediction_provider_setup.rs b/crates/settings_ui/src/pages/edit_prediction_provider_setup.rs index 193be67aad4760763637f116fad23066438b5b61..a2a457d33eb0788ff0bed981ce5666423890f05a 100644 --- a/crates/settings_ui/src/pages/edit_prediction_provider_setup.rs +++ b/crates/settings_ui/src/pages/edit_prediction_provider_setup.rs @@ -185,9 +185,15 @@ fn render_api_key_provider( cx: &mut Context, ) -> impl IntoElement { let weak_page = cx.weak_entity(); + let credentials_provider = zed_credentials_provider::global(cx); _ = window.use_keyed_state(current_url(cx), cx, |_, cx| { let task = api_key_state.update(cx, |key_state, cx| { - key_state.load_if_needed(current_url(cx), |state| state, cx) + key_state.load_if_needed( + current_url(cx), + |state| state, + credentials_provider.clone(), + cx, + ) }); cx.spawn(async move |_, cx| { task.await.ok(); @@ -208,10 +214,17 @@ fn render_api_key_provider( }); let write_key = move |api_key: Option, cx: &mut App| { + let credentials_provider = zed_credentials_provider::global(cx); api_key_state .update(cx, |key_state, cx| { let url = current_url(cx); - key_state.store(url, api_key, |key_state| key_state, cx) + key_state.store( + url, + api_key, + |key_state| key_state, + credentials_provider, + cx, + ) }) .detach_and_log_err(cx); }; diff --git a/crates/web_search_providers/src/cloud.rs b/crates/web_search_providers/src/cloud.rs index 17addd24d445a666138a1b37fef872beedd07aed..11227d8fb5c7152dc5b7e03b95fadea6cb714717 100644 --- a/crates/web_search_providers/src/cloud.rs +++ b/crates/web_search_providers/src/cloud.rs @@ -1,13 +1,13 @@ use std::sync::Arc; use anyhow::{Context as _, Result}; -use client::{Client, UserStore}; +use client::{Client, NeedsLlmTokenRefresh, UserStore, global_llm_token}; use cloud_api_types::OrganizationId; use cloud_llm_client::{WebSearchBody, WebSearchResponse}; use futures::AsyncReadExt as _; use gpui::{App, AppContext, Context, Entity, Task}; use http_client::{HttpClient, Method}; -use language_model::{LlmApiToken, NeedsLlmTokenRefresh}; +use language_model::LlmApiToken; use web_search::{WebSearchProvider, WebSearchProviderId}; pub struct CloudWebSearchProvider { @@ -30,7 +30,7 @@ pub struct State { impl State { pub fn new(client: Arc, user_store: Entity, cx: &mut Context) -> Self { - let llm_api_token = LlmApiToken::global(cx); + let llm_api_token = global_llm_token(cx); Self { client, @@ -73,8 +73,8 @@ async fn perform_web_search( let http_client = &client.http_client(); let mut retries_remaining = MAX_RETRIES; - let mut token = llm_api_token - .acquire(&client, organization_id.clone()) + let mut token = client + .acquire_llm_token(&llm_api_token, organization_id.clone()) .await?; loop { @@ -100,8 +100,8 @@ async fn perform_web_search( response.body_mut().read_to_string(&mut body).await?; return Ok(serde_json::from_str(&body)?); } else if response.needs_llm_token_refresh() { - token = llm_api_token - .refresh(&client, organization_id.clone()) + token = client + .refresh_llm_token(&llm_api_token, organization_id.clone()) .await?; retries_remaining -= 1; } else { diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 0e1cbc96ff1521626bfe8bcf62091404324132a0..902d147084ce42b34a34477593ecc755bc6aa7cc 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -10,7 +10,7 @@ use agent_ui::AgentPanel; use anyhow::{Context as _, Error, Result}; use clap::Parser; use cli::FORCE_CLI_MODE_ENV_VAR_NAME; -use client::{Client, ProxySettings, UserStore, parse_zed_link}; +use client::{Client, ProxySettings, RefreshLlmTokenListener, UserStore, parse_zed_link}; use collab_ui::channel_view::ChannelView; use collections::HashMap; use crashes::InitCrashHandler; @@ -664,7 +664,12 @@ fn main() { ); copilot_ui::init(&app_state, cx); - language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx); + language_model::init(cx); + RefreshLlmTokenListener::register( + app_state.client.clone(), + app_state.user_store.clone(), + cx, + ); language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx); acp_tools::init(cx); zed::telemetry_log::init(cx); diff --git a/crates/zed/src/visual_test_runner.rs b/crates/zed/src/visual_test_runner.rs index 7e081c15a564cb996f176345ee3330f00ee6b6f3..ad44ba4128b436597a74621694ae47c661f57bd1 100644 --- a/crates/zed/src/visual_test_runner.rs +++ b/crates/zed/src/visual_test_runner.rs @@ -201,7 +201,12 @@ fn run_visual_tests(project_path: PathBuf, update_baseline: bool) -> Result<()> }); prompt_store::init(cx); let prompt_builder = prompt_store::PromptBuilder::load(app_state.fs.clone(), false, cx); - language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx); + language_model::init(cx); + client::RefreshLlmTokenListener::register( + app_state.client.clone(), + app_state.user_store.clone(), + cx, + ); language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx); git_ui::init(cx); project::AgentRegistryStore::init_global( diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index fbebb37985c2ebd76a63db5b4b807a8a7e0203ce..8d7759948fcabba7388a5c63e0bfa6710aa21f74 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -5189,7 +5189,12 @@ mod tests { cx, ); image_viewer::init(cx); - language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx); + language_model::init(cx); + client::RefreshLlmTokenListener::register( + app_state.client.clone(), + app_state.user_store.clone(), + cx, + ); language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx); web_search::init(cx); git_graph::init(cx); diff --git a/crates/zed/src/zed/edit_prediction_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs index 8c9e74a42e6c3ddb2b340ac58da39752009825f0..d09dc07af839a681cea96d43217c4217927864d5 100644 --- a/crates/zed/src/zed/edit_prediction_registry.rs +++ b/crates/zed/src/zed/edit_prediction_registry.rs @@ -313,7 +313,12 @@ mod tests { let app_state = cx.update(|cx| { let app_state = AppState::test(cx); client::init(&app_state.client, cx); - language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx); + language_model::init(cx); + client::RefreshLlmTokenListener::register( + app_state.client.clone(), + app_state.user_store.clone(), + cx, + ); editor::init(cx); app_state }); diff --git a/crates/zed_credentials_provider/Cargo.toml b/crates/zed_credentials_provider/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..9f64801d4664111bceb0fb7b9ee8c007977b6389 --- /dev/null +++ b/crates/zed_credentials_provider/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "zed_credentials_provider" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/zed_credentials_provider.rs" + +[dependencies] +anyhow.workspace = true +credentials_provider.workspace = true +futures.workspace = true +gpui.workspace = true +paths.workspace = true +release_channel.workspace = true +serde.workspace = true +serde_json.workspace = true diff --git a/crates/zed_credentials_provider/LICENSE-GPL b/crates/zed_credentials_provider/LICENSE-GPL new file mode 120000 index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4 --- /dev/null +++ b/crates/zed_credentials_provider/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/zed_credentials_provider/src/zed_credentials_provider.rs b/crates/zed_credentials_provider/src/zed_credentials_provider.rs new file mode 100644 index 0000000000000000000000000000000000000000..6705e58d400b1a66990f2451d318b5950ea08dde --- /dev/null +++ b/crates/zed_credentials_provider/src/zed_credentials_provider.rs @@ -0,0 +1,181 @@ +use std::collections::HashMap; +use std::future::Future; +use std::path::PathBuf; +use std::pin::Pin; +use std::sync::{Arc, LazyLock}; + +use anyhow::Result; +use credentials_provider::CredentialsProvider; +use futures::FutureExt as _; +use gpui::{App, AsyncApp, Global}; +use release_channel::ReleaseChannel; + +/// An environment variable whose presence indicates that the system keychain +/// should be used in development. +/// +/// By default, running Zed in development uses the development credentials +/// provider. Setting this environment variable allows you to interact with the +/// system keychain (for instance, if you need to test something). +/// +/// Only works in development. Setting this environment variable in other +/// release channels is a no-op. +static ZED_DEVELOPMENT_USE_KEYCHAIN: LazyLock = LazyLock::new(|| { + std::env::var("ZED_DEVELOPMENT_USE_KEYCHAIN").is_ok_and(|value| !value.is_empty()) +}); + +pub struct ZedCredentialsProvider(pub Arc); + +impl Global for ZedCredentialsProvider {} + +/// Returns the global [`CredentialsProvider`]. +pub fn init_global(cx: &mut App) { + // The `CredentialsProvider` trait has `Send + Sync` bounds on it, so it + // seems like this is a false positive from Clippy. + #[allow(clippy::arc_with_non_send_sync)] + let provider = new(cx); + cx.set_global(ZedCredentialsProvider(provider)); +} + +pub fn global(cx: &App) -> Arc { + cx.try_global::() + .map(|provider| provider.0.clone()) + .unwrap_or_else(|| new(cx)) +} + +fn new(cx: &App) -> Arc { + let use_development_provider = match ReleaseChannel::try_global(cx) { + Some(ReleaseChannel::Dev) => { + // In development we default to using the development + // credentials provider to avoid getting spammed by relentless + // keychain access prompts. + // + // However, if the `ZED_DEVELOPMENT_USE_KEYCHAIN` environment + // variable is set, we will use the actual keychain. + !*ZED_DEVELOPMENT_USE_KEYCHAIN + } + Some(ReleaseChannel::Nightly | ReleaseChannel::Preview | ReleaseChannel::Stable) | None => { + false + } + }; + + if use_development_provider { + Arc::new(DevelopmentCredentialsProvider::new()) + } else { + Arc::new(KeychainCredentialsProvider) + } +} + +/// A credentials provider that stores credentials in the system keychain. +struct KeychainCredentialsProvider; + +impl CredentialsProvider for KeychainCredentialsProvider { + fn read_credentials<'a>( + &'a self, + url: &'a str, + cx: &'a AsyncApp, + ) -> Pin)>>> + 'a>> { + async move { cx.update(|cx| cx.read_credentials(url)).await }.boxed_local() + } + + fn write_credentials<'a>( + &'a self, + url: &'a str, + username: &'a str, + password: &'a [u8], + cx: &'a AsyncApp, + ) -> Pin> + 'a>> { + async move { + cx.update(move |cx| cx.write_credentials(url, username, password)) + .await + } + .boxed_local() + } + + fn delete_credentials<'a>( + &'a self, + url: &'a str, + cx: &'a AsyncApp, + ) -> Pin> + 'a>> { + async move { cx.update(move |cx| cx.delete_credentials(url)).await }.boxed_local() + } +} + +/// A credentials provider that stores credentials in a local file. +/// +/// This MUST only be used in development, as this is not a secure way of storing +/// credentials on user machines. +/// +/// Its existence is purely to work around the annoyance of having to constantly +/// re-allow access to the system keychain when developing Zed. +struct DevelopmentCredentialsProvider { + path: PathBuf, +} + +impl DevelopmentCredentialsProvider { + fn new() -> Self { + let path = paths::config_dir().join("development_credentials"); + + Self { path } + } + + fn load_credentials(&self) -> Result)>> { + let json = std::fs::read(&self.path)?; + let credentials: HashMap)> = serde_json::from_slice(&json)?; + + Ok(credentials) + } + + fn save_credentials(&self, credentials: &HashMap)>) -> Result<()> { + let json = serde_json::to_string(credentials)?; + std::fs::write(&self.path, json)?; + + Ok(()) + } +} + +impl CredentialsProvider for DevelopmentCredentialsProvider { + fn read_credentials<'a>( + &'a self, + url: &'a str, + _cx: &'a AsyncApp, + ) -> Pin)>>> + 'a>> { + async move { + Ok(self + .load_credentials() + .unwrap_or_default() + .get(url) + .cloned()) + } + .boxed_local() + } + + fn write_credentials<'a>( + &'a self, + url: &'a str, + username: &'a str, + password: &'a [u8], + _cx: &'a AsyncApp, + ) -> Pin> + 'a>> { + async move { + let mut credentials = self.load_credentials().unwrap_or_default(); + credentials.insert(url.to_string(), (username.to_string(), password.to_vec())); + + self.save_credentials(&credentials) + } + .boxed_local() + } + + fn delete_credentials<'a>( + &'a self, + url: &'a str, + _cx: &'a AsyncApp, + ) -> Pin> + 'a>> { + async move { + let mut credentials = self.load_credentials()?; + credentials.remove(url); + + self.save_credentials(&credentials) + } + .boxed_local() + } +} diff --git a/crates/zed_env_vars/Cargo.toml b/crates/zed_env_vars/Cargo.toml index 1cf32174c351c28ec7eb16deab7b7986655d4a48..bf863b742568f3f607ba7cb54bc8fc267f045cc9 100644 --- a/crates/zed_env_vars/Cargo.toml +++ b/crates/zed_env_vars/Cargo.toml @@ -15,4 +15,4 @@ path = "src/zed_env_vars.rs" default = [] [dependencies] -gpui.workspace = true +env_var.workspace = true diff --git a/crates/zed_env_vars/src/zed_env_vars.rs b/crates/zed_env_vars/src/zed_env_vars.rs index e601cc9536602ac943bd76bf1bfd8b8ac8979dd9..13451911295735762074bcb1cf152470afa55c36 100644 --- a/crates/zed_env_vars/src/zed_env_vars.rs +++ b/crates/zed_env_vars/src/zed_env_vars.rs @@ -1,45 +1,6 @@ -use gpui::SharedString; +pub use env_var::{EnvVar, bool_env_var, env_var}; use std::sync::LazyLock; /// Whether Zed is running in stateless mode. /// When true, Zed will use in-memory databases instead of persistent storage. pub static ZED_STATELESS: LazyLock = bool_env_var!("ZED_STATELESS"); - -#[derive(Clone)] -pub struct EnvVar { - pub name: SharedString, - /// Value of the environment variable. Also `None` when set to an empty string. - pub value: Option, -} - -impl EnvVar { - pub fn new(name: SharedString) -> Self { - let value = std::env::var(name.as_str()).ok(); - if value.as_ref().is_some_and(|v| v.is_empty()) { - Self { name, value: None } - } else { - Self { name, value } - } - } - - pub fn or(self, other: EnvVar) -> EnvVar { - if self.value.is_some() { self } else { other } - } -} - -/// Creates a `LazyLock` expression for use in a `static` declaration. -#[macro_export] -macro_rules! env_var { - ($name:expr) => { - ::std::sync::LazyLock::new(|| $crate::EnvVar::new(($name).into())) - }; -} - -/// Generates a `LazyLock` expression for use in a `static` declaration. Checks if the -/// environment variable exists and is non-empty. -#[macro_export] -macro_rules! bool_env_var { - ($name:expr) => { - ::std::sync::LazyLock::new(|| $crate::EnvVar::new(($name).into()).value.is_some()) - }; -}