diff --git a/Cargo.lock b/Cargo.lock
index 8b428dbcd537e33088f40fdde5e3251a6148672a..aae7afecc5ea6f6ba3d63453321c829b677e1c58 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -260,7 +260,6 @@ dependencies = [
"chrono",
"client",
"collections",
- "credentials_provider",
"env_logger 0.11.8",
"feature_flags",
"fs",
@@ -289,6 +288,7 @@ dependencies = [
"util",
"uuid",
"watch",
+ "zed_credentials_provider",
]
[[package]]
@@ -2856,6 +2856,7 @@ dependencies = [
"chrono",
"clock",
"cloud_api_client",
+ "cloud_api_types",
"cloud_llm_client",
"collections",
"credentials_provider",
@@ -2869,6 +2870,7 @@ dependencies = [
"http_client",
"http_client_tls",
"httparse",
+ "language_model",
"log",
"objc2-foundation",
"parking_lot",
@@ -2900,6 +2902,7 @@ dependencies = [
"util",
"windows 0.61.3",
"worktree",
+ "zed_credentials_provider",
]
[[package]]
@@ -3059,6 +3062,7 @@ dependencies = [
"serde",
"serde_json",
"text",
+ "zed_credentials_provider",
"zeta_prompt",
]
@@ -4035,12 +4039,8 @@ name = "credentials_provider"
version = "0.1.0"
dependencies = [
"anyhow",
- "futures 0.3.31",
"gpui",
- "paths",
- "release_channel",
"serde",
- "serde_json",
]
[[package]]
@@ -5115,6 +5115,7 @@ dependencies = [
"collections",
"copilot",
"copilot_ui",
+ "credentials_provider",
"ctor",
"db",
"edit_prediction_context",
@@ -5157,6 +5158,7 @@ dependencies = [
"workspace",
"worktree",
"zed_actions",
+ "zed_credentials_provider",
"zeta_prompt",
"zlog",
"zstd",
@@ -5583,6 +5585,13 @@ dependencies = [
"log",
]
+[[package]]
+name = "env_var"
+version = "0.1.0"
+dependencies = [
+ "gpui",
+]
+
[[package]]
name = "envy"
version = "0.4.2"
@@ -7132,7 +7141,6 @@ dependencies = [
"collections",
"db",
"editor",
- "feature_flags",
"fs",
"git",
"git_ui",
@@ -7190,7 +7198,6 @@ dependencies = [
"ctor",
"db",
"editor",
- "feature_flags",
"file_icons",
"futures 0.3.31",
"fuzzy",
@@ -9317,12 +9324,12 @@ dependencies = [
"anthropic",
"anyhow",
"base64 0.22.1",
- "client",
"cloud_api_client",
"cloud_api_types",
"cloud_llm_client",
"collections",
"credentials_provider",
+ "env_var",
"futures 0.3.31",
"gpui",
"http_client",
@@ -9338,7 +9345,6 @@ dependencies = [
"smol",
"thiserror 2.0.17",
"util",
- "zed_env_vars",
]
[[package]]
@@ -13139,6 +13145,7 @@ dependencies = [
"wax",
"which 6.0.3",
"worktree",
+ "zed_credentials_provider",
"zeroize",
"zlog",
"ztracing",
@@ -15748,6 +15755,7 @@ dependencies = [
"util",
"workspace",
"zed_actions",
+ "zed_credentials_provider",
]
[[package]]
@@ -22182,10 +22190,24 @@ dependencies = [
]
[[package]]
-name = "zed_env_vars"
+name = "zed_credentials_provider"
version = "0.1.0"
dependencies = [
+ "anyhow",
+ "credentials_provider",
+ "futures 0.3.31",
"gpui",
+ "paths",
+ "release_channel",
+ "serde",
+ "serde_json",
+]
+
+[[package]]
+name = "zed_env_vars"
+version = "0.1.0"
+dependencies = [
+ "env_var",
]
[[package]]
diff --git a/Cargo.toml b/Cargo.toml
index 3a393237ab9f5a5a8cd4b02517f6d22382ff51ff..81bbb1176ddddcc117fc9082586cbc08dbb95d61 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -61,6 +61,7 @@ members = [
"crates/edit_prediction_ui",
"crates/editor",
"crates/encoding_selector",
+ "crates/env_var",
"crates/etw_tracing",
"crates/eval_cli",
"crates/eval_utils",
@@ -220,6 +221,7 @@ members = [
"crates/x_ai",
"crates/zed",
"crates/zed_actions",
+ "crates/zed_credentials_provider",
"crates/zed_env_vars",
"crates/zeta_prompt",
"crates/zlog",
@@ -309,6 +311,7 @@ dev_container = { path = "crates/dev_container" }
diagnostics = { path = "crates/diagnostics" }
editor = { path = "crates/editor" }
encoding_selector = { path = "crates/encoding_selector" }
+env_var = { path = "crates/env_var" }
etw_tracing = { path = "crates/etw_tracing" }
eval_utils = { path = "crates/eval_utils" }
extension = { path = "crates/extension" }
@@ -465,6 +468,7 @@ worktree = { path = "crates/worktree" }
x_ai = { path = "crates/x_ai" }
zed = { path = "crates/zed" }
zed_actions = { path = "crates/zed_actions" }
+zed_credentials_provider = { path = "crates/zed_credentials_provider" }
zed_env_vars = { path = "crates/zed_env_vars" }
edit_prediction = { path = "crates/edit_prediction" }
zeta_prompt = { path = "crates/zeta_prompt" }
diff --git a/assets/icons/diff_split.svg b/assets/icons/diff_split.svg
index de2056466f7ef1081ee00dabb8b4d5baa8fc9217..dcafeb8df5c28bcac1f1fe8cf5783eebd8d8cd8a 100644
--- a/assets/icons/diff_split.svg
+++ b/assets/icons/diff_split.svg
@@ -1,5 +1,4 @@
diff --git a/assets/icons/diff_split_auto.svg b/assets/icons/diff_split_auto.svg
new file mode 100644
index 0000000000000000000000000000000000000000..f9dd7076be75aaf3e90286140a60deece5016114
--- /dev/null
+++ b/assets/icons/diff_split_auto.svg
@@ -0,0 +1,7 @@
+
diff --git a/assets/icons/diff_unified.svg b/assets/icons/diff_unified.svg
index b2d3895ae5466454e9cefc4e77e3c3f2a19cde8c..28735c16f682159b6b0a099176d6fc3b75cd248e 100644
--- a/assets/icons/diff_unified.svg
+++ b/assets/icons/diff_unified.svg
@@ -1,4 +1,4 @@
diff --git a/assets/keymaps/default-linux.json b/assets/keymaps/default-linux.json
index 98053432c5a186ecc886318f2d677f73a62295a2..4930fbea84b2b449f3b5c35fee2a390525cb3551 100644
--- a/assets/keymaps/default-linux.json
+++ b/assets/keymaps/default-linux.json
@@ -284,12 +284,36 @@
"context": "AcpThread",
"bindings": {
"ctrl--": "pane::GoBack",
+ "pageup": "agent::ScrollOutputPageUp",
+ "pagedown": "agent::ScrollOutputPageDown",
+ "home": "agent::ScrollOutputToTop",
+ "end": "agent::ScrollOutputToBottom",
+ "up": "agent::ScrollOutputLineUp",
+ "down": "agent::ScrollOutputLineDown",
+ "shift-pageup": "agent::ScrollOutputToPreviousMessage",
+ "shift-pagedown": "agent::ScrollOutputToNextMessage",
+ "ctrl-alt-pageup": "agent::ScrollOutputPageUp",
+ "ctrl-alt-pagedown": "agent::ScrollOutputPageDown",
+ "ctrl-alt-home": "agent::ScrollOutputToTop",
+ "ctrl-alt-end": "agent::ScrollOutputToBottom",
+ "ctrl-alt-up": "agent::ScrollOutputLineUp",
+ "ctrl-alt-down": "agent::ScrollOutputLineDown",
+ "ctrl-alt-shift-pageup": "agent::ScrollOutputToPreviousMessage",
+ "ctrl-alt-shift-pagedown": "agent::ScrollOutputToNextMessage",
},
},
{
"context": "AcpThread > Editor",
"use_key_equivalents": true,
"bindings": {
+ "ctrl-alt-pageup": "agent::ScrollOutputPageUp",
+ "ctrl-alt-pagedown": "agent::ScrollOutputPageDown",
+ "ctrl-alt-home": "agent::ScrollOutputToTop",
+ "ctrl-alt-end": "agent::ScrollOutputToBottom",
+ "ctrl-alt-up": "agent::ScrollOutputLineUp",
+ "ctrl-alt-down": "agent::ScrollOutputLineDown",
+ "ctrl-alt-shift-pageup": "agent::ScrollOutputToPreviousMessage",
+ "ctrl-alt-shift-pagedown": "agent::ScrollOutputToNextMessage",
"ctrl-shift-r": "agent::OpenAgentDiff",
"ctrl-shift-d": "git::Diff",
"shift-alt-y": "agent::KeepAll",
diff --git a/assets/keymaps/default-macos.json b/assets/keymaps/default-macos.json
index f0835a139a39602547d9d8da1cba93eaa7ee82a9..85c01bb33b54c30a55b5d046d03eb391d8c058c1 100644
--- a/assets/keymaps/default-macos.json
+++ b/assets/keymaps/default-macos.json
@@ -327,12 +327,36 @@
"context": "AcpThread",
"bindings": {
"ctrl--": "pane::GoBack",
+ "pageup": "agent::ScrollOutputPageUp",
+ "pagedown": "agent::ScrollOutputPageDown",
+ "home": "agent::ScrollOutputToTop",
+ "end": "agent::ScrollOutputToBottom",
+ "up": "agent::ScrollOutputLineUp",
+ "down": "agent::ScrollOutputLineDown",
+ "shift-pageup": "agent::ScrollOutputToPreviousMessage",
+ "shift-pagedown": "agent::ScrollOutputToNextMessage",
+ "ctrl-pageup": "agent::ScrollOutputPageUp",
+ "ctrl-pagedown": "agent::ScrollOutputPageDown",
+ "ctrl-home": "agent::ScrollOutputToTop",
+ "ctrl-end": "agent::ScrollOutputToBottom",
+ "ctrl-alt-up": "agent::ScrollOutputLineUp",
+ "ctrl-alt-down": "agent::ScrollOutputLineDown",
+ "ctrl-alt-pageup": "agent::ScrollOutputToPreviousMessage",
+ "ctrl-alt-pagedown": "agent::ScrollOutputToNextMessage",
},
},
{
"context": "AcpThread > Editor",
"use_key_equivalents": true,
"bindings": {
+ "ctrl-pageup": "agent::ScrollOutputPageUp",
+ "ctrl-pagedown": "agent::ScrollOutputPageDown",
+ "ctrl-home": "agent::ScrollOutputToTop",
+ "ctrl-end": "agent::ScrollOutputToBottom",
+ "ctrl-alt-up": "agent::ScrollOutputLineUp",
+ "ctrl-alt-down": "agent::ScrollOutputLineDown",
+ "ctrl-alt-pageup": "agent::ScrollOutputToPreviousMessage",
+ "ctrl-alt-pagedown": "agent::ScrollOutputToNextMessage",
"shift-ctrl-r": "agent::OpenAgentDiff",
"shift-ctrl-d": "git::Diff",
"shift-alt-y": "agent::KeepAll",
diff --git a/assets/keymaps/default-windows.json b/assets/keymaps/default-windows.json
index 41f36638e1dec40890ddecc6a808c669672e9317..0705717062ab5015de20cc3b93f651f867b5116d 100644
--- a/assets/keymaps/default-windows.json
+++ b/assets/keymaps/default-windows.json
@@ -285,12 +285,36 @@
"context": "AcpThread",
"bindings": {
"ctrl--": "pane::GoBack",
+ "pageup": "agent::ScrollOutputPageUp",
+ "pagedown": "agent::ScrollOutputPageDown",
+ "home": "agent::ScrollOutputToTop",
+ "end": "agent::ScrollOutputToBottom",
+ "up": "agent::ScrollOutputLineUp",
+ "down": "agent::ScrollOutputLineDown",
+ "shift-pageup": "agent::ScrollOutputToPreviousMessage",
+ "shift-pagedown": "agent::ScrollOutputToNextMessage",
+ "ctrl-alt-pageup": "agent::ScrollOutputPageUp",
+ "ctrl-alt-pagedown": "agent::ScrollOutputPageDown",
+ "ctrl-alt-home": "agent::ScrollOutputToTop",
+ "ctrl-alt-end": "agent::ScrollOutputToBottom",
+ "ctrl-alt-up": "agent::ScrollOutputLineUp",
+ "ctrl-alt-down": "agent::ScrollOutputLineDown",
+ "ctrl-alt-shift-pageup": "agent::ScrollOutputToPreviousMessage",
+ "ctrl-alt-shift-pagedown": "agent::ScrollOutputToNextMessage",
},
},
{
"context": "AcpThread > Editor",
"use_key_equivalents": true,
"bindings": {
+ "ctrl-alt-pageup": "agent::ScrollOutputPageUp",
+ "ctrl-alt-pagedown": "agent::ScrollOutputPageDown",
+ "ctrl-alt-home": "agent::ScrollOutputToTop",
+ "ctrl-alt-end": "agent::ScrollOutputToBottom",
+ "ctrl-alt-up": "agent::ScrollOutputLineUp",
+ "ctrl-alt-down": "agent::ScrollOutputLineDown",
+ "ctrl-alt-shift-pageup": "agent::ScrollOutputToPreviousMessage",
+ "ctrl-alt-shift-pagedown": "agent::ScrollOutputToNextMessage",
"ctrl-shift-r": "agent::OpenAgentDiff",
"ctrl-shift-d": "git::Diff",
"shift-alt-y": "agent::KeepAll",
diff --git a/crates/agent/src/edit_agent/evals.rs b/crates/agent/src/edit_agent/evals.rs
index e7b67e37bf4a8b71664a78b99b757c6985794ec6..ba8b7ed867ea26bcdcdee7f8bf20390c2f9592b3 100644
--- a/crates/agent/src/edit_agent/evals.rs
+++ b/crates/agent/src/edit_agent/evals.rs
@@ -4,7 +4,7 @@ use crate::{
ListDirectoryTool, ListDirectoryToolInput, ReadFileTool, ReadFileToolInput,
};
use Role::*;
-use client::{Client, UserStore};
+use client::{Client, RefreshLlmTokenListener, UserStore};
use eval_utils::{EvalOutput, EvalOutputProcessor, OutcomeKind};
use fs::FakeFs;
use futures::{FutureExt, future::LocalBoxFuture};
@@ -1423,7 +1423,8 @@ impl EditAgentTest {
let client = Client::production(cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
settings::init(cx);
- language_model::init(user_store.clone(), client.clone(), cx);
+ language_model::init(cx);
+ RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
language_models::init(user_store, client.clone(), cx);
});
diff --git a/crates/agent/src/tests/mod.rs b/crates/agent/src/tests/mod.rs
index 036a6f1030c43b16d51f864a1d0176891e90b772..9808b95dd0812f9a857da8a9c39e78fde40af1f9 100644
--- a/crates/agent/src/tests/mod.rs
+++ b/crates/agent/src/tests/mod.rs
@@ -6,7 +6,7 @@ use acp_thread::{
use agent_client_protocol::{self as acp};
use agent_settings::AgentProfileId;
use anyhow::Result;
-use client::{Client, UserStore};
+use client::{Client, RefreshLlmTokenListener, UserStore};
use collections::IndexMap;
use context_server::{ContextServer, ContextServerCommand, ContextServerId};
use feature_flags::FeatureFlagAppExt as _;
@@ -3253,7 +3253,8 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
let clock = Arc::new(clock::FakeSystemClock::new());
let client = Client::new(clock, http_client, cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
- language_model::init(user_store.clone(), client.clone(), cx);
+ language_model::init(cx);
+ RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
language_models::init(user_store, client.clone(), cx);
LanguageModelRegistry::test(cx);
});
@@ -3982,7 +3983,8 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
cx.set_http_client(Arc::new(http_client));
let client = Client::production(cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
- language_model::init(user_store.clone(), client.clone(), cx);
+ language_model::init(cx);
+ RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
language_models::init(user_store, client.clone(), cx);
}
};
diff --git a/crates/agent/src/tools/evals/streaming_edit_file.rs b/crates/agent/src/tools/evals/streaming_edit_file.rs
index 6a55517037e54ae4166cd22427201d9325ef0f76..0c6290ec098f9c37a0f6a077daf0a041c013d8ff 100644
--- a/crates/agent/src/tools/evals/streaming_edit_file.rs
+++ b/crates/agent/src/tools/evals/streaming_edit_file.rs
@@ -6,7 +6,7 @@ use crate::{
};
use Role::*;
use anyhow::{Context as _, Result};
-use client::{Client, UserStore};
+use client::{Client, RefreshLlmTokenListener, UserStore};
use fs::FakeFs;
use futures::{FutureExt, StreamExt, future::LocalBoxFuture};
use gpui::{AppContext as _, AsyncApp, Entity, TestAppContext, UpdateGlobal as _};
@@ -274,7 +274,8 @@ impl StreamingEditToolTest {
cx.set_http_client(http_client);
let client = Client::production(cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
- language_model::init(user_store.clone(), client.clone(), cx);
+ language_model::init(cx);
+ RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
language_models::init(user_store, client, cx);
});
diff --git a/crates/agent_servers/Cargo.toml b/crates/agent_servers/Cargo.toml
index 1542466be35bbce80983a73a3fc2e0998799160c..7151f0084b1cb7d9b206f57551ce715ef67483f7 100644
--- a/crates/agent_servers/Cargo.toml
+++ b/crates/agent_servers/Cargo.toml
@@ -32,7 +32,6 @@ futures.workspace = true
gpui.workspace = true
feature_flags.workspace = true
gpui_tokio = { workspace = true, optional = true }
-credentials_provider.workspace = true
google_ai.workspace = true
http_client.workspace = true
indoc.workspace = true
@@ -53,6 +52,7 @@ terminal.workspace = true
uuid.workspace = true
util.workspace = true
watch.workspace = true
+zed_credentials_provider.workspace = true
[target.'cfg(unix)'.dependencies]
libc.workspace = true
diff --git a/crates/agent_servers/src/custom.rs b/crates/agent_servers/src/custom.rs
index 0dcd2240d6ecf6dc052cdd55953cff8ec1442eae..fb8d0a515244576d2cf02e4989cbd71beca448c7 100644
--- a/crates/agent_servers/src/custom.rs
+++ b/crates/agent_servers/src/custom.rs
@@ -3,7 +3,6 @@ use acp_thread::AgentConnection;
use agent_client_protocol as acp;
use anyhow::{Context as _, Result};
use collections::HashSet;
-use credentials_provider::CredentialsProvider;
use fs::Fs;
use gpui::{App, AppContext as _, Entity, Task};
use language_model::{ApiKey, EnvVar};
@@ -392,7 +391,7 @@ fn api_key_for_gemini_cli(cx: &mut App) -> Task> {
if let Some(key) = env_var.value {
return Task::ready(Ok(key));
}
- let credentials_provider = ::global(cx);
+ let credentials_provider = zed_credentials_provider::global(cx);
let api_url = google_ai::API_URL.to_string();
cx.spawn(async move |cx| {
Ok(
diff --git a/crates/agent_servers/src/e2e_tests.rs b/crates/agent_servers/src/e2e_tests.rs
index 956d106df2a260bd2eb31c14f4f1f1705bf74cd6..aa29a0c230c13949b15f2b39a245ae41ead4884d 100644
--- a/crates/agent_servers/src/e2e_tests.rs
+++ b/crates/agent_servers/src/e2e_tests.rs
@@ -1,6 +1,7 @@
use crate::{AgentServer, AgentServerDelegate};
use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
use agent_client_protocol as acp;
+use client::RefreshLlmTokenListener;
use futures::{FutureExt, StreamExt, channel::mpsc, select};
use gpui::AppContext;
use gpui::{Entity, TestAppContext};
@@ -413,7 +414,8 @@ pub async fn init_test(cx: &mut TestAppContext) -> Arc {
cx.set_http_client(Arc::new(http_client));
let client = client::Client::production(cx);
let user_store = cx.new(|cx| client::UserStore::new(client.clone(), cx));
- language_model::init(user_store, client, cx);
+ language_model::init(cx);
+ RefreshLlmTokenListener::register(client.clone(), user_store, cx);
#[cfg(test)]
project::agent_server_store::AllAgentServersSettings::override_global(
diff --git a/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs b/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs
index 4e3dd63b0337f9be54b550f4f4a6a5ca2e7cdd42..b97583377a00d28ea1a8aae6a1380cff3b69e6a0 100644
--- a/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs
+++ b/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs
@@ -815,7 +815,7 @@ mod tests {
cx.set_global(store);
theme_settings::init(theme::LoadThemes::JustBase, cx);
- language_model::init_settings(cx);
+ language_model::init(cx);
editor::init(cx);
});
diff --git a/crates/agent_ui/src/agent_diff.rs b/crates/agent_ui/src/agent_diff.rs
index d5cf63f6cdde9a85a54daaa29f8fc2c6833bdd77..7b70740dd1ac462614a9d08d9e48d7d13ac2ed32 100644
--- a/crates/agent_ui/src/agent_diff.rs
+++ b/crates/agent_ui/src/agent_diff.rs
@@ -1809,7 +1809,7 @@ mod tests {
cx.set_global(settings_store);
prompt_store::init(cx);
theme_settings::init(theme::LoadThemes::JustBase, cx);
- language_model::init_settings(cx);
+ language_model::init(cx);
});
let fs = FakeFs::new(cx.executor());
@@ -1966,7 +1966,7 @@ mod tests {
cx.set_global(settings_store);
prompt_store::init(cx);
theme_settings::init(theme::LoadThemes::JustBase, cx);
- language_model::init_settings(cx);
+ language_model::init(cx);
workspace::register_project_item::(cx);
});
diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs
index 185a54825d3af18f16f2eb30188ea866c099bf32..e58c7eb3526cc1a53d7b8e6d449e968a5923425a 100644
--- a/crates/agent_ui/src/agent_ui.rs
+++ b/crates/agent_ui/src/agent_ui.rs
@@ -173,6 +173,22 @@ actions!(
ToggleThinkingEffortMenu,
/// Toggles fast mode for models that support it.
ToggleFastMode,
+ /// Scroll the output by one page up.
+ ScrollOutputPageUp,
+ /// Scroll the output by one page down.
+ ScrollOutputPageDown,
+ /// Scroll the output up by three lines.
+ ScrollOutputLineUp,
+ /// Scroll the output down by three lines.
+ ScrollOutputLineDown,
+ /// Scroll the output to the top.
+ ScrollOutputToTop,
+ /// Scroll the output to the bottom.
+ ScrollOutputToBottom,
+ /// Scroll the output to the previous user message.
+ ScrollOutputToPreviousMessage,
+ /// Scroll the output to the next user message.
+ ScrollOutputToNextMessage,
]
);
diff --git a/crates/agent_ui/src/conversation_view.rs b/crates/agent_ui/src/conversation_view.rs
index 83a0c158a11c54be1ff54f553ce4b427da2cabc2..1b9d364e9ce03702b47c63e8a856f0ba4b8aba87 100644
--- a/crates/agent_ui/src/conversation_view.rs
+++ b/crates/agent_ui/src/conversation_view.rs
@@ -85,8 +85,11 @@ use crate::{
AuthorizeToolCall, ClearMessageQueue, CycleFavoriteModels, CycleModeSelector,
CycleThinkingEffort, EditFirstQueuedMessage, ExpandMessageEditor, Follow, KeepAll, NewThread,
OpenAddContextMenu, OpenAgentDiff, OpenHistory, RejectAll, RejectOnce,
- RemoveFirstQueuedMessage, SendImmediately, SendNextQueuedMessage, ToggleFastMode,
- ToggleProfileSelector, ToggleThinkingEffortMenu, ToggleThinkingMode, UndoLastReject,
+ RemoveFirstQueuedMessage, ScrollOutputLineDown, ScrollOutputLineUp, ScrollOutputPageDown,
+ ScrollOutputPageUp, ScrollOutputToBottom, ScrollOutputToNextMessage,
+ ScrollOutputToPreviousMessage, ScrollOutputToTop, SendImmediately, SendNextQueuedMessage,
+ ToggleFastMode, ToggleProfileSelector, ToggleThinkingEffortMenu, ToggleThinkingMode,
+ UndoLastReject,
};
const STOPWATCH_THRESHOLD: Duration = Duration::from_secs(30);
@@ -1240,15 +1243,15 @@ impl ConversationView {
if let Some(active) = self.thread_view(&thread_id) {
let entry_view_state = active.read(cx).entry_view_state.clone();
let list_state = active.read(cx).list_state.clone();
- entry_view_state.update(cx, |view_state, cx| {
- view_state.sync_entry(index, thread, window, cx);
- list_state.splice_focusable(
- index..index,
- [view_state
- .entry(index)
- .and_then(|entry| entry.focus_handle(cx))],
- );
- });
+ notify_entry_changed(
+ &entry_view_state,
+ &list_state,
+ index..index,
+ index,
+ thread,
+ window,
+ cx,
+ );
active.update(cx, |active, cx| {
active.sync_editor_mode_for_empty_state(cx);
});
@@ -1257,9 +1260,16 @@ impl ConversationView {
AcpThreadEvent::EntryUpdated(index) => {
if let Some(active) = self.thread_view(&thread_id) {
let entry_view_state = active.read(cx).entry_view_state.clone();
- entry_view_state.update(cx, |view_state, cx| {
- view_state.sync_entry(*index, thread, window, cx)
- });
+ let list_state = active.read(cx).list_state.clone();
+ notify_entry_changed(
+ &entry_view_state,
+ &list_state,
+ *index..*index + 1,
+ *index,
+ thread,
+ window,
+ cx,
+ );
active.update(cx, |active, cx| {
active.auto_expand_streaming_thought(cx);
});
@@ -2598,6 +2608,32 @@ impl ConversationView {
}
}
+/// Syncs an entry's view state with the latest thread data and splices
+/// the list item so the list knows to re-measure it on the next paint.
+///
+/// Used by both `NewEntry` (splice range `index..index` to insert) and
+/// `EntryUpdated` (splice range `index..index+1` to replace), which is
+/// why the caller provides the splice range.
+fn notify_entry_changed(
+ entry_view_state: &Entity,
+ list_state: &ListState,
+ splice_range: std::ops::Range,
+ index: usize,
+ thread: &Entity,
+ window: &mut Window,
+ cx: &mut App,
+) {
+ entry_view_state.update(cx, |view_state, cx| {
+ view_state.sync_entry(index, thread, window, cx);
+ list_state.splice_focusable(
+ splice_range,
+ [view_state
+ .entry(index)
+ .and_then(|entry| entry.focus_handle(cx))],
+ );
+ });
+}
+
fn loading_contents_spinner(size: IconSize) -> AnyElement {
Icon::new(IconName::LoadCircle)
.size(size)
diff --git a/crates/agent_ui/src/conversation_view/thread_view.rs b/crates/agent_ui/src/conversation_view/thread_view.rs
index c065c3de3d83c0eb5b68bf9a3610ff925762c952..c113eb0b768ee143eb69b5e705c15c91e367e6c2 100644
--- a/crates/agent_ui/src/conversation_view/thread_view.rs
+++ b/crates/agent_ui/src/conversation_view/thread_view.rs
@@ -552,17 +552,10 @@ impl ThreadView {
let scroll_top = list_state.logical_scroll_top();
let _ = thread_view.update(cx, |this, cx| {
if !is_following_tail {
- let is_at_bottom = {
- let current_offset =
- list_state.scroll_px_offset_for_scrollbar().y.abs();
- let max_offset = list_state.max_offset_for_scrollbar().y;
- current_offset >= max_offset - px(1.0)
- };
-
let is_generating =
matches!(this.thread.read(cx).status(), ThreadStatus::Generating);
- if is_at_bottom && is_generating {
+ if list_state.is_at_bottom() && is_generating {
list_state.set_follow_tail(true);
}
}
@@ -4952,7 +4945,7 @@ impl ThreadView {
}
pub fn scroll_to_end(&mut self, cx: &mut Context) {
- self.list_state.scroll_to_end();
+ self.list_state.set_follow_tail(true);
cx.notify();
}
@@ -4974,10 +4967,122 @@ impl ThreadView {
}
pub(crate) fn scroll_to_top(&mut self, cx: &mut Context) {
+ self.list_state.set_follow_tail(false);
self.list_state.scroll_to(ListOffset::default());
cx.notify();
}
+ fn scroll_output_page_up(
+ &mut self,
+ _: &ScrollOutputPageUp,
+ _window: &mut Window,
+ cx: &mut Context,
+ ) {
+ let page_height = self.list_state.viewport_bounds().size.height;
+ self.list_state.set_follow_tail(false);
+ self.list_state.scroll_by(-page_height * 0.9);
+ cx.notify();
+ }
+
+ fn scroll_output_page_down(
+ &mut self,
+ _: &ScrollOutputPageDown,
+ _window: &mut Window,
+ cx: &mut Context,
+ ) {
+ let page_height = self.list_state.viewport_bounds().size.height;
+ self.list_state.set_follow_tail(false);
+ self.list_state.scroll_by(page_height * 0.9);
+ if self.list_state.is_at_bottom() {
+ self.list_state.set_follow_tail(true);
+ }
+ cx.notify();
+ }
+
+ fn scroll_output_line_up(
+ &mut self,
+ _: &ScrollOutputLineUp,
+ window: &mut Window,
+ cx: &mut Context,
+ ) {
+ self.list_state.set_follow_tail(false);
+ self.list_state.scroll_by(-window.line_height() * 3.);
+ cx.notify();
+ }
+
+ fn scroll_output_line_down(
+ &mut self,
+ _: &ScrollOutputLineDown,
+ window: &mut Window,
+ cx: &mut Context,
+ ) {
+ self.list_state.set_follow_tail(false);
+ self.list_state.scroll_by(window.line_height() * 3.);
+ if self.list_state.is_at_bottom() {
+ self.list_state.set_follow_tail(true);
+ }
+ cx.notify();
+ }
+
+ fn scroll_output_to_top(
+ &mut self,
+ _: &ScrollOutputToTop,
+ _window: &mut Window,
+ cx: &mut Context,
+ ) {
+ self.scroll_to_top(cx);
+ }
+
+ fn scroll_output_to_bottom(
+ &mut self,
+ _: &ScrollOutputToBottom,
+ _window: &mut Window,
+ cx: &mut Context,
+ ) {
+ self.scroll_to_end(cx);
+ }
+
+ fn scroll_output_to_previous_message(
+ &mut self,
+ _: &ScrollOutputToPreviousMessage,
+ _window: &mut Window,
+ cx: &mut Context,
+ ) {
+ let entries = self.thread.read(cx).entries();
+ let current_ix = self.list_state.logical_scroll_top().item_ix;
+ if let Some(target_ix) = (0..current_ix)
+ .rev()
+ .find(|&i| matches!(entries.get(i), Some(AgentThreadEntry::UserMessage(_))))
+ {
+ self.list_state.set_follow_tail(false);
+ self.list_state.scroll_to(ListOffset {
+ item_ix: target_ix,
+ offset_in_item: px(0.),
+ });
+ cx.notify();
+ }
+ }
+
+ fn scroll_output_to_next_message(
+ &mut self,
+ _: &ScrollOutputToNextMessage,
+ _window: &mut Window,
+ cx: &mut Context,
+ ) {
+ let entries = self.thread.read(cx).entries();
+ let current_ix = self.list_state.logical_scroll_top().item_ix;
+ if let Some(target_ix) = (current_ix + 1..entries.len())
+ .find(|&i| matches!(entries.get(i), Some(AgentThreadEntry::UserMessage(_))))
+ {
+ self.list_state.set_follow_tail(false);
+ self.list_state.scroll_to(ListOffset {
+ item_ix: target_ix,
+ offset_in_item: px(0.),
+ });
+ cx.notify();
+ }
+ }
+
pub fn open_thread_as_markdown(
&self,
workspace: Entity,
@@ -8541,6 +8646,14 @@ impl Render for ThreadView {
.on_action(cx.listener(Self::handle_toggle_command_pattern))
.on_action(cx.listener(Self::open_permission_dropdown))
.on_action(cx.listener(Self::open_add_context_menu))
+ .on_action(cx.listener(Self::scroll_output_page_up))
+ .on_action(cx.listener(Self::scroll_output_page_down))
+ .on_action(cx.listener(Self::scroll_output_line_up))
+ .on_action(cx.listener(Self::scroll_output_line_down))
+ .on_action(cx.listener(Self::scroll_output_to_top))
+ .on_action(cx.listener(Self::scroll_output_to_bottom))
+ .on_action(cx.listener(Self::scroll_output_to_previous_message))
+ .on_action(cx.listener(Self::scroll_output_to_next_message))
.on_action(cx.listener(|this, _: &ToggleFastMode, _window, cx| {
this.toggle_fast_mode(cx);
}))
diff --git a/crates/agent_ui/src/inline_assistant.rs b/crates/agent_ui/src/inline_assistant.rs
index 20e0b702978b7e72a8526b03570854965335310c..39d70790e0d4a18554b2a1c11510e529d921cd1b 100644
--- a/crates/agent_ui/src/inline_assistant.rs
+++ b/crates/agent_ui/src/inline_assistant.rs
@@ -2025,7 +2025,7 @@ fn merge_ranges(ranges: &mut Vec>, buffer: &MultiBufferSnapshot) {
pub mod evals {
use crate::InlineAssistant;
use agent::ThreadStore;
- use client::{Client, UserStore};
+ use client::{Client, RefreshLlmTokenListener, UserStore};
use editor::{Editor, MultiBuffer, MultiBufferOffset};
use eval_utils::{EvalOutput, NoProcessor};
use fs::FakeFs;
@@ -2091,7 +2091,8 @@ pub mod evals {
client::init(&client, cx);
workspace::init(app_state.clone(), cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
- language_model::init(user_store.clone(), client.clone(), cx);
+ language_model::init(cx);
+ RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
language_models::init(user_store, client.clone(), cx);
cx.set_global(inline_assistant);
diff --git a/crates/cli/src/cli.rs b/crates/cli/src/cli.rs
index 1a3ce059b8116ac7438f3eb0330b47660cc863de..d8da78c53210230597dab49ce297d9fa694e62f1 100644
--- a/crates/cli/src/cli.rs
+++ b/crates/cli/src/cli.rs
@@ -21,6 +21,7 @@ pub enum CliRequest {
reuse: bool,
env: Option>,
user_data_dir: Option,
+ dev_container: bool,
},
}
diff --git a/crates/cli/src/main.rs b/crates/cli/src/main.rs
index b8af5896285d3080ca3320a5909b3f58f72de643..41f2d14c1908ac18e7ea297eef19d8d9bd1cf8b5 100644
--- a/crates/cli/src/main.rs
+++ b/crates/cli/src/main.rs
@@ -118,6 +118,12 @@ struct Args {
/// Will attempt to give the correct command to run
#[arg(long)]
system_specs: bool,
+ /// Open the project in a dev container.
+ ///
+ /// Automatically triggers "Reopen in Dev Container" if a `.devcontainer/`
+ /// configuration is found in the project directory.
+ #[arg(long)]
+ dev_container: bool,
/// Pairs of file paths to diff. Can be specified multiple times.
/// When directories are provided, recurses into them and shows all changed files in a single multi-diff view.
#[arg(long, action = clap::ArgAction::Append, num_args = 2, value_names = ["OLD_PATH", "NEW_PATH"])]
@@ -670,6 +676,7 @@ fn main() -> Result<()> {
reuse: args.reuse,
env,
user_data_dir: user_data_dir_for_thread,
+ dev_container: args.dev_container,
})?;
while let Ok(response) = rx.recv() {
diff --git a/crates/client/Cargo.toml b/crates/client/Cargo.toml
index 1edbb3399e4332e2ebd23f812c66697bda72d587..7bbaccb22e0e6c7508240186103e216f83be2f0c 100644
--- a/crates/client/Cargo.toml
+++ b/crates/client/Cargo.toml
@@ -22,6 +22,7 @@ base64.workspace = true
chrono = { workspace = true, features = ["serde"] }
clock.workspace = true
cloud_api_client.workspace = true
+cloud_api_types.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
credentials_provider.workspace = true
@@ -35,6 +36,7 @@ gpui_tokio.workspace = true
http_client.workspace = true
http_client_tls.workspace = true
httparse = "1.10"
+language_model.workspace = true
log.workspace = true
parking_lot.workspace = true
paths.workspace = true
@@ -60,6 +62,7 @@ tokio.workspace = true
url.workspace = true
util.workspace = true
worktree.workspace = true
+zed_credentials_provider.workspace = true
[dev-dependencies]
clock = { workspace = true, features = ["test-support"] }
diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs
index 6a11a6b924eed3dfd79ff379638ed4085e2b7bcb..dfd9963a0ee52d167f8d4edb0b850f4debed7fd4 100644
--- a/crates/client/src/client.rs
+++ b/crates/client/src/client.rs
@@ -1,6 +1,7 @@
#[cfg(any(test, feature = "test-support"))]
pub mod test;
+mod llm_token;
mod proxy;
pub mod telemetry;
pub mod user;
@@ -13,8 +14,9 @@ use async_tungstenite::tungstenite::{
http::{HeaderValue, Request, StatusCode},
};
use clock::SystemClock;
-use cloud_api_client::CloudApiClient;
use cloud_api_client::websocket_protocol::MessageToClient;
+use cloud_api_client::{ClientApiError, CloudApiClient};
+use cloud_api_types::OrganizationId;
use credentials_provider::CredentialsProvider;
use feature_flags::FeatureFlagAppExt as _;
use futures::{
@@ -24,6 +26,7 @@ use futures::{
};
use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions};
use http_client::{HttpClient, HttpClientWithUrl, http, read_proxy_from_env};
+use language_model::LlmApiToken;
use parking_lot::{Mutex, RwLock};
use postage::watch;
use proxy::connect_proxy_stream;
@@ -51,6 +54,7 @@ use tokio::net::TcpStream;
use url::Url;
use util::{ConnectionResult, ResultExt};
+pub use llm_token::*;
pub use rpc::*;
pub use telemetry_events::Event;
pub use user::*;
@@ -339,7 +343,7 @@ pub struct ClientCredentialsProvider {
impl ClientCredentialsProvider {
pub fn new(cx: &App) -> Self {
Self {
- provider: ::global(cx),
+ provider: zed_credentials_provider::global(cx),
}
}
@@ -568,6 +572,10 @@ impl Client {
self.http.clone()
}
+ pub fn credentials_provider(&self) -> Arc {
+ self.credentials_provider.provider.clone()
+ }
+
pub fn cloud_client(&self) -> Arc {
self.cloud_client.clone()
}
@@ -1513,6 +1521,66 @@ impl Client {
})
}
+ pub async fn acquire_llm_token(
+ &self,
+ llm_token: &LlmApiToken,
+ organization_id: Option,
+ ) -> Result {
+ let system_id = self.telemetry().system_id().map(|x| x.to_string());
+ let cloud_client = self.cloud_client();
+ match llm_token
+ .acquire(&cloud_client, system_id, organization_id)
+ .await
+ {
+ Ok(token) => Ok(token),
+ Err(ClientApiError::Unauthorized) => {
+ self.request_sign_out();
+ Err(ClientApiError::Unauthorized).context("Failed to create LLM token")
+ }
+ Err(err) => Err(anyhow::Error::from(err)),
+ }
+ }
+
+ pub async fn refresh_llm_token(
+ &self,
+ llm_token: &LlmApiToken,
+ organization_id: Option,
+ ) -> Result {
+ let system_id = self.telemetry().system_id().map(|x| x.to_string());
+ let cloud_client = self.cloud_client();
+ match llm_token
+ .refresh(&cloud_client, system_id, organization_id)
+ .await
+ {
+ Ok(token) => Ok(token),
+ Err(ClientApiError::Unauthorized) => {
+ self.request_sign_out();
+ return Err(ClientApiError::Unauthorized).context("Failed to create LLM token");
+ }
+ Err(err) => return Err(anyhow::Error::from(err)),
+ }
+ }
+
+ pub async fn clear_and_refresh_llm_token(
+ &self,
+ llm_token: &LlmApiToken,
+ organization_id: Option,
+ ) -> Result {
+ let system_id = self.telemetry().system_id().map(|x| x.to_string());
+ let cloud_client = self.cloud_client();
+ match llm_token
+ .clear_and_refresh(&cloud_client, system_id, organization_id)
+ .await
+ {
+ Ok(token) => Ok(token),
+ Err(ClientApiError::Unauthorized) => {
+ self.request_sign_out();
+ return Err(ClientApiError::Unauthorized).context("Failed to create LLM token");
+ }
+ Err(err) => return Err(anyhow::Error::from(err)),
+ }
+ }
+
pub async fn sign_out(self: &Arc, cx: &AsyncApp) {
self.state.write().credentials = None;
self.cloud_client.clear_credentials();
diff --git a/crates/client/src/llm_token.rs b/crates/client/src/llm_token.rs
new file mode 100644
index 0000000000000000000000000000000000000000..f62aa6dd4dc3462bc3a0f6f46c35f0e4e5499816
--- /dev/null
+++ b/crates/client/src/llm_token.rs
@@ -0,0 +1,116 @@
+use super::{Client, UserStore};
+use cloud_api_types::websocket_protocol::MessageToClient;
+use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
+use gpui::{
+ App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription,
+};
+use language_model::LlmApiToken;
+use std::sync::Arc;
+
+pub trait NeedsLlmTokenRefresh {
+ /// Returns whether the LLM token needs to be refreshed.
+ fn needs_llm_token_refresh(&self) -> bool;
+}
+
+impl NeedsLlmTokenRefresh for http_client::Response {
+ fn needs_llm_token_refresh(&self) -> bool {
+ self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some()
+ || self.headers().get(OUTDATED_LLM_TOKEN_HEADER_NAME).is_some()
+ }
+}
+
+enum TokenRefreshMode {
+ Refresh,
+ ClearAndRefresh,
+}
+
+pub fn global_llm_token(cx: &App) -> LlmApiToken {
+ RefreshLlmTokenListener::global(cx)
+ .read(cx)
+ .llm_api_token
+ .clone()
+}
+
+struct GlobalRefreshLlmTokenListener(Entity);
+
+impl Global for GlobalRefreshLlmTokenListener {}
+
+pub struct LlmTokenRefreshedEvent;
+
+pub struct RefreshLlmTokenListener {
+ client: Arc,
+ user_store: Entity,
+ llm_api_token: LlmApiToken,
+ _subscription: Subscription,
+}
+
+impl EventEmitter for RefreshLlmTokenListener {}
+
+impl RefreshLlmTokenListener {
+ pub fn register(client: Arc, user_store: Entity, cx: &mut App) {
+ let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, user_store, cx));
+ cx.set_global(GlobalRefreshLlmTokenListener(listener));
+ }
+
+ pub fn global(cx: &App) -> Entity {
+ GlobalRefreshLlmTokenListener::global(cx).0.clone()
+ }
+
+ fn new(client: Arc, user_store: Entity, cx: &mut Context) -> Self {
+ client.add_message_to_client_handler({
+ let this = cx.weak_entity();
+ move |message, cx| {
+ if let Some(this) = this.upgrade() {
+ Self::handle_refresh_llm_token(this, message, cx);
+ }
+ }
+ });
+
+ let subscription = cx.subscribe(&user_store, |this, _user_store, event, cx| {
+ if matches!(event, super::user::Event::OrganizationChanged) {
+ this.refresh(TokenRefreshMode::ClearAndRefresh, cx);
+ }
+ });
+
+ Self {
+ client,
+ user_store,
+ llm_api_token: LlmApiToken::default(),
+ _subscription: subscription,
+ }
+ }
+
+ fn refresh(&self, mode: TokenRefreshMode, cx: &mut Context) {
+ let client = self.client.clone();
+ let llm_api_token = self.llm_api_token.clone();
+ let organization_id = self
+ .user_store
+ .read(cx)
+ .current_organization()
+ .map(|organization| organization.id.clone());
+ cx.spawn(async move |this, cx| {
+ match mode {
+ TokenRefreshMode::Refresh => {
+ client
+ .refresh_llm_token(&llm_api_token, organization_id)
+ .await?;
+ }
+ TokenRefreshMode::ClearAndRefresh => {
+ client
+ .clear_and_refresh_llm_token(&llm_api_token, organization_id)
+ .await?;
+ }
+ }
+ this.update(cx, |_this, cx| cx.emit(LlmTokenRefreshedEvent))
+ })
+ .detach_and_log_err(cx);
+ }
+
+ fn handle_refresh_llm_token(this: Entity, message: &MessageToClient, cx: &mut App) {
+ match message {
+ MessageToClient::UserUpdated => {
+ this.update(cx, |this, cx| this.refresh(TokenRefreshMode::Refresh, cx));
+ }
+ }
+ }
+}
diff --git a/crates/codestral/Cargo.toml b/crates/codestral/Cargo.toml
index 0daaee8fb1420c76757ca898655e8dd1a5244d7e..801221d3128b8aa2d25175e086a741d5d85da626 100644
--- a/crates/codestral/Cargo.toml
+++ b/crates/codestral/Cargo.toml
@@ -22,6 +22,7 @@ log.workspace = true
serde.workspace = true
serde_json.workspace = true
text.workspace = true
+zed_credentials_provider.workspace = true
zeta_prompt.workspace = true
[dev-dependencies]
diff --git a/crates/codestral/src/codestral.rs b/crates/codestral/src/codestral.rs
index 3930e2e873a91618bfae456bc188bbd90ffa64b9..7685fa8f5b1eae9e98a621484602e199c2b76f96 100644
--- a/crates/codestral/src/codestral.rs
+++ b/crates/codestral/src/codestral.rs
@@ -48,9 +48,10 @@ pub fn codestral_api_key(cx: &App) -> Option> {
}
pub fn load_codestral_api_key(cx: &mut App) -> Task> {
+ let credentials_provider = zed_credentials_provider::global(cx);
let api_url = codestral_api_url(cx);
codestral_api_key_state(cx).update(cx, |key_state, cx| {
- key_state.load_if_needed(api_url, |s| s, cx)
+ key_state.load_if_needed(api_url, |s| s, credentials_provider, cx)
})
}
diff --git a/crates/collab_ui/src/collab_panel.rs b/crates/collab_ui/src/collab_panel.rs
index 91385b298dc661c4a79e4fb52d5be0f38672bff5..d16db59ea4ae2d766018dfc03c245839e4862cb4 100644
--- a/crates/collab_ui/src/collab_panel.rs
+++ b/crates/collab_ui/src/collab_panel.rs
@@ -13,12 +13,13 @@ use db::kvp::KeyValueStore;
use editor::{Editor, EditorElement, EditorStyle};
use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
use gpui::{
- AnyElement, App, AsyncWindowContext, Bounds, ClickEvent, ClipboardItem, Context, DismissEvent,
- Div, Entity, EventEmitter, FocusHandle, Focusable, FontStyle, InteractiveElement, IntoElement,
- KeyContext, ListOffset, ListState, MouseDownEvent, ParentElement, Pixels, Point, PromptLevel,
- Render, SharedString, Styled, Subscription, Task, TextStyle, WeakEntity, Window, actions,
- anchored, canvas, deferred, div, fill, list, point, prelude::*, px,
+ AnyElement, App, AsyncWindowContext, Bounds, ClickEvent, ClipboardItem, DismissEvent, Div,
+ Empty, Entity, EventEmitter, FocusHandle, Focusable, FontStyle, KeyContext, ListOffset,
+ ListState, MouseDownEvent, Pixels, Point, PromptLevel, SharedString, Subscription, Task,
+ TextStyle, WeakEntity, Window, actions, anchored, canvas, deferred, div, fill, list, point,
+ prelude::*, px,
};
+
use menu::{Cancel, Confirm, SecondaryConfirm, SelectNext, SelectPrevious};
use project::{Fs, Project};
use rpc::{
@@ -1091,27 +1092,30 @@ impl CollabPanel {
room.read(cx).local_participant().role == proto::ChannelRole::Admin
});
+ let end_slot = if is_pending {
+ Label::new("Calling").color(Color::Muted).into_any_element()
+ } else if is_current_user {
+ IconButton::new("leave-call", IconName::Exit)
+ .icon_size(IconSize::Small)
+ .tooltip(Tooltip::text("Leave Call"))
+ .on_click(move |_, window, cx| Self::leave_call(window, cx))
+ .into_any_element()
+ } else if role == proto::ChannelRole::Guest {
+ Label::new("Guest").color(Color::Muted).into_any_element()
+ } else if role == proto::ChannelRole::Talker {
+ Label::new("Mic only")
+ .color(Color::Muted)
+ .into_any_element()
+ } else {
+ Empty.into_any_element()
+ };
+
ListItem::new(user.github_login.clone())
.start_slot(Avatar::new(user.avatar_uri.clone()))
.child(render_participant_name_and_handle(user))
.toggle_state(is_selected)
- .end_slot(if is_pending {
- Label::new("Calling").color(Color::Muted).into_any_element()
- } else if is_current_user {
- IconButton::new("leave-call", IconName::Exit)
- .style(ButtonStyle::Subtle)
- .on_click(move |_, window, cx| Self::leave_call(window, cx))
- .tooltip(Tooltip::text("Leave Call"))
- .into_any_element()
- } else if role == proto::ChannelRole::Guest {
- Label::new("Guest").color(Color::Muted).into_any_element()
- } else if role == proto::ChannelRole::Talker {
- Label::new("Mic only")
- .color(Color::Muted)
- .into_any_element()
- } else {
- div().into_any_element()
- })
+ .end_slot(end_slot)
+ .tooltip(Tooltip::text("Click to Follow"))
.when_some(peer_id, |el, peer_id| {
if role == proto::ChannelRole::Guest {
return el;
@@ -1156,6 +1160,7 @@ impl CollabPanel {
.into();
ListItem::new(project_id as usize)
+ .height(px(24.))
.toggle_state(is_selected)
.on_click(cx.listener(move |this, _, window, cx| {
this.workspace
@@ -1173,9 +1178,13 @@ impl CollabPanel {
}))
.start_slot(
h_flex()
- .gap_1()
+ .gap_1p5()
.child(render_tree_branch(is_last, false, window, cx))
- .child(IconButton::new(0, IconName::Folder)),
+ .child(
+ Icon::new(IconName::Folder)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ ),
)
.child(Label::new(project_name.clone()))
.tooltip(Tooltip::text(format!("Open {}", project_name)))
@@ -1192,12 +1201,17 @@ impl CollabPanel {
let id = peer_id.map_or(usize::MAX, |id| id.as_u64() as usize);
ListItem::new(("screen", id))
+ .height(px(24.))
.toggle_state(is_selected)
.start_slot(
h_flex()
- .gap_1()
+ .gap_1p5()
.child(render_tree_branch(is_last, false, window, cx))
- .child(IconButton::new(0, IconName::Screen)),
+ .child(
+ Icon::new(IconName::Screen)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ ),
)
.child(Label::new("Screen"))
.when_some(peer_id, |this, _| {
@@ -1208,7 +1222,7 @@ impl CollabPanel {
})
.ok();
}))
- .tooltip(Tooltip::text("Open shared screen"))
+ .tooltip(Tooltip::text("Open Shared Screen"))
})
}
@@ -1232,7 +1246,9 @@ impl CollabPanel {
) -> impl IntoElement {
let channel_store = self.channel_store.read(cx);
let has_channel_buffer_changed = channel_store.has_channel_buffer_changed(channel_id);
+
ListItem::new("channel-notes")
+ .height(px(24.))
.toggle_state(is_selected)
.on_click(cx.listener(move |this, _, window, cx| {
this.open_channel_notes(channel_id, window, cx);
@@ -1240,17 +1256,25 @@ impl CollabPanel {
.start_slot(
h_flex()
.relative()
- .gap_1()
+ .gap_1p5()
.child(render_tree_branch(false, true, window, cx))
- .child(IconButton::new(0, IconName::File))
- .children(has_channel_buffer_changed.then(|| {
- div()
- .w_1p5()
- .absolute()
- .right(px(2.))
- .top(px(2.))
- .child(Indicator::dot().color(Color::Info))
- })),
+ .child(
+ h_flex()
+ .child(
+ Icon::new(IconName::Reader)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
+ .when(has_channel_buffer_changed, |this| {
+ this.child(
+ div()
+ .absolute()
+ .top_neg_0p5()
+ .right_0()
+ .child(Indicator::dot().color(Color::Info)),
+ )
+ }),
+ ),
)
.child(Label::new("notes"))
.tooltip(Tooltip::text("Open Channel Notes"))
@@ -3144,10 +3168,14 @@ impl CollabPanel {
(IconName::Star, Color::Default, "Add to Favorites")
};
+ let height = px(24.);
+
h_flex()
.id(ix)
.group("")
+ .h(height)
.w_full()
+ .overflow_hidden()
.when(!channel.is_root_channel(), |el| {
el.on_drag(channel.clone(), move |channel, _, _, cx| {
cx.new(|_| DraggedChannelView {
@@ -3175,6 +3203,7 @@ impl CollabPanel {
)
.child(
ListItem::new(ix)
+ .height(height)
// Add one level of depth for the disclosure arrow.
.indent_level(depth + 1)
.indent_step_size(px(20.))
@@ -3256,12 +3285,13 @@ impl CollabPanel {
.child(
h_flex()
.visible_on_hover("")
+ .h_full()
.absolute()
.right_0()
.px_1()
.gap_px()
- .bg(cx.theme().colors().background)
.rounded_l_md()
+ .bg(cx.theme().colors().background)
.child({
let focus_handle = self.focus_handle.clone();
IconButton::new("channel_favorite", favorite_icon)
@@ -3335,9 +3365,8 @@ fn render_tree_branch(
) -> impl IntoElement {
let rem_size = window.rem_size();
let line_height = window.text_style().line_height_in_pixels(rem_size);
- let width = rem_size * 1.5;
let thickness = px(1.);
- let color = cx.theme().colors().text;
+ let color = cx.theme().colors().icon_disabled;
canvas(
|_, _, _| {},
@@ -3367,8 +3396,8 @@ fn render_tree_branch(
));
},
)
- .w(width)
- .h(line_height)
+ .w(rem_size)
+ .h(line_height - px(2.))
}
fn render_participant_name_and_handle(user: &User) -> impl IntoElement {
diff --git a/crates/credentials_provider/Cargo.toml b/crates/credentials_provider/Cargo.toml
index bf47bb24b12b90d54bc04f766efe06489c730b43..da83c0cd79a1b71bbb84746b3e893f33094783d6 100644
--- a/crates/credentials_provider/Cargo.toml
+++ b/crates/credentials_provider/Cargo.toml
@@ -13,9 +13,5 @@ path = "src/credentials_provider.rs"
[dependencies]
anyhow.workspace = true
-futures.workspace = true
gpui.workspace = true
-paths.workspace = true
-release_channel.workspace = true
serde.workspace = true
-serde_json.workspace = true
diff --git a/crates/credentials_provider/src/credentials_provider.rs b/crates/credentials_provider/src/credentials_provider.rs
index 249b8333e114223aa558cd33637fd103294a8f8d..b98e97673cc11272826af24c76e8a0a6a38b9211 100644
--- a/crates/credentials_provider/src/credentials_provider.rs
+++ b/crates/credentials_provider/src/credentials_provider.rs
@@ -1,26 +1,8 @@
-use std::collections::HashMap;
use std::future::Future;
-use std::path::PathBuf;
use std::pin::Pin;
-use std::sync::{Arc, LazyLock};
use anyhow::Result;
-use futures::FutureExt as _;
-use gpui::{App, AsyncApp};
-use release_channel::ReleaseChannel;
-
-/// An environment variable whose presence indicates that the system keychain
-/// should be used in development.
-///
-/// By default, running Zed in development uses the development credentials
-/// provider. Setting this environment variable allows you to interact with the
-/// system keychain (for instance, if you need to test something).
-///
-/// Only works in development. Setting this environment variable in other
-/// release channels is a no-op.
-static ZED_DEVELOPMENT_USE_KEYCHAIN: LazyLock = LazyLock::new(|| {
- std::env::var("ZED_DEVELOPMENT_USE_KEYCHAIN").is_ok_and(|value| !value.is_empty())
-});
+use gpui::AsyncApp;
/// A provider for credentials.
///
@@ -50,150 +32,3 @@ pub trait CredentialsProvider: Send + Sync {
cx: &'a AsyncApp,
) -> Pin> + 'a>>;
}
-
-impl dyn CredentialsProvider {
- /// Returns the global [`CredentialsProvider`].
- pub fn global(cx: &App) -> Arc {
- // The `CredentialsProvider` trait has `Send + Sync` bounds on it, so it
- // seems like this is a false positive from Clippy.
- #[allow(clippy::arc_with_non_send_sync)]
- Self::new(cx)
- }
-
- fn new(cx: &App) -> Arc {
- let use_development_provider = match ReleaseChannel::try_global(cx) {
- Some(ReleaseChannel::Dev) => {
- // In development we default to using the development
- // credentials provider to avoid getting spammed by relentless
- // keychain access prompts.
- //
- // However, if the `ZED_DEVELOPMENT_USE_KEYCHAIN` environment
- // variable is set, we will use the actual keychain.
- !*ZED_DEVELOPMENT_USE_KEYCHAIN
- }
- Some(ReleaseChannel::Nightly | ReleaseChannel::Preview | ReleaseChannel::Stable)
- | None => false,
- };
-
- if use_development_provider {
- Arc::new(DevelopmentCredentialsProvider::new())
- } else {
- Arc::new(KeychainCredentialsProvider)
- }
- }
-}
-
-/// A credentials provider that stores credentials in the system keychain.
-struct KeychainCredentialsProvider;
-
-impl CredentialsProvider for KeychainCredentialsProvider {
- fn read_credentials<'a>(
- &'a self,
- url: &'a str,
- cx: &'a AsyncApp,
- ) -> Pin)>>> + 'a>> {
- async move { cx.update(|cx| cx.read_credentials(url)).await }.boxed_local()
- }
-
- fn write_credentials<'a>(
- &'a self,
- url: &'a str,
- username: &'a str,
- password: &'a [u8],
- cx: &'a AsyncApp,
- ) -> Pin> + 'a>> {
- async move {
- cx.update(move |cx| cx.write_credentials(url, username, password))
- .await
- }
- .boxed_local()
- }
-
- fn delete_credentials<'a>(
- &'a self,
- url: &'a str,
- cx: &'a AsyncApp,
- ) -> Pin> + 'a>> {
- async move { cx.update(move |cx| cx.delete_credentials(url)).await }.boxed_local()
- }
-}
-
-/// A credentials provider that stores credentials in a local file.
-///
-/// This MUST only be used in development, as this is not a secure way of storing
-/// credentials on user machines.
-///
-/// Its existence is purely to work around the annoyance of having to constantly
-/// re-allow access to the system keychain when developing Zed.
-struct DevelopmentCredentialsProvider {
- path: PathBuf,
-}
-
-impl DevelopmentCredentialsProvider {
- fn new() -> Self {
- let path = paths::config_dir().join("development_credentials");
-
- Self { path }
- }
-
- fn load_credentials(&self) -> Result)>> {
- let json = std::fs::read(&self.path)?;
- let credentials: HashMap)> = serde_json::from_slice(&json)?;
-
- Ok(credentials)
- }
-
- fn save_credentials(&self, credentials: &HashMap)>) -> Result<()> {
- let json = serde_json::to_string(credentials)?;
- std::fs::write(&self.path, json)?;
-
- Ok(())
- }
-}
-
-impl CredentialsProvider for DevelopmentCredentialsProvider {
- fn read_credentials<'a>(
- &'a self,
- url: &'a str,
- _cx: &'a AsyncApp,
- ) -> Pin)>>> + 'a>> {
- async move {
- Ok(self
- .load_credentials()
- .unwrap_or_default()
- .get(url)
- .cloned())
- }
- .boxed_local()
- }
-
- fn write_credentials<'a>(
- &'a self,
- url: &'a str,
- username: &'a str,
- password: &'a [u8],
- _cx: &'a AsyncApp,
- ) -> Pin> + 'a>> {
- async move {
- let mut credentials = self.load_credentials().unwrap_or_default();
- credentials.insert(url.to_string(), (username.to_string(), password.to_vec()));
-
- self.save_credentials(&credentials)
- }
- .boxed_local()
- }
-
- fn delete_credentials<'a>(
- &'a self,
- url: &'a str,
- _cx: &'a AsyncApp,
- ) -> Pin> + 'a>> {
- async move {
- let mut credentials = self.load_credentials()?;
- credentials.remove(url);
-
- self.save_credentials(&credentials)
- }
- .boxed_local()
- }
-}
diff --git a/crates/edit_prediction/Cargo.toml b/crates/edit_prediction/Cargo.toml
index 75a589dea8f9c7fefe7bf13400cbdde54bf90bf1..eabb1641fd4fbec7b2f8ef0ba399a8fe9600dfa3 100644
--- a/crates/edit_prediction/Cargo.toml
+++ b/crates/edit_prediction/Cargo.toml
@@ -26,6 +26,7 @@ cloud_llm_client.workspace = true
collections.workspace = true
copilot.workspace = true
copilot_ui.workspace = true
+credentials_provider.workspace = true
db.workspace = true
edit_prediction_types.workspace = true
edit_prediction_context.workspace = true
@@ -65,6 +66,7 @@ uuid.workspace = true
workspace.workspace = true
worktree.workspace = true
zed_actions.workspace = true
+zed_credentials_provider.workspace = true
zeta_prompt.workspace = true
zstd.workspace = true
diff --git a/crates/edit_prediction/src/capture_example.rs b/crates/edit_prediction/src/capture_example.rs
index 5eb422246775c4409f7f15e3a672a2d407386acc..9463456132ce391b54aca8327cb6f900d81481d6 100644
--- a/crates/edit_prediction/src/capture_example.rs
+++ b/crates/edit_prediction/src/capture_example.rs
@@ -258,6 +258,7 @@ fn generate_timestamp_name() -> String {
mod tests {
use super::*;
use crate::EditPredictionStore;
+ use client::RefreshLlmTokenListener;
use client::{Client, UserStore};
use clock::FakeSystemClock;
use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
@@ -548,7 +549,8 @@ mod tests {
let http_client = FakeHttpClient::with_404_response();
let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
- language_model::init(user_store.clone(), client.clone(), cx);
+ language_model::init(cx);
+ RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
EditPredictionStore::global(&client, &user_store, cx);
})
}
diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs
index 61690c470829ca4bb16a6af9f1df2ea6e7cc6023..280427df006b510e1854ffb40cd7f995fcd9fdc6 100644
--- a/crates/edit_prediction/src/edit_prediction.rs
+++ b/crates/edit_prediction/src/edit_prediction.rs
@@ -1,5 +1,5 @@
use anyhow::Result;
-use client::{Client, EditPredictionUsage, UserStore};
+use client::{Client, EditPredictionUsage, NeedsLlmTokenRefresh, UserStore, global_llm_token};
use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
use cloud_llm_client::predict_edits_v3::{
PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
@@ -11,6 +11,7 @@ use cloud_llm_client::{
};
use collections::{HashMap, HashSet};
use copilot::{Copilot, Reinstall, SignIn, SignOut};
+use credentials_provider::CredentialsProvider;
use db::kvp::{Dismissable, KeyValueStore};
use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
@@ -30,7 +31,7 @@ use heapless::Vec as ArrayVec;
use language::language_settings::all_language_settings;
use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
use language::{BufferSnapshot, OffsetRangeExt};
-use language_model::{LlmApiToken, NeedsLlmTokenRefresh};
+use language_model::LlmApiToken;
use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
use release_channel::AppVersion;
use semver::Version;
@@ -150,6 +151,7 @@ pub struct EditPredictionStore {
rated_predictions: HashSet,
#[cfg(test)]
settled_event_callback: Option>,
+ credentials_provider: Arc,
}
pub(crate) struct EditPredictionRejectionPayload {
@@ -746,7 +748,7 @@ impl EditPredictionStore {
pub fn new(client: Arc, user_store: Entity, cx: &mut Context) -> Self {
let data_collection_choice = Self::load_data_collection_choice(cx);
- let llm_token = LlmApiToken::global(cx);
+ let llm_token = global_llm_token(cx);
let (reject_tx, reject_rx) = mpsc::unbounded();
cx.background_spawn({
@@ -787,6 +789,8 @@ impl EditPredictionStore {
.log_err();
});
+ let credentials_provider = zed_credentials_provider::global(cx);
+
let this = Self {
projects: HashMap::default(),
client,
@@ -807,6 +811,8 @@ impl EditPredictionStore {
shown_predictions: Default::default(),
#[cfg(test)]
settled_event_callback: None,
+
+ credentials_provider,
};
this
@@ -871,7 +877,9 @@ impl EditPredictionStore {
let experiments = cx
.background_spawn(async move {
let http_client = client.http_client();
- let token = llm_token.acquire(&client, organization_id).await?;
+ let token = client
+ .acquire_llm_token(&llm_token, organization_id.clone())
+ .await?;
let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
let request = http_client::Request::builder()
.method(Method::GET)
@@ -2315,7 +2323,10 @@ impl EditPredictionStore {
zeta::request_prediction_with_zeta(self, inputs, capture_data, cx)
}
EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
- EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
+ EditPredictionModel::Mercury => {
+ self.mercury
+ .request_prediction(inputs, self.credentials_provider.clone(), cx)
+ }
};
cx.spawn(async move |this, cx| {
@@ -2536,12 +2547,15 @@ impl EditPredictionStore {
Res: DeserializeOwned,
{
let http_client = client.http_client();
-
let mut token = if require_auth {
- Some(llm_token.acquire(&client, organization_id.clone()).await?)
+ Some(
+ client
+ .acquire_llm_token(&llm_token, organization_id.clone())
+ .await?,
+ )
} else {
- llm_token
- .acquire(&client, organization_id.clone())
+ client
+ .acquire_llm_token(&llm_token, organization_id.clone())
.await
.ok()
};
@@ -2585,7 +2599,11 @@ impl EditPredictionStore {
return Ok((serde_json::from_slice(&body)?, usage));
} else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
did_retry = true;
- token = Some(llm_token.refresh(&client, organization_id.clone()).await?);
+ token = Some(
+ client
+ .refresh_llm_token(&llm_token, organization_id.clone())
+ .await?,
+ );
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs
index 6fe61338e764a40aec9cf6f3191f1191bafe9200..1ba8b27aa785024a47a09c3299a1f3786a028ccf 100644
--- a/crates/edit_prediction/src/edit_prediction_tests.rs
+++ b/crates/edit_prediction/src/edit_prediction_tests.rs
@@ -1,6 +1,6 @@
use super::*;
use crate::udiff::apply_diff_to_string;
-use client::{UserStore, test::FakeServer};
+use client::{RefreshLlmTokenListener, UserStore, test::FakeServer};
use clock::FakeSystemClock;
use clock::ReplicaId;
use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
@@ -23,7 +23,7 @@ use language::{
Anchor, Buffer, Capability, CursorShape, Diagnostic, DiagnosticEntry, DiagnosticSet,
DiagnosticSeverity, Operation, Point, Selection, SelectionGoal,
};
-use language_model::RefreshLlmTokenListener;
+
use lsp::LanguageServerId;
use parking_lot::Mutex;
use pretty_assertions::{assert_eq, assert_matches};
@@ -2439,7 +2439,8 @@ fn init_test_with_fake_client(
client.cloud_client().set_credentials(1, "test".into());
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
- language_model::init(user_store.clone(), client.clone(), cx);
+ language_model::init(cx);
+ RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
let ep_store = EditPredictionStore::global(&client, &user_store, cx);
(
@@ -2891,7 +2892,7 @@ async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut
cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
let user_store = cx.update(|cx| cx.new(|cx| client::UserStore::new(client.clone(), cx)));
cx.update(|cx| {
- language_model::RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
+ RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
});
let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
diff --git a/crates/edit_prediction/src/mercury.rs b/crates/edit_prediction/src/mercury.rs
index df47a38062344512a784c6d2feb563e9848afb27..155fd449904687081da0a9eae3d4731863f02254 100644
--- a/crates/edit_prediction/src/mercury.rs
+++ b/crates/edit_prediction/src/mercury.rs
@@ -5,6 +5,7 @@ use crate::{
};
use anyhow::{Context as _, Result};
use cloud_llm_client::EditPredictionRejectReason;
+use credentials_provider::CredentialsProvider;
use futures::AsyncReadExt as _;
use gpui::{
App, AppContext as _, Context, Entity, Global, SharedString, Task,
@@ -51,10 +52,11 @@ impl Mercury {
debug_tx,
..
}: EditPredictionModelInput,
+ credentials_provider: Arc,
cx: &mut Context,
) -> Task>> {
self.api_token.update(cx, |key_state, cx| {
- _ = key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, cx);
+ _ = key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, credentials_provider, cx);
});
let Some(api_token) = self.api_token.read(cx).key(&MERCURY_CREDENTIALS_URL) else {
return Task::ready(Ok(None));
@@ -387,8 +389,9 @@ pub fn mercury_api_token(cx: &mut App) -> Entity {
}
pub fn load_mercury_api_token(cx: &mut App) -> Task> {
+ let credentials_provider = zed_credentials_provider::global(cx);
mercury_api_token(cx).update(cx, |key_state, cx| {
- key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, cx)
+ key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, credentials_provider, cx)
})
}
diff --git a/crates/edit_prediction/src/open_ai_compatible.rs b/crates/edit_prediction/src/open_ai_compatible.rs
index ca378ba1fd0bc9bdbb3e85c7610e1b94c1be388f..9a11164822857d78c2fe0d9245faeb5d4f7400a0 100644
--- a/crates/edit_prediction/src/open_ai_compatible.rs
+++ b/crates/edit_prediction/src/open_ai_compatible.rs
@@ -42,9 +42,10 @@ pub fn open_ai_compatible_api_token(cx: &mut App) -> Entity {
pub fn load_open_ai_compatible_api_token(
cx: &mut App,
) -> Task> {
+ let credentials_provider = zed_credentials_provider::global(cx);
let api_url = open_ai_compatible_api_url(cx);
open_ai_compatible_api_token(cx).update(cx, |key_state, cx| {
- key_state.load_if_needed(api_url, |s| s, cx)
+ key_state.load_if_needed(api_url, |s| s, credentials_provider, cx)
})
}
diff --git a/crates/edit_prediction_cli/src/headless.rs b/crates/edit_prediction_cli/src/headless.rs
index 3a204a7052f8a41d6e7c2c49860b62f588358644..48b7381020f48d868d9f6413ef343b30718e5be6 100644
--- a/crates/edit_prediction_cli/src/headless.rs
+++ b/crates/edit_prediction_cli/src/headless.rs
@@ -1,4 +1,4 @@
-use client::{Client, ProxySettings, UserStore};
+use client::{Client, ProxySettings, RefreshLlmTokenListener, UserStore};
use db::AppDatabase;
use extension::ExtensionHostProxy;
use fs::RealFs;
@@ -109,7 +109,8 @@ pub fn init(cx: &mut App) -> EpAppState {
debug_adapter_extension::init(extension_host_proxy.clone(), cx);
language_extension::init(LspAccess::Noop, extension_host_proxy, languages.clone());
- language_model::init(user_store.clone(), client.clone(), cx);
+ language_model::init(cx);
+ RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
language_models::init(user_store.clone(), client.clone(), cx);
languages::init(languages.clone(), fs.clone(), node_runtime.clone(), cx);
prompt_store::init(cx);
diff --git a/crates/env_var/Cargo.toml b/crates/env_var/Cargo.toml
new file mode 100644
index 0000000000000000000000000000000000000000..2cbbd08c7833d3e57a09766d42ffffe35c620a93
--- /dev/null
+++ b/crates/env_var/Cargo.toml
@@ -0,0 +1,15 @@
+[package]
+name = "env_var"
+version = "0.1.0"
+edition.workspace = true
+publish.workspace = true
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/env_var.rs"
+
+[dependencies]
+gpui.workspace = true
diff --git a/crates/env_var/LICENSE-GPL b/crates/env_var/LICENSE-GPL
new file mode 120000
index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4
--- /dev/null
+++ b/crates/env_var/LICENSE-GPL
@@ -0,0 +1 @@
+../../LICENSE-GPL
\ No newline at end of file
diff --git a/crates/env_var/src/env_var.rs b/crates/env_var/src/env_var.rs
new file mode 100644
index 0000000000000000000000000000000000000000..79f671e0147ebfaad4ab76a123cc477dc7e55cb7
--- /dev/null
+++ b/crates/env_var/src/env_var.rs
@@ -0,0 +1,40 @@
+use gpui::SharedString;
+
+#[derive(Clone)]
+pub struct EnvVar {
+ pub name: SharedString,
+ /// Value of the environment variable. Also `None` when set to an empty string.
+ pub value: Option,
+}
+
+impl EnvVar {
+ pub fn new(name: SharedString) -> Self {
+ let value = std::env::var(name.as_str()).ok();
+ if value.as_ref().is_some_and(|v| v.is_empty()) {
+ Self { name, value: None }
+ } else {
+ Self { name, value }
+ }
+ }
+
+ pub fn or(self, other: EnvVar) -> EnvVar {
+ if self.value.is_some() { self } else { other }
+ }
+}
+
+/// Creates a `LazyLock` expression for use in a `static` declaration.
+#[macro_export]
+macro_rules! env_var {
+ ($name:expr) => {
+ ::std::sync::LazyLock::new(|| $crate::EnvVar::new(($name).into()))
+ };
+}
+
+/// Generates a `LazyLock` expression for use in a `static` declaration. Checks if the
+/// environment variable exists and is non-empty.
+#[macro_export]
+macro_rules! bool_env_var {
+ ($name:expr) => {
+ ::std::sync::LazyLock::new(|| $crate::EnvVar::new(($name).into()).value.is_some())
+ };
+}
diff --git a/crates/eval_cli/src/headless.rs b/crates/eval_cli/src/headless.rs
index 72feaacbae270224240f1da9e6e6c1008ba97c84..0ddd99e8f8abd9dbd73e1d7461526f3e7cb24f11 100644
--- a/crates/eval_cli/src/headless.rs
+++ b/crates/eval_cli/src/headless.rs
@@ -1,7 +1,7 @@
use std::path::PathBuf;
use std::sync::Arc;
-use client::{Client, ProxySettings, UserStore};
+use client::{Client, ProxySettings, RefreshLlmTokenListener, UserStore};
use db::AppDatabase;
use extension::ExtensionHostProxy;
use fs::RealFs;
@@ -108,7 +108,8 @@ pub fn init(cx: &mut App) -> Arc {
let extension_host_proxy = ExtensionHostProxy::global(cx);
debug_adapter_extension::init(extension_host_proxy.clone(), cx);
language_extension::init(LspAccess::Noop, extension_host_proxy, languages.clone());
- language_model::init(user_store.clone(), client.clone(), cx);
+ language_model::init(cx);
+ RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
language_models::init(user_store.clone(), client.clone(), cx);
languages::init(languages.clone(), fs.clone(), node_runtime.clone(), cx);
prompt_store::init(cx);
diff --git a/crates/feature_flags/src/flags.rs b/crates/feature_flags/src/flags.rs
index 4d477aa4b393ee8b04829833324cd9092c2a04cd..54dc96ad37f8e51a1074a0a32976f8236cb1a0ed 100644
--- a/crates/feature_flags/src/flags.rs
+++ b/crates/feature_flags/src/flags.rs
@@ -47,12 +47,6 @@ impl FeatureFlag for DiffReviewFeatureFlag {
}
}
-pub struct GitGraphFeatureFlag;
-
-impl FeatureFlag for GitGraphFeatureFlag {
- const NAME: &'static str = "git-graph";
-}
-
pub struct StreamingEditFileToolFeatureFlag;
impl FeatureFlag for StreamingEditFileToolFeatureFlag {
diff --git a/crates/fs/src/fake_git_repo.rs b/crates/fs/src/fake_git_repo.rs
index 12ad38056ee5e9886609ad993f842061e338f158..f6e30b8597d24632ac632160d7b407abd8d7f26e 100644
--- a/crates/fs/src/fake_git_repo.rs
+++ b/crates/fs/src/fake_git_repo.rs
@@ -1135,10 +1135,88 @@ impl GitRepository for FakeGitRepository {
fn diff_checkpoints(
&self,
- _base_checkpoint: GitRepositoryCheckpoint,
- _target_checkpoint: GitRepositoryCheckpoint,
+ base_checkpoint: GitRepositoryCheckpoint,
+ target_checkpoint: GitRepositoryCheckpoint,
) -> BoxFuture<'_, Result> {
- unimplemented!()
+ let executor = self.executor.clone();
+ let checkpoints = self.checkpoints.clone();
+ async move {
+ executor.simulate_random_delay().await;
+ let checkpoints = checkpoints.lock();
+ let base = checkpoints
+ .get(&base_checkpoint.commit_sha)
+ .context(format!(
+ "invalid base checkpoint: {}",
+ base_checkpoint.commit_sha
+ ))?;
+ let target = checkpoints
+ .get(&target_checkpoint.commit_sha)
+ .context(format!(
+ "invalid target checkpoint: {}",
+ target_checkpoint.commit_sha
+ ))?;
+
+ fn collect_files(
+ entry: &FakeFsEntry,
+ prefix: String,
+ out: &mut std::collections::BTreeMap,
+ ) {
+ match entry {
+ FakeFsEntry::File { content, .. } => {
+ out.insert(prefix, String::from_utf8_lossy(content).into_owned());
+ }
+ FakeFsEntry::Dir { entries, .. } => {
+ for (name, child) in entries {
+ let path = if prefix.is_empty() {
+ name.clone()
+ } else {
+ format!("{prefix}/{name}")
+ };
+ collect_files(child, path, out);
+ }
+ }
+ FakeFsEntry::Symlink { .. } => {}
+ }
+ }
+
+ let mut base_files = std::collections::BTreeMap::new();
+ let mut target_files = std::collections::BTreeMap::new();
+ collect_files(base, String::new(), &mut base_files);
+ collect_files(target, String::new(), &mut target_files);
+
+ let all_paths: std::collections::BTreeSet<&String> =
+ base_files.keys().chain(target_files.keys()).collect();
+
+ let mut diff = String::new();
+ for path in all_paths {
+ match (base_files.get(path), target_files.get(path)) {
+ (Some(base_content), Some(target_content))
+ if base_content != target_content =>
+ {
+ diff.push_str(&format!("diff --git a/{path} b/{path}\n"));
+ diff.push_str(&format!("--- a/{path}\n"));
+ diff.push_str(&format!("+++ b/{path}\n"));
+ for line in base_content.lines() {
+ diff.push_str(&format!("-{line}\n"));
+ }
+ for line in target_content.lines() {
+ diff.push_str(&format!("+{line}\n"));
+ }
+ }
+ (Some(_), None) => {
+ diff.push_str(&format!("diff --git a/{path} /dev/null\n"));
+ diff.push_str("deleted file\n");
+ }
+ (None, Some(_)) => {
+ diff.push_str(&format!("diff --git /dev/null b/{path}\n"));
+ diff.push_str("new file\n");
+ }
+ _ => {}
+ }
+ }
+ Ok(diff)
+ }
+ .boxed()
}
fn default_branch(
diff --git a/crates/fs/tests/integration/fake_git_repo.rs b/crates/fs/tests/integration/fake_git_repo.rs
index 1d7c64a4d07e05bbb16a2864cea1a37248ed5f51..f4192a22bb42f88f8769ef59f817b2bf2a288fb9 100644
--- a/crates/fs/tests/integration/fake_git_repo.rs
+++ b/crates/fs/tests/integration/fake_git_repo.rs
@@ -159,7 +159,10 @@ async fn test_checkpoints(executor: BackgroundExecutor) {
.unwrap()
);
- repository.restore_checkpoint(checkpoint_1).await.unwrap();
+ repository
+ .restore_checkpoint(checkpoint_1.clone())
+ .await
+ .unwrap();
assert_eq!(
fs.files_with_contents(Path::new("")),
[
@@ -168,4 +171,22 @@ async fn test_checkpoints(executor: BackgroundExecutor) {
(Path::new(path!("/foo/b")).into(), b"ipsum".into())
]
);
+
+ // diff_checkpoints: identical checkpoints produce empty diff
+ let diff = repository
+ .diff_checkpoints(checkpoint_2.clone(), checkpoint_3.clone())
+ .await
+ .unwrap();
+ assert!(
+ diff.is_empty(),
+ "identical checkpoints should produce empty diff"
+ );
+
+ // diff_checkpoints: different checkpoints produce non-empty diff
+ let diff = repository
+ .diff_checkpoints(checkpoint_1.clone(), checkpoint_2.clone())
+ .await
+ .unwrap();
+ assert!(diff.contains("b"), "diff should mention changed file 'b'");
+ assert!(diff.contains("c"), "diff should mention added file 'c'");
}
diff --git a/crates/git_graph/Cargo.toml b/crates/git_graph/Cargo.toml
index cc3374a85932435d010daabdfe0e4b4eef628de6..e9e31a8361e367275c994e125ae6e04cbd652fc3 100644
--- a/crates/git_graph/Cargo.toml
+++ b/crates/git_graph/Cargo.toml
@@ -24,7 +24,6 @@ anyhow.workspace = true
collections.workspace = true
db.workspace = true
editor.workspace = true
-feature_flags.workspace = true
git.workspace = true
git_ui.workspace = true
gpui.workspace = true
diff --git a/crates/git_graph/src/git_graph.rs b/crates/git_graph/src/git_graph.rs
index d473fbbec618c6e7b309ab2ff9dc9eb5787ddc43..c56fb051b896f32ac364cd15e73ae8708498ca5a 100644
--- a/crates/git_graph/src/git_graph.rs
+++ b/crates/git_graph/src/git_graph.rs
@@ -1,6 +1,5 @@
use collections::{BTreeMap, HashMap, IndexSet};
use editor::Editor;
-use feature_flags::{FeatureFlagAppExt as _, GitGraphFeatureFlag};
use git::{
BuildCommitPermalinkParams, GitHostingProviderRegistry, GitRemote, Oid, ParsedGitRemote,
parse_git_remote_url,
@@ -42,8 +41,10 @@ use theme_settings::ThemeSettings;
use time::{OffsetDateTime, UtcOffset, format_description::BorrowedFormatItem};
use ui::{
ButtonLike, Chip, ColumnWidthConfig, CommonAnimationExt as _, ContextMenu, DiffStat, Divider,
- HighlightedLabel, RedistributableColumnsState, ScrollableHandle, Table, TableInteractionState,
- TableResizeBehavior, Tooltip, WithScrollbar, prelude::*,
+ HeaderResizeInfo, HighlightedLabel, RedistributableColumnsState, ScrollableHandle, Table,
+ TableInteractionState, TableRenderContext, TableResizeBehavior, Tooltip, WithScrollbar,
+ bind_redistributable_columns, prelude::*, render_redistributable_columns_resize_handles,
+ render_table_header, table_row::TableRow,
};
use workspace::{
Workspace,
@@ -730,8 +731,7 @@ pub fn init(cx: &mut App) {
cx.observe_new(|workspace: &mut workspace::Workspace, _, _| {
workspace.register_action_renderer(|div, workspace, _, cx| {
div.when(
- workspace.project().read(cx).active_repository(cx).is_some()
- && cx.has_flag::(),
+ workspace.project().read(cx).active_repository(cx).is_some(),
|div| {
let workspace = workspace.weak_handle();
@@ -901,9 +901,8 @@ pub struct GitGraph {
context_menu: Option<(Entity, Point, Subscription)>,
row_height: Pixels,
table_interaction_state: Entity,
- table_column_widths: Entity,
+ column_widths: Entity,
horizontal_scroll_offset: Pixels,
- graph_viewport_width: Pixels,
selected_entry_idx: Option,
hovered_entry_idx: Option,
graph_canvas_bounds: Rc| >>>,
@@ -933,8 +932,60 @@ impl GitGraph {
font_size + px(12.0)
}
- fn graph_content_width(&self) -> Pixels {
- (LANE_WIDTH * self.graph_data.max_lanes.min(8) as f32) + LEFT_PADDING * 2.0
+ fn graph_canvas_content_width(&self) -> Pixels {
+ (LANE_WIDTH * self.graph_data.max_lanes.max(6) as f32) + LEFT_PADDING * 2.0
+ }
+
+ fn preview_column_fractions(&self, window: &Window, cx: &App) -> [f32; 5] {
+ let fractions = self
+ .column_widths
+ .read(cx)
+ .preview_fractions(window.rem_size());
+ [
+ fractions[0],
+ fractions[1],
+ fractions[2],
+ fractions[3],
+ fractions[4],
+ ]
+ }
+
+ fn table_column_width_config(&self, window: &Window, cx: &App) -> ColumnWidthConfig {
+ let [_, description, date, author, commit] = self.preview_column_fractions(window, cx);
+ let table_total = description + date + author + commit;
+
+ let widths = if table_total > 0.0 {
+ vec![
+ DefiniteLength::Fraction(description / table_total),
+ DefiniteLength::Fraction(date / table_total),
+ DefiniteLength::Fraction(author / table_total),
+ DefiniteLength::Fraction(commit / table_total),
+ ]
+ } else {
+ vec![
+ DefiniteLength::Fraction(0.25),
+ DefiniteLength::Fraction(0.25),
+ DefiniteLength::Fraction(0.25),
+ DefiniteLength::Fraction(0.25),
+ ]
+ };
+
+ ColumnWidthConfig::explicit(widths)
+ }
+
+ fn graph_viewport_width(&self, window: &Window, cx: &App) -> Pixels {
+ self.column_widths
+ .read(cx)
+ .preview_column_width(0, window)
+ .unwrap_or_else(|| self.graph_canvas_content_width())
+ }
+
+ fn clamp_horizontal_scroll_offset(&mut self, graph_viewport_width: Pixels) {
+ let max_horizontal_scroll =
+ (self.graph_canvas_content_width() - graph_viewport_width).max(px(0.));
+ self.horizontal_scroll_offset = self
+ .horizontal_scroll_offset
+ .clamp(px(0.), max_horizontal_scroll);
}
pub fn new(
@@ -972,20 +1023,22 @@ impl GitGraph {
});
let table_interaction_state = cx.new(|cx| TableInteractionState::new(cx));
- let table_column_widths = cx.new(|_cx| {
+ let column_widths = cx.new(|_cx| {
RedistributableColumnsState::new(
- 4,
+ 5,
vec![
- DefiniteLength::Fraction(0.72),
- DefiniteLength::Fraction(0.12),
- DefiniteLength::Fraction(0.10),
- DefiniteLength::Fraction(0.06),
+ DefiniteLength::Fraction(0.14),
+ DefiniteLength::Fraction(0.6192),
+ DefiniteLength::Fraction(0.1032),
+ DefiniteLength::Fraction(0.086),
+ DefiniteLength::Fraction(0.0516),
],
vec![
TableResizeBehavior::Resizable,
TableResizeBehavior::Resizable,
TableResizeBehavior::Resizable,
TableResizeBehavior::Resizable,
+ TableResizeBehavior::Resizable,
],
)
});
@@ -1020,9 +1073,8 @@ impl GitGraph {
context_menu: None,
row_height,
table_interaction_state,
- table_column_widths,
+ column_widths,
horizontal_scroll_offset: px(0.),
- graph_viewport_width: px(88.),
selected_entry_idx: None,
hovered_entry_idx: None,
graph_canvas_bounds: Rc::new(Cell::new(None)),
@@ -2089,8 +2141,12 @@ impl GitGraph {
let vertical_scroll_offset = scroll_offset_y - (first_visible_row as f32 * row_height);
let horizontal_scroll_offset = self.horizontal_scroll_offset;
- let max_lanes = self.graph_data.max_lanes.max(6);
- let graph_width = LANE_WIDTH * max_lanes as f32 + LEFT_PADDING * 2.0;
+ let graph_viewport_width = self.graph_viewport_width(window, cx);
+ let graph_width = if self.graph_canvas_content_width() > graph_viewport_width {
+ self.graph_canvas_content_width()
+ } else {
+ graph_viewport_width
+ };
let last_visible_row =
first_visible_row + (viewport_height / row_height).ceil() as usize + 1;
@@ -2414,9 +2470,9 @@ impl GitGraph {
let new_y = (current_offset.y + delta.y).clamp(max_vertical_scroll, px(0.));
let new_offset = Point::new(current_offset.x, new_y);
- let max_lanes = self.graph_data.max_lanes.max(1);
- let graph_content_width = LANE_WIDTH * max_lanes as f32 + LEFT_PADDING * 2.0;
- let max_horizontal_scroll = (graph_content_width - self.graph_viewport_width).max(px(0.));
+ let graph_viewport_width = self.graph_viewport_width(window, cx);
+ let max_horizontal_scroll =
+ (self.graph_canvas_content_width() - graph_viewport_width).max(px(0.));
let new_horizontal_offset =
(self.horizontal_scroll_offset - delta.x).clamp(px(0.), max_horizontal_scroll);
@@ -2497,6 +2553,8 @@ impl Render for GitGraph {
cx,
);
self.graph_data.add_commits(&commits);
+ let graph_viewport_width = self.graph_viewport_width(window, cx);
+ self.clamp_horizontal_scroll_offset(graph_viewport_width);
(commits.len(), is_loading)
})
} else {
@@ -2527,118 +2585,202 @@ impl Render for GitGraph {
this.child(self.render_loading_spinner(cx))
})
} else {
- div()
+ let header_resize_info = HeaderResizeInfo::from_state(&self.column_widths, cx);
+ let header_context = TableRenderContext::for_column_widths(
+ Some(self.column_widths.read(cx).widths_to_render()),
+ true,
+ );
+ let [
+ graph_fraction,
+ description_fraction,
+ date_fraction,
+ author_fraction,
+ commit_fraction,
+ ] = self.preview_column_fractions(window, cx);
+ let table_fraction =
+ description_fraction + date_fraction + author_fraction + commit_fraction;
+ let table_width_config = self.table_column_width_config(window, cx);
+ let graph_viewport_width = self.graph_viewport_width(window, cx);
+ self.clamp_horizontal_scroll_offset(graph_viewport_width);
+
+ h_flex()
.size_full()
- .flex()
- .flex_row()
.child(
div()
- .w(self.graph_content_width())
- .h_full()
+ .flex_1()
+ .min_w_0()
+ .size_full()
.flex()
.flex_col()
- .child(
- div()
- .flex()
- .items_center()
- .px_1()
- .py_0p5()
- .border_b_1()
- .whitespace_nowrap()
- .border_color(cx.theme().colors().border)
- .child(Label::new("Graph").color(Color::Muted)),
- )
- .child(
- div()
- .id("graph-canvas")
- .flex_1()
- .overflow_hidden()
- .child(self.render_graph(window, cx))
- .on_scroll_wheel(cx.listener(Self::handle_graph_scroll))
- .on_mouse_move(cx.listener(Self::handle_graph_mouse_move))
- .on_click(cx.listener(Self::handle_graph_click))
- .on_hover(cx.listener(|this, &is_hovered: &bool, _, cx| {
- if !is_hovered && this.hovered_entry_idx.is_some() {
- this.hovered_entry_idx = None;
- cx.notify();
- }
- })),
- ),
- )
- .child({
- let row_height = self.row_height;
- let selected_entry_idx = self.selected_entry_idx;
- let hovered_entry_idx = self.hovered_entry_idx;
- let weak_self = cx.weak_entity();
- let focus_handle = self.focus_handle.clone();
- div().flex_1().size_full().child(
- Table::new(4)
- .interactable(&self.table_interaction_state)
- .hide_row_borders()
- .hide_row_hover()
- .header(vec![
- Label::new("Description")
- .color(Color::Muted)
- .into_any_element(),
- Label::new("Date").color(Color::Muted).into_any_element(),
- Label::new("Author").color(Color::Muted).into_any_element(),
- Label::new("Commit").color(Color::Muted).into_any_element(),
- ])
- .width_config(ColumnWidthConfig::redistributable(
- self.table_column_widths.clone(),
- ))
- .map_row(move |(index, row), window, cx| {
- let is_selected = selected_entry_idx == Some(index);
- let is_hovered = hovered_entry_idx == Some(index);
- let is_focused = focus_handle.is_focused(window);
- let weak = weak_self.clone();
- let weak_for_hover = weak.clone();
-
- let hover_bg = cx.theme().colors().element_hover.opacity(0.6);
- let selected_bg = if is_focused {
- cx.theme().colors().element_selected
- } else {
- cx.theme().colors().element_hover
- };
-
- row.h(row_height)
- .when(is_selected, |row| row.bg(selected_bg))
- .when(is_hovered && !is_selected, |row| row.bg(hover_bg))
- .on_hover(move |&is_hovered, _, cx| {
- weak_for_hover
- .update(cx, |this, cx| {
- if is_hovered {
- if this.hovered_entry_idx != Some(index) {
- this.hovered_entry_idx = Some(index);
- cx.notify();
- }
- } else if this.hovered_entry_idx == Some(index) {
- // Only clear if this row was the hovered one
- this.hovered_entry_idx = None;
- cx.notify();
- }
- })
- .ok();
- })
- .on_click(move |event, window, cx| {
- let click_count = event.click_count();
- weak.update(cx, |this, cx| {
- this.select_entry(index, ScrollStrategy::Center, cx);
- if click_count >= 2 {
- this.open_commit_view(index, window, cx);
- }
- })
- .ok();
- })
- .into_any_element()
- })
- .uniform_list(
- "git-graph-commits",
- commit_count,
- cx.processor(Self::render_table_rows),
+ .child(render_table_header(
+ TableRow::from_vec(
+ vec![
+ Label::new("Graph")
+ .color(Color::Muted)
+ .truncate()
+ .into_any_element(),
+ Label::new("Description")
+ .color(Color::Muted)
+ .into_any_element(),
+ Label::new("Date").color(Color::Muted).into_any_element(),
+ Label::new("Author").color(Color::Muted).into_any_element(),
+ Label::new("Commit").color(Color::Muted).into_any_element(),
+ ],
+ 5,
),
- )
- })
+ header_context,
+ Some(header_resize_info),
+ Some(self.column_widths.entity_id()),
+ cx,
+ ))
+ .child({
+ let row_height = self.row_height;
+ let selected_entry_idx = self.selected_entry_idx;
+ let hovered_entry_idx = self.hovered_entry_idx;
+ let weak_self = cx.weak_entity();
+ let focus_handle = self.focus_handle.clone();
+
+ bind_redistributable_columns(
+ div()
+ .relative()
+ .flex_1()
+ .w_full()
+ .overflow_hidden()
+ .child(
+ h_flex()
+ .size_full()
+ .child(
+ div()
+ .w(DefiniteLength::Fraction(graph_fraction))
+ .h_full()
+ .min_w_0()
+ .overflow_hidden()
+ .child(
+ div()
+ .id("graph-canvas")
+ .size_full()
+ .overflow_hidden()
+ .child(
+ div()
+ .size_full()
+ .child(self.render_graph(window, cx)),
+ )
+ .on_scroll_wheel(
+ cx.listener(Self::handle_graph_scroll),
+ )
+ .on_mouse_move(
+ cx.listener(Self::handle_graph_mouse_move),
+ )
+ .on_click(cx.listener(Self::handle_graph_click))
+ .on_hover(cx.listener(
+ |this, &is_hovered: &bool, _, cx| {
+ if !is_hovered
+ && this.hovered_entry_idx.is_some()
+ {
+ this.hovered_entry_idx = None;
+ cx.notify();
+ }
+ },
+ )),
+ ),
+ )
+ .child(
+ div()
+ .w(DefiniteLength::Fraction(table_fraction))
+ .h_full()
+ .min_w_0()
+ .child(
+ Table::new(4)
+ .interactable(&self.table_interaction_state)
+ .hide_row_borders()
+ .hide_row_hover()
+ .width_config(table_width_config)
+ .map_row(move |(index, row), window, cx| {
+ let is_selected =
+ selected_entry_idx == Some(index);
+ let is_hovered =
+ hovered_entry_idx == Some(index);
+ let is_focused =
+ focus_handle.is_focused(window);
+ let weak = weak_self.clone();
+ let weak_for_hover = weak.clone();
+
+ let hover_bg = cx
+ .theme()
+ .colors()
+ .element_hover
+ .opacity(0.6);
+ let selected_bg = if is_focused {
+ cx.theme().colors().element_selected
+ } else {
+ cx.theme().colors().element_hover
+ };
+
+ row.h(row_height)
+ .when(is_selected, |row| row.bg(selected_bg))
+ .when(
+ is_hovered && !is_selected,
+ |row| row.bg(hover_bg),
+ )
+ .on_hover(move |&is_hovered, _, cx| {
+ weak_for_hover
+ .update(cx, |this, cx| {
+ if is_hovered {
+ if this.hovered_entry_idx
+ != Some(index)
+ {
+ this.hovered_entry_idx =
+ Some(index);
+ cx.notify();
+ }
+ } else if this
+ .hovered_entry_idx
+ == Some(index)
+ {
+ this.hovered_entry_idx =
+ None;
+ cx.notify();
+ }
+ })
+ .ok();
+ })
+ .on_click(move |event, window, cx| {
+ let click_count = event.click_count();
+ weak.update(cx, |this, cx| {
+ this.select_entry(
+ index,
+ ScrollStrategy::Center,
+ cx,
+ );
+ if click_count >= 2 {
+ this.open_commit_view(
+ index,
+ window,
+ cx,
+ );
+ }
+ })
+ .ok();
+ })
+ .into_any_element()
+ })
+ .uniform_list(
+ "git-graph-commits",
+ commit_count,
+ cx.processor(Self::render_table_rows),
+ ),
+ ),
+ ),
+ )
+ .child(render_redistributable_columns_resize_handles(
+ &self.column_widths,
+ window,
+ cx,
+ )),
+ self.column_widths.clone(),
+ )
+ }),
+ )
.on_drag_move::(cx.listener(|this, event, window, cx| {
this.commit_details_split_state.update(cx, |state, cx| {
state.on_drag_move(event, window, cx);
@@ -3734,9 +3876,11 @@ mod tests {
});
cx.run_until_parked();
- git_graph.update_in(&mut *cx, |this, window, cx| {
- this.render(window, cx);
- });
+ cx.draw(
+ point(px(0.), px(0.)),
+ gpui::size(px(1200.), px(800.)),
+ |_, _| git_graph.clone().into_any_element(),
+ );
cx.run_until_parked();
let commit_count_after_switch_back =
diff --git a/crates/git_ui/Cargo.toml b/crates/git_ui/Cargo.toml
index d95e25fbc7821d42fac4386b522c4effb9462715..e06d16708697f721d9377365223dc444ba7b08ae 100644
--- a/crates/git_ui/Cargo.toml
+++ b/crates/git_ui/Cargo.toml
@@ -27,7 +27,6 @@ db.workspace = true
editor.workspace = true
file_icons.workspace = true
futures.workspace = true
-feature_flags.workspace = true
fuzzy.workspace = true
git.workspace = true
gpui.workspace = true
diff --git a/crates/git_ui/src/commit_view.rs b/crates/git_ui/src/commit_view.rs
index a298380336515aad24e9c55d637d392fa6898b35..aac44c7f9c6eaf6f18c72bea390c0a0b7ad1a4bd 100644
--- a/crates/git_ui/src/commit_view.rs
+++ b/crates/git_ui/src/commit_view.rs
@@ -3,7 +3,6 @@ use buffer_diff::BufferDiff;
use collections::HashMap;
use editor::display_map::{BlockPlacement, BlockProperties, BlockStyle};
use editor::{Addon, Editor, EditorEvent, ExcerptRange, MultiBuffer, multibuffer_context_lines};
-use feature_flags::{FeatureFlagAppExt as _, GitGraphFeatureFlag};
use git::repository::{CommitDetails, CommitDiff, RepoPath, is_binary_content};
use git::status::{FileStatus, StatusCode, TrackedStatus};
use git::{
@@ -1045,21 +1044,19 @@ impl Render for CommitViewToolbar {
}),
)
.when(!is_stash, |this| {
- this.when(cx.has_flag::(), |this| {
- this.child(
- IconButton::new("show-in-git-graph", IconName::GitGraph)
- .icon_size(IconSize::Small)
- .tooltip(Tooltip::text("Show in Git Graph"))
- .on_click(move |_, window, cx| {
- window.dispatch_action(
- Box::new(crate::git_panel::OpenAtCommit {
- sha: sha_for_graph.clone(),
- }),
- cx,
- );
- }),
- )
- })
+ this.child(
+ IconButton::new("show-in-git-graph", IconName::GitGraph)
+ .icon_size(IconSize::Small)
+ .tooltip(Tooltip::text("Show in Git Graph"))
+ .on_click(move |_, window, cx| {
+ window.dispatch_action(
+ Box::new(crate::git_panel::OpenAtCommit {
+ sha: sha_for_graph.clone(),
+ }),
+ cx,
+ );
+ }),
+ )
.children(remote_info.map(|(provider_name, url)| {
let icon = match provider_name.as_str() {
"GitHub" => IconName::Github,
diff --git a/crates/git_ui/src/git_panel.rs b/crates/git_ui/src/git_panel.rs
index ec8569988200079877d8edc999ebb0dcd155b88c..123f04a442597db66ae00963453d009b0ee8518c 100644
--- a/crates/git_ui/src/git_panel.rs
+++ b/crates/git_ui/src/git_panel.rs
@@ -20,7 +20,6 @@ use editor::{
actions::ExpandAllDiffHunks,
};
use editor::{EditorStyle, RewrapOptions};
-use feature_flags::{FeatureFlagAppExt as _, GitGraphFeatureFlag};
use file_icons::FileIcons;
use futures::StreamExt as _;
use git::commit::ParsedCommitMessage;
@@ -4535,7 +4534,6 @@ impl GitPanel {
let commit = branch.most_recent_commit.as_ref()?.clone();
let workspace = self.workspace.clone();
let this = cx.entity();
- let can_open_git_graph = cx.has_flag::();
Some(
h_flex()
@@ -4613,18 +4611,16 @@ impl GitPanel {
),
)
})
- .when(can_open_git_graph, |this| {
- this.child(
- panel_icon_button("git-graph-button", IconName::GitGraph)
- .icon_size(IconSize::Small)
- .tooltip(|_window, cx| {
- Tooltip::for_action("Open Git Graph", &Open, cx)
- })
- .on_click(|_, window, cx| {
- window.dispatch_action(Open.boxed_clone(), cx)
- }),
- )
- }),
+ .child(
+ panel_icon_button("git-graph-button", IconName::GitGraph)
+ .icon_size(IconSize::Small)
+ .tooltip(|_window, cx| {
+ Tooltip::for_action("Open Git Graph", &Open, cx)
+ })
+ .on_click(|_, window, cx| {
+ window.dispatch_action(Open.boxed_clone(), cx)
+ }),
+ ),
),
)
}
diff --git a/crates/gpui/src/elements/list.rs b/crates/gpui/src/elements/list.rs
index ed441e3b40534690d02b31109e719c60dd5802e0..b4c8e7ca9015190fb8bb1698f79f1b025bfa4829 100644
--- a/crates/gpui/src/elements/list.rs
+++ b/crates/gpui/src/elements/list.rs
@@ -427,6 +427,13 @@ impl ListState {
self.0.borrow().follow_tail
}
+ /// Returns whether the list is scrolled to the bottom (within 1px).
+ pub fn is_at_bottom(&self) -> bool {
+ let current_offset = self.scroll_px_offset_for_scrollbar().y.abs();
+ let max_offset = self.max_offset_for_scrollbar().y;
+ current_offset >= max_offset - px(1.0)
+ }
+
/// Scroll the list to the given offset
pub fn scroll_to(&self, mut scroll_top: ListOffset) {
let state = &mut *self.0.borrow_mut();
diff --git a/crates/grammars/src/javascript/highlights.scm b/crates/grammars/src/javascript/highlights.scm
index 4af87cc578e3060e72d1e1374f4904d8c7629ddf..f6354dd3a016f544e5be1616c3dfb12144855775 100644
--- a/crates/grammars/src/javascript/highlights.scm
+++ b/crates/grammars/src/javascript/highlights.scm
@@ -328,26 +328,26 @@
; JSX elements
(jsx_opening_element
[
- (identifier) @type
+ (identifier) @type @tag.component.jsx
(member_expression
- object: (identifier) @type
- property: (property_identifier) @type)
+ object: (identifier) @type @tag.component.jsx
+ property: (property_identifier) @type @tag.component.jsx)
])
(jsx_closing_element
[
- (identifier) @type
+ (identifier) @type @tag.component.jsx
(member_expression
- object: (identifier) @type
- property: (property_identifier) @type)
+ object: (identifier) @type @tag.component.jsx
+ property: (property_identifier) @type @tag.component.jsx)
])
(jsx_self_closing_element
[
- (identifier) @type
+ (identifier) @type @tag.component.jsx
(member_expression
- object: (identifier) @type
- property: (property_identifier) @type)
+ object: (identifier) @type @tag.component.jsx
+ property: (property_identifier) @type @tag.component.jsx)
])
(jsx_opening_element
diff --git a/crates/grammars/src/tsx/highlights.scm b/crates/grammars/src/tsx/highlights.scm
index 482bba7f081a44b78a2f2d72c3435d8a6419b874..0f203e7112cf14268d0edfed39b5624375d1a859 100644
--- a/crates/grammars/src/tsx/highlights.scm
+++ b/crates/grammars/src/tsx/highlights.scm
@@ -389,26 +389,26 @@
(jsx_opening_element
[
- (identifier) @type
+ (identifier) @type @tag.component.jsx
(member_expression
- object: (identifier) @type
- property: (property_identifier) @type)
+ object: (identifier) @type @tag.component.jsx
+ property: (property_identifier) @type @tag.component.jsx)
])
(jsx_closing_element
[
- (identifier) @type
+ (identifier) @type @tag.component.jsx
(member_expression
- object: (identifier) @type
- property: (property_identifier) @type)
+ object: (identifier) @type @tag.component.jsx
+ property: (property_identifier) @type @tag.component.jsx)
])
(jsx_self_closing_element
[
- (identifier) @type
+ (identifier) @type @tag.component.jsx
(member_expression
- object: (identifier) @type
- property: (property_identifier) @type)
+ object: (identifier) @type @tag.component.jsx
+ property: (property_identifier) @type @tag.component.jsx)
])
(jsx_opening_element
diff --git a/crates/icons/src/icons.rs b/crates/icons/src/icons.rs
index 6929ae4e4ca8ca0ee00c9793c948892043dd6dd6..e29b7d3593025556771d62dc0124786672c540de 100644
--- a/crates/icons/src/icons.rs
+++ b/crates/icons/src/icons.rs
@@ -95,6 +95,7 @@ pub enum IconName {
DebugStepOver,
Diff,
DiffSplit,
+ DiffSplitAuto,
DiffUnified,
Disconnected,
Download,
diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml
index 911100fc25b498ba5471c85d6177052495974665..4712d86dff6c44f9cdd8576a08349ccfa7d0ecca 100644
--- a/crates/language_model/Cargo.toml
+++ b/crates/language_model/Cargo.toml
@@ -20,11 +20,11 @@ anthropic = { workspace = true, features = ["schemars"] }
anyhow.workspace = true
credentials_provider.workspace = true
base64.workspace = true
-client.workspace = true
cloud_api_client.workspace = true
cloud_api_types.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
+env_var.workspace = true
futures.workspace = true
gpui.workspace = true
http_client.workspace = true
@@ -40,7 +40,6 @@ serde_json.workspace = true
smol.workspace = true
thiserror.workspace = true
util.workspace = true
-zed_env_vars.workspace = true
[dev-dependencies]
gpui = { workspace = true, features = ["test-support"] }
diff --git a/crates/language_model/src/api_key.rs b/crates/language_model/src/api_key.rs
index 754fde069295d8799820020bef286b1a1a3c590c..4be5a64d3db6231c98b830a524d5e299faace457 100644
--- a/crates/language_model/src/api_key.rs
+++ b/crates/language_model/src/api_key.rs
@@ -1,5 +1,6 @@
use anyhow::{Result, anyhow};
use credentials_provider::CredentialsProvider;
+use env_var::EnvVar;
use futures::{FutureExt, future};
use gpui::{AsyncApp, Context, SharedString, Task};
use std::{
@@ -7,7 +8,6 @@ use std::{
sync::Arc,
};
use util::ResultExt as _;
-use zed_env_vars::EnvVar;
use crate::AuthenticateError;
@@ -101,6 +101,7 @@ impl ApiKeyState {
url: SharedString,
key: Option,
get_this: impl Fn(&mut Ent) -> &mut Self + 'static,
+ provider: Arc,
cx: &Context,
) -> Task> {
if self.is_from_env_var() {
@@ -108,18 +109,14 @@ impl ApiKeyState {
"bug: attempted to store API key in system keychain when API key is from env var",
)));
}
- let credentials_provider = ::global(cx);
cx.spawn(async move |ent, cx| {
if let Some(key) = &key {
- credentials_provider
+ provider
.write_credentials(&url, "Bearer", key.as_bytes(), cx)
.await
.log_err();
} else {
- credentials_provider
- .delete_credentials(&url, cx)
- .await
- .log_err();
+ provider.delete_credentials(&url, cx).await.log_err();
}
ent.update(cx, |ent, cx| {
let this = get_this(ent);
@@ -144,12 +141,13 @@ impl ApiKeyState {
&mut self,
url: SharedString,
get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static,
+ provider: Arc,
cx: &mut Context,
) {
if url != self.url {
if !self.is_from_env_var() {
// loading will continue even though this result task is dropped
- let _task = self.load_if_needed(url, get_this, cx);
+ let _task = self.load_if_needed(url, get_this, provider, cx);
}
}
}
@@ -163,6 +161,7 @@ impl ApiKeyState {
&mut self,
url: SharedString,
get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static,
+ provider: Arc,
cx: &mut Context,
) -> Task> {
if let LoadStatus::Loaded { .. } = &self.load_status
@@ -185,7 +184,7 @@ impl ApiKeyState {
let task = if let Some(load_task) = &self.load_task {
load_task.clone()
} else {
- let load_task = Self::load(url.clone(), get_this.clone(), cx).shared();
+ let load_task = Self::load(url.clone(), get_this.clone(), provider, cx).shared();
self.url = url;
self.load_status = LoadStatus::NotPresent;
self.load_task = Some(load_task.clone());
@@ -206,14 +205,13 @@ impl ApiKeyState {
fn load(
url: SharedString,
get_this: impl Fn(&mut Ent) -> &mut Self + 'static,
+ provider: Arc,
cx: &Context,
) -> Task<()> {
- let credentials_provider = ::global(cx);
cx.spawn({
async move |ent, cx| {
let load_status =
- ApiKey::load_from_system_keychain_impl(&url, credentials_provider.as_ref(), cx)
- .await;
+ ApiKey::load_from_system_keychain_impl(&url, provider.as_ref(), cx).await;
ent.update(cx, |ent, cx| {
let this = get_this(ent);
this.url = url;
diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs
index ce71cee6bcaf4f7ea1e210cc3756bd3162715f55..3f309b7b1d4152c54324efaaf0ad3bdb7035eea4 100644
--- a/crates/language_model/src/language_model.rs
+++ b/crates/language_model/src/language_model.rs
@@ -11,12 +11,10 @@ pub mod tool_schema;
pub mod fake_provider;
use anyhow::{Result, anyhow};
-use client::Client;
-use client::UserStore;
use cloud_llm_client::CompletionRequestStatus;
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
-use gpui::{AnyView, App, AsyncApp, Entity, SharedString, Task, Window};
+use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::{StatusCode, http};
use icons::IconName;
use parking_lot::Mutex;
@@ -36,15 +34,10 @@ pub use crate::registry::*;
pub use crate::request::*;
pub use crate::role::*;
pub use crate::tool_schema::LanguageModelToolSchemaFormat;
+pub use env_var::{EnvVar, env_var};
pub use provider::*;
-pub use zed_env_vars::{EnvVar, env_var};
-pub fn init(user_store: Entity, client: Arc, cx: &mut App) {
- init_settings(cx);
- RefreshLlmTokenListener::register(client, user_store, cx);
-}
-
-pub fn init_settings(cx: &mut App) {
+pub fn init(cx: &mut App) {
registry::init(cx);
}
diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs
index a1362d78292082522f4e883efe42b2ca1e0a0300..db926aab1f70a46a4e70b1b67c2c9e4c4f465c2c 100644
--- a/crates/language_model/src/model/cloud_model.rs
+++ b/crates/language_model/src/model/cloud_model.rs
@@ -1,16 +1,9 @@
use std::fmt;
use std::sync::Arc;
-use anyhow::{Context as _, Result};
-use client::Client;
-use client::UserStore;
use cloud_api_client::ClientApiError;
+use cloud_api_client::CloudApiClient;
use cloud_api_types::OrganizationId;
-use cloud_api_types::websocket_protocol::MessageToClient;
-use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
-use gpui::{
- App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription,
-};
use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
use thiserror::Error;
@@ -30,18 +23,12 @@ impl fmt::Display for PaymentRequiredError {
pub struct LlmApiToken(Arc>>);
impl LlmApiToken {
- pub fn global(cx: &App) -> Self {
- RefreshLlmTokenListener::global(cx)
- .read(cx)
- .llm_api_token
- .clone()
- }
-
pub async fn acquire(
&self,
- client: &Arc,
+ client: &CloudApiClient,
+ system_id: Option,
organization_id: Option,
- ) -> Result {
+ ) -> Result {
let lock = self.0.upgradable_read().await;
if let Some(token) = lock.as_ref() {
Ok(token.to_string())
@@ -49,6 +36,7 @@ impl LlmApiToken {
Self::fetch(
RwLockUpgradableReadGuard::upgrade(lock).await,
client,
+ system_id,
organization_id,
)
.await
@@ -57,10 +45,11 @@ impl LlmApiToken {
pub async fn refresh(
&self,
- client: &Arc,
+ client: &CloudApiClient,
+ system_id: Option,
organization_id: Option,
- ) -> Result {
- Self::fetch(self.0.write().await, client, organization_id).await
+ ) -> Result {
+ Self::fetch(self.0.write().await, client, system_id, organization_id).await
}
/// Clears the existing token before attempting to fetch a new one.
@@ -69,28 +58,22 @@ impl LlmApiToken {
/// leave a token for the wrong organization.
pub async fn clear_and_refresh(
&self,
- client: &Arc,
+ client: &CloudApiClient,
+ system_id: Option,
organization_id: Option,
- ) -> Result {
+ ) -> Result {
let mut lock = self.0.write().await;
*lock = None;
- Self::fetch(lock, client, organization_id).await
+ Self::fetch(lock, client, system_id, organization_id).await
}
async fn fetch(
mut lock: RwLockWriteGuard<'_, Option>,
- client: &Arc,
+ client: &CloudApiClient,
+ system_id: Option,
organization_id: Option,
- ) -> Result {
- let system_id = client
- .telemetry()
- .system_id()
- .map(|system_id| system_id.to_string());
-
- let result = client
- .cloud_client()
- .create_llm_token(system_id, organization_id)
- .await;
+ ) -> Result {
+ let result = client.create_llm_token(system_id, organization_id).await;
match result {
Ok(response) => {
*lock = Some(response.token.0.clone());
@@ -98,112 +81,7 @@ impl LlmApiToken {
}
Err(err) => {
*lock = None;
- match err {
- ClientApiError::Unauthorized => {
- client.request_sign_out();
- Err(err).context("Failed to create LLM token")
- }
- ClientApiError::Other(err) => Err(err),
- }
- }
- }
- }
-}
-
-pub trait NeedsLlmTokenRefresh {
- /// Returns whether the LLM token needs to be refreshed.
- fn needs_llm_token_refresh(&self) -> bool;
-}
-
-impl NeedsLlmTokenRefresh for http_client::Response {
- fn needs_llm_token_refresh(&self) -> bool {
- self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some()
- || self.headers().get(OUTDATED_LLM_TOKEN_HEADER_NAME).is_some()
- }
-}
-
-enum TokenRefreshMode {
- Refresh,
- ClearAndRefresh,
-}
-
-struct GlobalRefreshLlmTokenListener(Entity);
-
-impl Global for GlobalRefreshLlmTokenListener {}
-
-pub struct LlmTokenRefreshedEvent;
-
-pub struct RefreshLlmTokenListener {
- client: Arc,
- user_store: Entity,
- llm_api_token: LlmApiToken,
- _subscription: Subscription,
-}
-
-impl EventEmitter for RefreshLlmTokenListener {}
-
-impl RefreshLlmTokenListener {
- pub fn register(client: Arc, user_store: Entity, cx: &mut App) {
- let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, user_store, cx));
- cx.set_global(GlobalRefreshLlmTokenListener(listener));
- }
-
- pub fn global(cx: &App) -> Entity {
- GlobalRefreshLlmTokenListener::global(cx).0.clone()
- }
-
- fn new(client: Arc, user_store: Entity, cx: &mut Context) -> Self {
- client.add_message_to_client_handler({
- let this = cx.weak_entity();
- move |message, cx| {
- if let Some(this) = this.upgrade() {
- Self::handle_refresh_llm_token(this, message, cx);
- }
- }
- });
-
- let subscription = cx.subscribe(&user_store, |this, _user_store, event, cx| {
- if matches!(event, client::user::Event::OrganizationChanged) {
- this.refresh(TokenRefreshMode::ClearAndRefresh, cx);
- }
- });
-
- Self {
- client,
- user_store,
- llm_api_token: LlmApiToken::default(),
- _subscription: subscription,
- }
- }
-
- fn refresh(&self, mode: TokenRefreshMode, cx: &mut Context) {
- let client = self.client.clone();
- let llm_api_token = self.llm_api_token.clone();
- let organization_id = self
- .user_store
- .read(cx)
- .current_organization()
- .map(|organization| organization.id.clone());
- cx.spawn(async move |this, cx| {
- match mode {
- TokenRefreshMode::Refresh => {
- llm_api_token.refresh(&client, organization_id).await?;
- }
- TokenRefreshMode::ClearAndRefresh => {
- llm_api_token
- .clear_and_refresh(&client, organization_id)
- .await?;
- }
- }
- this.update(cx, |_this, cx| cx.emit(LlmTokenRefreshedEvent))
- })
- .detach_and_log_err(cx);
- }
-
- fn handle_refresh_llm_token(this: Entity, message: &MessageToClient, cx: &mut App) {
- match message {
- MessageToClient::UserUpdated => {
- this.update(cx, |this, cx| this.refresh(TokenRefreshMode::Refresh, cx));
+ Err(err)
}
}
}
diff --git a/crates/language_models/src/language_models.rs b/crates/language_models/src/language_models.rs
index 4db1db8fa6ce5afb9d77a6685bfc0861d0fb8885..3154db91a43d1381f5b3f122a724be249adeb79b 100644
--- a/crates/language_models/src/language_models.rs
+++ b/crates/language_models/src/language_models.rs
@@ -3,6 +3,7 @@ use std::sync::Arc;
use ::settings::{Settings, SettingsStore};
use client::{Client, UserStore};
use collections::HashSet;
+use credentials_provider::CredentialsProvider;
use gpui::{App, Context, Entity};
use language_model::{LanguageModelProviderId, LanguageModelRegistry};
use provider::deepseek::DeepSeekLanguageModelProvider;
@@ -31,9 +32,16 @@ use crate::provider::x_ai::XAiLanguageModelProvider;
pub use crate::settings::*;
pub fn init(user_store: Entity, client: Arc, cx: &mut App) {
+ let credentials_provider = client.credentials_provider();
let registry = LanguageModelRegistry::global(cx);
registry.update(cx, |registry, cx| {
- register_language_model_providers(registry, user_store, client.clone(), cx);
+ register_language_model_providers(
+ registry,
+ user_store,
+ client.clone(),
+ credentials_provider.clone(),
+ cx,
+ );
});
// Subscribe to extension store events to track LLM extension installations
@@ -104,6 +112,7 @@ pub fn init(user_store: Entity, client: Arc, cx: &mut App) {
&HashSet::default(),
&openai_compatible_providers,
client.clone(),
+ credentials_provider.clone(),
cx,
);
});
@@ -124,6 +133,7 @@ pub fn init(user_store: Entity, client: Arc, cx: &mut App) {
&openai_compatible_providers,
&openai_compatible_providers_new,
client.clone(),
+ credentials_provider.clone(),
cx,
);
});
@@ -138,6 +148,7 @@ fn register_openai_compatible_providers(
old: &HashSet>,
new: &HashSet>,
client: Arc,
+ credentials_provider: Arc,
cx: &mut Context,
) {
for provider_id in old {
@@ -152,6 +163,7 @@ fn register_openai_compatible_providers(
Arc::new(OpenAiCompatibleLanguageModelProvider::new(
provider_id.clone(),
client.http_client(),
+ credentials_provider.clone(),
cx,
)),
cx,
@@ -164,6 +176,7 @@ fn register_language_model_providers(
registry: &mut LanguageModelRegistry,
user_store: Entity,
client: Arc,
+ credentials_provider: Arc,
cx: &mut Context,
) {
registry.register_provider(
@@ -177,62 +190,105 @@ fn register_language_model_providers(
registry.register_provider(
Arc::new(AnthropicLanguageModelProvider::new(
client.http_client(),
+ credentials_provider.clone(),
cx,
)),
cx,
);
registry.register_provider(
- Arc::new(OpenAiLanguageModelProvider::new(client.http_client(), cx)),
+ Arc::new(OpenAiLanguageModelProvider::new(
+ client.http_client(),
+ credentials_provider.clone(),
+ cx,
+ )),
cx,
);
registry.register_provider(
- Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)),
+ Arc::new(OllamaLanguageModelProvider::new(
+ client.http_client(),
+ credentials_provider.clone(),
+ cx,
+ )),
cx,
);
registry.register_provider(
- Arc::new(LmStudioLanguageModelProvider::new(client.http_client(), cx)),
+ Arc::new(LmStudioLanguageModelProvider::new(
+ client.http_client(),
+ credentials_provider.clone(),
+ cx,
+ )),
cx,
);
registry.register_provider(
- Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)),
+ Arc::new(DeepSeekLanguageModelProvider::new(
+ client.http_client(),
+ credentials_provider.clone(),
+ cx,
+ )),
cx,
);
registry.register_provider(
- Arc::new(GoogleLanguageModelProvider::new(client.http_client(), cx)),
+ Arc::new(GoogleLanguageModelProvider::new(
+ client.http_client(),
+ credentials_provider.clone(),
+ cx,
+ )),
cx,
);
registry.register_provider(
- MistralLanguageModelProvider::global(client.http_client(), cx),
+ MistralLanguageModelProvider::global(
+ client.http_client(),
+ credentials_provider.clone(),
+ cx,
+ ),
cx,
);
registry.register_provider(
- Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)),
+ Arc::new(BedrockLanguageModelProvider::new(
+ client.http_client(),
+ credentials_provider.clone(),
+ cx,
+ )),
cx,
);
registry.register_provider(
Arc::new(OpenRouterLanguageModelProvider::new(
client.http_client(),
+ credentials_provider.clone(),
cx,
)),
cx,
);
registry.register_provider(
- Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)),
+ Arc::new(VercelLanguageModelProvider::new(
+ client.http_client(),
+ credentials_provider.clone(),
+ cx,
+ )),
cx,
);
registry.register_provider(
Arc::new(VercelAiGatewayLanguageModelProvider::new(
client.http_client(),
+ credentials_provider.clone(),
cx,
)),
cx,
);
registry.register_provider(
- Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)),
+ Arc::new(XAiLanguageModelProvider::new(
+ client.http_client(),
+ credentials_provider.clone(),
+ cx,
+ )),
cx,
);
registry.register_provider(
- Arc::new(OpenCodeLanguageModelProvider::new(client.http_client(), cx)),
+ Arc::new(OpenCodeLanguageModelProvider::new(
+ client.http_client(),
+ credentials_provider,
+ cx,
+ )),
cx,
);
registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs
index a98a0ce142dfdbaaaddc056ab378455a45147830..c1b8bc1a3bb1b602b67ae5563d8acc3b05a94d47 100644
--- a/crates/language_models/src/provider/anthropic.rs
+++ b/crates/language_models/src/provider/anthropic.rs
@@ -6,6 +6,7 @@ use anthropic::{
};
use anyhow::Result;
use collections::{BTreeMap, HashMap};
+use credentials_provider::CredentialsProvider;
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Task};
use http_client::HttpClient;
@@ -51,6 +52,7 @@ static API_KEY_ENV_VAR: LazyLock = env_var!(API_KEY_ENV_VAR_NAME);
pub struct State {
api_key_state: ApiKeyState,
+ credentials_provider: Arc,
}
impl State {
@@ -59,30 +61,51 @@ impl State {
}
fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = AnthropicLanguageModelProvider::api_url(cx);
- self.api_key_state
- .store(api_url, api_key, |this| &mut this.api_key_state, cx)
+ self.api_key_state.store(
+ api_url,
+ api_key,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ )
}
fn authenticate(&mut self, cx: &mut Context) -> Task> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = AnthropicLanguageModelProvider::api_url(cx);
- self.api_key_state
- .load_if_needed(api_url, |this| &mut this.api_key_state, cx)
+ self.api_key_state.load_if_needed(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ )
}
}
impl AnthropicLanguageModelProvider {
- pub fn new(http_client: Arc, cx: &mut App) -> Self {
+ pub fn new(
+ http_client: Arc,
+ credentials_provider: Arc,
+ cx: &mut App,
+ ) -> Self {
let state = cx.new(|cx| {
cx.observe_global::(|this: &mut State, cx| {
+ let credentials_provider = this.credentials_provider.clone();
let api_url = Self::api_url(cx);
- this.api_key_state
- .handle_url_change(api_url, |this| &mut this.api_key_state, cx);
+ this.api_key_state.handle_url_change(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ );
cx.notify();
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
+ credentials_provider,
}
});
diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs
index f53f145dbd387aa948b977d854ba77f1cbe49ded..4320763e2c5c6de7f3fe9238d7a4991565c3bfcd 100644
--- a/crates/language_models/src/provider/bedrock.rs
+++ b/crates/language_models/src/provider/bedrock.rs
@@ -195,12 +195,13 @@ pub struct State {
settings: Option,
/// Whether credentials came from environment variables (only relevant for static credentials)
credentials_from_env: bool,
+ credentials_provider: Arc,
_subscription: Subscription,
}
impl State {
fn reset_auth(&self, cx: &mut Context) -> Task> {
- let credentials_provider = ::global(cx);
+ let credentials_provider = self.credentials_provider.clone();
cx.spawn(async move |this, cx| {
credentials_provider
.delete_credentials(AMAZON_AWS_URL, cx)
@@ -220,7 +221,7 @@ impl State {
cx: &mut Context,
) -> Task> {
let auth = credentials.clone().into_auth();
- let credentials_provider = ::global(cx);
+ let credentials_provider = self.credentials_provider.clone();
cx.spawn(async move |this, cx| {
credentials_provider
.write_credentials(
@@ -287,7 +288,7 @@ impl State {
&self,
cx: &mut Context,
) -> Task> {
- let credentials_provider = ::global(cx);
+ let credentials_provider = self.credentials_provider.clone();
cx.spawn(async move |this, cx| {
// Try environment variables first
let (auth, from_env) = if let Some(bearer_token) = &ZED_BEDROCK_BEARER_TOKEN_VAR.value {
@@ -400,11 +401,16 @@ pub struct BedrockLanguageModelProvider {
}
impl BedrockLanguageModelProvider {
- pub fn new(http_client: Arc, cx: &mut App) -> Self {
+ pub fn new(
+ http_client: Arc,
+ credentials_provider: Arc,
+ cx: &mut App,
+ ) -> Self {
let state = cx.new(|cx| State {
auth: None,
settings: Some(AllLanguageModelSettings::get_global(cx).bedrock.clone()),
credentials_from_env: false,
+ credentials_provider,
_subscription: cx.observe_global::(|_, cx| {
cx.notify();
}),
diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs
index f9372a4d7ea9c078c58f633cc58bd5597ef49212..29623cc998ad0fe933e9a29c45c651f7be010b07 100644
--- a/crates/language_models/src/provider/cloud.rs
+++ b/crates/language_models/src/provider/cloud.rs
@@ -1,7 +1,9 @@
use ai_onboarding::YoungAccountBanner;
use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow};
-use client::{Client, UserStore, zed_urls};
+use client::{
+ Client, NeedsLlmTokenRefresh, RefreshLlmTokenListener, UserStore, global_llm_token, zed_urls,
+};
use cloud_api_types::{OrganizationId, Plan};
use cloud_llm_client::{
CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME,
@@ -24,10 +26,9 @@ use language_model::{
LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelEffortLevel,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
- LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, NeedsLlmTokenRefresh,
- OPEN_AI_PROVIDER_ID, OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter,
- RefreshLlmTokenListener, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME, ZED_CLOUD_PROVIDER_ID,
- ZED_CLOUD_PROVIDER_NAME,
+ LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, OPEN_AI_PROVIDER_ID,
+ OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME,
+ ZED_CLOUD_PROVIDER_ID, ZED_CLOUD_PROVIDER_NAME,
};
use release_channel::AppVersion;
use schemars::JsonSchema;
@@ -111,7 +112,7 @@ impl State {
cx: &mut Context,
) -> Self {
let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
- let llm_api_token = LlmApiToken::global(cx);
+ let llm_api_token = global_llm_token(cx);
Self {
client: client.clone(),
llm_api_token,
@@ -226,7 +227,9 @@ impl State {
organization_id: Option,
) -> Result {
let http_client = &client.http_client();
- let token = llm_api_token.acquire(&client, organization_id).await?;
+ let token = client
+ .acquire_llm_token(&llm_api_token, organization_id)
+ .await?;
let request = http_client::Request::builder()
.method(Method::GET)
@@ -414,8 +417,8 @@ impl CloudLanguageModel {
) -> Result {
let http_client = &client.http_client();
- let mut token = llm_api_token
- .acquire(&client, organization_id.clone())
+ let mut token = client
+ .acquire_llm_token(&llm_api_token, organization_id.clone())
.await?;
let mut refreshed_token = false;
@@ -447,8 +450,8 @@ impl CloudLanguageModel {
}
if !refreshed_token && response.needs_llm_token_refresh() {
- token = llm_api_token
- .refresh(&client, organization_id.clone())
+ token = client
+ .refresh_llm_token(&llm_api_token, organization_id.clone())
.await?;
refreshed_token = true;
continue;
@@ -713,7 +716,9 @@ impl LanguageModel for CloudLanguageModel {
into_google(request, model_id.clone(), GoogleModelMode::Default);
async move {
let http_client = &client.http_client();
- let token = llm_api_token.acquire(&client, organization_id).await?;
+ let token = client
+ .acquire_llm_token(&llm_api_token, organization_id)
+ .await?;
let request_body = CountTokensBody {
provider: cloud_llm_client::LanguageModelProvider::Google,
diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs
index bd2469d865fd8421d6ad31208e6a4be413c0fe14..0cfb1af425c7cb0279d98fa124a589437f1bb1a1 100644
--- a/crates/language_models/src/provider/deepseek.rs
+++ b/crates/language_models/src/provider/deepseek.rs
@@ -1,5 +1,6 @@
use anyhow::{Result, anyhow};
use collections::{BTreeMap, HashMap};
+use credentials_provider::CredentialsProvider;
use deepseek::DEEPSEEK_API_URL;
use futures::Stream;
@@ -49,6 +50,7 @@ pub struct DeepSeekLanguageModelProvider {
pub struct State {
api_key_state: ApiKeyState,
+ credentials_provider: Arc,
}
impl State {
@@ -57,30 +59,51 @@ impl State {
}
fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = DeepSeekLanguageModelProvider::api_url(cx);
- self.api_key_state
- .store(api_url, api_key, |this| &mut this.api_key_state, cx)
+ self.api_key_state.store(
+ api_url,
+ api_key,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ )
}
fn authenticate(&mut self, cx: &mut Context) -> Task> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = DeepSeekLanguageModelProvider::api_url(cx);
- self.api_key_state
- .load_if_needed(api_url, |this| &mut this.api_key_state, cx)
+ self.api_key_state.load_if_needed(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ )
}
}
impl DeepSeekLanguageModelProvider {
- pub fn new(http_client: Arc, cx: &mut App) -> Self {
+ pub fn new(
+ http_client: Arc,
+ credentials_provider: Arc,
+ cx: &mut App,
+ ) -> Self {
let state = cx.new(|cx| {
cx.observe_global::(|this: &mut State, cx| {
+ let credentials_provider = this.credentials_provider.clone();
let api_url = Self::api_url(cx);
- this.api_key_state
- .handle_url_change(api_url, |this| &mut this.api_key_state, cx);
+ this.api_key_state.handle_url_change(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ );
cx.notify();
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
+ credentials_provider,
}
});
diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs
index 8fdfb514ac6e872bd24968d33f2c1169401d5a9c..244f7835a85ff67f0c4826321910ea13516371cb 100644
--- a/crates/language_models/src/provider/google.rs
+++ b/crates/language_models/src/provider/google.rs
@@ -1,5 +1,6 @@
use anyhow::{Context as _, Result};
use collections::BTreeMap;
+use credentials_provider::CredentialsProvider;
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
use google_ai::{
FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction,
@@ -60,6 +61,7 @@ pub struct GoogleLanguageModelProvider {
pub struct State {
api_key_state: ApiKeyState,
+ credentials_provider: Arc,
}
const GEMINI_API_KEY_VAR_NAME: &str = "GEMINI_API_KEY";
@@ -76,30 +78,51 @@ impl State {
}
fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = GoogleLanguageModelProvider::api_url(cx);
- self.api_key_state
- .store(api_url, api_key, |this| &mut this.api_key_state, cx)
+ self.api_key_state.store(
+ api_url,
+ api_key,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ )
}
fn authenticate(&mut self, cx: &mut Context) -> Task> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = GoogleLanguageModelProvider::api_url(cx);
- self.api_key_state
- .load_if_needed(api_url, |this| &mut this.api_key_state, cx)
+ self.api_key_state.load_if_needed(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ )
}
}
impl GoogleLanguageModelProvider {
- pub fn new(http_client: Arc, cx: &mut App) -> Self {
+ pub fn new(
+ http_client: Arc,
+ credentials_provider: Arc,
+ cx: &mut App,
+ ) -> Self {
let state = cx.new(|cx| {
cx.observe_global::(|this: &mut State, cx| {
+ let credentials_provider = this.credentials_provider.clone();
let api_url = Self::api_url(cx);
- this.api_key_state
- .handle_url_change(api_url, |this| &mut this.api_key_state, cx);
+ this.api_key_state.handle_url_change(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ );
cx.notify();
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
+ credentials_provider,
}
});
diff --git a/crates/language_models/src/provider/lmstudio.rs b/crates/language_models/src/provider/lmstudio.rs
index 6c8d3c6e1c50185a4b09e9afc80c688f4c8d1381..0d60fef16791087e35bac7d846b2ec99821d5470 100644
--- a/crates/language_models/src/provider/lmstudio.rs
+++ b/crates/language_models/src/provider/lmstudio.rs
@@ -1,5 +1,6 @@
use anyhow::{Result, anyhow};
use collections::HashMap;
+use credentials_provider::CredentialsProvider;
use fs::Fs;
use futures::Stream;
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
@@ -52,6 +53,7 @@ pub struct LmStudioLanguageModelProvider {
pub struct State {
api_key_state: ApiKeyState,
+ credentials_provider: Arc,
http_client: Arc,
available_models: Vec,
fetch_model_task: Option>>,
@@ -64,10 +66,15 @@ impl State {
}
fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = LmStudioLanguageModelProvider::api_url(cx).into();
- let task = self
- .api_key_state
- .store(api_url, api_key, |this| &mut this.api_key_state, cx);
+ let task = self.api_key_state.store(
+ api_url,
+ api_key,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ );
self.restart_fetch_models_task(cx);
task
}
@@ -114,10 +121,14 @@ impl State {
}
fn authenticate(&mut self, cx: &mut Context) -> Task> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = LmStudioLanguageModelProvider::api_url(cx).into();
- let _task = self
- .api_key_state
- .load_if_needed(api_url, |this| &mut this.api_key_state, cx);
+ let _task = self.api_key_state.load_if_needed(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ );
if self.is_authenticated() {
return Task::ready(Ok(()));
@@ -152,16 +163,29 @@ impl State {
}
impl LmStudioLanguageModelProvider {
- pub fn new(http_client: Arc, cx: &mut App) -> Self {
+ pub fn new(
+ http_client: Arc,
+ credentials_provider: Arc,
+ cx: &mut App,
+ ) -> Self {
let this = Self {
http_client: http_client.clone(),
state: cx.new(|cx| {
let subscription = cx.observe_global::({
let mut settings = AllLanguageModelSettings::get_global(cx).lmstudio.clone();
move |this: &mut State, cx| {
- let new_settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
- if &settings != new_settings {
- settings = new_settings.clone();
+ let new_settings =
+ AllLanguageModelSettings::get_global(cx).lmstudio.clone();
+ if settings != new_settings {
+ let credentials_provider = this.credentials_provider.clone();
+ let api_url = Self::api_url(cx).into();
+ this.api_key_state.handle_url_change(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ );
+ settings = new_settings;
this.restart_fetch_models_task(cx);
cx.notify();
}
@@ -173,6 +197,7 @@ impl LmStudioLanguageModelProvider {
Self::api_url(cx).into(),
(*API_KEY_ENV_VAR).clone(),
),
+ credentials_provider,
http_client,
available_models: Default::default(),
fetch_model_task: None,
diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs
index 72f0cae2993da4efb3e19cb19ec42b186290920d..4cd1375fe50cd792a3a7bc8c85ba7b5b5af9520a 100644
--- a/crates/language_models/src/provider/mistral.rs
+++ b/crates/language_models/src/provider/mistral.rs
@@ -1,5 +1,6 @@
use anyhow::{Result, anyhow};
use collections::BTreeMap;
+use credentials_provider::CredentialsProvider;
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Global, SharedString, Task, Window};
@@ -43,6 +44,7 @@ pub struct MistralLanguageModelProvider {
pub struct State {
api_key_state: ApiKeyState,
+ credentials_provider: Arc,
}
impl State {
@@ -51,15 +53,26 @@ impl State {
}
fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = MistralLanguageModelProvider::api_url(cx);
- self.api_key_state
- .store(api_url, api_key, |this| &mut this.api_key_state, cx)
+ self.api_key_state.store(
+ api_url,
+ api_key,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ )
}
fn authenticate(&mut self, cx: &mut Context) -> Task> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = MistralLanguageModelProvider::api_url(cx);
- self.api_key_state
- .load_if_needed(api_url, |this| &mut this.api_key_state, cx)
+ self.api_key_state.load_if_needed(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ )
}
}
@@ -73,20 +86,30 @@ impl MistralLanguageModelProvider {
.map(|this| &this.0)
}
- pub fn global(http_client: Arc, cx: &mut App) -> Arc {
+ pub fn global(
+ http_client: Arc,
+ credentials_provider: Arc,
+ cx: &mut App,
+ ) -> Arc {
if let Some(this) = cx.try_global::() {
return this.0.clone();
}
let state = cx.new(|cx| {
cx.observe_global::(|this: &mut State, cx| {
+ let credentials_provider = this.credentials_provider.clone();
let api_url = Self::api_url(cx);
- this.api_key_state
- .handle_url_change(api_url, |this| &mut this.api_key_state, cx);
+ this.api_key_state.handle_url_change(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ );
cx.notify();
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
+ credentials_provider,
}
});
diff --git a/crates/language_models/src/provider/ollama.rs b/crates/language_models/src/provider/ollama.rs
index 551fcd55358c11bdf64bf2f27b32fa9a7f702252..49c326683a225bf73f604a584307ea1316a710c4 100644
--- a/crates/language_models/src/provider/ollama.rs
+++ b/crates/language_models/src/provider/ollama.rs
@@ -1,4 +1,5 @@
use anyhow::{Result, anyhow};
+use credentials_provider::CredentialsProvider;
use fs::Fs;
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use futures::{Stream, TryFutureExt, stream};
@@ -54,6 +55,7 @@ pub struct OllamaLanguageModelProvider {
pub struct State {
api_key_state: ApiKeyState,
+ credentials_provider: Arc,
http_client: Arc,
fetched_models: Vec,
fetch_model_task: Option>>,
@@ -65,10 +67,15 @@ impl State {
}
fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = OllamaLanguageModelProvider::api_url(cx);
- let task = self
- .api_key_state
- .store(api_url, api_key, |this| &mut this.api_key_state, cx);
+ let task = self.api_key_state.store(
+ api_url,
+ api_key,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ );
self.fetched_models.clear();
cx.spawn(async move |this, cx| {
@@ -80,10 +87,14 @@ impl State {
}
fn authenticate(&mut self, cx: &mut Context) -> Task> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = OllamaLanguageModelProvider::api_url(cx);
- let task = self
- .api_key_state
- .load_if_needed(api_url, |this| &mut this.api_key_state, cx);
+ let task = self.api_key_state.load_if_needed(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ );
// Always try to fetch models - if no API key is needed (local Ollama), it will work
// If API key is needed and provided, it will work
@@ -157,7 +168,11 @@ impl State {
}
impl OllamaLanguageModelProvider {
- pub fn new(http_client: Arc, cx: &mut App) -> Self {
+ pub fn new(
+ http_client: Arc,
+ credentials_provider: Arc,
+ cx: &mut App,
+ ) -> Self {
let this = Self {
http_client: http_client.clone(),
state: cx.new(|cx| {
@@ -170,6 +185,14 @@ impl OllamaLanguageModelProvider {
let url_changed = last_settings.api_url != current_settings.api_url;
last_settings = current_settings.clone();
if url_changed {
+ let credentials_provider = this.credentials_provider.clone();
+ let api_url = Self::api_url(cx);
+ this.api_key_state.handle_url_change(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ );
this.fetched_models.clear();
this.authenticate(cx).detach();
}
@@ -184,6 +207,7 @@ impl OllamaLanguageModelProvider {
fetched_models: Default::default(),
fetch_model_task: None,
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
+ credentials_provider,
}
}),
};
diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs
index 9289c66b2a4c9213826d2d027555511c9746d00e..6a2313487f4a1922cdc2aa20d23ede01c4b7d158 100644
--- a/crates/language_models/src/provider/open_ai.rs
+++ b/crates/language_models/src/provider/open_ai.rs
@@ -1,5 +1,6 @@
use anyhow::{Result, anyhow};
use collections::{BTreeMap, HashMap};
+use credentials_provider::CredentialsProvider;
use futures::Stream;
use futures::{FutureExt, StreamExt, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
@@ -55,6 +56,7 @@ pub struct OpenAiLanguageModelProvider {
pub struct State {
api_key_state: ApiKeyState,
+ credentials_provider: Arc,
}
impl State {
@@ -63,30 +65,51 @@ impl State {
}
fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = OpenAiLanguageModelProvider::api_url(cx);
- self.api_key_state
- .store(api_url, api_key, |this| &mut this.api_key_state, cx)
+ self.api_key_state.store(
+ api_url,
+ api_key,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ )
}
fn authenticate(&mut self, cx: &mut Context) -> Task> {
+ let credentials_provider = self.credentials_provider.clone();
let api_url = OpenAiLanguageModelProvider::api_url(cx);
- self.api_key_state
- .load_if_needed(api_url, |this| &mut this.api_key_state, cx)
+ self.api_key_state.load_if_needed(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ )
}
}
impl OpenAiLanguageModelProvider {
- pub fn new(http_client: Arc, cx: &mut App) -> Self {
+ pub fn new(
+ http_client: Arc,
+ credentials_provider: Arc,
+ cx: &mut App,
+ ) -> Self {
let state = cx.new(|cx| {
cx.observe_global::(|this: &mut State, cx| {
+ let credentials_provider = this.credentials_provider.clone();
let api_url = Self::api_url(cx);
- this.api_key_state
- .handle_url_change(api_url, |this| &mut this.api_key_state, cx);
+ this.api_key_state.handle_url_change(
+ api_url,
+ |this| &mut this.api_key_state,
+ credentials_provider,
+ cx,
+ );
cx.notify();
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
+ credentials_provider,
}
});
diff --git a/crates/language_models/src/provider/open_ai_compatible.rs b/crates/language_models/src/provider/open_ai_compatible.rs
index 87a08097782198238a5d2467af32cc66b3183664..9f63a1e1a039998c275637f3831b51474c8049ac 100644
--- a/crates/language_models/src/provider/open_ai_compatible.rs
+++ b/crates/language_models/src/provider/open_ai_compatible.rs
@@ -1,5 +1,6 @@
use anyhow::Result;
use convert_case::{Case, Casing};
+use credentials_provider::CredentialsProvider;
use futures::{FutureExt, StreamExt, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::HttpClient;
@@ -44,6 +45,7 @@ pub struct State {
id: Arc |