language_model: Decouple from Zed-specific implementation details (#52913)

Jakub Konka created

This PR decouples `language_model` from Zed-specific implementation
details. In particular:
* `credentials_provider` is split into a generic `credentials_provider`
crate that provides only the trait, and a new `zed_credentials_provider`
crate that implements that trait for Zed-specific providers and exposes
functions to populate a global state with them
* `zed_env_vars` is split into a generic `env_var` crate that provides
shared tooling for managing env vars, and a slimmed-down `zed_env_vars`
crate that contains only the Zed-specific statics
* the dependency between `client` and `language_model` is inverted:
`client` now depends on `language_model` rather than vice versa

Release Notes:

- N/A

Change summary

Cargo.lock                                                        |  40 
Cargo.toml                                                        |   4 
crates/agent/src/edit_agent/evals.rs                              |   5 
crates/agent/src/tests/mod.rs                                     |   8 
crates/agent/src/tools/evals/streaming_edit_file.rs               |   5 
crates/agent_servers/Cargo.toml                                   |   2 
crates/agent_servers/src/custom.rs                                |   3 
crates/agent_servers/src/e2e_tests.rs                             |   4 
crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs |   2 
crates/agent_ui/src/agent_diff.rs                                 |   4 
crates/agent_ui/src/inline_assistant.rs                           |   5 
crates/client/Cargo.toml                                          |   3 
crates/client/src/client.rs                                       |  72 
crates/client/src/llm_token.rs                                    | 116 
crates/codestral/Cargo.toml                                       |   1 
crates/codestral/src/codestral.rs                                 |   3 
crates/credentials_provider/Cargo.toml                            |   4 
crates/credentials_provider/src/credentials_provider.rs           | 167 
crates/edit_prediction/Cargo.toml                                 |   2 
crates/edit_prediction/src/capture_example.rs                     |   4 
crates/edit_prediction/src/edit_prediction.rs                     |  38 
crates/edit_prediction/src/edit_prediction_tests.rs               |   9 
crates/edit_prediction/src/mercury.rs                             |   7 
crates/edit_prediction/src/open_ai_compatible.rs                  |   3 
crates/edit_prediction_cli/src/headless.rs                        |   5 
crates/env_var/Cargo.toml                                         |  15 
crates/env_var/LICENSE-GPL                                        |   1 
crates/env_var/src/env_var.rs                                     |  40 
crates/eval_cli/src/headless.rs                                   |   5 
crates/language_model/Cargo.toml                                  |   3 
crates/language_model/src/api_key.rs                              |  22 
crates/language_model/src/language_model.rs                       |  13 
crates/language_model/src/model/cloud_model.rs                    | 158 
crates/language_models/src/language_models.rs                     |  78 
crates/language_models/src/provider/anthropic.rs                  |  37 
crates/language_models/src/provider/bedrock.rs                    |  14 
crates/language_models/src/provider/cloud.rs                      |  29 
crates/language_models/src/provider/deepseek.rs                   |  37 
crates/language_models/src/provider/google.rs                     |  37 
crates/language_models/src/provider/lmstudio.rs                   |  45 
crates/language_models/src/provider/mistral.rs                    |  37 
crates/language_models/src/provider/ollama.rs                     |  38 
crates/language_models/src/provider/open_ai.rs                    |  37 
crates/language_models/src/provider/open_ai_compatible.rs         |  31 
crates/language_models/src/provider/open_router.rs                |  29 
crates/language_models/src/provider/opencode.rs                   |  37 
crates/language_models/src/provider/vercel.rs                     |  37 
crates/language_models/src/provider/vercel_ai_gateway.rs          |  29 
crates/language_models/src/provider/x_ai.rs                       |  37 
crates/project/Cargo.toml                                         |   1 
crates/project/src/context_server_store.rs                        |  11 
crates/settings_ui/Cargo.toml                                     |   1 
crates/settings_ui/src/pages/edit_prediction_provider_setup.rs    |  17 
crates/web_search_providers/src/cloud.rs                          |  14 
crates/zed/src/main.rs                                            |   9 
crates/zed/src/visual_test_runner.rs                              |   7 
crates/zed/src/zed.rs                                             |   7 
crates/zed/src/zed/edit_prediction_registry.rs                    |   7 
crates/zed_credentials_provider/Cargo.toml                        |  22 
crates/zed_credentials_provider/LICENSE-GPL                       |   1 
crates/zed_credentials_provider/src/zed_credentials_provider.rs   | 181 +
crates/zed_env_vars/Cargo.toml                                    |   2 
crates/zed_env_vars/src/zed_env_vars.rs                           |  41 
63 files changed, 1,122 insertions(+), 561 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -260,7 +260,6 @@ dependencies = [
  "chrono",
  "client",
  "collections",
- "credentials_provider",
  "env_logger 0.11.8",
  "feature_flags",
  "fs",
@@ -289,6 +288,7 @@ dependencies = [
  "util",
  "uuid",
  "watch",
+ "zed_credentials_provider",
 ]
 
 [[package]]
@@ -2856,6 +2856,7 @@ dependencies = [
  "chrono",
  "clock",
  "cloud_api_client",
+ "cloud_api_types",
  "cloud_llm_client",
  "collections",
  "credentials_provider",
@@ -2869,6 +2870,7 @@ dependencies = [
  "http_client",
  "http_client_tls",
  "httparse",
+ "language_model",
  "log",
  "objc2-foundation",
  "parking_lot",
@@ -2900,6 +2902,7 @@ dependencies = [
  "util",
  "windows 0.61.3",
  "worktree",
+ "zed_credentials_provider",
 ]
 
 [[package]]
@@ -3059,6 +3062,7 @@ dependencies = [
  "serde",
  "serde_json",
  "text",
+ "zed_credentials_provider",
  "zeta_prompt",
 ]
 
@@ -4035,12 +4039,8 @@ name = "credentials_provider"
 version = "0.1.0"
 dependencies = [
  "anyhow",
- "futures 0.3.31",
  "gpui",
- "paths",
- "release_channel",
  "serde",
- "serde_json",
 ]
 
 [[package]]
@@ -5115,6 +5115,7 @@ dependencies = [
  "collections",
  "copilot",
  "copilot_ui",
+ "credentials_provider",
  "ctor",
  "db",
  "edit_prediction_context",
@@ -5157,6 +5158,7 @@ dependencies = [
  "workspace",
  "worktree",
  "zed_actions",
+ "zed_credentials_provider",
  "zeta_prompt",
  "zlog",
  "zstd",
@@ -5583,6 +5585,13 @@ dependencies = [
  "log",
 ]
 
+[[package]]
+name = "env_var"
+version = "0.1.0"
+dependencies = [
+ "gpui",
+]
+
 [[package]]
 name = "envy"
 version = "0.4.2"
@@ -9315,12 +9324,12 @@ dependencies = [
  "anthropic",
  "anyhow",
  "base64 0.22.1",
- "client",
  "cloud_api_client",
  "cloud_api_types",
  "cloud_llm_client",
  "collections",
  "credentials_provider",
+ "env_var",
  "futures 0.3.31",
  "gpui",
  "http_client",
@@ -9336,7 +9345,6 @@ dependencies = [
  "smol",
  "thiserror 2.0.17",
  "util",
- "zed_env_vars",
 ]
 
 [[package]]
@@ -13137,6 +13145,7 @@ dependencies = [
  "wax",
  "which 6.0.3",
  "worktree",
+ "zed_credentials_provider",
  "zeroize",
  "zlog",
  "ztracing",
@@ -15746,6 +15755,7 @@ dependencies = [
  "util",
  "workspace",
  "zed_actions",
+ "zed_credentials_provider",
 ]
 
 [[package]]
@@ -22180,10 +22190,24 @@ dependencies = [
 ]
 
 [[package]]
-name = "zed_env_vars"
+name = "zed_credentials_provider"
 version = "0.1.0"
 dependencies = [
+ "anyhow",
+ "credentials_provider",
+ "futures 0.3.31",
  "gpui",
+ "paths",
+ "release_channel",
+ "serde",
+ "serde_json",
+]
+
+[[package]]
+name = "zed_env_vars"
+version = "0.1.0"
+dependencies = [
+ "env_var",
 ]
 
 [[package]]

Cargo.toml 🔗

@@ -61,6 +61,7 @@ members = [
     "crates/edit_prediction_ui",
     "crates/editor",
     "crates/encoding_selector",
+    "crates/env_var",
     "crates/etw_tracing",
     "crates/eval_cli",
     "crates/eval_utils",
@@ -220,6 +221,7 @@ members = [
     "crates/x_ai",
     "crates/zed",
     "crates/zed_actions",
+    "crates/zed_credentials_provider",
     "crates/zed_env_vars",
     "crates/zeta_prompt",
     "crates/zlog",
@@ -309,6 +311,7 @@ dev_container = { path = "crates/dev_container" }
 diagnostics = { path = "crates/diagnostics" }
 editor = { path = "crates/editor" }
 encoding_selector = { path = "crates/encoding_selector" }
+env_var = { path = "crates/env_var" }
 etw_tracing = { path = "crates/etw_tracing" }
 eval_utils = { path = "crates/eval_utils" }
 extension = { path = "crates/extension" }
@@ -465,6 +468,7 @@ worktree = { path = "crates/worktree" }
 x_ai = { path = "crates/x_ai" }
 zed = { path = "crates/zed" }
 zed_actions = { path = "crates/zed_actions" }
+zed_credentials_provider = { path = "crates/zed_credentials_provider" }
 zed_env_vars = { path = "crates/zed_env_vars" }
 edit_prediction = { path = "crates/edit_prediction" }
 zeta_prompt = { path = "crates/zeta_prompt" }

crates/agent/src/edit_agent/evals.rs 🔗

@@ -4,7 +4,7 @@ use crate::{
     ListDirectoryTool, ListDirectoryToolInput, ReadFileTool, ReadFileToolInput,
 };
 use Role::*;
-use client::{Client, UserStore};
+use client::{Client, RefreshLlmTokenListener, UserStore};
 use eval_utils::{EvalOutput, EvalOutputProcessor, OutcomeKind};
 use fs::FakeFs;
 use futures::{FutureExt, future::LocalBoxFuture};
@@ -1423,7 +1423,8 @@ impl EditAgentTest {
             let client = Client::production(cx);
             let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
             settings::init(cx);
-            language_model::init(user_store.clone(), client.clone(), cx);
+            language_model::init(cx);
+            RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
             language_models::init(user_store, client.clone(), cx);
         });
 

crates/agent/src/tests/mod.rs 🔗

@@ -6,7 +6,7 @@ use acp_thread::{
 use agent_client_protocol::{self as acp};
 use agent_settings::AgentProfileId;
 use anyhow::Result;
-use client::{Client, UserStore};
+use client::{Client, RefreshLlmTokenListener, UserStore};
 use collections::IndexMap;
 use context_server::{ContextServer, ContextServerCommand, ContextServerId};
 use feature_flags::FeatureFlagAppExt as _;
@@ -3253,7 +3253,8 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
         let clock = Arc::new(clock::FakeSystemClock::new());
         let client = Client::new(clock, http_client, cx);
         let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
-        language_model::init(user_store.clone(), client.clone(), cx);
+        language_model::init(cx);
+        RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
         language_models::init(user_store, client.clone(), cx);
         LanguageModelRegistry::test(cx);
     });
@@ -3982,7 +3983,8 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
                 cx.set_http_client(Arc::new(http_client));
                 let client = Client::production(cx);
                 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
-                language_model::init(user_store.clone(), client.clone(), cx);
+                language_model::init(cx);
+                RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
                 language_models::init(user_store, client.clone(), cx);
             }
         };

crates/agent/src/tools/evals/streaming_edit_file.rs 🔗

@@ -6,7 +6,7 @@ use crate::{
 };
 use Role::*;
 use anyhow::{Context as _, Result};
-use client::{Client, UserStore};
+use client::{Client, RefreshLlmTokenListener, UserStore};
 use fs::FakeFs;
 use futures::{FutureExt, StreamExt, future::LocalBoxFuture};
 use gpui::{AppContext as _, AsyncApp, Entity, TestAppContext, UpdateGlobal as _};
@@ -274,7 +274,8 @@ impl StreamingEditToolTest {
             cx.set_http_client(http_client);
             let client = Client::production(cx);
             let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
-            language_model::init(user_store.clone(), client.clone(), cx);
+            language_model::init(cx);
+            RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
             language_models::init(user_store, client, cx);
         });
 

crates/agent_servers/Cargo.toml 🔗

@@ -32,7 +32,6 @@ futures.workspace = true
 gpui.workspace = true
 feature_flags.workspace = true
 gpui_tokio = { workspace = true, optional = true }
-credentials_provider.workspace = true
 google_ai.workspace = true
 http_client.workspace = true
 indoc.workspace = true
@@ -53,6 +52,7 @@ terminal.workspace = true
 uuid.workspace = true
 util.workspace = true
 watch.workspace = true
+zed_credentials_provider.workspace = true
 
 [target.'cfg(unix)'.dependencies]
 libc.workspace = true

crates/agent_servers/src/custom.rs 🔗

@@ -3,7 +3,6 @@ use acp_thread::AgentConnection;
 use agent_client_protocol as acp;
 use anyhow::{Context as _, Result};
 use collections::HashSet;
-use credentials_provider::CredentialsProvider;
 use fs::Fs;
 use gpui::{App, AppContext as _, Entity, Task};
 use language_model::{ApiKey, EnvVar};
@@ -392,7 +391,7 @@ fn api_key_for_gemini_cli(cx: &mut App) -> Task<Result<String>> {
     if let Some(key) = env_var.value {
         return Task::ready(Ok(key));
     }
-    let credentials_provider = <dyn CredentialsProvider>::global(cx);
+    let credentials_provider = zed_credentials_provider::global(cx);
     let api_url = google_ai::API_URL.to_string();
     cx.spawn(async move |cx| {
         Ok(

crates/agent_servers/src/e2e_tests.rs 🔗

@@ -1,6 +1,7 @@
 use crate::{AgentServer, AgentServerDelegate};
 use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
 use agent_client_protocol as acp;
+use client::RefreshLlmTokenListener;
 use futures::{FutureExt, StreamExt, channel::mpsc, select};
 use gpui::AppContext;
 use gpui::{Entity, TestAppContext};
@@ -413,7 +414,8 @@ pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
         cx.set_http_client(Arc::new(http_client));
         let client = client::Client::production(cx);
         let user_store = cx.new(|cx| client::UserStore::new(client.clone(), cx));
-        language_model::init(user_store, client, cx);
+        language_model::init(cx);
+        RefreshLlmTokenListener::register(client.clone(), user_store, cx);
 
         #[cfg(test)]
         project::agent_server_store::AllAgentServersSettings::override_global(

crates/agent_ui/src/agent_diff.rs 🔗

@@ -1809,7 +1809,7 @@ mod tests {
             cx.set_global(settings_store);
             prompt_store::init(cx);
             theme_settings::init(theme::LoadThemes::JustBase, cx);
-            language_model::init_settings(cx);
+            language_model::init(cx);
         });
 
         let fs = FakeFs::new(cx.executor());
@@ -1966,7 +1966,7 @@ mod tests {
             cx.set_global(settings_store);
             prompt_store::init(cx);
             theme_settings::init(theme::LoadThemes::JustBase, cx);
-            language_model::init_settings(cx);
+            language_model::init(cx);
             workspace::register_project_item::<Editor>(cx);
         });
 

crates/agent_ui/src/inline_assistant.rs 🔗

@@ -2025,7 +2025,7 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
 pub mod evals {
     use crate::InlineAssistant;
     use agent::ThreadStore;
-    use client::{Client, UserStore};
+    use client::{Client, RefreshLlmTokenListener, UserStore};
     use editor::{Editor, MultiBuffer, MultiBufferOffset};
     use eval_utils::{EvalOutput, NoProcessor};
     use fs::FakeFs;
@@ -2091,7 +2091,8 @@ pub mod evals {
             client::init(&client, cx);
             workspace::init(app_state.clone(), cx);
             let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
-            language_model::init(user_store.clone(), client.clone(), cx);
+            language_model::init(cx);
+            RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
             language_models::init(user_store, client.clone(), cx);
 
             cx.set_global(inline_assistant);

crates/client/Cargo.toml 🔗

@@ -22,6 +22,7 @@ base64.workspace = true
 chrono = { workspace = true, features = ["serde"] }
 clock.workspace = true
 cloud_api_client.workspace = true
+cloud_api_types.workspace = true
 cloud_llm_client.workspace = true
 collections.workspace = true
 credentials_provider.workspace = true
@@ -35,6 +36,7 @@ gpui_tokio.workspace = true
 http_client.workspace = true
 http_client_tls.workspace = true
 httparse = "1.10"
+language_model.workspace = true
 log.workspace = true
 parking_lot.workspace = true
 paths.workspace = true
@@ -60,6 +62,7 @@ tokio.workspace = true
 url.workspace = true
 util.workspace = true
 worktree.workspace = true
+zed_credentials_provider.workspace = true
 
 [dev-dependencies]
 clock = { workspace = true, features = ["test-support"] }

crates/client/src/client.rs 🔗

@@ -1,6 +1,7 @@
 #[cfg(any(test, feature = "test-support"))]
 pub mod test;
 
+mod llm_token;
 mod proxy;
 pub mod telemetry;
 pub mod user;
@@ -13,8 +14,9 @@ use async_tungstenite::tungstenite::{
     http::{HeaderValue, Request, StatusCode},
 };
 use clock::SystemClock;
-use cloud_api_client::CloudApiClient;
 use cloud_api_client::websocket_protocol::MessageToClient;
+use cloud_api_client::{ClientApiError, CloudApiClient};
+use cloud_api_types::OrganizationId;
 use credentials_provider::CredentialsProvider;
 use feature_flags::FeatureFlagAppExt as _;
 use futures::{
@@ -24,6 +26,7 @@ use futures::{
 };
 use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions};
 use http_client::{HttpClient, HttpClientWithUrl, http, read_proxy_from_env};
+use language_model::LlmApiToken;
 use parking_lot::{Mutex, RwLock};
 use postage::watch;
 use proxy::connect_proxy_stream;
@@ -51,6 +54,7 @@ use tokio::net::TcpStream;
 use url::Url;
 use util::{ConnectionResult, ResultExt};
 
+pub use llm_token::*;
 pub use rpc::*;
 pub use telemetry_events::Event;
 pub use user::*;
@@ -339,7 +343,7 @@ pub struct ClientCredentialsProvider {
 impl ClientCredentialsProvider {
     pub fn new(cx: &App) -> Self {
         Self {
-            provider: <dyn CredentialsProvider>::global(cx),
+            provider: zed_credentials_provider::global(cx),
         }
     }
 
@@ -568,6 +572,10 @@ impl Client {
         self.http.clone()
     }
 
+    pub fn credentials_provider(&self) -> Arc<dyn CredentialsProvider> {
+        self.credentials_provider.provider.clone()
+    }
+
     pub fn cloud_client(&self) -> Arc<CloudApiClient> {
         self.cloud_client.clone()
     }
@@ -1513,6 +1521,66 @@ impl Client {
         })
     }
 
+    pub async fn acquire_llm_token(
+        &self,
+        llm_token: &LlmApiToken,
+        organization_id: Option<OrganizationId>,
+    ) -> Result<String> {
+        let system_id = self.telemetry().system_id().map(|x| x.to_string());
+        let cloud_client = self.cloud_client();
+        match llm_token
+            .acquire(&cloud_client, system_id, organization_id)
+            .await
+        {
+            Ok(token) => Ok(token),
+            Err(ClientApiError::Unauthorized) => {
+                self.request_sign_out();
+                Err(ClientApiError::Unauthorized).context("Failed to create LLM token")
+            }
+            Err(err) => Err(anyhow::Error::from(err)),
+        }
+    }
+
+    pub async fn refresh_llm_token(
+        &self,
+        llm_token: &LlmApiToken,
+        organization_id: Option<OrganizationId>,
+    ) -> Result<String> {
+        let system_id = self.telemetry().system_id().map(|x| x.to_string());
+        let cloud_client = self.cloud_client();
+        match llm_token
+            .refresh(&cloud_client, system_id, organization_id)
+            .await
+        {
+            Ok(token) => Ok(token),
+            Err(ClientApiError::Unauthorized) => {
+                self.request_sign_out();
+                return Err(ClientApiError::Unauthorized).context("Failed to create LLM token");
+            }
+            Err(err) => return Err(anyhow::Error::from(err)),
+        }
+    }
+
+    pub async fn clear_and_refresh_llm_token(
+        &self,
+        llm_token: &LlmApiToken,
+        organization_id: Option<OrganizationId>,
+    ) -> Result<String> {
+        let system_id = self.telemetry().system_id().map(|x| x.to_string());
+        let cloud_client = self.cloud_client();
+        match llm_token
+            .clear_and_refresh(&cloud_client, system_id, organization_id)
+            .await
+        {
+            Ok(token) => Ok(token),
+            Err(ClientApiError::Unauthorized) => {
+                self.request_sign_out();
+                return Err(ClientApiError::Unauthorized).context("Failed to create LLM token");
+            }
+            Err(err) => return Err(anyhow::Error::from(err)),
+        }
+    }
+
     pub async fn sign_out(self: &Arc<Self>, cx: &AsyncApp) {
         self.state.write().credentials = None;
         self.cloud_client.clear_credentials();

crates/client/src/llm_token.rs 🔗

@@ -0,0 +1,116 @@
+use super::{Client, UserStore};
+use cloud_api_types::websocket_protocol::MessageToClient;
+use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
+use gpui::{
+    App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription,
+};
+use language_model::LlmApiToken;
+use std::sync::Arc;
+
+pub trait NeedsLlmTokenRefresh {
+    /// Returns whether the LLM token needs to be refreshed.
+    fn needs_llm_token_refresh(&self) -> bool;
+}
+
+impl NeedsLlmTokenRefresh for http_client::Response<http_client::AsyncBody> {
+    fn needs_llm_token_refresh(&self) -> bool {
+        self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some()
+            || self.headers().get(OUTDATED_LLM_TOKEN_HEADER_NAME).is_some()
+    }
+}
+
+enum TokenRefreshMode {
+    Refresh,
+    ClearAndRefresh,
+}
+
+pub fn global_llm_token(cx: &App) -> LlmApiToken {
+    RefreshLlmTokenListener::global(cx)
+        .read(cx)
+        .llm_api_token
+        .clone()
+}
+
+struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
+
+impl Global for GlobalRefreshLlmTokenListener {}
+
+pub struct LlmTokenRefreshedEvent;
+
+pub struct RefreshLlmTokenListener {
+    client: Arc<Client>,
+    user_store: Entity<UserStore>,
+    llm_api_token: LlmApiToken,
+    _subscription: Subscription,
+}
+
+impl EventEmitter<LlmTokenRefreshedEvent> for RefreshLlmTokenListener {}
+
+impl RefreshLlmTokenListener {
+    pub fn register(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
+        let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, user_store, cx));
+        cx.set_global(GlobalRefreshLlmTokenListener(listener));
+    }
+
+    pub fn global(cx: &App) -> Entity<Self> {
+        GlobalRefreshLlmTokenListener::global(cx).0.clone()
+    }
+
+    fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
+        client.add_message_to_client_handler({
+            let this = cx.weak_entity();
+            move |message, cx| {
+                if let Some(this) = this.upgrade() {
+                    Self::handle_refresh_llm_token(this, message, cx);
+                }
+            }
+        });
+
+        let subscription = cx.subscribe(&user_store, |this, _user_store, event, cx| {
+            if matches!(event, super::user::Event::OrganizationChanged) {
+                this.refresh(TokenRefreshMode::ClearAndRefresh, cx);
+            }
+        });
+
+        Self {
+            client,
+            user_store,
+            llm_api_token: LlmApiToken::default(),
+            _subscription: subscription,
+        }
+    }
+
+    fn refresh(&self, mode: TokenRefreshMode, cx: &mut Context<Self>) {
+        let client = self.client.clone();
+        let llm_api_token = self.llm_api_token.clone();
+        let organization_id = self
+            .user_store
+            .read(cx)
+            .current_organization()
+            .map(|organization| organization.id.clone());
+        cx.spawn(async move |this, cx| {
+            match mode {
+                TokenRefreshMode::Refresh => {
+                    client
+                        .refresh_llm_token(&llm_api_token, organization_id)
+                        .await?;
+                }
+                TokenRefreshMode::ClearAndRefresh => {
+                    client
+                        .clear_and_refresh_llm_token(&llm_api_token, organization_id)
+                        .await?;
+                }
+            }
+            this.update(cx, |_this, cx| cx.emit(LlmTokenRefreshedEvent))
+        })
+        .detach_and_log_err(cx);
+    }
+
+    fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
+        match message {
+            MessageToClient::UserUpdated => {
+                this.update(cx, |this, cx| this.refresh(TokenRefreshMode::Refresh, cx));
+            }
+        }
+    }
+}

crates/codestral/Cargo.toml 🔗

@@ -22,6 +22,7 @@ log.workspace = true
 serde.workspace = true
 serde_json.workspace = true
 text.workspace = true
+zed_credentials_provider.workspace = true
 zeta_prompt.workspace = true
 
 [dev-dependencies]

crates/codestral/src/codestral.rs 🔗

@@ -48,9 +48,10 @@ pub fn codestral_api_key(cx: &App) -> Option<Arc<str>> {
 }
 
 pub fn load_codestral_api_key(cx: &mut App) -> Task<Result<(), AuthenticateError>> {
+    let credentials_provider = zed_credentials_provider::global(cx);
     let api_url = codestral_api_url(cx);
     codestral_api_key_state(cx).update(cx, |key_state, cx| {
-        key_state.load_if_needed(api_url, |s| s, cx)
+        key_state.load_if_needed(api_url, |s| s, credentials_provider, cx)
     })
 }
 

crates/credentials_provider/Cargo.toml 🔗

@@ -13,9 +13,5 @@ path = "src/credentials_provider.rs"
 
 [dependencies]
 anyhow.workspace = true
-futures.workspace = true
 gpui.workspace = true
-paths.workspace = true
-release_channel.workspace = true
 serde.workspace = true
-serde_json.workspace = true

crates/credentials_provider/src/credentials_provider.rs 🔗

@@ -1,26 +1,8 @@
-use std::collections::HashMap;
 use std::future::Future;
-use std::path::PathBuf;
 use std::pin::Pin;
-use std::sync::{Arc, LazyLock};
 
 use anyhow::Result;
-use futures::FutureExt as _;
-use gpui::{App, AsyncApp};
-use release_channel::ReleaseChannel;
-
-/// An environment variable whose presence indicates that the system keychain
-/// should be used in development.
-///
-/// By default, running Zed in development uses the development credentials
-/// provider. Setting this environment variable allows you to interact with the
-/// system keychain (for instance, if you need to test something).
-///
-/// Only works in development. Setting this environment variable in other
-/// release channels is a no-op.
-static ZED_DEVELOPMENT_USE_KEYCHAIN: LazyLock<bool> = LazyLock::new(|| {
-    std::env::var("ZED_DEVELOPMENT_USE_KEYCHAIN").is_ok_and(|value| !value.is_empty())
-});
+use gpui::AsyncApp;
 
 /// A provider for credentials.
 ///
@@ -50,150 +32,3 @@ pub trait CredentialsProvider: Send + Sync {
         cx: &'a AsyncApp,
     ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>>;
 }
-
-impl dyn CredentialsProvider {
-    /// Returns the global [`CredentialsProvider`].
-    pub fn global(cx: &App) -> Arc<Self> {
-        // The `CredentialsProvider` trait has `Send + Sync` bounds on it, so it
-        // seems like this is a false positive from Clippy.
-        #[allow(clippy::arc_with_non_send_sync)]
-        Self::new(cx)
-    }
-
-    fn new(cx: &App) -> Arc<Self> {
-        let use_development_provider = match ReleaseChannel::try_global(cx) {
-            Some(ReleaseChannel::Dev) => {
-                // In development we default to using the development
-                // credentials provider to avoid getting spammed by relentless
-                // keychain access prompts.
-                //
-                // However, if the `ZED_DEVELOPMENT_USE_KEYCHAIN` environment
-                // variable is set, we will use the actual keychain.
-                !*ZED_DEVELOPMENT_USE_KEYCHAIN
-            }
-            Some(ReleaseChannel::Nightly | ReleaseChannel::Preview | ReleaseChannel::Stable)
-            | None => false,
-        };
-
-        if use_development_provider {
-            Arc::new(DevelopmentCredentialsProvider::new())
-        } else {
-            Arc::new(KeychainCredentialsProvider)
-        }
-    }
-}
-
-/// A credentials provider that stores credentials in the system keychain.
-struct KeychainCredentialsProvider;
-
-impl CredentialsProvider for KeychainCredentialsProvider {
-    fn read_credentials<'a>(
-        &'a self,
-        url: &'a str,
-        cx: &'a AsyncApp,
-    ) -> Pin<Box<dyn Future<Output = Result<Option<(String, Vec<u8>)>>> + 'a>> {
-        async move { cx.update(|cx| cx.read_credentials(url)).await }.boxed_local()
-    }
-
-    fn write_credentials<'a>(
-        &'a self,
-        url: &'a str,
-        username: &'a str,
-        password: &'a [u8],
-        cx: &'a AsyncApp,
-    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
-        async move {
-            cx.update(move |cx| cx.write_credentials(url, username, password))
-                .await
-        }
-        .boxed_local()
-    }
-
-    fn delete_credentials<'a>(
-        &'a self,
-        url: &'a str,
-        cx: &'a AsyncApp,
-    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
-        async move { cx.update(move |cx| cx.delete_credentials(url)).await }.boxed_local()
-    }
-}
-
-/// A credentials provider that stores credentials in a local file.
-///
-/// This MUST only be used in development, as this is not a secure way of storing
-/// credentials on user machines.
-///
-/// Its existence is purely to work around the annoyance of having to constantly
-/// re-allow access to the system keychain when developing Zed.
-struct DevelopmentCredentialsProvider {
-    path: PathBuf,
-}
-
-impl DevelopmentCredentialsProvider {
-    fn new() -> Self {
-        let path = paths::config_dir().join("development_credentials");
-
-        Self { path }
-    }
-
-    fn load_credentials(&self) -> Result<HashMap<String, (String, Vec<u8>)>> {
-        let json = std::fs::read(&self.path)?;
-        let credentials: HashMap<String, (String, Vec<u8>)> = serde_json::from_slice(&json)?;
-
-        Ok(credentials)
-    }
-
-    fn save_credentials(&self, credentials: &HashMap<String, (String, Vec<u8>)>) -> Result<()> {
-        let json = serde_json::to_string(credentials)?;
-        std::fs::write(&self.path, json)?;
-
-        Ok(())
-    }
-}
-
-impl CredentialsProvider for DevelopmentCredentialsProvider {
-    fn read_credentials<'a>(
-        &'a self,
-        url: &'a str,
-        _cx: &'a AsyncApp,
-    ) -> Pin<Box<dyn Future<Output = Result<Option<(String, Vec<u8>)>>> + 'a>> {
-        async move {
-            Ok(self
-                .load_credentials()
-                .unwrap_or_default()
-                .get(url)
-                .cloned())
-        }
-        .boxed_local()
-    }
-
-    fn write_credentials<'a>(
-        &'a self,
-        url: &'a str,
-        username: &'a str,
-        password: &'a [u8],
-        _cx: &'a AsyncApp,
-    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
-        async move {
-            let mut credentials = self.load_credentials().unwrap_or_default();
-            credentials.insert(url.to_string(), (username.to_string(), password.to_vec()));
-
-            self.save_credentials(&credentials)
-        }
-        .boxed_local()
-    }
-
-    fn delete_credentials<'a>(
-        &'a self,
-        url: &'a str,
-        _cx: &'a AsyncApp,
-    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
-        async move {
-            let mut credentials = self.load_credentials()?;
-            credentials.remove(url);
-
-            self.save_credentials(&credentials)
-        }
-        .boxed_local()
-    }
-}

crates/edit_prediction/Cargo.toml 🔗

@@ -26,6 +26,7 @@ cloud_llm_client.workspace = true
 collections.workspace = true
 copilot.workspace = true
 copilot_ui.workspace = true
+credentials_provider.workspace = true
 db.workspace = true
 edit_prediction_types.workspace = true
 edit_prediction_context.workspace = true
@@ -65,6 +66,7 @@ uuid.workspace = true
 workspace.workspace = true
 worktree.workspace = true
 zed_actions.workspace = true
+zed_credentials_provider.workspace = true
 zeta_prompt.workspace = true
 zstd.workspace = true
 

crates/edit_prediction/src/capture_example.rs 🔗

@@ -258,6 +258,7 @@ fn generate_timestamp_name() -> String {
 mod tests {
     use super::*;
     use crate::EditPredictionStore;
+    use client::RefreshLlmTokenListener;
     use client::{Client, UserStore};
     use clock::FakeSystemClock;
     use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
@@ -548,7 +549,8 @@ mod tests {
             let http_client = FakeHttpClient::with_404_response();
             let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
             let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
-            language_model::init(user_store.clone(), client.clone(), cx);
+            language_model::init(cx);
+            RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
             EditPredictionStore::global(&client, &user_store, cx);
         })
     }

crates/edit_prediction/src/edit_prediction.rs 🔗

@@ -1,5 +1,5 @@
 use anyhow::Result;
-use client::{Client, EditPredictionUsage, UserStore};
+use client::{Client, EditPredictionUsage, NeedsLlmTokenRefresh, UserStore, global_llm_token};
 use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
 use cloud_llm_client::predict_edits_v3::{
     PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
@@ -11,6 +11,7 @@ use cloud_llm_client::{
 };
 use collections::{HashMap, HashSet};
 use copilot::{Copilot, Reinstall, SignIn, SignOut};
+use credentials_provider::CredentialsProvider;
 use db::kvp::{Dismissable, KeyValueStore};
 use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
 use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
@@ -30,7 +31,7 @@ use heapless::Vec as ArrayVec;
 use language::language_settings::all_language_settings;
 use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
 use language::{BufferSnapshot, OffsetRangeExt};
-use language_model::{LlmApiToken, NeedsLlmTokenRefresh};
+use language_model::LlmApiToken;
 use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
 use release_channel::AppVersion;
 use semver::Version;
@@ -150,6 +151,7 @@ pub struct EditPredictionStore {
     rated_predictions: HashSet<EditPredictionId>,
     #[cfg(test)]
     settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
+    credentials_provider: Arc<dyn CredentialsProvider>,
 }
 
 pub(crate) struct EditPredictionRejectionPayload {
@@ -746,7 +748,7 @@ impl EditPredictionStore {
     pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
         let data_collection_choice = Self::load_data_collection_choice(cx);
 
-        let llm_token = LlmApiToken::global(cx);
+        let llm_token = global_llm_token(cx);
 
         let (reject_tx, reject_rx) = mpsc::unbounded();
         cx.background_spawn({
@@ -787,6 +789,8 @@ impl EditPredictionStore {
             .log_err();
         });
 
+        let credentials_provider = zed_credentials_provider::global(cx);
+
         let this = Self {
             projects: HashMap::default(),
             client,
@@ -807,6 +811,8 @@ impl EditPredictionStore {
             shown_predictions: Default::default(),
             #[cfg(test)]
             settled_event_callback: None,
+
+            credentials_provider,
         };
 
         this
@@ -871,7 +877,9 @@ impl EditPredictionStore {
             let experiments = cx
                 .background_spawn(async move {
                     let http_client = client.http_client();
-                    let token = llm_token.acquire(&client, organization_id).await?;
+                    let token = client
+                        .acquire_llm_token(&llm_token, organization_id.clone())
+                        .await?;
                     let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
                     let request = http_client::Request::builder()
                         .method(Method::GET)
@@ -2315,7 +2323,10 @@ impl EditPredictionStore {
                 zeta::request_prediction_with_zeta(self, inputs, capture_data, cx)
             }
             EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
-            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
+            EditPredictionModel::Mercury => {
+                self.mercury
+                    .request_prediction(inputs, self.credentials_provider.clone(), cx)
+            }
         };
 
         cx.spawn(async move |this, cx| {
@@ -2536,12 +2547,15 @@ impl EditPredictionStore {
         Res: DeserializeOwned,
     {
         let http_client = client.http_client();
-
         let mut token = if require_auth {
-            Some(llm_token.acquire(&client, organization_id.clone()).await?)
+            Some(
+                client
+                    .acquire_llm_token(&llm_token, organization_id.clone())
+                    .await?,
+            )
         } else {
-            llm_token
-                .acquire(&client, organization_id.clone())
+            client
+                .acquire_llm_token(&llm_token, organization_id.clone())
                 .await
                 .ok()
         };
@@ -2585,7 +2599,11 @@ impl EditPredictionStore {
                 return Ok((serde_json::from_slice(&body)?, usage));
             } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
                 did_retry = true;
-                token = Some(llm_token.refresh(&client, organization_id.clone()).await?);
+                token = Some(
+                    client
+                        .refresh_llm_token(&llm_token, organization_id.clone())
+                        .await?,
+                );
             } else {
                 let mut body = String::new();
                 response.body_mut().read_to_string(&mut body).await?;

crates/edit_prediction/src/edit_prediction_tests.rs 🔗

@@ -1,6 +1,6 @@
 use super::*;
 use crate::udiff::apply_diff_to_string;
-use client::{UserStore, test::FakeServer};
+use client::{RefreshLlmTokenListener, UserStore, test::FakeServer};
 use clock::FakeSystemClock;
 use clock::ReplicaId;
 use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
@@ -23,7 +23,7 @@ use language::{
     Anchor, Buffer, Capability, CursorShape, Diagnostic, DiagnosticEntry, DiagnosticSet,
     DiagnosticSeverity, Operation, Point, Selection, SelectionGoal,
 };
-use language_model::RefreshLlmTokenListener;
+
 use lsp::LanguageServerId;
 use parking_lot::Mutex;
 use pretty_assertions::{assert_eq, assert_matches};
@@ -2439,7 +2439,8 @@ fn init_test_with_fake_client(
         client.cloud_client().set_credentials(1, "test".into());
 
         let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
-        language_model::init(user_store.clone(), client.clone(), cx);
+        language_model::init(cx);
+        RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
         let ep_store = EditPredictionStore::global(&client, &user_store, cx);
 
         (
@@ -2891,7 +2892,7 @@ async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut
         cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
     let user_store = cx.update(|cx| cx.new(|cx| client::UserStore::new(client.clone(), cx)));
     cx.update(|cx| {
-        language_model::RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
+        RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
     });
 
     let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));

crates/edit_prediction/src/mercury.rs 🔗

@@ -5,6 +5,7 @@ use crate::{
 };
 use anyhow::{Context as _, Result};
 use cloud_llm_client::EditPredictionRejectReason;
+use credentials_provider::CredentialsProvider;
 use futures::AsyncReadExt as _;
 use gpui::{
     App, AppContext as _, Context, Entity, Global, SharedString, Task,
@@ -51,10 +52,11 @@ impl Mercury {
             debug_tx,
             ..
         }: EditPredictionModelInput,
+        credentials_provider: Arc<dyn CredentialsProvider>,
         cx: &mut Context<EditPredictionStore>,
     ) -> Task<Result<Option<EditPredictionResult>>> {
         self.api_token.update(cx, |key_state, cx| {
-            _ = key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, cx);
+            _ = key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, credentials_provider, cx);
         });
         let Some(api_token) = self.api_token.read(cx).key(&MERCURY_CREDENTIALS_URL) else {
             return Task::ready(Ok(None));
@@ -387,8 +389,9 @@ pub fn mercury_api_token(cx: &mut App) -> Entity<ApiKeyState> {
 }
 
 pub fn load_mercury_api_token(cx: &mut App) -> Task<Result<(), language_model::AuthenticateError>> {
+    let credentials_provider = zed_credentials_provider::global(cx);
     mercury_api_token(cx).update(cx, |key_state, cx| {
-        key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, cx)
+        key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, credentials_provider, cx)
     })
 }
 

crates/edit_prediction/src/open_ai_compatible.rs 🔗

@@ -42,9 +42,10 @@ pub fn open_ai_compatible_api_token(cx: &mut App) -> Entity<ApiKeyState> {
 pub fn load_open_ai_compatible_api_token(
     cx: &mut App,
 ) -> Task<Result<(), language_model::AuthenticateError>> {
+    let credentials_provider = zed_credentials_provider::global(cx);
     let api_url = open_ai_compatible_api_url(cx);
     open_ai_compatible_api_token(cx).update(cx, |key_state, cx| {
-        key_state.load_if_needed(api_url, |s| s, cx)
+        key_state.load_if_needed(api_url, |s| s, credentials_provider, cx)
     })
 }
 

crates/edit_prediction_cli/src/headless.rs 🔗

@@ -1,4 +1,4 @@
-use client::{Client, ProxySettings, UserStore};
+use client::{Client, ProxySettings, RefreshLlmTokenListener, UserStore};
 use db::AppDatabase;
 use extension::ExtensionHostProxy;
 use fs::RealFs;
@@ -109,7 +109,8 @@ pub fn init(cx: &mut App) -> EpAppState {
 
     debug_adapter_extension::init(extension_host_proxy.clone(), cx);
     language_extension::init(LspAccess::Noop, extension_host_proxy, languages.clone());
-    language_model::init(user_store.clone(), client.clone(), cx);
+    language_model::init(cx);
+    RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
     language_models::init(user_store.clone(), client.clone(), cx);
     languages::init(languages.clone(), fs.clone(), node_runtime.clone(), cx);
     prompt_store::init(cx);

crates/env_var/Cargo.toml 🔗

@@ -0,0 +1,15 @@
+[package]
+name = "env_var"
+version = "0.1.0"
+edition.workspace = true
+publish.workspace = true
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/env_var.rs"
+
+[dependencies]
+gpui.workspace = true

crates/env_var/src/env_var.rs 🔗

@@ -0,0 +1,40 @@
+use gpui::SharedString;
+
+#[derive(Clone)]
+pub struct EnvVar {
+    pub name: SharedString,
+    /// Value of the environment variable. `None` if the variable is unset or set to an empty string.
+    pub value: Option<String>,
+}
+
+impl EnvVar {
+    pub fn new(name: SharedString) -> Self {
+        let value = std::env::var(name.as_str()).ok();
+        if value.as_ref().is_some_and(|v| v.is_empty()) {
+            Self { name, value: None }
+        } else {
+            Self { name, value }
+        }
+    }
+
+    pub fn or(self, other: EnvVar) -> EnvVar {
+        if self.value.is_some() { self } else { other }
+    }
+}
+
+/// Creates a `LazyLock<EnvVar>` expression for use in a `static` declaration.
+#[macro_export]
+macro_rules! env_var {
+    ($name:expr) => {
+        ::std::sync::LazyLock::new(|| $crate::EnvVar::new(($name).into()))
+    };
+}
+
+/// Creates a `LazyLock<bool>` expression for use in a `static` declaration. Evaluates to
+/// `true` if the environment variable exists and is non-empty.
+#[macro_export]
+macro_rules! bool_env_var {
+    ($name:expr) => {
+        ::std::sync::LazyLock::new(|| $crate::EnvVar::new(($name).into()).value.is_some())
+    };
+}

crates/eval_cli/src/headless.rs 🔗

@@ -1,7 +1,7 @@
 use std::path::PathBuf;
 use std::sync::Arc;
 
-use client::{Client, ProxySettings, UserStore};
+use client::{Client, ProxySettings, RefreshLlmTokenListener, UserStore};
 use db::AppDatabase;
 use extension::ExtensionHostProxy;
 use fs::RealFs;
@@ -108,7 +108,8 @@ pub fn init(cx: &mut App) -> Arc<AgentCliAppState> {
     let extension_host_proxy = ExtensionHostProxy::global(cx);
     debug_adapter_extension::init(extension_host_proxy.clone(), cx);
     language_extension::init(LspAccess::Noop, extension_host_proxy, languages.clone());
-    language_model::init(user_store.clone(), client.clone(), cx);
+    language_model::init(cx);
+    RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
     language_models::init(user_store.clone(), client.clone(), cx);
     languages::init(languages.clone(), fs.clone(), node_runtime.clone(), cx);
     prompt_store::init(cx);

crates/language_model/Cargo.toml 🔗

@@ -20,11 +20,11 @@ anthropic = { workspace = true, features = ["schemars"] }
 anyhow.workspace = true
 credentials_provider.workspace = true
 base64.workspace = true
-client.workspace = true
 cloud_api_client.workspace = true
 cloud_api_types.workspace = true
 cloud_llm_client.workspace = true
 collections.workspace = true
+env_var.workspace = true
 futures.workspace = true
 gpui.workspace = true
 http_client.workspace = true
@@ -40,7 +40,6 @@ serde_json.workspace = true
 smol.workspace = true
 thiserror.workspace = true
 util.workspace = true
-zed_env_vars.workspace = true
 
 [dev-dependencies]
 gpui = { workspace = true, features = ["test-support"] }

crates/language_model/src/api_key.rs 🔗

@@ -1,5 +1,6 @@
 use anyhow::{Result, anyhow};
 use credentials_provider::CredentialsProvider;
+use env_var::EnvVar;
 use futures::{FutureExt, future};
 use gpui::{AsyncApp, Context, SharedString, Task};
 use std::{
@@ -7,7 +8,6 @@ use std::{
     sync::Arc,
 };
 use util::ResultExt as _;
-use zed_env_vars::EnvVar;
 
 use crate::AuthenticateError;
 
@@ -101,6 +101,7 @@ impl ApiKeyState {
         url: SharedString,
         key: Option<String>,
         get_this: impl Fn(&mut Ent) -> &mut Self + 'static,
+        provider: Arc<dyn CredentialsProvider>,
         cx: &Context<Ent>,
     ) -> Task<Result<()>> {
         if self.is_from_env_var() {
@@ -108,18 +109,14 @@ impl ApiKeyState {
                 "bug: attempted to store API key in system keychain when API key is from env var",
             )));
         }
-        let credentials_provider = <dyn CredentialsProvider>::global(cx);
         cx.spawn(async move |ent, cx| {
             if let Some(key) = &key {
-                credentials_provider
+                provider
                     .write_credentials(&url, "Bearer", key.as_bytes(), cx)
                     .await
                     .log_err();
             } else {
-                credentials_provider
-                    .delete_credentials(&url, cx)
-                    .await
-                    .log_err();
+                provider.delete_credentials(&url, cx).await.log_err();
             }
             ent.update(cx, |ent, cx| {
                 let this = get_this(ent);
@@ -144,12 +141,13 @@ impl ApiKeyState {
         &mut self,
         url: SharedString,
         get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static,
+        provider: Arc<dyn CredentialsProvider>,
         cx: &mut Context<Ent>,
     ) {
         if url != self.url {
             if !self.is_from_env_var() {
                 // loading will continue even though this result task is dropped
-                let _task = self.load_if_needed(url, get_this, cx);
+                let _task = self.load_if_needed(url, get_this, provider, cx);
             }
         }
     }
@@ -163,6 +161,7 @@ impl ApiKeyState {
         &mut self,
         url: SharedString,
         get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static,
+        provider: Arc<dyn CredentialsProvider>,
         cx: &mut Context<Ent>,
     ) -> Task<Result<(), AuthenticateError>> {
         if let LoadStatus::Loaded { .. } = &self.load_status
@@ -185,7 +184,7 @@ impl ApiKeyState {
         let task = if let Some(load_task) = &self.load_task {
             load_task.clone()
         } else {
-            let load_task = Self::load(url.clone(), get_this.clone(), cx).shared();
+            let load_task = Self::load(url.clone(), get_this.clone(), provider, cx).shared();
             self.url = url;
             self.load_status = LoadStatus::NotPresent;
             self.load_task = Some(load_task.clone());
@@ -206,14 +205,13 @@ impl ApiKeyState {
     fn load<Ent: 'static>(
         url: SharedString,
         get_this: impl Fn(&mut Ent) -> &mut Self + 'static,
+        provider: Arc<dyn CredentialsProvider>,
         cx: &Context<Ent>,
     ) -> Task<()> {
-        let credentials_provider = <dyn CredentialsProvider>::global(cx);
         cx.spawn({
             async move |ent, cx| {
                 let load_status =
-                    ApiKey::load_from_system_keychain_impl(&url, credentials_provider.as_ref(), cx)
-                        .await;
+                    ApiKey::load_from_system_keychain_impl(&url, provider.as_ref(), cx).await;
                 ent.update(cx, |ent, cx| {
                     let this = get_this(ent);
                     this.url = url;

crates/language_model/src/language_model.rs 🔗

@@ -11,12 +11,10 @@ pub mod tool_schema;
 pub mod fake_provider;
 
 use anyhow::{Result, anyhow};
-use client::Client;
-use client::UserStore;
 use cloud_llm_client::CompletionRequestStatus;
 use futures::FutureExt;
 use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
-use gpui::{AnyView, App, AsyncApp, Entity, SharedString, Task, Window};
+use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
 use http_client::{StatusCode, http};
 use icons::IconName;
 use parking_lot::Mutex;
@@ -36,15 +34,10 @@ pub use crate::registry::*;
 pub use crate::request::*;
 pub use crate::role::*;
 pub use crate::tool_schema::LanguageModelToolSchemaFormat;
+pub use env_var::{EnvVar, env_var};
 pub use provider::*;
-pub use zed_env_vars::{EnvVar, env_var};
 
-pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
-    init_settings(cx);
-    RefreshLlmTokenListener::register(client, user_store, cx);
-}
-
-pub fn init_settings(cx: &mut App) {
+pub fn init(cx: &mut App) {
     registry::init(cx);
 }
 

crates/language_model/src/model/cloud_model.rs 🔗

@@ -1,16 +1,9 @@
 use std::fmt;
 use std::sync::Arc;
 
-use anyhow::{Context as _, Result};
-use client::Client;
-use client::UserStore;
 use cloud_api_client::ClientApiError;
+use cloud_api_client::CloudApiClient;
 use cloud_api_types::OrganizationId;
-use cloud_api_types::websocket_protocol::MessageToClient;
-use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
-use gpui::{
-    App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription,
-};
 use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
 use thiserror::Error;
 
@@ -30,18 +23,12 @@ impl fmt::Display for PaymentRequiredError {
 pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
 
 impl LlmApiToken {
-    pub fn global(cx: &App) -> Self {
-        RefreshLlmTokenListener::global(cx)
-            .read(cx)
-            .llm_api_token
-            .clone()
-    }
-
     pub async fn acquire(
         &self,
-        client: &Arc<Client>,
+        client: &CloudApiClient,
+        system_id: Option<String>,
         organization_id: Option<OrganizationId>,
-    ) -> Result<String> {
+    ) -> Result<String, ClientApiError> {
         let lock = self.0.upgradable_read().await;
         if let Some(token) = lock.as_ref() {
             Ok(token.to_string())
@@ -49,6 +36,7 @@ impl LlmApiToken {
             Self::fetch(
                 RwLockUpgradableReadGuard::upgrade(lock).await,
                 client,
+                system_id,
                 organization_id,
             )
             .await
@@ -57,10 +45,11 @@ impl LlmApiToken {
 
     pub async fn refresh(
         &self,
-        client: &Arc<Client>,
+        client: &CloudApiClient,
+        system_id: Option<String>,
         organization_id: Option<OrganizationId>,
-    ) -> Result<String> {
-        Self::fetch(self.0.write().await, client, organization_id).await
+    ) -> Result<String, ClientApiError> {
+        Self::fetch(self.0.write().await, client, system_id, organization_id).await
     }
 
     /// Clears the existing token before attempting to fetch a new one.
@@ -69,28 +58,22 @@ impl LlmApiToken {
     /// leave a token for the wrong organization.
     pub async fn clear_and_refresh(
         &self,
-        client: &Arc<Client>,
+        client: &CloudApiClient,
+        system_id: Option<String>,
         organization_id: Option<OrganizationId>,
-    ) -> Result<String> {
+    ) -> Result<String, ClientApiError> {
         let mut lock = self.0.write().await;
         *lock = None;
-        Self::fetch(lock, client, organization_id).await
+        Self::fetch(lock, client, system_id, organization_id).await
     }
 
     async fn fetch(
         mut lock: RwLockWriteGuard<'_, Option<String>>,
-        client: &Arc<Client>,
+        client: &CloudApiClient,
+        system_id: Option<String>,
         organization_id: Option<OrganizationId>,
-    ) -> Result<String> {
-        let system_id = client
-            .telemetry()
-            .system_id()
-            .map(|system_id| system_id.to_string());
-
-        let result = client
-            .cloud_client()
-            .create_llm_token(system_id, organization_id)
-            .await;
+    ) -> Result<String, ClientApiError> {
+        let result = client.create_llm_token(system_id, organization_id).await;
         match result {
             Ok(response) => {
                 *lock = Some(response.token.0.clone());
@@ -98,112 +81,7 @@ impl LlmApiToken {
             }
             Err(err) => {
                 *lock = None;
-                match err {
-                    ClientApiError::Unauthorized => {
-                        client.request_sign_out();
-                        Err(err).context("Failed to create LLM token")
-                    }
-                    ClientApiError::Other(err) => Err(err),
-                }
-            }
-        }
-    }
-}
-
-pub trait NeedsLlmTokenRefresh {
-    /// Returns whether the LLM token needs to be refreshed.
-    fn needs_llm_token_refresh(&self) -> bool;
-}
-
-impl NeedsLlmTokenRefresh for http_client::Response<http_client::AsyncBody> {
-    fn needs_llm_token_refresh(&self) -> bool {
-        self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some()
-            || self.headers().get(OUTDATED_LLM_TOKEN_HEADER_NAME).is_some()
-    }
-}
-
-enum TokenRefreshMode {
-    Refresh,
-    ClearAndRefresh,
-}
-
-struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
-
-impl Global for GlobalRefreshLlmTokenListener {}
-
-pub struct LlmTokenRefreshedEvent;
-
-pub struct RefreshLlmTokenListener {
-    client: Arc<Client>,
-    user_store: Entity<UserStore>,
-    llm_api_token: LlmApiToken,
-    _subscription: Subscription,
-}
-
-impl EventEmitter<LlmTokenRefreshedEvent> for RefreshLlmTokenListener {}
-
-impl RefreshLlmTokenListener {
-    pub fn register(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
-        let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, user_store, cx));
-        cx.set_global(GlobalRefreshLlmTokenListener(listener));
-    }
-
-    pub fn global(cx: &App) -> Entity<Self> {
-        GlobalRefreshLlmTokenListener::global(cx).0.clone()
-    }
-
-    fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
-        client.add_message_to_client_handler({
-            let this = cx.weak_entity();
-            move |message, cx| {
-                if let Some(this) = this.upgrade() {
-                    Self::handle_refresh_llm_token(this, message, cx);
-                }
-            }
-        });
-
-        let subscription = cx.subscribe(&user_store, |this, _user_store, event, cx| {
-            if matches!(event, client::user::Event::OrganizationChanged) {
-                this.refresh(TokenRefreshMode::ClearAndRefresh, cx);
-            }
-        });
-
-        Self {
-            client,
-            user_store,
-            llm_api_token: LlmApiToken::default(),
-            _subscription: subscription,
-        }
-    }
-
-    fn refresh(&self, mode: TokenRefreshMode, cx: &mut Context<Self>) {
-        let client = self.client.clone();
-        let llm_api_token = self.llm_api_token.clone();
-        let organization_id = self
-            .user_store
-            .read(cx)
-            .current_organization()
-            .map(|organization| organization.id.clone());
-        cx.spawn(async move |this, cx| {
-            match mode {
-                TokenRefreshMode::Refresh => {
-                    llm_api_token.refresh(&client, organization_id).await?;
-                }
-                TokenRefreshMode::ClearAndRefresh => {
-                    llm_api_token
-                        .clear_and_refresh(&client, organization_id)
-                        .await?;
-                }
-            }
-            this.update(cx, |_this, cx| cx.emit(LlmTokenRefreshedEvent))
-        })
-        .detach_and_log_err(cx);
-    }
-
-    fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
-        match message {
-            MessageToClient::UserUpdated => {
-                this.update(cx, |this, cx| this.refresh(TokenRefreshMode::Refresh, cx));
+                Err(err)
             }
         }
     }

crates/language_models/src/language_models.rs 🔗

@@ -3,6 +3,7 @@ use std::sync::Arc;
 use ::settings::{Settings, SettingsStore};
 use client::{Client, UserStore};
 use collections::HashSet;
+use credentials_provider::CredentialsProvider;
 use gpui::{App, Context, Entity};
 use language_model::{LanguageModelProviderId, LanguageModelRegistry};
 use provider::deepseek::DeepSeekLanguageModelProvider;
@@ -31,9 +32,16 @@ use crate::provider::x_ai::XAiLanguageModelProvider;
 pub use crate::settings::*;
 
 pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
+    let credentials_provider = client.credentials_provider();
     let registry = LanguageModelRegistry::global(cx);
     registry.update(cx, |registry, cx| {
-        register_language_model_providers(registry, user_store, client.clone(), cx);
+        register_language_model_providers(
+            registry,
+            user_store,
+            client.clone(),
+            credentials_provider.clone(),
+            cx,
+        );
     });
 
     // Subscribe to extension store events to track LLM extension installations
@@ -104,6 +112,7 @@ pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
             &HashSet::default(),
             &openai_compatible_providers,
             client.clone(),
+            credentials_provider.clone(),
             cx,
         );
     });
@@ -124,6 +133,7 @@ pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
                     &openai_compatible_providers,
                     &openai_compatible_providers_new,
                     client.clone(),
+                    credentials_provider.clone(),
                     cx,
                 );
             });
@@ -138,6 +148,7 @@ fn register_openai_compatible_providers(
     old: &HashSet<Arc<str>>,
     new: &HashSet<Arc<str>>,
     client: Arc<Client>,
+    credentials_provider: Arc<dyn CredentialsProvider>,
     cx: &mut Context<LanguageModelRegistry>,
 ) {
     for provider_id in old {
@@ -152,6 +163,7 @@ fn register_openai_compatible_providers(
                 Arc::new(OpenAiCompatibleLanguageModelProvider::new(
                     provider_id.clone(),
                     client.http_client(),
+                    credentials_provider.clone(),
                     cx,
                 )),
                 cx,
@@ -164,6 +176,7 @@ fn register_language_model_providers(
     registry: &mut LanguageModelRegistry,
     user_store: Entity<UserStore>,
     client: Arc<Client>,
+    credentials_provider: Arc<dyn CredentialsProvider>,
     cx: &mut Context<LanguageModelRegistry>,
 ) {
     registry.register_provider(
@@ -177,62 +190,105 @@ fn register_language_model_providers(
     registry.register_provider(
         Arc::new(AnthropicLanguageModelProvider::new(
             client.http_client(),
+            credentials_provider.clone(),
             cx,
         )),
         cx,
     );
     registry.register_provider(
-        Arc::new(OpenAiLanguageModelProvider::new(client.http_client(), cx)),
+        Arc::new(OpenAiLanguageModelProvider::new(
+            client.http_client(),
+            credentials_provider.clone(),
+            cx,
+        )),
         cx,
     );
     registry.register_provider(
-        Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)),
+        Arc::new(OllamaLanguageModelProvider::new(
+            client.http_client(),
+            credentials_provider.clone(),
+            cx,
+        )),
         cx,
     );
     registry.register_provider(
-        Arc::new(LmStudioLanguageModelProvider::new(client.http_client(), cx)),
+        Arc::new(LmStudioLanguageModelProvider::new(
+            client.http_client(),
+            credentials_provider.clone(),
+            cx,
+        )),
         cx,
     );
     registry.register_provider(
-        Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)),
+        Arc::new(DeepSeekLanguageModelProvider::new(
+            client.http_client(),
+            credentials_provider.clone(),
+            cx,
+        )),
         cx,
     );
     registry.register_provider(
-        Arc::new(GoogleLanguageModelProvider::new(client.http_client(), cx)),
+        Arc::new(GoogleLanguageModelProvider::new(
+            client.http_client(),
+            credentials_provider.clone(),
+            cx,
+        )),
         cx,
     );
     registry.register_provider(
-        MistralLanguageModelProvider::global(client.http_client(), cx),
+        MistralLanguageModelProvider::global(
+            client.http_client(),
+            credentials_provider.clone(),
+            cx,
+        ),
         cx,
     );
     registry.register_provider(
-        Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)),
+        Arc::new(BedrockLanguageModelProvider::new(
+            client.http_client(),
+            credentials_provider.clone(),
+            cx,
+        )),
         cx,
     );
     registry.register_provider(
         Arc::new(OpenRouterLanguageModelProvider::new(
             client.http_client(),
+            credentials_provider.clone(),
             cx,
         )),
         cx,
     );
     registry.register_provider(
-        Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)),
+        Arc::new(VercelLanguageModelProvider::new(
+            client.http_client(),
+            credentials_provider.clone(),
+            cx,
+        )),
         cx,
     );
     registry.register_provider(
         Arc::new(VercelAiGatewayLanguageModelProvider::new(
             client.http_client(),
+            credentials_provider.clone(),
             cx,
         )),
         cx,
     );
     registry.register_provider(
-        Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)),
+        Arc::new(XAiLanguageModelProvider::new(
+            client.http_client(),
+            credentials_provider.clone(),
+            cx,
+        )),
         cx,
     );
     registry.register_provider(
-        Arc::new(OpenCodeLanguageModelProvider::new(client.http_client(), cx)),
+        Arc::new(OpenCodeLanguageModelProvider::new(
+            client.http_client(),
+            credentials_provider,
+            cx,
+        )),
         cx,
     );
     registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);

crates/language_models/src/provider/anthropic.rs 🔗

@@ -6,6 +6,7 @@ use anthropic::{
 };
 use anyhow::Result;
 use collections::{BTreeMap, HashMap};
+use credentials_provider::CredentialsProvider;
 use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
 use gpui::{AnyView, App, AsyncApp, Context, Entity, Task};
 use http_client::HttpClient;
@@ -51,6 +52,7 @@ static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
 
 pub struct State {
     api_key_state: ApiKeyState,
+    credentials_provider: Arc<dyn CredentialsProvider>,
 }
 
 impl State {
@@ -59,30 +61,51 @@ impl State {
     }
 
     fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+        let credentials_provider = self.credentials_provider.clone();
         let api_url = AnthropicLanguageModelProvider::api_url(cx);
-        self.api_key_state
-            .store(api_url, api_key, |this| &mut this.api_key_state, cx)
+        self.api_key_state.store(
+            api_url,
+            api_key,
+            |this| &mut this.api_key_state,
+            credentials_provider,
+            cx,
+        )
     }
 
     fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+        let credentials_provider = self.credentials_provider.clone();
         let api_url = AnthropicLanguageModelProvider::api_url(cx);
-        self.api_key_state
-            .load_if_needed(api_url, |this| &mut this.api_key_state, cx)
+        self.api_key_state.load_if_needed(
+            api_url,
+            |this| &mut this.api_key_state,
+            credentials_provider,
+            cx,
+        )
     }
 }
 
 impl AnthropicLanguageModelProvider {
-    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+    pub fn new(
+        http_client: Arc<dyn HttpClient>,
+        credentials_provider: Arc<dyn CredentialsProvider>,
+        cx: &mut App,
+    ) -> Self {
         let state = cx.new(|cx| {
             cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
+                let credentials_provider = this.credentials_provider.clone();
                 let api_url = Self::api_url(cx);
-                this.api_key_state
-                    .handle_url_change(api_url, |this| &mut this.api_key_state, cx);
+                this.api_key_state.handle_url_change(
+                    api_url,
+                    |this| &mut this.api_key_state,
+                    credentials_provider,
+                    cx,
+                );
                 cx.notify();
             })
             .detach();
             State {
                 api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
+                credentials_provider,
             }
         });
 

crates/language_models/src/provider/bedrock.rs 🔗

@@ -195,12 +195,13 @@ pub struct State {
     settings: Option<AmazonBedrockSettings>,
     /// Whether credentials came from environment variables (only relevant for static credentials)
     credentials_from_env: bool,
+    credentials_provider: Arc<dyn CredentialsProvider>,
     _subscription: Subscription,
 }
 
 impl State {
     fn reset_auth(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
-        let credentials_provider = <dyn CredentialsProvider>::global(cx);
+        let credentials_provider = self.credentials_provider.clone();
         cx.spawn(async move |this, cx| {
             credentials_provider
                 .delete_credentials(AMAZON_AWS_URL, cx)
@@ -220,7 +221,7 @@ impl State {
         cx: &mut Context<Self>,
     ) -> Task<Result<()>> {
         let auth = credentials.clone().into_auth();
-        let credentials_provider = <dyn CredentialsProvider>::global(cx);
+        let credentials_provider = self.credentials_provider.clone();
         cx.spawn(async move |this, cx| {
             credentials_provider
                 .write_credentials(
@@ -287,7 +288,7 @@ impl State {
         &self,
         cx: &mut Context<Self>,
     ) -> Task<Result<(), AuthenticateError>> {
-        let credentials_provider = <dyn CredentialsProvider>::global(cx);
+        let credentials_provider = self.credentials_provider.clone();
         cx.spawn(async move |this, cx| {
             // Try environment variables first
             let (auth, from_env) = if let Some(bearer_token) = &ZED_BEDROCK_BEARER_TOKEN_VAR.value {
@@ -400,11 +401,16 @@ pub struct BedrockLanguageModelProvider {
 }
 
 impl BedrockLanguageModelProvider {
-    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+    pub fn new(
+        http_client: Arc<dyn HttpClient>,
+        credentials_provider: Arc<dyn CredentialsProvider>,
+        cx: &mut App,
+    ) -> Self {
         let state = cx.new(|cx| State {
             auth: None,
             settings: Some(AllLanguageModelSettings::get_global(cx).bedrock.clone()),
             credentials_from_env: false,
+            credentials_provider,
             _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                 cx.notify();
             }),

crates/language_models/src/provider/cloud.rs 🔗

@@ -1,7 +1,9 @@
 use ai_onboarding::YoungAccountBanner;
 use anthropic::AnthropicModelMode;
 use anyhow::{Context as _, Result, anyhow};
-use client::{Client, UserStore, zed_urls};
+use client::{
+    Client, NeedsLlmTokenRefresh, RefreshLlmTokenListener, UserStore, global_llm_token, zed_urls,
+};
 use cloud_api_types::{OrganizationId, Plan};
 use cloud_llm_client::{
     CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME,
@@ -24,10 +26,9 @@ use language_model::{
     LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelEffortLevel,
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
-    LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, NeedsLlmTokenRefresh,
-    OPEN_AI_PROVIDER_ID, OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter,
-    RefreshLlmTokenListener, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME, ZED_CLOUD_PROVIDER_ID,
-    ZED_CLOUD_PROVIDER_NAME,
+    LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, OPEN_AI_PROVIDER_ID,
+    OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME,
+    ZED_CLOUD_PROVIDER_ID, ZED_CLOUD_PROVIDER_NAME,
 };
 use release_channel::AppVersion;
 use schemars::JsonSchema;
@@ -111,7 +112,7 @@ impl State {
         cx: &mut Context<Self>,
     ) -> Self {
         let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
-        let llm_api_token = LlmApiToken::global(cx);
+        let llm_api_token = global_llm_token(cx);
         Self {
             client: client.clone(),
             llm_api_token,
@@ -226,7 +227,9 @@ impl State {
         organization_id: Option<OrganizationId>,
     ) -> Result<ListModelsResponse> {
         let http_client = &client.http_client();
-        let token = llm_api_token.acquire(&client, organization_id).await?;
+        let token = client
+            .acquire_llm_token(&llm_api_token, organization_id)
+            .await?;
 
         let request = http_client::Request::builder()
             .method(Method::GET)
@@ -414,8 +417,8 @@ impl CloudLanguageModel {
     ) -> Result<PerformLlmCompletionResponse> {
         let http_client = &client.http_client();
 
-        let mut token = llm_api_token
-            .acquire(&client, organization_id.clone())
+        let mut token = client
+            .acquire_llm_token(&llm_api_token, organization_id.clone())
             .await?;
         let mut refreshed_token = false;
 
@@ -447,8 +450,8 @@ impl CloudLanguageModel {
             }
 
             if !refreshed_token && response.needs_llm_token_refresh() {
-                token = llm_api_token
-                    .refresh(&client, organization_id.clone())
+                token = client
+                    .refresh_llm_token(&llm_api_token, organization_id.clone())
                     .await?;
                 refreshed_token = true;
                 continue;
@@ -713,7 +716,9 @@ impl LanguageModel for CloudLanguageModel {
                     into_google(request, model_id.clone(), GoogleModelMode::Default);
                 async move {
                     let http_client = &client.http_client();
-                    let token = llm_api_token.acquire(&client, organization_id).await?;
+                    let token = client
+                        .acquire_llm_token(&llm_api_token, organization_id)
+                        .await?;
 
                     let request_body = CountTokensBody {
                         provider: cloud_llm_client::LanguageModelProvider::Google,

crates/language_models/src/provider/deepseek.rs 🔗

@@ -1,5 +1,6 @@
 use anyhow::{Result, anyhow};
 use collections::{BTreeMap, HashMap};
+use credentials_provider::CredentialsProvider;
 use deepseek::DEEPSEEK_API_URL;
 
 use futures::Stream;
@@ -49,6 +50,7 @@ pub struct DeepSeekLanguageModelProvider {
 
 pub struct State {
     api_key_state: ApiKeyState,
+    credentials_provider: Arc<dyn CredentialsProvider>,
 }
 
 impl State {
@@ -57,30 +59,51 @@ impl State {
     }
 
     fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+        let credentials_provider = self.credentials_provider.clone();
         let api_url = DeepSeekLanguageModelProvider::api_url(cx);
-        self.api_key_state
-            .store(api_url, api_key, |this| &mut this.api_key_state, cx)
+        self.api_key_state.store(
+            api_url,
+            api_key,
+            |this| &mut this.api_key_state,
+            credentials_provider,
+            cx,
+        )
     }
 
     fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+        let credentials_provider = self.credentials_provider.clone();
         let api_url = DeepSeekLanguageModelProvider::api_url(cx);
-        self.api_key_state
-            .load_if_needed(api_url, |this| &mut this.api_key_state, cx)
+        self.api_key_state.load_if_needed(
+            api_url,
+            |this| &mut this.api_key_state,
+            credentials_provider,
+            cx,
+        )
     }
 }
 
 impl DeepSeekLanguageModelProvider {
-    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+    pub fn new(
+        http_client: Arc<dyn HttpClient>,
+        credentials_provider: Arc<dyn CredentialsProvider>,
+        cx: &mut App,
+    ) -> Self {
         let state = cx.new(|cx| {
             cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
+                let credentials_provider = this.credentials_provider.clone();
                 let api_url = Self::api_url(cx);
-                this.api_key_state
-                    .handle_url_change(api_url, |this| &mut this.api_key_state, cx);
+                this.api_key_state.handle_url_change(
+                    api_url,
+                    |this| &mut this.api_key_state,
+                    credentials_provider,
+                    cx,
+                );
                 cx.notify();
             })
             .detach();
             State {
                 api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
+                credentials_provider,
             }
         });
 

crates/language_models/src/provider/google.rs 🔗

@@ -1,5 +1,6 @@
 use anyhow::{Context as _, Result};
 use collections::BTreeMap;
+use credentials_provider::CredentialsProvider;
 use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
 use google_ai::{
     FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction,
@@ -60,6 +61,7 @@ pub struct GoogleLanguageModelProvider {
 
 pub struct State {
     api_key_state: ApiKeyState,
+    credentials_provider: Arc<dyn CredentialsProvider>,
 }
 
 const GEMINI_API_KEY_VAR_NAME: &str = "GEMINI_API_KEY";
@@ -76,30 +78,51 @@ impl State {
     }
 
     fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+        let credentials_provider = self.credentials_provider.clone();
         let api_url = GoogleLanguageModelProvider::api_url(cx);
-        self.api_key_state
-            .store(api_url, api_key, |this| &mut this.api_key_state, cx)
+        self.api_key_state.store(
+            api_url,
+            api_key,
+            |this| &mut this.api_key_state,
+            credentials_provider,
+            cx,
+        )
     }
 
     fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+        let credentials_provider = self.credentials_provider.clone();
         let api_url = GoogleLanguageModelProvider::api_url(cx);
-        self.api_key_state
-            .load_if_needed(api_url, |this| &mut this.api_key_state, cx)
+        self.api_key_state.load_if_needed(
+            api_url,
+            |this| &mut this.api_key_state,
+            credentials_provider,
+            cx,
+        )
     }
 }
 
 impl GoogleLanguageModelProvider {
-    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+    pub fn new(
+        http_client: Arc<dyn HttpClient>,
+        credentials_provider: Arc<dyn CredentialsProvider>,
+        cx: &mut App,
+    ) -> Self {
         let state = cx.new(|cx| {
             cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
+                let credentials_provider = this.credentials_provider.clone();
                 let api_url = Self::api_url(cx);
-                this.api_key_state
-                    .handle_url_change(api_url, |this| &mut this.api_key_state, cx);
+                this.api_key_state.handle_url_change(
+                    api_url,
+                    |this| &mut this.api_key_state,
+                    credentials_provider,
+                    cx,
+                );
                 cx.notify();
             })
             .detach();
             State {
                 api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
+                credentials_provider,
             }
         });
 

crates/language_models/src/provider/lmstudio.rs 🔗

@@ -1,5 +1,6 @@
 use anyhow::{Result, anyhow};
 use collections::HashMap;
+use credentials_provider::CredentialsProvider;
 use fs::Fs;
 use futures::Stream;
 use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
@@ -52,6 +53,7 @@ pub struct LmStudioLanguageModelProvider {
 
 pub struct State {
     api_key_state: ApiKeyState,
+    credentials_provider: Arc<dyn CredentialsProvider>,
     http_client: Arc<dyn HttpClient>,
     available_models: Vec<lmstudio::Model>,
     fetch_model_task: Option<Task<Result<()>>>,
@@ -64,10 +66,15 @@ impl State {
     }
 
     fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+        let credentials_provider = self.credentials_provider.clone();
         let api_url = LmStudioLanguageModelProvider::api_url(cx).into();
-        let task = self
-            .api_key_state
-            .store(api_url, api_key, |this| &mut this.api_key_state, cx);
+        let task = self.api_key_state.store(
+            api_url,
+            api_key,
+            |this| &mut this.api_key_state,
+            credentials_provider,
+            cx,
+        );
         self.restart_fetch_models_task(cx);
         task
     }
@@ -114,10 +121,14 @@ impl State {
     }
 
     fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+        let credentials_provider = self.credentials_provider.clone();
         let api_url = LmStudioLanguageModelProvider::api_url(cx).into();
-        let _task = self
-            .api_key_state
-            .load_if_needed(api_url, |this| &mut this.api_key_state, cx);
+        let _task = self.api_key_state.load_if_needed(
+            api_url,
+            |this| &mut this.api_key_state,
+            credentials_provider,
+            cx,
+        );
 
         if self.is_authenticated() {
             return Task::ready(Ok(()));
@@ -152,16 +163,29 @@ impl State {
 }
 
 impl LmStudioLanguageModelProvider {
-    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+    pub fn new(
+        http_client: Arc<dyn HttpClient>,
+        credentials_provider: Arc<dyn CredentialsProvider>,
+        cx: &mut App,
+    ) -> Self {
         let this = Self {
             http_client: http_client.clone(),
             state: cx.new(|cx| {
                 let subscription = cx.observe_global::<SettingsStore>({
                     let mut settings = AllLanguageModelSettings::get_global(cx).lmstudio.clone();
                     move |this: &mut State, cx| {
-                        let new_settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
-                        if &settings != new_settings {
-                            settings = new_settings.clone();
+                        let new_settings =
+                            AllLanguageModelSettings::get_global(cx).lmstudio.clone();
+                        if settings != new_settings {
+                            let credentials_provider = this.credentials_provider.clone();
+                            let api_url = Self::api_url(cx).into();
+                            this.api_key_state.handle_url_change(
+                                api_url,
+                                |this| &mut this.api_key_state,
+                                credentials_provider,
+                                cx,
+                            );
+                            settings = new_settings;
                             this.restart_fetch_models_task(cx);
                             cx.notify();
                         }
@@ -173,6 +197,7 @@ impl LmStudioLanguageModelProvider {
                         Self::api_url(cx).into(),
                         (*API_KEY_ENV_VAR).clone(),
                     ),
+                    credentials_provider,
                     http_client,
                     available_models: Default::default(),
                     fetch_model_task: None,

crates/language_models/src/provider/mistral.rs 🔗

@@ -1,5 +1,6 @@
 use anyhow::{Result, anyhow};
 use collections::BTreeMap;
+use credentials_provider::CredentialsProvider;
 
 use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
 use gpui::{AnyView, App, AsyncApp, Context, Entity, Global, SharedString, Task, Window};
@@ -43,6 +44,7 @@ pub struct MistralLanguageModelProvider {
 
 pub struct State {
     api_key_state: ApiKeyState,
+    credentials_provider: Arc<dyn CredentialsProvider>,
 }
 
 impl State {
@@ -51,15 +53,26 @@ impl State {
     }
 
     fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+        let credentials_provider = self.credentials_provider.clone();
         let api_url = MistralLanguageModelProvider::api_url(cx);
-        self.api_key_state
-            .store(api_url, api_key, |this| &mut this.api_key_state, cx)
+        self.api_key_state.store(
+            api_url,
+            api_key,
+            |this| &mut this.api_key_state,
+            credentials_provider,
+            cx,
+        )
     }
 
     fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+        let credentials_provider = self.credentials_provider.clone();
         let api_url = MistralLanguageModelProvider::api_url(cx);
-        self.api_key_state
-            .load_if_needed(api_url, |this| &mut this.api_key_state, cx)
+        self.api_key_state.load_if_needed(
+            api_url,
+            |this| &mut this.api_key_state,
+            credentials_provider,
+            cx,
+        )
     }
 }
 
@@ -73,20 +86,30 @@ impl MistralLanguageModelProvider {
             .map(|this| &this.0)
     }
 
-    pub fn global(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Arc<Self> {
+    pub fn global(
+        http_client: Arc<dyn HttpClient>,
+        credentials_provider: Arc<dyn CredentialsProvider>,
+        cx: &mut App,
+    ) -> Arc<Self> {
         if let Some(this) = cx.try_global::<GlobalMistralLanguageModelProvider>() {
             return this.0.clone();
         }
         let state = cx.new(|cx| {
             cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
+                let credentials_provider = this.credentials_provider.clone();
                 let api_url = Self::api_url(cx);
-                this.api_key_state
-                    .handle_url_change(api_url, |this| &mut this.api_key_state, cx);
+                this.api_key_state.handle_url_change(
+                    api_url,
+                    |this| &mut this.api_key_state,
+                    credentials_provider,
+                    cx,
+                );
                 cx.notify();
             })
             .detach();
             State {
                 api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
+                credentials_provider,
             }
         });
 

crates/language_models/src/provider/ollama.rs 🔗

@@ -1,4 +1,5 @@
 use anyhow::{Result, anyhow};
+use credentials_provider::CredentialsProvider;
 use fs::Fs;
 use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
 use futures::{Stream, TryFutureExt, stream};
@@ -54,6 +55,7 @@ pub struct OllamaLanguageModelProvider {
 
 pub struct State {
     api_key_state: ApiKeyState,
+    credentials_provider: Arc<dyn CredentialsProvider>,
     http_client: Arc<dyn HttpClient>,
     fetched_models: Vec<ollama::Model>,
     fetch_model_task: Option<Task<Result<()>>>,
@@ -65,10 +67,15 @@ impl State {
     }
 
     fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+        let credentials_provider = self.credentials_provider.clone();
         let api_url = OllamaLanguageModelProvider::api_url(cx);
-        let task = self
-            .api_key_state
-            .store(api_url, api_key, |this| &mut this.api_key_state, cx);
+        let task = self.api_key_state.store(
+            api_url,
+            api_key,
+            |this| &mut this.api_key_state,
+            credentials_provider,
+            cx,
+        );
 
         self.fetched_models.clear();
         cx.spawn(async move |this, cx| {
@@ -80,10 +87,14 @@ impl State {
     }
 
     fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+        let credentials_provider = self.credentials_provider.clone();
         let api_url = OllamaLanguageModelProvider::api_url(cx);
-        let task = self
-            .api_key_state
-            .load_if_needed(api_url, |this| &mut this.api_key_state, cx);
+        let task = self.api_key_state.load_if_needed(
+            api_url,
+            |this| &mut this.api_key_state,
+            credentials_provider,
+            cx,
+        );
 
         // Always try to fetch models - if no API key is needed (local Ollama), it will work
         // If API key is needed and provided, it will work
@@ -157,7 +168,11 @@ impl State {
 }
 
 impl OllamaLanguageModelProvider {
-    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+    pub fn new(
+        http_client: Arc<dyn HttpClient>,
+        credentials_provider: Arc<dyn CredentialsProvider>,
+        cx: &mut App,
+    ) -> Self {
         let this = Self {
             http_client: http_client.clone(),
             state: cx.new(|cx| {
@@ -170,6 +185,14 @@ impl OllamaLanguageModelProvider {
                             let url_changed = last_settings.api_url != current_settings.api_url;
                             last_settings = current_settings.clone();
                             if url_changed {
+                                let credentials_provider = this.credentials_provider.clone();
+                                let api_url = Self::api_url(cx);
+                                this.api_key_state.handle_url_change(
+                                    api_url,
+                                    |this| &mut this.api_key_state,
+                                    credentials_provider,
+                                    cx,
+                                );
                                 this.fetched_models.clear();
                                 this.authenticate(cx).detach();
                             }
@@ -184,6 +207,7 @@ impl OllamaLanguageModelProvider {
                     fetched_models: Default::default(),
                     fetch_model_task: None,
                     api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
+                    credentials_provider,
                 }
             }),
         };

crates/language_models/src/provider/open_ai.rs 🔗

@@ -1,5 +1,6 @@
 use anyhow::{Result, anyhow};
 use collections::{BTreeMap, HashMap};
+use credentials_provider::CredentialsProvider;
 use futures::Stream;
 use futures::{FutureExt, StreamExt, future::BoxFuture};
 use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
@@ -55,6 +56,7 @@ pub struct OpenAiLanguageModelProvider {
 
 pub struct State {
     api_key_state: ApiKeyState,
+    credentials_provider: Arc<dyn CredentialsProvider>,
 }
 
 impl State {
@@ -63,30 +65,51 @@ impl State {
     }
 
     fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+        let credentials_provider = self.credentials_provider.clone();
         let api_url = OpenAiLanguageModelProvider::api_url(cx);
-        self.api_key_state
-            .store(api_url, api_key, |this| &mut this.api_key_state, cx)
+        self.api_key_state.store(
+            api_url,
+            api_key,
+            |this| &mut this.api_key_state,
+            credentials_provider,
+            cx,
+        )
     }
 
     fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+        let credentials_provider = self.credentials_provider.clone();
         let api_url = OpenAiLanguageModelProvider::api_url(cx);
-        self.api_key_state
-            .load_if_needed(api_url, |this| &mut this.api_key_state, cx)
+        self.api_key_state.load_if_needed(
+            api_url,
+            |this| &mut this.api_key_state,
+            credentials_provider,
+            cx,
+        )
     }
 }
 
 impl OpenAiLanguageModelProvider {
-    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+    pub fn new(
+        http_client: Arc<dyn HttpClient>,
+        credentials_provider: Arc<dyn CredentialsProvider>,
+        cx: &mut App,
+    ) -> Self {
         let state = cx.new(|cx| {
             cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
+                let credentials_provider = this.credentials_provider.clone();
                 let api_url = Self::api_url(cx);
-                this.api_key_state
-                    .handle_url_change(api_url, |this| &mut this.api_key_state, cx);
+                this.api_key_state.handle_url_change(
+                    api_url,
+                    |this| &mut this.api_key_state,
+                    credentials_provider,
+                    cx,
+                );
                 cx.notify();
             })
             .detach();
             State {
                 api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
+                credentials_provider,
             }
         });
 

crates/language_models/src/provider/open_ai_compatible.rs 🔗

@@ -1,5 +1,6 @@
 use anyhow::Result;
 use convert_case::{Case, Casing};
+use credentials_provider::CredentialsProvider;
 use futures::{FutureExt, StreamExt, future::BoxFuture};
 use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
 use http_client::HttpClient;
@@ -44,6 +45,7 @@ pub struct State {
     id: Arc<str>,
     api_key_state: ApiKeyState,
     settings: OpenAiCompatibleSettings,
+    credentials_provider: Arc<dyn CredentialsProvider>,
 }
 
 impl State {
@@ -52,20 +54,36 @@ impl State {
     }
 
     fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+        let credentials_provider = self.credentials_provider.clone();
         let api_url = SharedString::new(self.settings.api_url.as_str());
-        self.api_key_state
-            .store(api_url, api_key, |this| &mut this.api_key_state, cx)
+        self.api_key_state.store(
+            api_url,
+            api_key,
+            |this| &mut this.api_key_state,
+            credentials_provider,
+            cx,
+        )
     }
 
     fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+        let credentials_provider = self.credentials_provider.clone();
         let api_url = SharedString::new(self.settings.api_url.clone());
-        self.api_key_state
-            .load_if_needed(api_url, |this| &mut this.api_key_state, cx)
+        self.api_key_state.load_if_needed(
+            api_url,
+            |this| &mut this.api_key_state,
+            credentials_provider,
+            cx,
+        )
     }
 }
 
 impl OpenAiCompatibleLanguageModelProvider {
-    pub fn new(id: Arc<str>, http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+    pub fn new(
+        id: Arc<str>,
+        http_client: Arc<dyn HttpClient>,
+        credentials_provider: Arc<dyn CredentialsProvider>,
+        cx: &mut App,
+    ) -> Self {
         fn resolve_settings<'a>(id: &'a str, cx: &'a App) -> Option<&'a OpenAiCompatibleSettings> {
             crate::AllLanguageModelSettings::get_global(cx)
                 .openai_compatible
@@ -79,10 +97,12 @@ impl OpenAiCompatibleLanguageModelProvider {
                     return;
                 };
                 if &this.settings != &settings {
+                    let credentials_provider = this.credentials_provider.clone();
                     let api_url = SharedString::new(settings.api_url.as_str());
                     this.api_key_state.handle_url_change(
                         api_url,
                         |this| &mut this.api_key_state,
+                        credentials_provider,
                         cx,
                     );
                     this.settings = settings;
@@ -98,6 +118,7 @@ impl OpenAiCompatibleLanguageModelProvider {
                     EnvVar::new(api_key_env_var_name),
                 ),
                 settings,
+                credentials_provider,
             }
         });
 

crates/language_models/src/provider/open_router.rs 🔗

@@ -1,5 +1,6 @@
 use anyhow::Result;
 use collections::HashMap;
+use credentials_provider::CredentialsProvider;
 use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
 use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task};
 use http_client::HttpClient;
@@ -42,6 +43,7 @@ pub struct OpenRouterLanguageModelProvider {
 
 pub struct State {
     api_key_state: ApiKeyState,
+    credentials_provider: Arc<dyn CredentialsProvider>,
     http_client: Arc<dyn HttpClient>,
     available_models: Vec<open_router::Model>,
     fetch_models_task: Option<Task<Result<(), LanguageModelCompletionError>>>,
@@ -53,16 +55,26 @@ impl State {
     }
 
     fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+        let credentials_provider = self.credentials_provider.clone();
         let api_url = OpenRouterLanguageModelProvider::api_url(cx);
-        self.api_key_state
-            .store(api_url, api_key, |this| &mut this.api_key_state, cx)
+        self.api_key_state.store(
+            api_url,
+            api_key,
+            |this| &mut this.api_key_state,
+            credentials_provider,
+            cx,
+        )
     }
 
     fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+        let credentials_provider = self.credentials_provider.clone();
         let api_url = OpenRouterLanguageModelProvider::api_url(cx);
-        let task = self
-            .api_key_state
-            .load_if_needed(api_url, |this| &mut this.api_key_state, cx);
+        let task = self.api_key_state.load_if_needed(
+            api_url,
+            |this| &mut this.api_key_state,
+            credentials_provider,
+            cx,
+        );
 
         cx.spawn(async move |this, cx| {
             let result = task.await;
@@ -114,7 +126,11 @@ impl State {
 }
 
 impl OpenRouterLanguageModelProvider {
-    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+    pub fn new(
+        http_client: Arc<dyn HttpClient>,
+        credentials_provider: Arc<dyn CredentialsProvider>,
+        cx: &mut App,
+    ) -> Self {
         let state = cx.new(|cx| {
             cx.observe_global::<SettingsStore>({
                 let mut last_settings = OpenRouterLanguageModelProvider::settings(cx).clone();
@@ -131,6 +147,7 @@ impl OpenRouterLanguageModelProvider {
             .detach();
             State {
                 api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
+                credentials_provider,
                 http_client: http_client.clone(),
                 available_models: Vec::new(),
                 fetch_models_task: None,

crates/language_models/src/provider/opencode.rs 🔗

@@ -1,5 +1,6 @@
 use anyhow::Result;
 use collections::BTreeMap;
+use credentials_provider::CredentialsProvider;
 use futures::{FutureExt, StreamExt, future::BoxFuture};
 use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
 use http_client::HttpClient;
@@ -43,6 +44,7 @@ pub struct OpenCodeLanguageModelProvider {
 
 pub struct State {
     api_key_state: ApiKeyState,
+    credentials_provider: Arc<dyn CredentialsProvider>,
 }
 
 impl State {
@@ -51,30 +53,51 @@ impl State {
     }
 
     fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+        let credentials_provider = self.credentials_provider.clone();
         let api_url = OpenCodeLanguageModelProvider::api_url(cx);
-        self.api_key_state
-            .store(api_url, api_key, |this| &mut this.api_key_state, cx)
+        self.api_key_state.store(
+            api_url,
+            api_key,
+            |this| &mut this.api_key_state,
+            credentials_provider,
+            cx,
+        )
     }
 
     fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+        let credentials_provider = self.credentials_provider.clone();
         let api_url = OpenCodeLanguageModelProvider::api_url(cx);
-        self.api_key_state
-            .load_if_needed(api_url, |this| &mut this.api_key_state, cx)
+        self.api_key_state.load_if_needed(
+            api_url,
+            |this| &mut this.api_key_state,
+            credentials_provider,
+            cx,
+        )
     }
 }
 
 impl OpenCodeLanguageModelProvider {
-    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+    pub fn new(
+        http_client: Arc<dyn HttpClient>,
+        credentials_provider: Arc<dyn CredentialsProvider>,
+        cx: &mut App,
+    ) -> Self {
         let state = cx.new(|cx| {
             cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
+                let credentials_provider = this.credentials_provider.clone();
                 let api_url = Self::api_url(cx);
-                this.api_key_state
-                    .handle_url_change(api_url, |this| &mut this.api_key_state, cx);
+                this.api_key_state.handle_url_change(
+                    api_url,
+                    |this| &mut this.api_key_state,
+                    credentials_provider,
+                    cx,
+                );
                 cx.notify();
             })
             .detach();
             State {
                 api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
+                credentials_provider,
             }
         });
 

crates/language_models/src/provider/vercel.rs 🔗

@@ -1,5 +1,6 @@
 use anyhow::Result;
 use collections::BTreeMap;
+use credentials_provider::CredentialsProvider;
 use futures::{FutureExt, StreamExt, future::BoxFuture};
 use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
 use http_client::HttpClient;
@@ -38,6 +39,7 @@ pub struct VercelLanguageModelProvider {
 
 pub struct State {
     api_key_state: ApiKeyState,
+    credentials_provider: Arc<dyn CredentialsProvider>,
 }
 
 impl State {
@@ -46,30 +48,51 @@ impl State {
     }
 
     fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+        let credentials_provider = self.credentials_provider.clone();
         let api_url = VercelLanguageModelProvider::api_url(cx);
-        self.api_key_state
-            .store(api_url, api_key, |this| &mut this.api_key_state, cx)
+        self.api_key_state.store(
+            api_url,
+            api_key,
+            |this| &mut this.api_key_state,
+            credentials_provider,
+            cx,
+        )
     }
 
     fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+        let credentials_provider = self.credentials_provider.clone();
         let api_url = VercelLanguageModelProvider::api_url(cx);
-        self.api_key_state
-            .load_if_needed(api_url, |this| &mut this.api_key_state, cx)
+        self.api_key_state.load_if_needed(
+            api_url,
+            |this| &mut this.api_key_state,
+            credentials_provider,
+            cx,
+        )
     }
 }
 
 impl VercelLanguageModelProvider {
-    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+    pub fn new(
+        http_client: Arc<dyn HttpClient>,
+        credentials_provider: Arc<dyn CredentialsProvider>,
+        cx: &mut App,
+    ) -> Self {
         let state = cx.new(|cx| {
             cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
+                let credentials_provider = this.credentials_provider.clone();
                 let api_url = Self::api_url(cx);
-                this.api_key_state
-                    .handle_url_change(api_url, |this| &mut this.api_key_state, cx);
+                this.api_key_state.handle_url_change(
+                    api_url,
+                    |this| &mut this.api_key_state,
+                    credentials_provider,
+                    cx,
+                );
                 cx.notify();
             })
             .detach();
             State {
                 api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
+                credentials_provider,
             }
         });
 

crates/language_models/src/provider/vercel_ai_gateway.rs 🔗

@@ -1,5 +1,6 @@
 use anyhow::Result;
 use collections::BTreeMap;
+use credentials_provider::CredentialsProvider;
 use futures::{AsyncReadExt, FutureExt, StreamExt, future::BoxFuture};
 use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
 use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
@@ -41,6 +42,7 @@ pub struct VercelAiGatewayLanguageModelProvider {
 
 pub struct State {
     api_key_state: ApiKeyState,
+    credentials_provider: Arc<dyn CredentialsProvider>,
     http_client: Arc<dyn HttpClient>,
     available_models: Vec<AvailableModel>,
     fetch_models_task: Option<Task<Result<(), LanguageModelCompletionError>>>,
@@ -52,16 +54,26 @@ impl State {
     }
 
     fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+        let credentials_provider = self.credentials_provider.clone();
         let api_url = VercelAiGatewayLanguageModelProvider::api_url(cx);
-        self.api_key_state
-            .store(api_url, api_key, |this| &mut this.api_key_state, cx)
+        self.api_key_state.store(
+            api_url,
+            api_key,
+            |this| &mut this.api_key_state,
+            credentials_provider,
+            cx,
+        )
     }
 
     fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+        let credentials_provider = self.credentials_provider.clone();
         let api_url = VercelAiGatewayLanguageModelProvider::api_url(cx);
-        let task = self
-            .api_key_state
-            .load_if_needed(api_url, |this| &mut this.api_key_state, cx);
+        let task = self.api_key_state.load_if_needed(
+            api_url,
+            |this| &mut this.api_key_state,
+            credentials_provider,
+            cx,
+        );
 
         cx.spawn(async move |this, cx| {
             let result = task.await;
@@ -100,7 +112,11 @@ impl State {
 }
 
 impl VercelAiGatewayLanguageModelProvider {
-    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+    pub fn new(
+        http_client: Arc<dyn HttpClient>,
+        credentials_provider: Arc<dyn CredentialsProvider>,
+        cx: &mut App,
+    ) -> Self {
         let state = cx.new(|cx| {
             cx.observe_global::<SettingsStore>({
                 let mut last_settings = VercelAiGatewayLanguageModelProvider::settings(cx).clone();
@@ -116,6 +132,7 @@ impl VercelAiGatewayLanguageModelProvider {
             .detach();
             State {
                 api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
+                credentials_provider,
                 http_client: http_client.clone(),
                 available_models: Vec::new(),
                 fetch_models_task: None,

crates/language_models/src/provider/x_ai.rs 🔗

@@ -1,5 +1,6 @@
 use anyhow::Result;
 use collections::BTreeMap;
+use credentials_provider::CredentialsProvider;
 use futures::{FutureExt, StreamExt, future::BoxFuture};
 use gpui::{AnyView, App, AsyncApp, Context, Entity, Task, Window};
 use http_client::HttpClient;
@@ -39,6 +40,7 @@ pub struct XAiLanguageModelProvider {
 
 pub struct State {
     api_key_state: ApiKeyState,
+    credentials_provider: Arc<dyn CredentialsProvider>,
 }
 
 impl State {
@@ -47,30 +49,51 @@ impl State {
     }
 
     fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+        let credentials_provider = self.credentials_provider.clone();
         let api_url = XAiLanguageModelProvider::api_url(cx);
-        self.api_key_state
-            .store(api_url, api_key, |this| &mut this.api_key_state, cx)
+        self.api_key_state.store(
+            api_url,
+            api_key,
+            |this| &mut this.api_key_state,
+            credentials_provider,
+            cx,
+        )
     }
 
     fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+        let credentials_provider = self.credentials_provider.clone();
         let api_url = XAiLanguageModelProvider::api_url(cx);
-        self.api_key_state
-            .load_if_needed(api_url, |this| &mut this.api_key_state, cx)
+        self.api_key_state.load_if_needed(
+            api_url,
+            |this| &mut this.api_key_state,
+            credentials_provider,
+            cx,
+        )
     }
 }
 
 impl XAiLanguageModelProvider {
-    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+    pub fn new(
+        http_client: Arc<dyn HttpClient>,
+        credentials_provider: Arc<dyn CredentialsProvider>,
+        cx: &mut App,
+    ) -> Self {
         let state = cx.new(|cx| {
             cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
+                let credentials_provider = this.credentials_provider.clone();
                 let api_url = Self::api_url(cx);
-                this.api_key_state
-                    .handle_url_change(api_url, |this| &mut this.api_key_state, cx);
+                this.api_key_state.handle_url_change(
+                    api_url,
+                    |this| &mut this.api_key_state,
+                    credentials_provider,
+                    cx,
+                );
                 cx.notify();
             })
             .detach();
             State {
                 api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
+                credentials_provider,
             }
         });
 

crates/project/Cargo.toml 🔗

@@ -98,6 +98,7 @@ watch.workspace = true
 wax.workspace = true
 which.workspace = true
 worktree.workspace = true
+zed_credentials_provider.workspace = true
 zeroize.workspace = true
 zlog.workspace = true
 ztracing.workspace = true

crates/project/src/context_server_store.rs 🔗

@@ -684,7 +684,7 @@ impl ContextServerStore {
             let server_url = url.clone();
             let id = id.clone();
             cx.spawn(async move |_this, cx| {
-                let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx));
+                let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
                 if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await
                 {
                     log::warn!("{} failed to clear OAuth session on removal: {}", id, err);
@@ -797,8 +797,7 @@ impl ContextServerStore {
                 if configuration.has_static_auth_header() {
                     None
                 } else {
-                    let credentials_provider =
-                        cx.update(|cx| <dyn CredentialsProvider>::global(cx));
+                    let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
                     let http_client = cx.update(|cx| cx.http_client());
 
                     match Self::load_session(&credentials_provider, url, &cx).await {
@@ -1070,7 +1069,7 @@ impl ContextServerStore {
             .context("Failed to start OAuth callback server")?;
 
         let http_client = cx.update(|cx| cx.http_client());
-        let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx));
+        let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
         let server_url = match configuration.as_ref() {
             ContextServerConfiguration::Http { url, .. } => url.clone(),
             _ => anyhow::bail!("OAuth authentication only supported for HTTP servers"),
@@ -1233,7 +1232,7 @@ impl ContextServerStore {
         self.stop_server(&id, cx)?;
 
         cx.spawn(async move |this, cx| {
-            let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx));
+            let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
             if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await {
                 log::error!("{} failed to clear OAuth session: {}", id, err);
             }
@@ -1451,7 +1450,7 @@ async fn resolve_start_failure(
     // (e.g. timeout because the server rejected the token silently). Clear it
     // so the next start attempt can get a clean 401 and trigger the auth flow.
     if www_authenticate.is_none() {
-        let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx));
+        let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
         match ContextServerStore::load_session(&credentials_provider, &server_url, cx).await {
             Ok(Some(_)) => {
                 log::info!("{id} start failed with a cached OAuth session present; clearing it");

crates/settings_ui/Cargo.toml 🔗

@@ -59,6 +59,7 @@ ui.workspace = true
 util.workspace = true
 workspace.workspace = true
 zed_actions.workspace = true
+zed_credentials_provider.workspace = true
 
 [dev-dependencies]
 fs = { workspace = true, features = ["test-support"] }

crates/settings_ui/src/pages/edit_prediction_provider_setup.rs 🔗

@@ -185,9 +185,15 @@ fn render_api_key_provider(
     cx: &mut Context<SettingsWindow>,
 ) -> impl IntoElement {
     let weak_page = cx.weak_entity();
+    let credentials_provider = zed_credentials_provider::global(cx);
     _ = window.use_keyed_state(current_url(cx), cx, |_, cx| {
         let task = api_key_state.update(cx, |key_state, cx| {
-            key_state.load_if_needed(current_url(cx), |state| state, cx)
+            key_state.load_if_needed(
+                current_url(cx),
+                |state| state,
+                credentials_provider.clone(),
+                cx,
+            )
         });
         cx.spawn(async move |_, cx| {
             task.await.ok();
@@ -208,10 +214,17 @@ fn render_api_key_provider(
     });
 
     let write_key = move |api_key: Option<String>, cx: &mut App| {
+        let credentials_provider = zed_credentials_provider::global(cx);
         api_key_state
             .update(cx, |key_state, cx| {
                 let url = current_url(cx);
-                key_state.store(url, api_key, |key_state| key_state, cx)
+                key_state.store(
+                    url,
+                    api_key,
+                    |key_state| key_state,
+                    credentials_provider,
+                    cx,
+                )
             })
             .detach_and_log_err(cx);
     };

crates/web_search_providers/src/cloud.rs 🔗

@@ -1,13 +1,13 @@
 use std::sync::Arc;
 
 use anyhow::{Context as _, Result};
-use client::{Client, UserStore};
+use client::{Client, NeedsLlmTokenRefresh, UserStore, global_llm_token};
 use cloud_api_types::OrganizationId;
 use cloud_llm_client::{WebSearchBody, WebSearchResponse};
 use futures::AsyncReadExt as _;
 use gpui::{App, AppContext, Context, Entity, Task};
 use http_client::{HttpClient, Method};
-use language_model::{LlmApiToken, NeedsLlmTokenRefresh};
+use language_model::LlmApiToken;
 use web_search::{WebSearchProvider, WebSearchProviderId};
 
 pub struct CloudWebSearchProvider {
@@ -30,7 +30,7 @@ pub struct State {
 
 impl State {
     pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
-        let llm_api_token = LlmApiToken::global(cx);
+        let llm_api_token = global_llm_token(cx);
 
         Self {
             client,
@@ -73,8 +73,8 @@ async fn perform_web_search(
 
     let http_client = &client.http_client();
     let mut retries_remaining = MAX_RETRIES;
-    let mut token = llm_api_token
-        .acquire(&client, organization_id.clone())
+    let mut token = client
+        .acquire_llm_token(&llm_api_token, organization_id.clone())
         .await?;
 
     loop {
@@ -100,8 +100,8 @@ async fn perform_web_search(
             response.body_mut().read_to_string(&mut body).await?;
             return Ok(serde_json::from_str(&body)?);
         } else if response.needs_llm_token_refresh() {
-            token = llm_api_token
-                .refresh(&client, organization_id.clone())
+            token = client
+                .refresh_llm_token(&llm_api_token, organization_id.clone())
                 .await?;
             retries_remaining -= 1;
         } else {

crates/zed/src/main.rs 🔗

@@ -10,7 +10,7 @@ use agent_ui::AgentPanel;
 use anyhow::{Context as _, Error, Result};
 use clap::Parser;
 use cli::FORCE_CLI_MODE_ENV_VAR_NAME;
-use client::{Client, ProxySettings, UserStore, parse_zed_link};
+use client::{Client, ProxySettings, RefreshLlmTokenListener, UserStore, parse_zed_link};
 use collab_ui::channel_view::ChannelView;
 use collections::HashMap;
 use crashes::InitCrashHandler;
@@ -664,7 +664,12 @@ fn main() {
         );
 
         copilot_ui::init(&app_state, cx);
-        language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
+        language_model::init(cx);
+        RefreshLlmTokenListener::register(
+            app_state.client.clone(),
+            app_state.user_store.clone(),
+            cx,
+        );
         language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
         acp_tools::init(cx);
         zed::telemetry_log::init(cx);

crates/zed/src/visual_test_runner.rs 🔗

@@ -201,7 +201,12 @@ fn run_visual_tests(project_path: PathBuf, update_baseline: bool) -> Result<()>
         });
         prompt_store::init(cx);
         let prompt_builder = prompt_store::PromptBuilder::load(app_state.fs.clone(), false, cx);
-        language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
+        language_model::init(cx);
+        client::RefreshLlmTokenListener::register(
+            app_state.client.clone(),
+            app_state.user_store.clone(),
+            cx,
+        );
         language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
         git_ui::init(cx);
         project::AgentRegistryStore::init_global(

crates/zed/src/zed.rs 🔗

@@ -5189,7 +5189,12 @@ mod tests {
                 cx,
             );
             image_viewer::init(cx);
-            language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
+            language_model::init(cx);
+            client::RefreshLlmTokenListener::register(
+                app_state.client.clone(),
+                app_state.user_store.clone(),
+                cx,
+            );
             language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
             web_search::init(cx);
             git_graph::init(cx);

crates/zed/src/zed/edit_prediction_registry.rs 🔗

@@ -313,7 +313,12 @@ mod tests {
         let app_state = cx.update(|cx| {
             let app_state = AppState::test(cx);
             client::init(&app_state.client, cx);
-            language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
+            language_model::init(cx);
+            client::RefreshLlmTokenListener::register(
+                app_state.client.clone(),
+                app_state.user_store.clone(),
+                cx,
+            );
             editor::init(cx);
             app_state
         });

crates/zed_credentials_provider/Cargo.toml 🔗

@@ -0,0 +1,22 @@
+[package]
+name = "zed_credentials_provider"
+version = "0.1.0"
+edition.workspace = true
+publish.workspace = true
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/zed_credentials_provider.rs"
+
+[dependencies]
+anyhow.workspace = true
+credentials_provider.workspace = true
+futures.workspace = true
+gpui.workspace = true
+paths.workspace = true
+release_channel.workspace = true
+serde.workspace = true
+serde_json.workspace = true

crates/zed_credentials_provider/src/zed_credentials_provider.rs 🔗

@@ -0,0 +1,181 @@
+use std::collections::HashMap;
+use std::future::Future;
+use std::path::PathBuf;
+use std::pin::Pin;
+use std::sync::{Arc, LazyLock};
+
+use anyhow::Result;
+use credentials_provider::CredentialsProvider;
+use futures::FutureExt as _;
+use gpui::{App, AsyncApp, Global};
+use release_channel::ReleaseChannel;
+
+/// An environment variable whose presence indicates that the system keychain
+/// should be used in development.
+///
+/// By default, running Zed in development uses the development credentials
+/// provider. Setting this environment variable allows you to interact with the
+/// system keychain (for instance, if you need to test something).
+///
+/// Only works in development. Setting this environment variable in other
+/// release channels is a no-op.
+static ZED_DEVELOPMENT_USE_KEYCHAIN: LazyLock<bool> = LazyLock::new(|| {
+    std::env::var("ZED_DEVELOPMENT_USE_KEYCHAIN").is_ok_and(|value| !value.is_empty())
+});
+
+pub struct ZedCredentialsProvider(pub Arc<dyn CredentialsProvider>);
+
+impl Global for ZedCredentialsProvider {}
+
+/// Returns the global [`CredentialsProvider`].
+pub fn init_global(cx: &mut App) {
+    // The `CredentialsProvider` trait has `Send + Sync` bounds on it, so it
+    // seems like this is a false positive from Clippy.
+    #[allow(clippy::arc_with_non_send_sync)]
+    let provider = new(cx);
+    cx.set_global(ZedCredentialsProvider(provider));
+}
+
+pub fn global(cx: &App) -> Arc<dyn CredentialsProvider> {
+    cx.try_global::<ZedCredentialsProvider>()
+        .map(|provider| provider.0.clone())
+        .unwrap_or_else(|| new(cx))
+}
+
+fn new(cx: &App) -> Arc<dyn CredentialsProvider> {
+    let use_development_provider = match ReleaseChannel::try_global(cx) {
+        Some(ReleaseChannel::Dev) => {
+            // In development we default to using the development
+            // credentials provider to avoid getting spammed by relentless
+            // keychain access prompts.
+            //
+            // However, if the `ZED_DEVELOPMENT_USE_KEYCHAIN` environment
+            // variable is set, we will use the actual keychain.
+            !*ZED_DEVELOPMENT_USE_KEYCHAIN
+        }
+        Some(ReleaseChannel::Nightly | ReleaseChannel::Preview | ReleaseChannel::Stable) | None => {
+            false
+        }
+    };
+
+    if use_development_provider {
+        Arc::new(DevelopmentCredentialsProvider::new())
+    } else {
+        Arc::new(KeychainCredentialsProvider)
+    }
+}
+
+/// A credentials provider that stores credentials in the system keychain.
+struct KeychainCredentialsProvider;
+
+impl CredentialsProvider for KeychainCredentialsProvider {
+    fn read_credentials<'a>(
+        &'a self,
+        url: &'a str,
+        cx: &'a AsyncApp,
+    ) -> Pin<Box<dyn Future<Output = Result<Option<(String, Vec<u8>)>>> + 'a>> {
+        async move { cx.update(|cx| cx.read_credentials(url)).await }.boxed_local()
+    }
+
+    fn write_credentials<'a>(
+        &'a self,
+        url: &'a str,
+        username: &'a str,
+        password: &'a [u8],
+        cx: &'a AsyncApp,
+    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
+        async move {
+            cx.update(move |cx| cx.write_credentials(url, username, password))
+                .await
+        }
+        .boxed_local()
+    }
+
+    fn delete_credentials<'a>(
+        &'a self,
+        url: &'a str,
+        cx: &'a AsyncApp,
+    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
+        async move { cx.update(move |cx| cx.delete_credentials(url)).await }.boxed_local()
+    }
+}
+
+/// A credentials provider that stores credentials in a local file.
+///
+/// This MUST only be used in development, as this is not a secure way of storing
+/// credentials on user machines.
+///
+/// Its existence is purely to work around the annoyance of having to constantly
+/// re-allow access to the system keychain when developing Zed.
+struct DevelopmentCredentialsProvider {
+    path: PathBuf,
+}
+
+impl DevelopmentCredentialsProvider {
+    fn new() -> Self {
+        let path = paths::config_dir().join("development_credentials");
+
+        Self { path }
+    }
+
+    fn load_credentials(&self) -> Result<HashMap<String, (String, Vec<u8>)>> {
+        let json = std::fs::read(&self.path)?;
+        let credentials: HashMap<String, (String, Vec<u8>)> = serde_json::from_slice(&json)?;
+
+        Ok(credentials)
+    }
+
+    fn save_credentials(&self, credentials: &HashMap<String, (String, Vec<u8>)>) -> Result<()> {
+        let json = serde_json::to_string(credentials)?;
+        std::fs::write(&self.path, json)?;
+
+        Ok(())
+    }
+}
+
+impl CredentialsProvider for DevelopmentCredentialsProvider {
+    fn read_credentials<'a>(
+        &'a self,
+        url: &'a str,
+        _cx: &'a AsyncApp,
+    ) -> Pin<Box<dyn Future<Output = Result<Option<(String, Vec<u8>)>>> + 'a>> {
+        async move {
+            Ok(self
+                .load_credentials()
+                .unwrap_or_default()
+                .get(url)
+                .cloned())
+        }
+        .boxed_local()
+    }
+
+    fn write_credentials<'a>(
+        &'a self,
+        url: &'a str,
+        username: &'a str,
+        password: &'a [u8],
+        _cx: &'a AsyncApp,
+    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
+        async move {
+            let mut credentials = self.load_credentials().unwrap_or_default();
+            credentials.insert(url.to_string(), (username.to_string(), password.to_vec()));
+
+            self.save_credentials(&credentials)
+        }
+        .boxed_local()
+    }
+
+    fn delete_credentials<'a>(
+        &'a self,
+        url: &'a str,
+        _cx: &'a AsyncApp,
+    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
+        async move {
+            let mut credentials = self.load_credentials()?;
+            credentials.remove(url);
+
+            self.save_credentials(&credentials)
+        }
+        .boxed_local()
+    }
+}

crates/zed_env_vars/src/zed_env_vars.rs 🔗

@@ -1,45 +1,6 @@
-use gpui::SharedString;
+pub use env_var::{EnvVar, bool_env_var, env_var};
 use std::sync::LazyLock;
 
 /// Whether Zed is running in stateless mode.
 /// When true, Zed will use in-memory databases instead of persistent storage.
 pub static ZED_STATELESS: LazyLock<bool> = bool_env_var!("ZED_STATELESS");
-
-#[derive(Clone)]
-pub struct EnvVar {
-    pub name: SharedString,
-    /// Value of the environment variable. Also `None` when set to an empty string.
-    pub value: Option<String>,
-}
-
-impl EnvVar {
-    pub fn new(name: SharedString) -> Self {
-        let value = std::env::var(name.as_str()).ok();
-        if value.as_ref().is_some_and(|v| v.is_empty()) {
-            Self { name, value: None }
-        } else {
-            Self { name, value }
-        }
-    }
-
-    pub fn or(self, other: EnvVar) -> EnvVar {
-        if self.value.is_some() { self } else { other }
-    }
-}
-
-/// Creates a `LazyLock<EnvVar>` expression for use in a `static` declaration.
-#[macro_export]
-macro_rules! env_var {
-    ($name:expr) => {
-        ::std::sync::LazyLock::new(|| $crate::EnvVar::new(($name).into()))
-    };
-}
-
-/// Generates a `LazyLock<bool>` expression for use in a `static` declaration. Checks if the
-/// environment variable exists and is non-empty.
-#[macro_export]
-macro_rules! bool_env_var {
-    ($name:expr) => {
-        ::std::sync::LazyLock::new(|| $crate::EnvVar::new(($name).into()).value.is_some())
-    };
-}