Detailed changes
@@ -263,6 +263,39 @@ jobs:
- name: steps::show_sccache_stats
run: sccache --show-stats || true
timeout-minutes: 60
+ clippy_mac_x86_64:
+ needs:
+ - orchestrate
+ if: needs.orchestrate.outputs.run_tests == 'true'
+ runs-on: namespace-profile-mac-large
+ steps:
+ - name: steps::checkout_repo
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
+ with:
+ clean: false
+ - name: steps::setup_cargo_config
+ run: |
+ mkdir -p ./../.cargo
+ cp ./.cargo/ci-config.toml ./../.cargo/config.toml
+ - name: steps::cache_rust_dependencies_namespace
+ uses: namespacelabs/nscloud-cache-action@v1
+ with:
+ cache: rust
+ path: ~/.rustup
+ - name: steps::install_rustup_target
+ run: rustup target add x86_64-apple-darwin
+ - name: steps::setup_sccache
+ run: ./script/setup-sccache
+ env:
+ R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }}
+ R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }}
+ R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }}
+ SCCACHE_BUCKET: sccache-zed
+ - name: steps::clippy
+ run: ./script/clippy --target x86_64-apple-darwin
+ - name: steps::show_sccache_stats
+ run: sccache --show-stats || true
+ timeout-minutes: 60
run_tests_windows:
needs:
- orchestrate
@@ -731,6 +764,7 @@ jobs:
- clippy_windows
- clippy_linux
- clippy_mac
+ - clippy_mac_x86_64
- run_tests_windows
- run_tests_linux
- run_tests_mac
@@ -760,6 +794,7 @@ jobs:
check_result "clippy_windows" "$RESULT_CLIPPY_WINDOWS"
check_result "clippy_linux" "$RESULT_CLIPPY_LINUX"
check_result "clippy_mac" "$RESULT_CLIPPY_MAC"
+ check_result "clippy_mac_x86_64" "$RESULT_CLIPPY_MAC_X86_64"
check_result "run_tests_windows" "$RESULT_RUN_TESTS_WINDOWS"
check_result "run_tests_linux" "$RESULT_RUN_TESTS_LINUX"
check_result "run_tests_mac" "$RESULT_RUN_TESTS_MAC"
@@ -779,6 +814,7 @@ jobs:
RESULT_CLIPPY_WINDOWS: ${{ needs.clippy_windows.result }}
RESULT_CLIPPY_LINUX: ${{ needs.clippy_linux.result }}
RESULT_CLIPPY_MAC: ${{ needs.clippy_mac.result }}
+ RESULT_CLIPPY_MAC_X86_64: ${{ needs.clippy_mac_x86_64.result }}
RESULT_RUN_TESTS_WINDOWS: ${{ needs.run_tests_windows.result }}
RESULT_RUN_TESTS_LINUX: ${{ needs.run_tests_linux.result }}
RESULT_RUN_TESTS_MAC: ${{ needs.run_tests_mac.result }}
@@ -334,7 +334,6 @@ dependencies = [
"agent_settings",
"ai_onboarding",
"anyhow",
- "arrayvec",
"assistant_slash_command",
"assistant_slash_commands",
"assistant_text_thread",
@@ -363,6 +362,7 @@ dependencies = [
"git",
"gpui",
"gpui_tokio",
+ "heapless",
"html_to_markdown",
"http_client",
"image",
@@ -733,9 +733,6 @@ name = "arrayvec"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
-dependencies = [
- "serde",
-]
[[package]]
name = "as-raw-xcb-connection"
@@ -3575,6 +3572,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
+ "base64 0.22.1",
"collections",
"futures 0.3.31",
"gpui",
@@ -3583,14 +3581,17 @@ dependencies = [
"net",
"parking_lot",
"postage",
+ "rand 0.9.2",
"schemars",
"serde",
"serde_json",
"settings",
+ "sha2",
"slotmap",
"smol",
"tempfile",
"terminal",
+ "tiny_http",
"url",
"util",
]
@@ -5238,7 +5239,6 @@ version = "0.1.0"
dependencies = [
"ai_onboarding",
"anyhow",
- "arrayvec",
"brotli",
"buffer_diff",
"client",
@@ -5256,6 +5256,7 @@ dependencies = [
"fs",
"futures 0.3.31",
"gpui",
+ "heapless",
"indoc",
"itertools 0.14.0",
"language",
@@ -8027,6 +8028,15 @@ dependencies = [
"smallvec",
]
+[[package]]
+name = "hash32"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "47d60b12902ba28e2730cd37e95b8c9223af2808df9e902d4df49588d1470606"
+dependencies = [
+ "byteorder",
+]
+
[[package]]
name = "hashbrown"
version = "0.12.3"
@@ -8111,6 +8121,16 @@ dependencies = [
"http 0.2.12",
]
+[[package]]
+name = "heapless"
+version = "0.9.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2af2455f757db2b292a9b1768c4b70186d443bcb3b316252d6b540aec1cd89ed"
+dependencies = [
+ "hash32",
+ "stable_deref_trait",
+]
+
[[package]]
name = "heck"
version = "0.3.3"
@@ -9494,6 +9514,7 @@ dependencies = [
"ollama",
"open_ai",
"open_router",
+ "opencode",
"partial-json-fixer",
"pretty_assertions",
"release_channel",
@@ -10004,6 +10025,7 @@ dependencies = [
"tokio",
"ui",
"util",
+ "webrtc-sys",
"zed-scap",
]
@@ -11644,6 +11666,20 @@ dependencies = [
"thiserror 2.0.17",
]
+[[package]]
+name = "opencode"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "futures 0.3.31",
+ "google_ai",
+ "http_client",
+ "schemars",
+ "serde",
+ "serde_json",
+ "strum 0.27.2",
+]
+
[[package]]
name = "opener"
version = "0.7.2"
@@ -13172,6 +13208,7 @@ dependencies = [
"clock",
"collections",
"context_server",
+ "credentials_provider",
"dap",
"encoding_rs",
"extension",
@@ -14671,10 +14708,10 @@ dependencies = [
name = "rope"
version = "0.1.0"
dependencies = [
- "arrayvec",
"criterion",
"ctor",
"gpui",
+ "heapless",
"log",
"rand 0.9.2",
"rayon",
@@ -16735,8 +16772,8 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292"
name = "sum_tree"
version = "0.1.0"
dependencies = [
- "arrayvec",
"ctor",
+ "heapless",
"log",
"proptest",
"rand 0.9.2",
@@ -134,6 +134,7 @@ members = [
"crates/notifications",
"crates/ollama",
"crates/onboarding",
+ "crates/opencode",
"crates/open_ai",
"crates/open_path_prompt",
"crates/open_router",
@@ -381,6 +382,7 @@ node_runtime = { path = "crates/node_runtime" }
notifications = { path = "crates/notifications" }
ollama = { path = "crates/ollama" }
onboarding = { path = "crates/onboarding" }
+opencode = { path = "crates/opencode" }
open_ai = { path = "crates/open_ai" }
open_path_prompt = { path = "crates/open_path_prompt" }
open_router = { path = "crates/open_router", features = ["schemars"] }
@@ -480,7 +482,6 @@ aho-corasick = "1.1"
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty", rev = "9d9640d4" }
any_vec = "0.14"
anyhow = "1.0.86"
-arrayvec = { version = "0.7.4", features = ["serde"] }
ashpd = { version = "0.13", default-features = false, features = [
"async-io",
"notification",
@@ -564,6 +565,7 @@ futures-lite = "1.13"
gh-workflow = { git = "https://github.com/zed-industries/gh-workflow", rev = "37f3c0575d379c218a9c455ee67585184e40d43f" }
git2 = { version = "0.20.1", default-features = false, features = ["vendored-libgit2"] }
globset = "0.4"
handlebars = "4.3"
+heapless = "0.9.2"
heck = "0.5"
heed = { version = "0.21.0", features = ["read-txn-no-tls"] }
@@ -779,6 +781,7 @@ wax = "0.7"
which = "6.0.0"
wasm-bindgen = "0.2.113"
web-time = "1.1.0"
+webrtc-sys = "0.3.23"
wgpu = { git = "https://github.com/zed-industries/wgpu.git", branch = "v29" }
windows-core = "0.61"
yawc = "0.2.5"
@@ -849,6 +852,7 @@ windows-capture = { git = "https://github.com/zed-industries/windows-capture.git
calloop = { git = "https://github.com/zed-industries/calloop" }
livekit = { git = "https://github.com/zed-industries/livekit-rust-sdks", rev = "c1209aa155cbf4543383774f884a46ae7e53ee2e" }
libwebrtc = { git = "https://github.com/zed-industries/livekit-rust-sdks", rev = "c1209aa155cbf4543383774f884a46ae7e53ee2e" }
+webrtc-sys = { git = "https://github.com/zed-industries/livekit-rust-sdks", rev = "c1209aa155cbf4543383774f884a46ae7e53ee2e" }
[profile.dev]
split-debuginfo = "unpacked"
@@ -0,0 +1,3 @@
+<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
+<path d="M11.2 3.2H4.8V12.8H11.2V3.2ZM14.4 16H1.6V0H14.4V16Z" fill="black"/>
+</svg>
@@ -0,0 +1,7 @@
+<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
+<path d="M7.99567 13.0812C8.93101 13.0812 9.68925 12.3229 9.68925 11.3876C9.68925 10.4522 8.93101 9.694 7.99567 9.694C7.06033 9.694 6.30209 10.4522 6.30209 11.3876C6.30209 12.3229 7.06033 13.0812 7.99567 13.0812Z" stroke="#A9AFBC" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+<path d="M4.61023 6.30643C5.54557 6.30643 6.30381 5.54819 6.30381 4.61286C6.30381 3.67752 5.54557 2.91928 4.61023 2.91928C3.6749 2.91928 2.91666 3.67752 2.91666 4.61286C2.91666 5.54819 3.6749 6.30643 4.61023 6.30643Z" stroke="#A9AFBC" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+<path d="M11.3915 6.30643C12.3268 6.30643 13.0851 5.54819 13.0851 4.61286C13.0851 3.67752 12.3268 2.91928 11.3915 2.91928C10.4561 2.91928 9.69791 3.67752 9.69791 4.61286C9.69791 5.54819 10.4561 6.30643 11.3915 6.30643Z" stroke="#A9AFBC" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+<path d="M11.3889 6.306V7.43505C11.3889 7.77377 11.1631 7.99958 10.8244 7.99958H5.17912C4.8404 7.99958 4.61459 7.77377 4.61459 7.43505V6.306" stroke="#A9AFBC" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+<path d="M8 8V9.69358" stroke="#A9AFBC" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+</svg>
@@ -785,11 +785,16 @@
"bindings": {
"alt-tab": "editor::AcceptEditPrediction",
"alt-l": "editor::AcceptEditPrediction",
- "tab": "editor::AcceptEditPrediction",
"alt-k": "editor::AcceptNextWordEditPrediction",
"alt-j": "editor::AcceptNextLineEditPrediction",
},
},
+ {
+ "context": "Editor && edit_prediction && edit_prediction_mode == eager",
+ "bindings": {
+ "tab": "editor::AcceptEditPrediction",
+ },
+ },
{
"context": "Editor && showing_code_actions",
"bindings": {
@@ -1451,8 +1456,8 @@
{
"context": "GitPicker",
"bindings": {
- "alt-1": "git_picker::ActivateBranchesTab",
- "alt-2": "git_picker::ActivateWorktreesTab",
+ "alt-1": "git_picker::ActivateWorktreesTab",
+ "alt-2": "git_picker::ActivateBranchesTab",
"alt-3": "git_picker::ActivateStashTab",
},
},
@@ -847,11 +847,16 @@
"context": "Editor && edit_prediction",
"bindings": {
"alt-tab": "editor::AcceptEditPrediction",
- "tab": "editor::AcceptEditPrediction",
"ctrl-cmd-right": "editor::AcceptNextWordEditPrediction",
"ctrl-cmd-down": "editor::AcceptNextLineEditPrediction",
},
},
+ {
+ "context": "Editor && edit_prediction && edit_prediction_mode == eager",
+ "bindings": {
+ "tab": "editor::AcceptEditPrediction",
+ },
+ },
{
"context": "Editor && showing_code_actions",
"use_key_equivalents": true,
@@ -1526,8 +1531,8 @@
{
"context": "GitPicker",
"bindings": {
- "cmd-1": "git_picker::ActivateBranchesTab",
- "cmd-2": "git_picker::ActivateWorktreesTab",
+ "cmd-1": "git_picker::ActivateWorktreesTab",
+ "cmd-2": "git_picker::ActivateBranchesTab",
"cmd-3": "git_picker::ActivateStashTab",
},
},
@@ -779,11 +779,17 @@
"bindings": {
"alt-tab": "editor::AcceptEditPrediction",
"alt-l": "editor::AcceptEditPrediction",
- "tab": "editor::AcceptEditPrediction",
"alt-k": "editor::AcceptNextWordEditPrediction",
"alt-j": "editor::AcceptNextLineEditPrediction",
},
},
+ {
+ "context": "Editor && edit_prediction && edit_prediction_mode == eager",
+ "use_key_equivalents": true,
+ "bindings": {
+ "tab": "editor::AcceptEditPrediction",
+ },
+ },
{
"context": "Editor && showing_code_actions",
"use_key_equivalents": true,
@@ -1440,8 +1446,8 @@
{
"context": "GitPicker",
"bindings": {
- "alt-1": "git_picker::ActivateBranchesTab",
- "alt-2": "git_picker::ActivateWorktreesTab",
+ "alt-1": "git_picker::ActivateWorktreesTab",
+ "alt-2": "git_picker::ActivateBranchesTab",
"alt-3": "git_picker::ActivateStashTab",
},
},
@@ -1060,7 +1060,7 @@
},
},
{
- "context": "Editor && edit_prediction",
+ "context": "Editor && edit_prediction && edit_prediction_mode == eager",
"bindings": {
// This is identical to the binding in the base keymap, but the vim bindings above to
// "vim::Tab" shadow it, so it needs to be bound again.
@@ -460,6 +460,8 @@
"show_sign_in": true,
// Whether to show the menus in the titlebar.
"show_menus": false,
+ // The layout of window control buttons in the title bar (Linux only).
+ "button_layout": "platform_default",
},
"audio": {
// Opt into the new audio system.
@@ -2245,6 +2247,9 @@
"api_url": "https://api.openai.com/v1",
},
"openai_compatible": {},
+ "opencode": {
+ "api_url": "https://opencode.ai/zen",
+ },
"open_router": {
"api_url": "https://openrouter.ai/api/v1",
},
@@ -119,6 +119,16 @@
"style": ["type"],
},
// References
+ {
+ "token_type": "parameter",
+ "token_modifiers": ["declaration"],
+ "style": ["variable.parameter"]
+ },
+ {
+ "token_type": "parameter",
+ "token_modifiers": ["definition"],
+ "style": ["variable.parameter"]
+ },
{
"token_type": "parameter",
"token_modifiers": [],
@@ -201,6 +211,11 @@
"token_modifiers": [],
"style": ["comment"],
},
+ {
+ "token_type": "string",
+ "token_modifiers": ["documentation"],
+ "style": ["string.doc"],
+ },
{
"token_type": "string",
"token_modifiers": [],
@@ -502,13 +502,15 @@ pub enum SelectedPermissionParams {
#[derive(Debug)]
pub struct SelectedPermissionOutcome {
pub option_id: acp::PermissionOptionId,
+ pub option_kind: acp::PermissionOptionKind,
pub params: Option<SelectedPermissionParams>,
}
impl SelectedPermissionOutcome {
- pub fn new(option_id: acp::PermissionOptionId) -> Self {
+ pub fn new(option_id: acp::PermissionOptionId, option_kind: acp::PermissionOptionKind) -> Self {
Self {
option_id,
+ option_kind,
params: None,
}
}
@@ -519,12 +521,6 @@ impl SelectedPermissionOutcome {
}
}
-impl From<acp::PermissionOptionId> for SelectedPermissionOutcome {
- fn from(option_id: acp::PermissionOptionId) -> Self {
- Self::new(option_id)
- }
-}
-
impl From<SelectedPermissionOutcome> for acp::SelectedPermissionOutcome {
fn from(value: SelectedPermissionOutcome) -> Self {
Self::new(value.option_id)
@@ -924,6 +920,7 @@ impl Plan {
}
acp::PlanEntryStatus::InProgress => {
stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
+ stats.pending += 1;
}
acp::PlanEntryStatus::Completed => {
stats.completed += 1;
@@ -1013,7 +1010,7 @@ pub struct AcpThread {
session_id: acp::SessionId,
work_dirs: Option<PathList>,
parent_session_id: Option<acp::SessionId>,
- title: SharedString,
+ title: Option<SharedString>,
provisional_title: Option<SharedString>,
entries: Vec<AgentThreadEntry>,
plan: Plan,
@@ -1176,7 +1173,7 @@ impl Error for LoadError {}
impl AcpThread {
pub fn new(
parent_session_id: Option<acp::SessionId>,
- title: impl Into<SharedString>,
+ title: Option<SharedString>,
work_dirs: Option<PathList>,
connection: Rc<dyn AgentConnection>,
project: Entity<Project>,
@@ -1203,7 +1200,7 @@ impl AcpThread {
shared_buffers: Default::default(),
entries: Default::default(),
plan: Default::default(),
- title: title.into(),
+ title,
provisional_title: None,
project,
running_turn: None,
@@ -1259,10 +1256,10 @@ impl AcpThread {
&self.project
}
- pub fn title(&self) -> SharedString {
- self.provisional_title
+ pub fn title(&self) -> Option<SharedString> {
+ self.title
.clone()
- .unwrap_or_else(|| self.title.clone())
+ .or_else(|| self.provisional_title.clone())
}
pub fn has_provisional_title(&self) -> bool {
@@ -1387,8 +1384,8 @@ impl AcpThread {
if let acp::MaybeUndefined::Value(title) = info_update.title {
let had_provisional = self.provisional_title.take().is_some();
let title: SharedString = title.into();
- if title != self.title {
- self.title = title;
+ if self.title.as_ref() != Some(&title) {
+ self.title = Some(title);
cx.emit(AcpThreadEvent::TitleUpdated);
} else if had_provisional {
cx.emit(AcpThreadEvent::TitleUpdated);
@@ -1676,8 +1673,8 @@ impl AcpThread {
pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
let had_provisional = self.provisional_title.take().is_some();
- if title != self.title {
- self.title = title.clone();
+ if self.title.as_ref() != Some(&title) {
+ self.title = Some(title.clone());
cx.emit(AcpThreadEvent::TitleUpdated);
if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
return set_title.run(title, cx);
@@ -2012,14 +2009,13 @@ impl AcpThread {
&mut self,
id: acp::ToolCallId,
outcome: SelectedPermissionOutcome,
- option_kind: acp::PermissionOptionKind,
cx: &mut Context<Self>,
) {
let Some((ix, call)) = self.tool_call_mut(&id) else {
return;
};
- let new_status = match option_kind {
+ let new_status = match outcome.option_kind {
acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
ToolCallStatus::Rejected
}
@@ -4297,7 +4293,7 @@ mod tests {
let thread = cx.new(|cx| {
AcpThread::new(
None,
- "Test",
+ None,
Some(work_dirs),
self.clone(),
project,
@@ -4999,7 +4995,7 @@ mod tests {
// Initial title is the default.
thread.read_with(cx, |thread, _| {
- assert_eq!(thread.title().as_ref(), "Test");
+ assert_eq!(thread.title(), None);
});
// Setting a provisional title updates the display title.
@@ -5007,7 +5003,10 @@ mod tests {
thread.set_provisional_title("Hello, can you helpβ¦".into(), cx);
});
thread.read_with(cx, |thread, _| {
- assert_eq!(thread.title().as_ref(), "Hello, can you helpβ¦");
+ assert_eq!(
+ thread.title().as_ref().map(|s| s.as_str()),
+ Some("Hello, can you helpβ¦")
+ );
});
// The provisional title should NOT have propagated to the connection.
@@ -5024,7 +5023,10 @@ mod tests {
});
task.await.expect("set_title should succeed");
thread.read_with(cx, |thread, _| {
- assert_eq!(thread.title().as_ref(), "Helping with Rust question");
+ assert_eq!(
+ thread.title().as_ref().map(|s| s.as_str()),
+ Some("Helping with Rust question")
+ );
});
assert_eq!(
set_title_calls.borrow().as_slice(),
@@ -5088,7 +5090,10 @@ mod tests {
result.expect("session info update should succeed");
thread.read_with(cx, |thread, _| {
- assert_eq!(thread.title().as_ref(), "Helping with Rust question");
+ assert_eq!(
+ thread.title().as_ref().map(|s| s.as_str()),
+ Some("Helping with Rust question")
+ );
assert!(
!thread.has_provisional_title(),
"session info title update should clear provisional title"
@@ -477,6 +477,24 @@ impl PermissionOptionChoice {
pub fn label(&self) -> SharedString {
self.allow.name.clone().into()
}
+
+ /// Build a `SelectedPermissionOutcome` for this choice.
+ ///
+ /// If the choice carries `sub_patterns`, they are attached as
+ /// `SelectedPermissionParams::Terminal`.
+ pub fn build_outcome(&self, is_allow: bool) -> crate::SelectedPermissionOutcome {
+ let option = if is_allow { &self.allow } else { &self.deny };
+
+ let params = if !self.sub_patterns.is_empty() {
+ Some(crate::SelectedPermissionParams::Terminal {
+ patterns: self.sub_patterns.clone(),
+ })
+ } else {
+ None
+ };
+
+ crate::SelectedPermissionOutcome::new(option.option_id.clone(), option.kind).params(params)
+ }
}
/// Pairs a tool's permission pattern with its display name
@@ -548,6 +566,57 @@ impl PermissionOptions {
self.first_option_of_kind(acp::PermissionOptionKind::RejectOnce)
.map(|option| option.option_id.clone())
}
+
+ /// Build a `SelectedPermissionOutcome` for the `DropdownWithPatterns`
+ /// variant when the user has checked specific pattern indices.
+ ///
+ /// Returns `Some` with the always-allow/deny outcome when at least one
+ /// pattern is checked. Returns `None` when zero patterns are checked,
+ /// signaling that the caller should degrade to allow-once / deny-once.
+ ///
+ /// Panics (debug) or returns `None` (release) if called on a non-
+ /// `DropdownWithPatterns` variant.
+ pub fn build_outcome_for_checked_patterns(
+ &self,
+ checked_indices: &[usize],
+ is_allow: bool,
+ ) -> Option<crate::SelectedPermissionOutcome> {
+ let PermissionOptions::DropdownWithPatterns {
+ choices, patterns, ..
+ } = self
+ else {
+ debug_assert!(
+ false,
+ "build_outcome_for_checked_patterns called on non-DropdownWithPatterns"
+ );
+ return None;
+ };
+
+ let checked_patterns: Vec<String> = patterns
+ .iter()
+ .enumerate()
+ .filter(|(index, _)| checked_indices.contains(index))
+ .map(|(_, cp)| cp.pattern.clone())
+ .collect();
+
+ if checked_patterns.is_empty() {
+ return None;
+ }
+
+ // Use the first choice (the "Always" choice) as the base for the outcome.
+ let always_choice = choices.first()?;
+ let option = if is_allow {
+ &always_choice.allow
+ } else {
+ &always_choice.deny
+ };
+
+ let outcome = crate::SelectedPermissionOutcome::new(option.option_id.clone(), option.kind)
+ .params(Some(crate::SelectedPermissionParams::Terminal {
+ patterns: checked_patterns,
+ }));
+ Some(outcome)
+ }
}
#[cfg(feature = "test-support")]
@@ -665,11 +734,10 @@ mod test_support {
cx: &mut gpui::App,
) -> Entity<AcpThread> {
let action_log = cx.new(|_| ActionLog::new(project.clone()));
- let thread_title = title.unwrap_or_else(|| SharedString::new_static("Test"));
let thread = cx.new(|cx| {
AcpThread::new(
None,
- thread_title,
+ title,
Some(work_dirs),
self.clone(),
project,
@@ -82,7 +82,7 @@ struct Session {
/// The ACP thread that handles protocol communication
acp_thread: Entity<acp_thread::AcpThread>,
project_id: EntityId,
- pending_save: Task<()>,
+ pending_save: Task<Result<()>>,
_subscriptions: Vec<Subscription>,
}
@@ -387,7 +387,7 @@ impl NativeAgent {
acp_thread: acp_thread.clone(),
project_id,
_subscriptions: subscriptions,
- pending_save: Task::ready(()),
+ pending_save: Task::ready(Ok(())),
},
);
@@ -662,14 +662,16 @@ impl NativeAgent {
let Some(session) = self.sessions.get(session_id) else {
return;
};
- let thread = thread.downgrade();
- let acp_thread = session.acp_thread.downgrade();
- cx.spawn(async move |_, cx| {
- let title = thread.read_with(cx, |thread, _| thread.title())?;
- let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
- task.await
- })
- .detach_and_log_err(cx);
+
+ if let Some(title) = thread.read(cx).title() {
+ let acp_thread = session.acp_thread.downgrade();
+ cx.spawn(async move |_, cx| {
+ let task =
+ acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
+ task.await
+ })
+ .detach_and_log_err(cx);
+ }
}
fn handle_thread_token_usage_updated(
@@ -727,7 +729,7 @@ impl NativeAgent {
fn handle_models_updated_event(
&mut self,
_registry: Entity<LanguageModelRegistry>,
- _event: &language_model::Event,
+ event: &language_model::Event,
cx: &mut Context<Self>,
) {
self.models.refresh_list(cx);
@@ -744,7 +746,13 @@ impl NativeAgent {
thread.set_model(model, cx);
cx.notify();
}
- thread.set_summarization_model(summarization_model.clone(), cx);
+ if let Some(model) = summarization_model.clone() {
+ if thread.summarization_model().is_none()
+ || matches!(event, language_model::Event::ThreadSummaryModelChanged)
+ {
+ thread.set_summarization_model(Some(model), cx);
+ }
+ }
});
}
}
@@ -992,7 +1000,7 @@ impl NativeAgent {
let thread_store = self.thread_store.clone();
session.pending_save = cx.spawn(async move |_, cx| {
let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
- return;
+ return Ok(());
};
let db_thread = db_thread.await;
database
@@ -1000,6 +1008,7 @@ impl NativeAgent {
.await
.log_err();
thread_store.update(cx, |store, cx| store.reload(cx));
+ Ok(())
});
}
@@ -1436,18 +1445,23 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
cx: &mut App,
) -> Task<Result<()>> {
self.0.update(cx, |agent, cx| {
+ let thread = agent.sessions.get(session_id).map(|s| s.thread.clone());
+ if let Some(thread) = thread {
+ agent.save_thread(thread, cx);
+ }
+
let Some(session) = agent.sessions.remove(session_id) else {
- return;
+ return Task::ready(Ok(()));
};
let project_id = session.project_id;
- agent.save_thread(session.thread, cx);
let has_remaining = agent.sessions.values().any(|s| s.project_id == project_id);
if !has_remaining {
agent.projects.remove(&project_id);
}
- });
- Task::ready(Ok(()))
+
+ session.pending_save
+ })
}
fn auth_methods(&self) -> &[acp::AuthMethod] {
@@ -2456,6 +2470,61 @@ mod internal_tests {
});
}
+ #[gpui::test]
+ async fn test_summarization_model_survives_transient_registry_clearing(
+ cx: &mut TestAppContext,
+ ) {
+ init_test(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree("/", json!({ "a": {} })).await;
+ let project = Project::test(fs.clone(), [], cx).await;
+
+ let thread_store = cx.new(|cx| ThreadStore::new(cx));
+ let agent =
+ cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
+ let connection = Rc::new(NativeAgentConnection(agent.clone()));
+
+ let acp_thread = cx
+ .update(|cx| {
+ connection.clone().new_session(
+ project.clone(),
+ PathList::new(&[Path::new("/a")]),
+ cx,
+ )
+ })
+ .await
+ .unwrap();
+ let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
+
+ let thread = agent.read_with(cx, |agent, _| {
+ agent.sessions.get(&session_id).unwrap().thread.clone()
+ });
+
+ thread.read_with(cx, |thread, _| {
+ assert!(
+ thread.summarization_model().is_some(),
+ "session should have a summarization model from the test registry"
+ );
+ });
+
+ // Simulate what happens during a provider blip:
+ // update_active_language_model_from_settings calls set_default_model(None)
+ // when it can't resolve the model, clearing all fallbacks.
+ cx.update(|cx| {
+ LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
+ registry.set_default_model(None, cx);
+ });
+ });
+ cx.run_until_parked();
+
+ thread.read_with(cx, |thread, _| {
+ assert!(
+ thread.summarization_model().is_some(),
+ "summarization model should survive a transient default model clearing"
+ );
+ });
+ }
+
#[gpui::test]
async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
init_test(cx);
@@ -2767,7 +2836,9 @@ mod internal_tests {
cx.run_until_parked();
- // Set a draft prompt with rich content blocks before saving.
+ // Set a draft prompt with rich content blocks and scroll position
+ // AFTER run_until_parked, so the only save that captures these
+ // changes is the one performed by close_session itself.
let draft_blocks = vec![
acp::ContentBlock::Text(acp::TextContent::new("Check out ")),
acp::ContentBlock::ResourceLink(acp::ResourceLink::new("b.md", uri.to_string())),
@@ -2782,8 +2853,6 @@ mod internal_tests {
offset_in_item: gpui::px(12.5),
}));
});
- thread.update(cx, |_thread, cx| cx.notify());
- cx.run_until_parked();
// Close the session so it can be reloaded from disk.
cx.update(|cx| connection.clone().close_session(&session_id, cx))
@@ -2849,6 +2918,87 @@ mod internal_tests {
});
}
+ #[gpui::test]
+ async fn test_close_session_saves_thread(cx: &mut TestAppContext) {
+ init_test(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/",
+ json!({
+ "a": {
+ "file.txt": "hello"
+ }
+ }),
+ )
+ .await;
+ let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
+ let thread_store = cx.new(|cx| ThreadStore::new(cx));
+ let agent = cx.update(|cx| {
+ NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
+ });
+ let connection = Rc::new(NativeAgentConnection(agent.clone()));
+
+ let acp_thread = cx
+ .update(|cx| {
+ connection
+ .clone()
+ .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
+ })
+ .await
+ .unwrap();
+ let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
+ let thread = agent.read_with(cx, |agent, _| {
+ agent.sessions.get(&session_id).unwrap().thread.clone()
+ });
+
+ let model = Arc::new(FakeLanguageModel::default());
+ thread.update(cx, |thread, cx| {
+ thread.set_model(model.clone(), cx);
+ });
+
+ // Send a message so the thread is non-empty (empty threads aren't saved).
+ let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx));
+ let send = cx.foreground_executor().spawn(send);
+ cx.run_until_parked();
+
+ model.send_last_completion_stream_text_chunk("world");
+ model.end_last_completion_stream();
+ send.await.unwrap();
+ cx.run_until_parked();
+
+ // Set a draft prompt WITHOUT calling run_until_parked afterwards.
+ // This means no observe-triggered save has run for this change.
+ // The only way this data gets persisted is if close_session
+ // itself performs the save.
+ let draft_blocks = vec![acp::ContentBlock::Text(acp::TextContent::new(
+ "unsaved draft",
+ ))];
+ acp_thread.update(cx, |thread, _cx| {
+ thread.set_draft_prompt(Some(draft_blocks.clone()));
+ });
+
+ // Close the session immediately β no run_until_parked in between.
+ cx.update(|cx| connection.clone().close_session(&session_id, cx))
+ .await
+ .unwrap();
+ cx.run_until_parked();
+
+ // Reopen and verify the draft prompt was saved.
+ let reloaded = agent
+ .update(cx, |agent, cx| {
+ agent.open_thread(session_id.clone(), project.clone(), cx)
+ })
+ .await
+ .unwrap();
+ reloaded.read_with(cx, |thread, _| {
+ assert_eq!(
+ thread.draft_prompt(),
+ Some(draft_blocks.as_slice()),
+ "close_session must save the thread; draft prompt was lost"
+ );
+ });
+ }
+
fn thread_entries(
thread_store: &Entity<ThreadStore>,
cx: &mut TestAppContext,
@@ -841,14 +841,20 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
// Approve the first - send "allow" option_id (UI transforms "once" to "allow")
tool_call_auth_1
.response
- .send(acp::PermissionOptionId::new("allow").into())
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow"),
+ acp::PermissionOptionKind::AllowOnce,
+ ))
.unwrap();
cx.run_until_parked();
// Reject the second - send "deny" option_id directly since Deny is now a button
tool_call_auth_2
.response
- .send(acp::PermissionOptionId::new("deny").into())
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("deny"),
+ acp::PermissionOptionKind::RejectOnce,
+ ))
.unwrap();
cx.run_until_parked();
@@ -892,7 +898,10 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
tool_call_auth_3
.response
- .send(acp::PermissionOptionId::new("always_allow:tool_requiring_permission").into())
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("always_allow:tool_requiring_permission"),
+ acp::PermissionOptionKind::AllowAlways,
+ ))
.unwrap();
cx.run_until_parked();
let completion = fake_model.pending_completions().pop().unwrap();
@@ -3122,7 +3131,7 @@ async fn test_title_generation(cx: &mut TestAppContext) {
fake_model.send_last_completion_stream_text_chunk("Hey!");
fake_model.end_last_completion_stream();
cx.run_until_parked();
- thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
+ thread.read_with(cx, |thread, _| assert_eq!(thread.title(), None));
// Ensure the summary model has been invoked to generate a title.
summary_model.send_last_completion_stream_text_chunk("Hello ");
@@ -3131,7 +3140,9 @@ async fn test_title_generation(cx: &mut TestAppContext) {
summary_model.end_last_completion_stream();
send.collect::<Vec<_>>().await;
cx.run_until_parked();
- thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
+ thread.read_with(cx, |thread, _| {
+ assert_eq!(thread.title(), Some("Hello world".into()))
+ });
// Send another message, ensuring no title is generated this time.
let send = thread
@@ -3145,7 +3156,9 @@ async fn test_title_generation(cx: &mut TestAppContext) {
cx.run_until_parked();
assert_eq!(summary_model.pending_completions(), Vec::new());
send.collect::<Vec<_>>().await;
- thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
+ thread.read_with(cx, |thread, _| {
+ assert_eq!(thread.title(), Some("Hello world".into()))
+ });
}
#[gpui::test]
@@ -1312,7 +1312,7 @@ impl Thread {
pub fn to_db(&self, cx: &App) -> Task<DbThread> {
let initial_project_snapshot = self.initial_project_snapshot.clone();
let mut thread = DbThread {
- title: self.title(),
+ title: self.title().unwrap_or_default(),
messages: self.messages.clone(),
updated_at: self.updated_at,
detailed_summary: self.summary.clone(),
@@ -2491,8 +2491,8 @@ impl Thread {
}
}
- pub fn title(&self) -> SharedString {
- self.title.clone().unwrap_or("New Thread".into())
+ pub fn title(&self) -> Option<SharedString> {
+ self.title.clone()
}
pub fn is_generating_summary(&self) -> bool {
@@ -253,12 +253,14 @@ impl ContextServerRegistry {
let project::context_server_store::ServerStatusChangedEvent { server_id, status } = event;
match status {
- ContextServerStatus::Starting => {}
+ ContextServerStatus::Starting | ContextServerStatus::Authenticating => {}
ContextServerStatus::Running => {
self.reload_tools_for_server(server_id.clone(), cx);
self.reload_prompts_for_server(server_id.clone(), cx);
}
- ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
+ ContextServerStatus::Stopped
+ | ContextServerStatus::Error(_)
+ | ContextServerStatus::AuthRequired => {
if let Some(registered_server) = self.registered_servers.remove(server_id) {
if !registered_server.tools.is_empty() {
cx.emit(ContextServerRegistryEvent::ToolsChanged);
@@ -266,7 +266,10 @@ mod tests {
);
auth.response
- .send(acp::PermissionOptionId::new("allow").into())
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow"),
+ acp::PermissionOptionKind::AllowOnce,
+ ))
.unwrap();
let result = task.await;
@@ -372,7 +375,10 @@ mod tests {
);
auth.response
- .send(acp::PermissionOptionId::new("allow").into())
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow"),
+ acp::PermissionOptionKind::AllowOnce,
+ ))
.unwrap();
assert!(
@@ -241,7 +241,10 @@ mod tests {
);
auth.response
- .send(acp::PermissionOptionId::new("allow").into())
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow"),
+ acp::PermissionOptionKind::AllowOnce,
+ ))
.unwrap();
let result = task.await;
@@ -359,7 +362,10 @@ mod tests {
);
auth.response
- .send(acp::PermissionOptionId::new("allow").into())
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow"),
+ acp::PermissionOptionKind::AllowOnce,
+ ))
.unwrap();
assert!(
@@ -301,7 +301,10 @@ mod tests {
);
auth.response
- .send(acp::PermissionOptionId::new("allow").into())
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow"),
+ acp::PermissionOptionKind::AllowOnce,
+ ))
.unwrap();
let result = task.await;
@@ -428,7 +431,10 @@ mod tests {
);
auth.response
- .send(acp::PermissionOptionId::new("allow").into())
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow"),
+ acp::PermissionOptionKind::AllowOnce,
+ ))
.unwrap();
assert!(
@@ -1374,7 +1374,10 @@ mod tests {
event
.response
- .send(acp::PermissionOptionId::new("allow").into())
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow"),
+ acp::PermissionOptionKind::AllowOnce,
+ ))
.unwrap();
authorize_task.await.unwrap();
}
@@ -848,7 +848,10 @@ mod tests {
);
auth.response
- .send(acp::PermissionOptionId::new("allow").into())
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow"),
+ acp::PermissionOptionKind::AllowOnce,
+ ))
.unwrap();
let result = task.await;
@@ -273,7 +273,10 @@ mod tests {
);
auth.response
- .send(acp::PermissionOptionId::new("allow").into())
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow"),
+ acp::PermissionOptionKind::AllowOnce,
+ ))
.unwrap();
let result = task.await;
@@ -379,7 +382,10 @@ mod tests {
);
auth.response
- .send(acp::PermissionOptionId::new("allow").into())
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow"),
+ acp::PermissionOptionKind::AllowOnce,
+ ))
.unwrap();
assert!(
@@ -896,7 +896,10 @@ mod test {
);
authorization
.response
- .send(acp::PermissionOptionId::new("allow").into())
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow"),
+ acp::PermissionOptionKind::AllowOnce,
+ ))
.unwrap();
let result = read_task.await;
@@ -1185,7 +1188,10 @@ mod test {
);
auth.response
- .send(acp::PermissionOptionId::new("allow").into())
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow"),
+ acp::PermissionOptionKind::AllowOnce,
+ ))
.unwrap();
let result = task.await;
@@ -523,7 +523,10 @@ mod tests {
);
auth.response
- .send(acp::PermissionOptionId::new("allow").into())
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow"),
+ acp::PermissionOptionKind::AllowOnce,
+ ))
.unwrap();
let _result = task.await;
@@ -651,7 +654,10 @@ mod tests {
);
auth.response
- .send(acp::PermissionOptionId::new("allow").into())
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow"),
+ acp::PermissionOptionKind::AllowOnce,
+ ))
.unwrap();
assert!(
@@ -518,7 +518,10 @@ mod tests {
);
auth.response
- .send(acp::PermissionOptionId::new("allow").into())
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow"),
+ acp::PermissionOptionKind::AllowOnce,
+ ))
.unwrap();
let _result = task.await;
@@ -646,7 +649,10 @@ mod tests {
);
auth.response
- .send(acp::PermissionOptionId::new("allow").into())
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow"),
+ acp::PermissionOptionKind::AllowOnce,
+ ))
.unwrap();
assert!(
@@ -727,7 +733,10 @@ mod tests {
let auth = event_rx.expect_authorization().await;
auth.response
- .send(acp::PermissionOptionId::new("deny").into())
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("deny"),
+ acp::PermissionOptionKind::RejectOnce,
+ ))
.unwrap();
let output = task.await.unwrap();
@@ -2581,7 +2581,10 @@ mod tests {
event
.response
- .send(acp::PermissionOptionId::new("allow").into())
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow"),
+ acp::PermissionOptionKind::AllowOnce,
+ ))
.unwrap();
authorize_task.await.unwrap();
}
@@ -42,7 +42,6 @@ pub struct UnsupportedVersion;
pub struct AcpConnection {
id: AgentId,
- display_name: SharedString,
telemetry_id: SharedString,
connection: Rc<acp::ClientSideConnection>,
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
@@ -167,7 +166,6 @@ impl AgentSessionList for AcpSessionList {
pub async fn connect(
agent_id: AgentId,
project: Entity<Project>,
- display_name: SharedString,
command: AgentServerCommand,
default_mode: Option<acp::SessionModeId>,
default_model: Option<acp::ModelId>,
@@ -177,7 +175,6 @@ pub async fn connect(
let conn = AcpConnection::stdio(
agent_id,
project,
- display_name,
command.clone(),
default_mode,
default_model,
@@ -194,7 +191,6 @@ impl AcpConnection {
pub async fn stdio(
agent_id: AgentId,
project: Entity<Project>,
- display_name: SharedString,
command: AgentServerCommand,
default_mode: Option<acp::SessionModeId>,
default_model: Option<acp::ModelId>,
@@ -364,7 +360,6 @@ impl AcpConnection {
auth_methods,
command,
connection,
- display_name,
telemetry_id,
sessions,
agent_capabilities: response.agent_capabilities,
@@ -660,7 +655,7 @@ impl AgentConnection for AcpConnection {
let thread: Entity<AcpThread> = cx.new(|cx| {
AcpThread::new(
None,
- self.display_name.clone(),
+ None,
Some(work_dirs),
self.clone(),
project,
@@ -718,7 +713,6 @@ impl AgentConnection for AcpConnection {
let mcp_servers = mcp_servers_for_project(&project, cx);
let action_log = cx.new(|_| ActionLog::new(project.clone()));
- let title = title.unwrap_or_else(|| self.display_name.clone());
let thread: Entity<AcpThread> = cx.new(|cx| {
AcpThread::new(
None,
@@ -801,7 +795,6 @@ impl AgentConnection for AcpConnection {
let mcp_servers = mcp_servers_for_project(&project, cx);
let action_log = cx.new(|_| ActionLog::new(project.clone()));
- let title = title.unwrap_or_else(|| self.display_name.clone());
let thread: Entity<AcpThread> = cx.new(|cx| {
AcpThread::new(
None,
@@ -296,11 +296,6 @@ impl AgentServer for CustomAgentServer {
cx: &mut App,
) -> Task<Result<Rc<dyn AgentConnection>>> {
let agent_id = self.agent_id();
- let display_name = delegate
- .store
- .read(cx)
- .agent_display_name(&agent_id)
- .unwrap_or_else(|| agent_id.0.clone());
let default_mode = self.default_mode(cx);
let default_model = self.default_model(cx);
let is_registry_agent = is_registry_agent(agent_id.clone(), cx);
@@ -376,7 +371,6 @@ impl AgentServer for CustomAgentServer {
let connection = crate::acp::connect(
agent_id,
project,
- display_name,
command,
default_mode,
default_model,
@@ -208,8 +208,10 @@ pub async fn test_tool_call_with_permission<T, F>(
thread.update(cx, |thread, cx| {
thread.authorize_tool_call(
tool_call_id,
- allow_option_id.into(),
- acp::PermissionOptionKind::AllowOnce,
+ acp_thread::SelectedPermissionOutcome::new(
+ allow_option_id,
+ acp::PermissionOptionKind::AllowOnce,
+ ),
cx,
);
@@ -34,7 +34,7 @@ agent_servers.workspace = true
agent_settings.workspace = true
ai_onboarding.workspace = true
anyhow.workspace = true
-arrayvec.workspace = true
+heapless.workspace = true
assistant_text_thread.workspace = true
assistant_slash_command.workspace = true
assistant_slash_commands.workspace = true
@@ -517,11 +517,7 @@ impl AgentConfiguration {
}
}
- fn render_context_servers_section(
- &mut self,
- window: &mut Window,
- cx: &mut Context<Self>,
- ) -> impl IntoElement {
+ fn render_context_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
let context_server_ids = self.context_server_store.read(cx).server_ids();
let add_server_popover = PopoverMenu::new("add-server-popover")
@@ -601,7 +597,7 @@ impl AgentConfiguration {
} else {
parent.children(itertools::intersperse_with(
context_server_ids.iter().cloned().map(|context_server_id| {
- self.render_context_server(context_server_id, window, cx)
+ self.render_context_server(context_server_id, cx)
.into_any_element()
}),
|| {
@@ -618,7 +614,6 @@ impl AgentConfiguration {
fn render_context_server(
&self,
context_server_id: ContextServerId,
- window: &mut Window,
cx: &Context<Self>,
) -> impl use<> + IntoElement {
let server_status = self
@@ -646,6 +641,9 @@ impl AgentConfiguration {
} else {
None
};
+ let auth_required = matches!(server_status, ContextServerStatus::AuthRequired);
+ let authenticating = matches!(server_status, ContextServerStatus::Authenticating);
+ let context_server_store = self.context_server_store.clone();
let tool_count = self
.context_server_registry
@@ -689,11 +687,33 @@ impl AgentConfiguration {
Indicator::dot().color(Color::Muted).into_any_element(),
"Server is stopped.",
),
+ ContextServerStatus::AuthRequired => (
+ Indicator::dot().color(Color::Warning).into_any_element(),
+ "Authentication required.",
+ ),
+ ContextServerStatus::Authenticating => (
+ Icon::new(IconName::LoadCircle)
+ .size(IconSize::XSmall)
+ .color(Color::Accent)
+ .with_keyed_rotate_animation(
+ SharedString::from(format!("{}-authenticating", context_server_id.0)),
+ 3,
+ )
+ .into_any_element(),
+ "Waiting for authorization...",
+ ),
};
+
let is_remote = server_configuration
.as_ref()
.map(|config| matches!(config.as_ref(), ContextServerConfiguration::Http { .. }))
.unwrap_or(false);
+
+ let should_show_logout_button = server_configuration.as_ref().is_some_and(|config| {
+ matches!(config.as_ref(), ContextServerConfiguration::Http { .. })
+ && !config.has_static_auth_header()
+ });
+
let context_server_configuration_menu = PopoverMenu::new("context-server-config-menu")
.trigger_with_tooltip(
IconButton::new("context-server-config-menu", IconName::Settings)
@@ -708,6 +728,7 @@ impl AgentConfiguration {
let language_registry = self.language_registry.clone();
let workspace = self.workspace.clone();
let context_server_registry = self.context_server_registry.clone();
+ let context_server_store = context_server_store.clone();
move |window, cx| {
Some(ContextMenu::build(window, cx, |menu, _window, _cx| {
@@ -754,6 +775,17 @@ impl AgentConfiguration {
.ok();
}
}))
+ .when(should_show_logout_button, |this| {
+ this.entry("Log Out", None, {
+ let context_server_store = context_server_store.clone();
+ let context_server_id = context_server_id.clone();
+ move |_window, cx| {
+ context_server_store.update(cx, |store, cx| {
+ store.logout_server(&context_server_id, cx).log_err();
+ });
+ }
+ })
+ })
.separator()
.entry("Uninstall", None, {
let fs = fs.clone();
@@ -810,6 +842,9 @@ impl AgentConfiguration {
}
});
+ let feedback_base_container =
+ || h_flex().py_1().min_w_0().w_full().gap_1().justify_between();
+
v_flex()
.min_w_0()
.id(item_id.clone())
@@ -868,6 +903,7 @@ impl AgentConfiguration {
.on_click({
let context_server_manager = self.context_server_store.clone();
let fs = self.fs.clone();
+ let context_server_id = context_server_id.clone();
move |state, _window, cx| {
let is_enabled = match state {
@@ -915,30 +951,111 @@ impl AgentConfiguration {
)
.map(|parent| {
if let Some(error) = error {
+ return parent
+ .child(
+ feedback_base_container()
+ .child(
+ h_flex()
+ .pr_4()
+ .min_w_0()
+ .w_full()
+ .gap_2()
+ .child(
+ Icon::new(IconName::XCircle)
+ .size(IconSize::XSmall)
+ .color(Color::Error),
+ )
+ .child(
+ div().min_w_0().flex_1().child(
+ Label::new(error)
+ .color(Color::Muted)
+ .size(LabelSize::Small),
+ ),
+ ),
+ )
+ .when(should_show_logout_button, |this| {
+ this.child(
+ Button::new("error-logout-server", "Log Out")
+ .style(ButtonStyle::Outlined)
+ .label_size(LabelSize::Small)
+ .on_click({
+ let context_server_store =
+ context_server_store.clone();
+ let context_server_id =
+ context_server_id.clone();
+ move |_event, _window, cx| {
+ context_server_store.update(
+ cx,
+ |store, cx| {
+ store
+ .logout_server(
+ &context_server_id,
+ cx,
+ )
+ .log_err();
+ },
+ );
+ }
+ }),
+ )
+ }),
+ );
+ }
+ if auth_required {
return parent.child(
- h_flex()
- .gap_2()
- .pr_4()
- .items_start()
+ feedback_base_container()
.child(
h_flex()
- .flex_none()
- .h(window.line_height() / 1.6_f32)
- .justify_center()
+ .pr_4()
+ .min_w_0()
+ .w_full()
+ .gap_2()
.child(
- Icon::new(IconName::XCircle)
+ Icon::new(IconName::Info)
.size(IconSize::XSmall)
- .color(Color::Error),
+ .color(Color::Muted),
+ )
+ .child(
+ Label::new("Authenticate to connect this server")
+ .color(Color::Muted)
+ .size(LabelSize::Small),
),
)
.child(
- div().w_full().child(
- Label::new(error)
- .buffer_font(cx)
- .color(Color::Muted)
- .size(LabelSize::Small),
- ),
+ Button::new("error-logout-server", "Authenticate")
+ .style(ButtonStyle::Outlined)
+ .label_size(LabelSize::Small)
+ .on_click({
+ let context_server_store = context_server_store.clone();
+ let context_server_id = context_server_id.clone();
+ move |_event, _window, cx| {
+ context_server_store.update(cx, |store, cx| {
+ store
+ .authenticate_server(&context_server_id, cx)
+ .log_err();
+ });
+ }
+ }),
+ ),
+ );
+ }
+ if authenticating {
+ return parent.child(
+ h_flex()
+ .mt_1()
+ .pr_4()
+ .min_w_0()
+ .w_full()
+ .gap_2()
+ .child(
+ div().size_3().flex_shrink_0(), // Alignment Div
+ )
+ .child(
+                                Label::new("Authenticating…")
+ .color(Color::Muted)
+ .size(LabelSize::Small),
),
+
);
}
parent
@@ -1234,7 +1351,7 @@ impl Render for AgentConfiguration {
.min_w_0()
.overflow_y_scroll()
.child(self.render_agent_servers_section(cx))
- .child(self.render_context_servers_section(window, cx))
+ .child(self.render_context_servers_section(cx))
.child(self.render_provider_configuration_section(cx)),
)
.vertical_scrollbar_for(&self.scroll_handle, window, cx),
@@ -1,25 +1,27 @@
-use std::sync::{Arc, Mutex};
-
use anyhow::{Context as _, Result};
use collections::HashMap;
use context_server::{ContextServerCommand, ContextServerId};
use editor::{Editor, EditorElement, EditorStyle};
+
use gpui::{
AsyncWindowContext, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, ScrollHandle,
- Task, TextStyle, TextStyleRefinement, UnderlineStyle, WeakEntity, prelude::*,
+ Subscription, Task, TextStyle, TextStyleRefinement, UnderlineStyle, WeakEntity, prelude::*,
};
use language::{Language, LanguageRegistry};
use markdown::{Markdown, MarkdownElement, MarkdownStyle};
use notifications::status_toast::{StatusToast, ToastIcon};
+use parking_lot::Mutex;
use project::{
context_server_store::{
- ContextServerStatus, ContextServerStore, registry::ContextServerDescriptorRegistry,
+ ContextServerStatus, ContextServerStore, ServerStatusChangedEvent,
+ registry::ContextServerDescriptorRegistry,
},
project_settings::{ContextServerSettings, ProjectSettings},
worktree_store::WorktreeStore,
};
use serde::Deserialize;
use settings::{Settings as _, update_settings_file};
+use std::sync::Arc;
use theme::ThemeSettings;
use ui::{
CommonAnimationExt, KeyBinding, Modal, ModalFooter, ModalHeader, Section, Tooltip,
@@ -237,6 +239,8 @@ fn context_server_input(existing: Option<(ContextServerId, ContextServerCommand)
format!(
r#"{{
+ /// Configure an MCP server that runs locally via stdin/stdout
+ ///
/// The name of your MCP server
"{name}": {{
/// The command which runs the MCP server
@@ -280,6 +284,8 @@ fn context_server_http_input(
format!(
r#"{{
+ /// Configure an MCP server that you connect to over HTTP
+ ///
/// The name of your remote MCP server
"{name}": {{
/// The URL of the remote MCP server
@@ -342,6 +348,8 @@ fn resolve_context_server_extension(
enum State {
Idle,
Waiting,
+ AuthRequired { server_id: ContextServerId },
+ Authenticating { _server_id: ContextServerId },
Error(SharedString),
}
@@ -352,6 +360,7 @@ pub struct ConfigureContextServerModal {
state: State,
original_server_id: Option<ContextServerId>,
scroll_handle: ScrollHandle,
+ _auth_subscription: Option<Subscription>,
}
impl ConfigureContextServerModal {
@@ -475,6 +484,7 @@ impl ConfigureContextServerModal {
cx,
),
scroll_handle: ScrollHandle::new(),
+ _auth_subscription: None,
})
})
})
@@ -486,6 +496,13 @@ impl ConfigureContextServerModal {
}
fn confirm(&mut self, _: &menu::Confirm, cx: &mut Context<Self>) {
+ if matches!(
+ self.state,
+ State::Waiting | State::AuthRequired { .. } | State::Authenticating { .. }
+ ) {
+ return;
+ }
+
self.state = State::Idle;
let Some(workspace) = self.workspace.upgrade() else {
return;
@@ -515,14 +532,19 @@ impl ConfigureContextServerModal {
async move |this, cx| {
let result = wait_for_context_server_task.await;
this.update(cx, |this, cx| match result {
- Ok(_) => {
+ Ok(ContextServerStatus::Running) => {
this.state = State::Idle;
this.show_configured_context_server_toast(id, cx);
cx.emit(DismissEvent);
}
+ Ok(ContextServerStatus::AuthRequired) => {
+ this.state = State::AuthRequired { server_id: id };
+ cx.notify();
+ }
Err(err) => {
this.set_error(err, cx);
}
+ Ok(_) => {}
})
}
})
@@ -558,6 +580,49 @@ impl ConfigureContextServerModal {
cx.emit(DismissEvent);
}
+ fn authenticate(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
+ self.context_server_store.update(cx, |store, cx| {
+ store.authenticate_server(&server_id, cx).log_err();
+ });
+
+ self.state = State::Authenticating {
+ _server_id: server_id.clone(),
+ };
+
+ self._auth_subscription = Some(cx.subscribe(
+ &self.context_server_store,
+ move |this, _, event: &ServerStatusChangedEvent, cx| {
+ if event.server_id != server_id {
+ return;
+ }
+ match &event.status {
+ ContextServerStatus::Running => {
+ this._auth_subscription = None;
+ this.state = State::Idle;
+ this.show_configured_context_server_toast(event.server_id.clone(), cx);
+ cx.emit(DismissEvent);
+ }
+ ContextServerStatus::AuthRequired => {
+ this._auth_subscription = None;
+ this.state = State::AuthRequired {
+ server_id: event.server_id.clone(),
+ };
+ cx.notify();
+ }
+ ContextServerStatus::Error(error) => {
+ this._auth_subscription = None;
+ this.set_error(error.clone(), cx);
+ }
+ ContextServerStatus::Authenticating
+ | ContextServerStatus::Starting
+ | ContextServerStatus::Stopped => {}
+ }
+ },
+ ));
+
+ cx.notify();
+ }
+
fn show_configured_context_server_toast(&self, id: ContextServerId, cx: &mut App) {
self.workspace
.update(cx, {
@@ -615,7 +680,8 @@ impl ConfigureContextServerModal {
}
fn render_modal_description(&self, window: &mut Window, cx: &mut Context<Self>) -> AnyElement {
- const MODAL_DESCRIPTION: &str = "Visit the MCP server configuration docs to find all necessary arguments and environment variables.";
+ const MODAL_DESCRIPTION: &str =
+ "Check the server docs for required arguments and environment variables.";
if let ConfigurationSource::Extension {
installation_instructions: Some(installation_instructions),
@@ -637,6 +703,67 @@ impl ConfigureContextServerModal {
}
}
+ fn render_tab_bar(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
+ let is_http = match &self.source {
+ ConfigurationSource::New { is_http, .. } => *is_http,
+ _ => return None,
+ };
+
+ let tab = |label: &'static str, active: bool| {
+ div()
+ .id(label)
+ .cursor_pointer()
+ .p_1()
+ .text_sm()
+ .border_b_1()
+ .when(active, |this| {
+ this.border_color(cx.theme().colors().border_focused)
+ })
+ .when(!active, |this| {
+ this.border_color(gpui::transparent_black())
+ .text_color(cx.theme().colors().text_muted)
+ .hover(|s| s.text_color(cx.theme().colors().text))
+ })
+ .child(label)
+ };
+
+ Some(
+ h_flex()
+ .pt_1()
+ .mb_2p5()
+ .gap_1()
+ .border_b_1()
+ .border_color(cx.theme().colors().border.opacity(0.5))
+ .child(
+ tab("Local", !is_http).on_click(cx.listener(|this, _, window, cx| {
+ if let ConfigurationSource::New { editor, is_http } = &mut this.source {
+ if *is_http {
+ *is_http = false;
+ let new_text = context_server_input(None);
+ editor.update(cx, |editor, cx| {
+ editor.set_text(new_text, window, cx);
+ });
+ }
+ }
+ })),
+ )
+ .child(
+ tab("Remote", is_http).on_click(cx.listener(|this, _, window, cx| {
+ if let ConfigurationSource::New { editor, is_http } = &mut this.source {
+ if !*is_http {
+ *is_http = true;
+ let new_text = context_server_http_input(None);
+ editor.update(cx, |editor, cx| {
+ editor.set_text(new_text, window, cx);
+ });
+ }
+ }
+ })),
+ )
+ .into_any_element(),
+ )
+ }
+
fn render_modal_content(&self, cx: &App) -> AnyElement {
let editor = match &self.source {
ConfigurationSource::New { editor, .. } => editor,
@@ -682,7 +809,10 @@ impl ConfigureContextServerModal {
fn render_modal_footer(&self, cx: &mut Context<Self>) -> ModalFooter {
let focus_handle = self.focus_handle(cx);
- let is_connecting = matches!(self.state, State::Waiting);
+ let is_busy = matches!(
+ self.state,
+ State::Waiting | State::AuthRequired { .. } | State::Authenticating { .. }
+ );
ModalFooter::new()
.start_slot::<Button>(
@@ -714,36 +844,6 @@ impl ConfigureContextServerModal {
move |_, _, cx| cx.open_url(&repository_url)
}),
)
- } else if let ConfigurationSource::New { is_http, .. } = &self.source {
- let label = if *is_http {
- "Configure Local"
- } else {
- "Configure Remote"
- };
- let tooltip = if *is_http {
- "Configure an MCP server that runs on stdin/stdout."
- } else {
- "Configure an MCP server that you connect to over HTTP"
- };
-
- Some(
- Button::new("toggle-kind", label)
- .tooltip(Tooltip::text(tooltip))
- .on_click(cx.listener(|this, _, window, cx| match &mut this.source {
- ConfigurationSource::New { editor, is_http } => {
- *is_http = !*is_http;
- let new_text = if *is_http {
- context_server_http_input(None)
- } else {
- context_server_input(None)
- };
- editor.update(cx, |editor, cx| {
- editor.set_text(new_text, window, cx);
- })
- }
- _ => {}
- })),
- )
} else {
None
},
@@ -777,7 +877,7 @@ impl ConfigureContextServerModal {
"Configure Server"
},
)
- .disabled(is_connecting)
+ .disabled(is_busy)
.key_binding(
KeyBinding::for_action_in(&menu::Confirm, &focus_handle, cx)
.map(|kb| kb.size(rems_from_px(12.))),
@@ -791,29 +891,62 @@ impl ConfigureContextServerModal {
)
}
- fn render_waiting_for_context_server() -> Div {
+ fn render_loading(&self, label: impl Into<SharedString>) -> Div {
h_flex()
- .gap_2()
+ .h_8()
+ .gap_1p5()
+ .justify_center()
.child(
- Icon::new(IconName::ArrowCircle)
+ Icon::new(IconName::LoadCircle)
.size(IconSize::XSmall)
- .color(Color::Info)
- .with_rotate_animation(2)
- .into_any_element(),
+ .color(Color::Muted)
+ .with_rotate_animation(3),
)
+ .child(Label::new(label).size(LabelSize::Small).color(Color::Muted))
+ }
+
+ fn render_auth_required(&self, server_id: &ContextServerId, cx: &mut Context<Self>) -> Div {
+ h_flex()
+ .h_8()
+ .min_w_0()
+ .w_full()
+ .gap_2()
+ .justify_center()
.child(
- Label::new("Waiting for Context Server")
- .size(LabelSize::Small)
- .color(Color::Muted),
+ h_flex()
+ .gap_1p5()
+ .child(
+ Icon::new(IconName::Info)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
+ .child(
+ Label::new("Authenticate to connect this server")
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ ),
+ )
+ .child(
+ Button::new("authenticate-server", "Authenticate")
+ .style(ButtonStyle::Outlined)
+ .label_size(LabelSize::Small)
+ .on_click({
+ let server_id = server_id.clone();
+ cx.listener(move |this, _event, _window, cx| {
+ this.authenticate(server_id.clone(), cx);
+ })
+ }),
)
}
fn render_modal_error(error: SharedString) -> Div {
h_flex()
- .gap_2()
+ .h_8()
+ .gap_1p5()
+ .justify_center()
.child(
Icon::new(IconName::Warning)
- .size(IconSize::XSmall)
+ .size(IconSize::Small)
.color(Color::Warning),
)
.child(
@@ -828,7 +961,7 @@ impl Render for ConfigureContextServerModal {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
div()
.elevation_3(cx)
- .w(rems(34.))
+ .w(rems(40.))
.key_context("ConfigureContextServerModal")
.on_action(
cx.listener(|this, _: &menu::Cancel, _window, cx| this.cancel(&menu::Cancel, cx)),
@@ -855,11 +988,18 @@ impl Render for ConfigureContextServerModal {
.overflow_y_scroll()
.track_scroll(&self.scroll_handle)
.child(self.render_modal_description(window, cx))
+ .children(self.render_tab_bar(cx))
.child(self.render_modal_content(cx))
.child(match &self.state {
State::Idle => div(),
State::Waiting => {
- Self::render_waiting_for_context_server()
+                    self.render_loading("Connecting Server…")
+ }
+ State::AuthRequired { server_id } => {
+ self.render_auth_required(&server_id.clone(), cx)
+ }
+ State::Authenticating { .. } => {
+                    self.render_loading("Authenticating…")
}
State::Error(error) => {
Self::render_modal_error(error.clone())
@@ -878,7 +1018,7 @@ fn wait_for_context_server(
context_server_store: &Entity<ContextServerStore>,
context_server_id: ContextServerId,
cx: &mut App,
-) -> Task<Result<(), Arc<str>>> {
+) -> Task<Result<ContextServerStatus, Arc<str>>> {
use std::time::Duration;
const WAIT_TIMEOUT: Duration = Duration::from_secs(120);
@@ -888,31 +1028,29 @@ fn wait_for_context_server(
let context_server_id_for_timeout = context_server_id.clone();
let subscription = cx.subscribe(context_server_store, move |_, event, _cx| {
- let project::context_server_store::ServerStatusChangedEvent { server_id, status } = event;
+ let ServerStatusChangedEvent { server_id, status } = event;
+
+ if server_id != &context_server_id {
+ return;
+ }
match status {
- ContextServerStatus::Running => {
- if server_id == &context_server_id
- && let Some(tx) = tx.lock().unwrap().take()
- {
- let _ = tx.send(Ok(()));
+ ContextServerStatus::Running | ContextServerStatus::AuthRequired => {
+ if let Some(tx) = tx.lock().take() {
+ let _ = tx.send(Ok(status.clone()));
}
}
ContextServerStatus::Stopped => {
- if server_id == &context_server_id
- && let Some(tx) = tx.lock().unwrap().take()
- {
+ if let Some(tx) = tx.lock().take() {
let _ = tx.send(Err("Context server stopped running".into()));
}
}
ContextServerStatus::Error(error) => {
- if server_id == &context_server_id
- && let Some(tx) = tx.lock().unwrap().take()
- {
+ if let Some(tx) = tx.lock().take() {
let _ = tx.send(Err(error.clone()));
}
}
- _ => {}
+ ContextServerStatus::Starting | ContextServerStatus::Authenticating => {}
}
});
@@ -44,7 +44,6 @@ pub struct AgentDiffPane {
thread: Entity<AcpThread>,
focus_handle: FocusHandle,
workspace: WeakEntity<Workspace>,
- title: SharedString,
_subscriptions: Vec<Subscription>,
}
@@ -113,7 +112,6 @@ impl AgentDiffPane {
this.handle_acp_thread_event(event, cx)
}),
],
- title: SharedString::default(),
multibuffer,
editor,
thread,
@@ -121,7 +119,6 @@ impl AgentDiffPane {
workspace,
};
this.update_excerpts(window, cx);
- this.update_title(cx);
this
}
@@ -231,17 +228,9 @@ impl AgentDiffPane {
}
}
- fn update_title(&mut self, cx: &mut Context<Self>) {
- let new_title = self.thread.read(cx).title();
- if new_title != self.title {
- self.title = new_title;
- cx.emit(EditorEvent::TitleChanged);
- }
- }
-
fn handle_acp_thread_event(&mut self, event: &AcpThreadEvent, cx: &mut Context<Self>) {
if let AcpThreadEvent::TitleUpdated = event {
- self.update_title(cx)
+ cx.emit(EditorEvent::TitleChanged);
}
}
@@ -534,13 +523,17 @@ impl Item for AgentDiffPane {
fn tab_content(&self, params: TabContentParams, _window: &Window, cx: &App) -> AnyElement {
let title = self.thread.read(cx).title();
- Label::new(format!("Review: {}", title))
- .color(if params.selected {
- Color::Default
- } else {
- Color::Muted
- })
- .into_any_element()
+ Label::new(if let Some(title) = title {
+ format!("Review: {}", title)
+ } else {
+ "Review".to_string()
+ })
+ .color(if params.selected {
+ Color::Default
+ } else {
+ Color::Muted
+ })
+ .into_any_element()
}
fn telemetry_event_text(&self) -> Option<&'static str> {
@@ -26,7 +26,6 @@ use zed_actions::agent::{
ResolveConflictedFilesWithAgent, ResolveConflictsWithAgent, ReviewBranchDiff,
};
-use crate::ui::{AcpOnboardingModal, ClaudeCodeOnboardingModal, HoldForDefault};
use crate::{
AddContextServer, AgentDiffPane, ConversationView, CopyThreadToClipboard, CycleStartThreadIn,
Follow, InlineAssistant, LoadThreadFromClipboard, NewTextThread, NewThread,
@@ -42,6 +41,10 @@ use crate::{
Agent, AgentInitialContent, ExternalSourcePrompt, NewExternalAgentThread,
NewNativeAgentThreadFromSummary,
};
+use crate::{
+ DEFAULT_THREAD_TITLE,
+ ui::{AcpOnboardingModal, ClaudeCodeOnboardingModal, HoldForDefault},
+};
use crate::{
ExpandMessageEditor, ThreadHistoryView,
text_thread_history::{TextThreadHistory, TextThreadHistoryEvent},
@@ -75,8 +78,8 @@ use search::{BufferSearchBar, buffer_search};
use settings::{Settings, update_settings_file};
use theme::ThemeSettings;
use ui::{
- Button, Callout, ContextMenu, ContextMenuEntry, DocumentationSide, KeyBinding, PopoverMenu,
- PopoverMenuHandle, SpinnerLabel, Tab, Tooltip, prelude::*, utils::WithRemSize,
+ Button, Callout, CommonAnimationExt, ContextMenu, ContextMenuEntry, DocumentationSide,
+ KeyBinding, PopoverMenu, PopoverMenuHandle, Tab, Tooltip, prelude::*, utils::WithRemSize,
};
use util::{ResultExt as _, debug_panic};
use workspace::{
@@ -92,7 +95,6 @@ use zed_actions::{
const AGENT_PANEL_KEY: &str = "agent_panel";
const RECENTLY_UPDATED_MENU_LIMIT: usize = 6;
-const DEFAULT_THREAD_TITLE: &str = "New Thread";
fn read_serialized_panel(
workspace_id: workspace::WorkspaceId,
@@ -222,7 +224,7 @@ pub fn init(cx: &mut App) {
.register_action(|workspace, _: &OpenAgentDiff, window, cx| {
let thread = workspace
.panel::<AgentPanel>(cx)
- .and_then(|panel| panel.read(cx).active_conversation().cloned())
+ .and_then(|panel| panel.read(cx).active_conversation_view().cloned())
.and_then(|conversation| {
conversation
.read(cx)
@@ -404,17 +406,17 @@ pub fn init(cx: &mut App) {
});
},
)
- .register_action(|workspace, action: &StartThreadIn, _window, cx| {
+ .register_action(|workspace, action: &StartThreadIn, window, cx| {
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
panel.update(cx, |panel, cx| {
- panel.set_start_thread_in(action, cx);
+ panel.set_start_thread_in(action, window, cx);
});
}
})
- .register_action(|workspace, _: &CycleStartThreadIn, _window, cx| {
+ .register_action(|workspace, _: &CycleStartThreadIn, window, cx| {
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
panel.update(cx, |panel, cx| {
- panel.cycle_start_thread_in(cx);
+ panel.cycle_start_thread_in(window, cx);
});
}
});
@@ -775,11 +777,7 @@ impl AgentPanel {
SerializedActiveThread {
session_id: thread.session_id().0.to_string(),
agent_type: self.selected_agent_type.clone(),
- title: if title.as_ref() != DEFAULT_THREAD_TITLE {
- Some(title.to_string())
- } else {
- None
- },
+ title: title.map(|t| t.to_string()),
work_dirs: work_dirs.map(|dirs| dirs.serialize()),
}
});
@@ -1188,18 +1186,6 @@ impl AgentPanel {
.unwrap_or(false)
}
- pub fn active_conversation(&self) -> Option<&Entity<ConversationView>> {
- match &self.active_view {
- ActiveView::AgentThread {
- conversation_view, ..
- } => Some(conversation_view),
- ActiveView::Uninitialized
- | ActiveView::TextThread { .. }
- | ActiveView::History { .. }
- | ActiveView::Configuration => None,
- }
- }
-
pub fn new_thread(&mut self, _action: &NewThread, window: &mut Window, cx: &mut Context<Self>) {
self.new_agent_thread(AgentType::NativeAgent, window, cx);
}
@@ -1411,7 +1397,7 @@ impl AgentPanel {
}
fn expand_message_editor(&mut self, window: &mut Window, cx: &mut Context<Self>) {
- let Some(conversation_view) = self.active_conversation() else {
+ let Some(conversation_view) = self.active_conversation_view() else {
return;
};
@@ -1737,7 +1723,7 @@ impl AgentPanel {
cx: &mut Context<Self>,
) {
if let Some(workspace) = self.workspace.upgrade()
- && let Some(conversation_view) = self.active_conversation()
+ && let Some(conversation_view) = self.active_conversation_view()
&& let Some(active_thread) = conversation_view.read(cx).active_thread().cloned()
{
active_thread.update(cx, |thread, cx| {
@@ -2263,7 +2249,12 @@ impl AgentPanel {
&self.start_thread_in
}
- fn set_start_thread_in(&mut self, action: &StartThreadIn, cx: &mut Context<Self>) {
+ fn set_start_thread_in(
+ &mut self,
+ action: &StartThreadIn,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
if matches!(action, StartThreadIn::NewWorktree) && !cx.has_flag::<AgentV2FeatureFlag>() {
return;
}
@@ -2285,16 +2276,19 @@ impl AgentPanel {
}
};
self.start_thread_in = new_target;
+ if let Some(thread) = self.active_thread_view(cx) {
+ thread.update(cx, |thread, cx| thread.focus_handle(cx).focus(window, cx));
+ }
self.serialize(cx);
cx.notify();
}
- fn cycle_start_thread_in(&mut self, cx: &mut Context<Self>) {
+ fn cycle_start_thread_in(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let next = match self.start_thread_in {
StartThreadIn::LocalProject => StartThreadIn::NewWorktree,
StartThreadIn::NewWorktree => StartThreadIn::LocalProject,
};
- self.set_start_thread_in(&next, cx);
+ self.set_start_thread_in(&next, window, cx);
}
fn reset_start_thread_in_to_default(&mut self, cx: &mut Context<Self>) {
@@ -2302,7 +2296,13 @@ impl AgentPanel {
let default = AgentSettings::get_global(cx).new_thread_location;
let start_thread_in = match default {
NewThreadLocation::LocalProject => StartThreadIn::LocalProject,
- NewThreadLocation::NewWorktree => StartThreadIn::NewWorktree,
+ NewThreadLocation::NewWorktree => {
+ if self.project_has_git_repository(cx) {
+ StartThreadIn::NewWorktree
+ } else {
+ StartThreadIn::LocalProject
+ }
+ }
};
if self.start_thread_in != start_thread_in {
self.start_thread_in = start_thread_in;
@@ -2536,7 +2536,7 @@ impl AgentPanel {
}
pub fn active_thread_is_draft(&self, cx: &App) -> bool {
- self.active_conversation().is_some() && !self.active_thread_has_messages(cx)
+ self.active_conversation_view().is_some() && !self.active_thread_has_messages(cx)
}
fn handle_first_send_requested(
@@ -2828,8 +2828,7 @@ impl AgentPanel {
None => {
this.update_in(cx, |this, window, cx| {
this.set_worktree_creation_error(
- "Failed to generate a branch name: all typewriter names are taken"
- .into(),
+ "Failed to generate a unique branch name".into(),
window,
cx,
);
@@ -2946,13 +2945,35 @@ impl AgentPanel {
})?
.await?;
- let panels_task = new_window_handle.update(cx, |_, _, cx| {
- new_workspace.update(cx, |workspace, _cx| workspace.take_panels_task())
- })?;
+ let panels_task = new_workspace.update(cx, |workspace, _cx| workspace.take_panels_task());
+
if let Some(task) = panels_task {
task.await.log_err();
}
+ new_workspace
+ .update(cx, |workspace, cx| {
+ workspace.project().read(cx).wait_for_initial_scan(cx)
+ })
+ .await;
+
+ new_workspace
+ .update(cx, |workspace, cx| {
+ let repos = workspace
+ .project()
+ .read(cx)
+ .repositories(cx)
+ .values()
+ .cloned()
+ .collect::<Vec<_>>();
+
+ let tasks = repos
+ .into_iter()
+ .map(|repo| repo.update(cx, |repo, _| repo.barrier()));
+ futures::future::join_all(tasks)
+ })
+ .await;
+
let initial_content = AgentInitialContent::ContentBlock {
blocks: content,
auto_submit: true,
@@ -3219,7 +3240,7 @@ impl AgentPanel {
.map(|r| r.read(cx).title_editor.clone())
{
if is_generating_title {
- Label::new("New Thread…")
+ Label::new(DEFAULT_THREAD_TITLE)
.color(Color::Muted)
.truncate()
.with_animation(
@@ -3930,7 +3951,7 @@ impl AgentPanel {
};
let is_thread_loading = self
- .active_conversation()
+ .active_conversation_view()
.map(|thread| thread.read(cx).is_loading())
.unwrap_or(false);
@@ -4053,9 +4074,10 @@ impl AgentPanel {
.gap(DynamicSpacing::Base04.rems(cx))
.pl(DynamicSpacing::Base04.rems(cx))
.child(agent_selector_menu)
- .when(has_visible_worktrees, |this| {
- this.child(self.render_start_thread_in_selector(cx))
- }),
+ .when(
+ has_visible_worktrees && self.project_has_git_repository(cx),
+ |this| this.child(self.render_start_thread_in_selector(cx)),
+ ),
)
.child(
h_flex()
@@ -4134,41 +4156,31 @@ impl AgentPanel {
match status {
WorktreeCreationStatus::Creating => Some(
h_flex()
+ .absolute()
+ .bottom_12()
.w_full()
- .px(DynamicSpacing::Base06.rems(cx))
- .py(DynamicSpacing::Base02.rems(cx))
- .gap_2()
- .bg(cx.theme().colors().surface_background)
- .border_b_1()
- .border_color(cx.theme().colors().border)
- .child(SpinnerLabel::new().size(LabelSize::Small))
+ .p_2()
+ .gap_1()
+ .justify_center()
+ .bg(cx.theme().colors().editor_background)
+ .child(
+ Icon::new(IconName::LoadCircle)
+ .size(IconSize::Small)
+ .color(Color::Muted)
+ .with_rotate_animation(3),
+ )
.child(
- Label::new("Creating worktree…")
+ Label::new("Creating Worktree…")
.color(Color::Muted)
.size(LabelSize::Small),
)
.into_any_element(),
),
WorktreeCreationStatus::Error(message) => Some(
- h_flex()
- .w_full()
- .px(DynamicSpacing::Base06.rems(cx))
- .py(DynamicSpacing::Base02.rems(cx))
- .gap_2()
- .bg(cx.theme().colors().surface_background)
- .border_b_1()
- .border_color(cx.theme().colors().border)
- .child(
- Icon::new(IconName::Warning)
- .size(IconSize::Small)
- .color(Color::Warning),
- )
- .child(
- Label::new(message.clone())
- .color(Color::Warning)
- .size(LabelSize::Small)
- .truncate(),
- )
+ Callout::new()
+ .icon(IconName::Warning)
+ .severity(Severity::Warning)
+ .title(message.clone())
.into_any_element(),
),
}
@@ -4604,14 +4616,13 @@ impl Render for AgentPanel {
.on_action(cx.listener(Self::reset_font_size))
.on_action(cx.listener(Self::toggle_zoom))
.on_action(cx.listener(|this, _: &ReauthenticateAgent, window, cx| {
- if let Some(conversation_view) = this.active_conversation() {
+ if let Some(conversation_view) = this.active_conversation_view() {
conversation_view.update(cx, |conversation_view, cx| {
conversation_view.reauthenticate(window, cx)
})
}
}))
.child(self.render_toolbar(window, cx))
- .children(self.render_worktree_creation_status(cx))
.children(self.render_workspace_trust_message(cx))
.children(self.render_onboarding(window, cx))
.map(|parent| {
@@ -4668,6 +4679,7 @@ impl Render for AgentPanel {
ActiveView::Configuration => parent.children(self.configuration.clone()),
}
})
+ .children(self.render_worktree_creation_status(cx))
.children(self.render_trial_end_upsell(window, cx));
match self.active_view.which_font_size_used() {
@@ -4800,7 +4812,7 @@ impl AgentPanelDelegate for ConcreteAssistantPanelDelegate {
// Wait to create a new context until the workspace is no longer
// being updated.
cx.defer_in(window, move |panel, window, cx| {
- if let Some(conversation_view) = panel.active_conversation() {
+ if let Some(conversation_view) = panel.active_conversation_view() {
conversation_view.update(cx, |conversation_view, cx| {
conversation_view.insert_selections(window, cx);
});
@@ -4838,7 +4850,7 @@ impl AgentPanelDelegate for ConcreteAssistantPanelDelegate {
// Wait to create a new context until the workspace is no longer
// being updated.
cx.defer_in(window, move |panel, window, cx| {
- if let Some(conversation_view) = panel.active_conversation() {
+ if let Some(conversation_view) = panel.active_conversation_view() {
conversation_view.update(cx, |conversation_view, cx| {
conversation_view.insert_terminal_text(text, window, cx);
});
@@ -4904,7 +4916,7 @@ impl AgentPanel {
/// This is a test-only accessor that exposes the private `active_thread_view()`
/// method for test assertions. Not compiled into production builds.
pub fn active_thread_view_for_tests(&self) -> Option<&Entity<ConversationView>> {
- self.active_conversation()
+ self.active_conversation_view()
}
/// Sets the start_thread_in value directly, bypassing validation.
@@ -5094,7 +5106,7 @@ mod tests {
"workspace A agent type should be restored"
);
assert!(
- panel.active_conversation().is_some(),
+ panel.active_conversation_view().is_some(),
"workspace A should have its active thread restored"
);
});
@@ -5114,7 +5126,7 @@ mod tests {
"workspace B agent type should be restored"
);
assert!(
- panel.active_conversation().is_none(),
+ panel.active_conversation_view().is_none(),
"workspace B should have no active thread"
);
});
@@ -5566,7 +5578,7 @@ mod tests {
send_message(&panel, &mut cx);
let weak_view_a = panel.read_with(&cx, |panel, _cx| {
- panel.active_conversation().unwrap().downgrade()
+ panel.active_conversation_view().unwrap().downgrade()
});
let session_id_a = active_session_id(&panel, &cx);
@@ -5973,8 +5985,8 @@ mod tests {
});
// Change thread target to NewWorktree.
- panel.update(cx, |panel, cx| {
- panel.set_start_thread_in(&StartThreadIn::NewWorktree, cx);
+ panel.update_in(cx, |panel, window, cx| {
+ panel.set_start_thread_in(&StartThreadIn::NewWorktree, window, cx);
});
panel.read_with(cx, |panel, _cx| {
@@ -6196,11 +6208,11 @@ mod tests {
// Set the selected agent to Codex (a custom agent) and start_thread_in
// to NewWorktree. We do this AFTER opening the thread because
// open_external_thread_with_server overrides selected_agent_type.
- panel.update(cx, |panel, cx| {
+ panel.update_in(cx, |panel, window, cx| {
panel.selected_agent_type = AgentType::Custom {
id: CODEX_ID.into(),
};
- panel.set_start_thread_in(&StartThreadIn::NewWorktree, cx);
+ panel.set_start_thread_in(&StartThreadIn::NewWorktree, window, cx);
});
// Verify the panel has the Codex agent selected.
@@ -80,6 +80,8 @@ pub(crate) use thread_history::ThreadHistory;
pub(crate) use thread_history_view::*;
use zed_actions;
+pub const DEFAULT_THREAD_TITLE: &str = "New Thread";
+
actions!(
agent,
[
@@ -1,710 +1,77 @@
use collections::HashSet;
use rand::Rng;
-/// Names of historical typewriter brands, for use in auto-generated branch names.
-/// (Hyphens and parens have been dropped so that the branch names are one-word.)
-///
-/// Thanks to https://typewriterdatabase.com/alph.0.brands for the names!
-const TYPEWRITER_NAMES: &[&str] = &[
- "abeille",
- "acme",
- "addo",
- "adler",
- "adlerette",
- "adlerita",
- "admiral",
- "agamli",
- "agar",
- "agidel",
- "agil",
- "aguia",
- "aguila",
- "ahram",
- "aigle",
- "ajax",
- "aktiv",
- "ala",
- "alba",
- "albus",
- "alexander",
- "alexis",
- "alfa",
- "allen",
- "alonso",
- "alpina",
- "amata",
- "amaya",
- "amka",
- "anavi",
- "anderson",
- "andina",
- "antares",
- "apex",
- "apsco",
- "aquila",
- "archo",
- "ardita",
- "argyle",
- "aristocrat",
- "aristokrat",
- "arlington",
- "armstrong",
- "arpha",
- "artus",
- "astoria",
- "atlantia",
- "atlantic",
- "atlas",
- "augusta",
- "aurora",
- "austro",
- "automatic",
- "avanti",
- "avona",
- "azzurra",
- "bajnok",
- "baldwin",
- "balkan",
- "baltica",
- "baltimore",
- "barlock",
- "barr",
- "barrat",
- "bartholomew",
- "bashkiriya",
- "bavaria",
- "beaucourt",
- "beko",
- "belka",
- "bennett",
- "bennington",
- "berni",
- "bianca",
- "bijou",
- "bing",
- "bisei",
- "biser",
- "bluebird",
- "bolida",
- "borgo",
- "boston",
- "boyce",
- "bradford",
- "brandenburg",
- "brigitte",
- "briton",
- "brooks",
- "brosette",
- "buddy",
- "burns",
- "burroughs",
- "byron",
- "calanda",
- "caligraph",
- "cappel",
- "cardinal",
- "carissima",
- "carlem",
- "carlton",
- "carmen",
- "cawena",
- "cella",
- "celtic",
- "century",
- "champignon",
- "cherryland",
- "chevron",
- "chicago",
- "cicero",
- "cifra",
- "citizen",
- "claudia",
- "cleveland",
- "clover",
- "coffman",
- "cole",
- "columbia",
- "commercial",
- "companion",
- "concentra",
- "concord",
- "concordia",
- "conover",
- "constanta",
- "consul",
- "conta",
- "contenta",
- "contimat",
- "contina",
- "continento",
- "cornelia",
- "coronado",
- "cosmopolita",
- "courier",
- "craftamatic",
- "crandall",
- "crown",
- "culema",
- "dactyle",
- "dankers",
- "dart",
- "daugherty",
- "davis",
- "dayton",
- "dea",
- "delmar",
- "densmore",
- "depantio",
- "diadema",
- "dial",
- "diamant",
- "diana",
- "dictatype",
- "diplomat",
- "diskret",
- "dolfus",
- "dollar",
- "domus",
- "drake",
- "draper",
- "duplex",
- "durabel",
- "dynacord",
- "eagle",
- "eclipse",
- "edelmann",
- "edelweiss",
- "edison",
- "edita",
- "edland",
- "efka",
- "eldorado",
- "electa",
- "electromatic",
- "elektro",
- "elgin",
- "elliot",
- "emerson",
- "emka",
- "emona",
- "empire",
- "engadine",
- "engler",
- "erfurt",
- "erika",
- "esko",
- "essex",
- "eureka",
- "europa",
- "everest",
- "everlux",
- "excelsior",
- "express",
- "fabers",
- "facit",
- "fairbanks",
- "faktotum",
- "famos",
- "federal",
- "felio",
- "fidat",
- "filius",
- "fips",
- "fish",
- "fitch",
- "fleet",
- "florida",
- "flott",
- "flyer",
- "flying",
- "fontana",
- "ford",
- "forto",
- "fortuna",
- "fox",
- "framo",
- "franconia",
- "franklin",
- "friden",
- "frolio",
- "furstenberg",
- "galesburg",
- "galiette",
- "gallia",
- "garbell",
- "gardner",
- "geka",
- "generation",
- "genia",
- "geniatus",
- "gerda",
- "gisela",
- "glashutte",
- "gloria",
- "godrej",
- "gossen",
- "gourland",
- "grandjean",
- "granta",
- "granville",
- "graphic",
- "gritzner",
- "groma",
- "guhl",
- "guidonia",
- "gundka",
- "hacabo",
- "haddad",
- "halberg",
- "halda",
- "hall",
- "hammond",
- "hammonia",
- "hanford",
- "hansa",
- "harmony",
- "harris",
- "hartford",
- "hassia",
- "hatch",
- "heady",
- "hebronia",
- "hebros",
- "hega",
- "helios",
- "helma",
- "herald",
- "hercules",
- "hermes",
- "herold",
- "heros",
- "hesperia",
- "hogar",
- "hooven",
- "hopkins",
- "horton",
- "hugin",
- "hungaria",
- "hurtu",
- "iberia",
- "idea",
- "ideal",
- "imperia",
- "impo",
- "industria",
- "industrio",
- "ingersoll",
- "international",
- "invicta",
- "irene",
- "iris",
- "iskra",
- "ivitsa",
- "ivriah",
- "jackson",
- "janalif",
- "janos",
- "jolux",
- "juki",
- "junior",
- "juventa",
- "juwel",
- "kamkap",
- "kamo",
- "kanzler",
- "kappel",
- "karli",
- "karstadt",
- "keaton",
- "kenbar",
- "keystone",
- "kim",
- "klein",
- "kneist",
- "knoch",
- "koh",
- "kolibri",
- "kolumbus",
- "komet",
- "kondor",
- "koniger",
- "konryu",
- "kontor",
- "kosmopolit",
- "krypton",
- "lambert",
- "lasalle",
- "lectra",
- "leframa",
- "lemair",
- "lemco",
- "liberty",
- "libia",
- "liga",
- "lignose",
- "lilliput",
- "lindeteves",
- "linowriter",
- "listvitsa",
- "ludolf",
- "lutece",
- "luxa",
- "lyubava",
- "mafra",
- "magnavox",
- "maher",
- "majestic",
- "majitouch",
- "manhattan",
- "mapuua",
- "marathon",
- "marburger",
- "maritsa",
- "maruzen",
- "maskelyne",
- "masspro",
- "matous",
- "mccall",
- "mccool",
- "mcloughlin",
- "mead",
- "mechno",
- "mehano",
- "meiselbach",
- "melbi",
- "melior",
- "melotyp",
- "mentor",
- "mepas",
- "mercedesia",
- "mercurius",
- "mercury",
- "merkur",
- "merritt",
- "merz",
- "messa",
- "meteco",
- "meteor",
- "micron",
- "mignon",
- "mikro",
- "minerva",
- "mirian",
- "mirina",
- "mitex",
- "molle",
- "monac",
- "monarch",
- "mondiale",
- "monica",
- "monofix",
- "monopol",
- "monpti",
- "monta",
- "montana",
- "montgomery",
- "moon",
- "morgan",
- "morris",
- "morse",
- "moya",
- "moyer",
- "munson",
- "musicwriter",
- "nadex",
- "nakajima",
- "neckermann",
- "neubert",
- "neya",
- "ninety",
- "nisa",
- "noiseless",
- "noor",
- "nora",
- "nord",
- "norden",
- "norica",
- "norma",
- "norman",
- "north",
- "nototyp",
- "nova",
- "novalevi",
- "odell",
- "odhner",
- "odo",
- "odoma",
- "ohio",
- "ohtani",
- "oliva",
- "oliver",
- "olivetti",
- "olympia",
- "omega",
- "optima",
- "orbis",
- "orel",
- "orga",
- "oriette",
- "orion",
- "orn",
- "orplid",
- "pacior",
- "pagina",
- "parisienne",
- "passat",
- "pearl",
- "peerless",
- "perfect",
- "perfecta",
- "perkeo",
- "perkins",
- "perlita",
- "pettypet",
- "phoenix",
- "piccola",
- "picht",
- "pinnock",
- "pionier",
- "plurotyp",
- "plutarch",
- "pneumatic",
- "pocket",
- "polyglott",
- "polygraph",
- "pontiac",
- "portable",
- "portex",
- "pozzi",
- "premier",
- "presto",
- "primavera",
- "progress",
- "protos",
- "pterotype",
- "pullman",
- "pulsatta",
- "quick",
- "racer",
- "radio",
- "rally",
- "rand",
- "readers",
- "reed",
- "referent",
- "reff",
- "regent",
- "regia",
- "regina",
- "rekord",
- "reliable",
- "reliance",
- "remagg",
- "rembrandt",
- "remer",
- "remington",
- "remsho",
- "remstar",
- "remtor",
- "reporters",
- "resko",
- "rex",
- "rexpel",
- "rheinita",
- "rheinmetall",
- "rival",
- "roberts",
- "robotron",
- "rocher",
- "rochester",
- "roebuck",
- "rofa",
- "roland",
- "rooy",
- "rover",
- "roxy",
- "roy",
- "royal",
- "rundstatler",
- "sabaudia",
- "sabb",
- "saleem",
- "salter",
- "sampo",
- "sarafan",
- "saturn",
- "saxonia",
- "schade",
- "schapiro",
- "schreibi",
- "scripta",
- "sears",
- "secor",
- "selectric",
- "selekta",
- "senator",
- "sense",
- "senta",
- "serd",
- "shilling",
- "shimade",
- "shimer",
- "sholes",
- "shuang",
- "siegfried",
- "siemag",
- "silma",
- "silver",
- "simplex",
- "simtype",
- "singer",
- "smith",
- "soemtron",
- "sonja",
- "speedwriter",
- "sphinx",
- "starlet",
- "stearns",
- "steel",
- "stella",
- "steno",
- "sterling",
- "stoewer",
- "stolzenberg",
- "stott",
- "strangfeld",
- "sture",
- "stylotyp",
- "sun",
- "superba",
- "superia",
- "supermetall",
- "surety",
- "swintec",
- "swissa",
- "talbos",
- "talleres",
- "tatrapoint",
- "taurus",
- "taylorix",
- "tell",
- "tempotype",
- "tippco",
- "titania",
- "tops",
- "towa",
- "toyo",
- "tradition",
- "transatlantic",
- "traveller",
- "trebla",
- "triumph",
- "turia",
- "typatune",
- "typen",
- "typorium",
- "ugro",
- "ultima",
- "unda",
- "underwood",
- "unica",
- "unitype",
- "ursula",
- "utax",
- "varityper",
- "vasanta",
- "vendex",
- "venus",
- "victor",
- "victoria",
- "video",
- "viking",
- "vira",
- "virotyp",
- "visigraph",
- "vittoria",
- "volcan",
- "vornado",
- "voss",
- "vultur",
- "waltons",
- "wanamaker",
- "wanderer",
- "ward",
- "warner",
- "waterloo",
- "waverley",
- "wayne",
- "webster",
- "wedgefield",
- "welco",
- "wellington",
- "wellon",
- "weltblick",
- "westphalia",
- "wiedmer",
- "williams",
- "wilson",
- "winkel",
- "winsor",
- "wizard",
- "woodstock",
- "woodwards",
- "yatran",
- "yost",
- "zenit",
- "zentronik",
- "zeta",
- "zeya",
+const ADJECTIVES: &[&str] = &[
+ "able", "agate", "agile", "alpine", "amber", "ample", "aqua", "arctic", "arid", "astral",
+ "autumn", "avid", "azure", "balmy", "birch", "bold", "boreal", "brave", "breezy", "brief",
+ "bright", "brisk", "broad", "bronze", "calm", "cerith", "civil", "clean", "clear", "clever",
+ "cobalt", "cool", "copper", "coral", "cozy", "crisp", "cubic", "cyan", "deft", "dense", "dewy",
+ "direct", "dusky", "dusty", "eager", "early", "earnest", "elder", "elfin", "equal", "even",
+ "exact", "faint", "fair", "fast", "fawn", "ferny", "fiery", "fine", "firm", "fleet", "floral",
+ "focal", "fond", "frank", "fresh", "frosty", "full", "gentle", "gilded", "glacial", "glad",
+ "glossy", "golden", "grand", "green", "gusty", "hale", "happy", "hardy", "hazel", "hearty",
+ "hilly", "humble", "hushed", "icy", "ideal", "inner", "iron", "ivory", "jade", "jovial",
+ "keen", "kind", "lapis", "leafy", "level", "light", "lilac", "limber", "lively", "local",
+ "lofty", "lucid", "lunar", "major", "maple", "mellow", "merry", "mild", "milky", "misty",
+ "modal", "modest", "mossy", "muted", "native", "naval", "neat", "nimble", "noble", "north",
+ "novel", "oaken", "ochre", "olive", "onyx", "opal", "open", "optic", "outer", "owed", "ozone",
+ "pale", "pastel", "pearl", "pecan", "peppy", "pilot", "placid", "plain", "plum", "plush",
+ "poised", "polar", "polished", "poplar", "prime", "proof", "proud", "pure", "quartz", "quick",
+ "quiet", "rapid", "raspy", "ready", "regal", "rooted", "rosy", "round", "royal", "ruby",
+ "ruddy", "russet", "rustic", "sage", "salty", "sandy", "satin", "scenic", "sedge", "serene",
+ "sharp", "sheer", "silky", "silver", "sleek", "smart", "smooth", "snowy", "solar", "solid",
+ "south", "spry", "stark", "steady", "steel", "steep", "still", "stoic", "stony", "stout",
+ "sturdy", "suede", "sunny", "supple", "sure", "swift", "tall", "tawny", "teal", "terse",
+ "thick", "tidal", "tidy", "timber", "topaz", "total", "trim", "tropic", "true", "tulip",
+ "upper", "urban", "valid", "vast", "velvet", "verde", "vivid", "vocal", "warm", "waxen",
+ "west", "whole", "wide", "wild", "wise", "witty", "woven", "young", "zealous", "zephyr",
+ "zesty", "zinc",
];
-/// Picks a typewriter name that isn't already taken by an existing branch.
-///
-/// Each entry in `existing_branches` is expected to be a full branch name
-/// like `"olivetti-a3f9b2c1"`. The prefix before the last `'-'` is treated
-/// as the taken typewriter name. Branches without a `'-'` are ignored.
+const NOUNS: &[&str] = &[
+ "anchor", "anvil", "arbor", "arch", "arrow", "atlas", "badge", "badger", "basin", "bay",
+ "beacon", "beam", "bell", "birch", "blade", "bloom", "bluff", "bolt", "bower", "breeze",
+ "bridge", "brook", "bunting", "cabin", "cairn", "canyon", "cape", "cedar", "chasm", "cliff",
+ "cloud", "clover", "coast", "cobble", "colt", "comet", "condor", "coral", "cove", "crane",
+ "crater", "creek", "crest", "curlew", "cypress", "dale", "dawn", "delta", "den", "dove",
+ "drake", "drift", "drum", "dune", "dusk", "eagle", "echo", "egret", "elk", "elm", "ember",
+ "falcon", "fawn", "fern", "ferry", "field", "finch", "fjord", "flame", "flint", "flower",
+ "forge", "fossil", "fox", "frost", "gale", "garnet", "gate", "gazelle", "geyser", "glade",
+ "glen", "gorge", "granite", "grove", "gull", "harbor", "hare", "haven", "hawk", "hazel",
+ "heath", "hedge", "heron", "hill", "hollow", "horizon", "ibis", "inlet", "isle", "ivy",
+ "jackal", "jasper", "juniper", "kestrel", "kinglet", "knoll", "lagoon", "lake", "lantern",
+ "larch", "lark", "laurel", "lava", "leaf", "ledge", "lily", "linden", "lodge", "loft", "lotus",
+ "lynx", "mantle", "maple", "marble", "marsh", "marten", "meadow", "merlin", "mesa", "mill",
+ "mint", "moon", "moose", "moss", "newt", "north", "nutmeg", "oak", "oasis", "obsidian",
+ "orbit", "orchid", "oriole", "osprey", "otter", "owl", "palm", "panther", "pass", "path",
+ "peak", "pebble", "pelican", "peony", "perch", "pier", "pine", "plover", "plume", "pond",
+ "poppy", "prairie", "prism", "puma", "quail", "quarry", "quartz", "rain", "rampart", "range",
+ "raven", "ravine", "reed", "reef", "ridge", "river", "robin", "rowan", "sage", "salmon",
+ "sequoia", "shore", "shrike", "sigma", "sky", "slate", "slope", "snow", "spark", "sparrow",
+ "spider", "spruce", "stag", "star", "stone", "stork", "storm", "stream", "summit", "swift",
+ "sycamore", "tern", "terrace", "thistle", "thorn", "thrush", "tide", "timber", "torch",
+ "tower", "trail", "trout", "tulip", "tundra", "vale", "valley", "veranda", "viper", "vista",
+ "vole", "walrus", "warbler", "willow", "wolf", "wren", "yew", "zenith",
+];
+
+/// Generates a branch name in `"adjective-noun"` format (e.g. `"swift-falcon"`).
///
-/// Returns `None` when every name in the pool is already taken.
-pub fn pick_typewriter_name(
- existing_branches: &[&str],
- rng: &mut impl Rng,
-) -> Option<&'static str> {
- let disallowed: HashSet<&str> = existing_branches
- .iter()
- .filter_map(|branch| branch.rsplit_once('-').map(|(prefix, _)| prefix))
- .collect();
+/// Tries up to 100 random combinations, skipping any name that already appears
+/// in `existing_branches`. Returns `None` if no unused name is found.
+pub fn generate_branch_name(existing_branches: &[&str], rng: &mut impl Rng) -> Option<String> {
+ let existing: HashSet<&str> = existing_branches.iter().copied().collect();
- let available: Vec<&'static str> = TYPEWRITER_NAMES
- .iter()
- .copied()
- .filter(|name| !disallowed.contains(name))
- .collect();
+ for _ in 0..100 {
+ let adjective = ADJECTIVES[rng.random_range(0..ADJECTIVES.len())];
+ let noun = NOUNS[rng.random_range(0..NOUNS.len())];
+ let name = format!("{adjective}-{noun}");
- if available.is_empty() {
- return None;
+ if !existing.contains(name.as_str()) {
+ return Some(name);
+ }
}
- let index = rng.random_range(0..available.len());
- Some(available[index])
-}
-
-/// Generates a branch name like `"olivetti-a3f9b2c1"` by picking a typewriter
-/// name that isn't already taken and appending an 8-character alphanumeric hash.
-///
-/// Returns `None` when every typewriter name in the pool is already taken.
-pub fn generate_branch_name(existing_branches: &[&str], rng: &mut impl Rng) -> Option<String> {
- let typewriter_name = pick_typewriter_name(existing_branches, rng)?;
- let hash: String = (0..8)
- .map(|_| {
- let idx: u8 = rng.random_range(0..36);
- if idx < 10 {
- (b'0' + idx) as char
- } else {
- (b'a' + idx - 10) as char
- }
- })
- .collect();
- Some(format!("{typewriter_name}-{hash}"))
+ None
}
#[cfg(test)]
@@ -713,134 +80,91 @@ mod tests {
use rand::rngs::StdRng;
#[gpui::test(iterations = 10)]
- fn test_pick_typewriter_name_with_no_disallowed(mut rng: StdRng) {
- let name = pick_typewriter_name(&[], &mut rng);
- assert!(name.is_some());
- assert!(TYPEWRITER_NAMES.contains(&name.unwrap()));
- }
-
- #[gpui::test(iterations = 10)]
- fn test_pick_typewriter_name_excludes_taken_names(mut rng: StdRng) {
- let branch_names = &["olivetti-abc12345", "selectric-def67890"];
- let name = pick_typewriter_name(branch_names, &mut rng).unwrap();
- assert_ne!(name, "olivetti");
- assert_ne!(name, "selectric");
- }
-
- #[gpui::test]
- fn test_pick_typewriter_name_all_taken(mut rng: StdRng) {
- let branch_names: Vec<String> = TYPEWRITER_NAMES
- .iter()
- .map(|name| format!("{name}-00000000"))
- .collect();
- let branch_name_refs: Vec<&str> = branch_names.iter().map(|s| s.as_str()).collect();
- let name = pick_typewriter_name(&branch_name_refs, &mut rng);
- assert!(name.is_none());
- }
-
- #[gpui::test(iterations = 10)]
- fn test_pick_typewriter_name_ignores_branches_without_hyphen(mut rng: StdRng) {
- let branch_names = &["main", "develop", "feature"];
- let name = pick_typewriter_name(branch_names, &mut rng);
- assert!(name.is_some());
- assert!(TYPEWRITER_NAMES.contains(&name.unwrap()));
+ fn test_generate_branch_name_format(mut rng: StdRng) {
+ let name = generate_branch_name(&[], &mut rng).unwrap();
+ let (adjective, noun) = name.split_once('-').expect("name should contain a hyphen");
+ assert!(
+ ADJECTIVES.contains(&adjective),
+ "{adjective:?} is not in ADJECTIVES"
+ );
+ assert!(NOUNS.contains(&noun), "{noun:?} is not in NOUNS");
}
- #[gpui::test(iterations = 10)]
- fn test_generate_branch_name_format(mut rng: StdRng) {
- let branch_name = generate_branch_name(&[], &mut rng).unwrap();
- let (prefix, suffix) = branch_name.rsplit_once('-').unwrap();
- assert!(TYPEWRITER_NAMES.contains(&prefix));
- assert_eq!(suffix.len(), 8);
- assert!(suffix.chars().all(|c| c.is_ascii_alphanumeric()));
+ #[gpui::test(iterations = 100)]
+ fn test_generate_branch_name_avoids_existing(mut rng: StdRng) {
+ let existing = &["swift-falcon", "calm-river", "bold-cedar"];
+ let name = generate_branch_name(existing, &mut rng).unwrap();
+ for &branch in existing {
+ assert_ne!(
+ name, branch,
+ "generated name should not match an existing branch"
+ );
+ }
}
#[gpui::test]
- fn test_generate_branch_name_returns_none_when_exhausted(mut rng: StdRng) {
- let branch_names: Vec<String> = TYPEWRITER_NAMES
+ fn test_generate_branch_name_returns_none_when_stuck(mut rng: StdRng) {
+ let all_names: Vec<String> = ADJECTIVES
.iter()
- .map(|name| format!("{name}-00000000"))
+ .flat_map(|adj| NOUNS.iter().map(move |noun| format!("{adj}-{noun}")))
.collect();
- let branch_name_refs: Vec<&str> = branch_names.iter().map(|s| s.as_str()).collect();
- let result = generate_branch_name(&branch_name_refs, &mut rng);
+ let refs: Vec<&str> = all_names.iter().map(|s| s.as_str()).collect();
+ let result = generate_branch_name(&refs, &mut rng);
assert!(result.is_none());
}
- #[gpui::test(iterations = 100)]
- fn test_generate_branch_name_never_reuses_taken_prefix(mut rng: StdRng) {
- let existing = &["olivetti-123abc", "selectric-def456"];
- let branch_name = generate_branch_name(existing, &mut rng).unwrap();
- let (prefix, _) = branch_name.rsplit_once('-').unwrap();
- assert_ne!(prefix, "olivetti");
- assert_ne!(prefix, "selectric");
- }
+ #[test]
+ fn test_adjectives_are_valid() {
+ let mut seen = HashSet::default();
+ for &word in ADJECTIVES {
+ assert!(seen.insert(word), "duplicate entry in ADJECTIVES: {word:?}");
+ }
- #[gpui::test(iterations = 100)]
- fn test_generate_branch_name_avoids_multiple_taken_prefixes(mut rng: StdRng) {
- let existing = &[
- "olivetti-aaa11111",
- "selectric-bbb22222",
- "corona-ccc33333",
- "remington-ddd44444",
- "underwood-eee55555",
- ];
- let taken_prefixes: HashSet<&str> = existing
- .iter()
- .filter_map(|b| b.rsplit_once('-').map(|(prefix, _)| prefix))
- .collect();
- let branch_name = generate_branch_name(existing, &mut rng).unwrap();
- let (prefix, _) = branch_name.rsplit_once('-').unwrap();
- assert!(
- !taken_prefixes.contains(prefix),
- "generated prefix {prefix:?} collides with an existing branch"
- );
- }
+ for window in ADJECTIVES.windows(2) {
+ assert!(
+ window[0] < window[1],
+ "ADJECTIVES is not sorted: {0:?} should come before {1:?}",
+ window[0],
+ window[1],
+ );
+ }
- #[gpui::test(iterations = 100)]
- fn test_generate_branch_name_with_varied_hash_suffixes(mut rng: StdRng) {
- let existing = &[
- "olivetti-aaaaaaaa",
- "olivetti-bbbbbbbb",
- "olivetti-cccccccc",
- ];
- let branch_name = generate_branch_name(existing, &mut rng).unwrap();
- let (prefix, _) = branch_name.rsplit_once('-').unwrap();
- assert_ne!(
- prefix, "olivetti",
- "should avoid olivetti regardless of how many variants exist"
- );
+ for &word in ADJECTIVES {
+ assert!(
+ !word.contains('-'),
+ "ADJECTIVES entry contains a hyphen: {word:?}"
+ );
+ assert!(
+ word.chars().all(|c| c.is_lowercase()),
+ "ADJECTIVES entry is not all lowercase: {word:?}"
+ );
+ }
}
#[test]
- fn test_typewriter_names_are_valid() {
+ fn test_nouns_are_valid() {
let mut seen = HashSet::default();
- for &name in TYPEWRITER_NAMES {
- assert!(
- seen.insert(name),
- "duplicate entry in TYPEWRITER_NAMES: {name:?}"
- );
+ for &word in NOUNS {
+ assert!(seen.insert(word), "duplicate entry in NOUNS: {word:?}");
}
- for window in TYPEWRITER_NAMES.windows(2) {
+ for window in NOUNS.windows(2) {
assert!(
- window[0] <= window[1],
- "TYPEWRITER_NAMES is not sorted: {0:?} should come after {1:?}",
- window[1],
+ window[0] < window[1],
+ "NOUNS is not sorted: {0:?} should come before {1:?}",
window[0],
+ window[1],
);
}
- for &name in TYPEWRITER_NAMES {
+ for &word in NOUNS {
assert!(
- !name.contains('-'),
- "TYPEWRITER_NAMES entry contains a hyphen: {name:?}"
+ !word.contains('-'),
+ "NOUNS entry contains a hyphen: {word:?}"
);
- }
-
- for &name in TYPEWRITER_NAMES {
assert!(
- name.chars().all(|c| c.is_lowercase() || !c.is_alphabetic()),
- "TYPEWRITER_NAMES entry is not lowercase: {name:?}"
+ word.chars().all(|c| c.is_lowercase()),
+ "NOUNS entry is not all lowercase: {word:?}"
);
}
}
@@ -4,6 +4,7 @@ use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
+use crate::DEFAULT_THREAD_TITLE;
use crate::ThreadHistory;
use acp_thread::MentionUri;
use agent_client_protocol as acp;
@@ -192,7 +193,7 @@ pub struct EntryMatch {
fn session_title(title: Option<SharedString>) -> SharedString {
title
.filter(|title| !title.is_empty())
- .unwrap_or_else(|| SharedString::new_static("New Thread"))
+ .unwrap_or_else(|| SharedString::new_static(DEFAULT_THREAD_TITLE))
}
#[derive(Debug, Clone)]
@@ -873,7 +874,7 @@ impl<T: PromptCompletionProviderDelegate> PromptCompletionProvider<T> {
let project = workspace.read(cx).project().clone();
let repo = project.read(cx).active_repository(cx)?;
- let default_branch_receiver = repo.update(cx, |repo, _| repo.default_branch(false));
+ let default_branch_receiver = repo.update(cx, |repo, _| repo.default_branch(true));
Some(cx.spawn(async move |_cx| {
let base_ref = default_branch_receiver
@@ -1098,11 +1099,11 @@ impl<T: PromptCompletionProviderDelegate> PromptCompletionProvider<T> {
if let Some(agent_panel) = workspace.panel::<AgentPanel>(cx)
&& let Some(thread) = agent_panel.read(cx).active_agent_thread(cx)
+ && let Some(title) = thread.read(cx).title()
{
- let thread = thread.read(cx);
mentions.insert(MentionUri::Thread {
- id: thread.session_id().clone(),
- name: thread.title().into(),
+ id: thread.read(cx).session_id().clone(),
+ name: title.to_string(),
});
}
@@ -1,9 +1,8 @@
use acp_thread::{
AcpThread, AcpThreadEvent, AgentSessionInfo, AgentThreadEntry, AssistantMessage,
AssistantMessageChunk, AuthRequired, LoadError, MentionUri, PermissionOptionChoice,
- PermissionOptions, PermissionPattern, RetryStatus, SelectedPermissionOutcome,
- SelectedPermissionParams, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus,
- UserMessageId,
+ PermissionOptions, PermissionPattern, RetryStatus, SelectedPermissionOutcome, ThreadStatus,
+ ToolCall, ToolCallContent, ToolCallStatus, UserMessageId,
};
use acp_thread::{AgentConnection, Plan};
use action_log::{ActionLog, ActionLogTelemetry, DiffStats};
@@ -14,7 +13,6 @@ use agent_servers::AgentServerDelegate;
use agent_servers::{AgentServer, GEMINI_TERMINAL_AUTH_METHOD_ID};
use agent_settings::{AgentProfileId, AgentSettings};
use anyhow::{Result, anyhow};
-use arrayvec::ArrayVec;
use audio::{Audio, Sound};
use buffer_diff::BufferDiff;
use client::zed_urls;
@@ -41,6 +39,7 @@ use parking_lot::RwLock;
use project::{AgentId, AgentServerStore, Project, ProjectEntryId};
use prompt_store::{PromptId, PromptStore};
+use crate::DEFAULT_THREAD_TITLE;
use crate::message_editor::SessionCapabilities;
use rope::Point;
use settings::{NotifyWhenAgentWaiting, Settings as _, SettingsStore};
@@ -249,8 +248,7 @@ impl Conversation {
self.authorize_tool_call(
session_id.clone(),
tool_call_id,
- option.option_id.clone().into(),
- option.kind,
+ SelectedPermissionOutcome::new(option.option_id.clone(), option.kind),
cx,
);
Some(())
@@ -261,7 +259,6 @@ impl Conversation {
session_id: acp::SessionId,
tool_call_id: acp::ToolCallId,
outcome: SelectedPermissionOutcome,
- option_kind: acp::PermissionOptionKind,
cx: &mut Context<Self>,
) {
let Some(thread) = self.threads.get(&session_id) else {
@@ -273,11 +270,11 @@ impl Conversation {
"Agent Tool Call Authorized",
agent = agent_telemetry_id,
session = session_id,
- option = option_kind
+ option = outcome.option_kind
);
thread.update(cx, |thread, cx| {
- thread.authorize_tool_call(tool_call_id, outcome, option_kind, cx);
+ thread.authorize_tool_call(tool_call_id, outcome, cx);
});
cx.notify();
}
@@ -552,7 +549,7 @@ impl ConversationView {
(
Some(thread.session_id().clone()),
thread.work_dirs().cloned(),
- Some(thread.title()),
+ thread.title(),
)
})
.unwrap_or((None, None, None));
@@ -1107,9 +1104,12 @@ impl ConversationView {
&self.workspace
}
- pub fn title(&self, _cx: &App) -> SharedString {
+ pub fn title(&self, cx: &App) -> SharedString {
match &self.server_state {
- ServerState::Connected(_) => "New Thread".into(),
+ ServerState::Connected(view) => view
+ .active_view()
+ .and_then(|v| v.read(cx).thread.read(cx).title())
+ .unwrap_or_else(|| DEFAULT_THREAD_TITLE.into()),
ServerState::Loading(_) => "Loadingβ¦".into(),
ServerState::LoadError { error, .. } => match error {
LoadError::Unsupported { .. } => {
@@ -1351,8 +1351,9 @@ impl ConversationView {
);
}
AcpThreadEvent::TitleUpdated => {
- let title = thread.read(cx).title();
- if let Some(active_thread) = self.thread_view(&thread_id) {
+ if let Some(title) = thread.read(cx).title()
+ && let Some(active_thread) = self.thread_view(&thread_id)
+ {
let title_editor = active_thread.read(cx).title_editor.clone();
title_editor.update(cx, |editor, cx| {
if editor.text(cx) != title {
@@ -3709,7 +3710,7 @@ pub(crate) mod tests {
cx.new(|cx| {
AcpThread::new(
None,
- name,
+ Some(name.into()),
None,
connection,
project,
@@ -3909,7 +3910,7 @@ pub(crate) mod tests {
Task::ready(Ok(cx.new(|cx| {
AcpThread::new(
None,
- "AuthGatedAgent",
+ None,
Some(work_dirs),
self,
project,
@@ -3983,7 +3984,7 @@ pub(crate) mod tests {
let action_log = cx.new(|_| ActionLog::new(project.clone()));
AcpThread::new(
None,
- "SaboteurAgentConnection",
+ None,
Some(work_dirs),
self,
project,
@@ -4053,7 +4054,7 @@ pub(crate) mod tests {
let action_log = cx.new(|_| ActionLog::new(project.clone()));
AcpThread::new(
None,
- "RefusalAgentConnection",
+ None,
Some(work_dirs),
self,
project,
@@ -4133,7 +4134,7 @@ pub(crate) mod tests {
let thread = cx.new(|cx| {
AcpThread::new(
None,
- "CwdCapturingConnection",
+ None,
Some(work_dirs),
self.clone(),
project,
@@ -4168,7 +4169,7 @@ pub(crate) mod tests {
let thread = cx.new(|cx| {
AcpThread::new(
None,
- "CwdCapturingConnection",
+ None,
Some(work_dirs),
self.clone(),
project,
@@ -6110,7 +6111,7 @@ pub(crate) mod tests {
assert_eq!(editor.text(cx), "My Custom Title");
});
thread.read_with(cx, |thread, _cx| {
- assert_eq!(thread.title().as_ref(), "My Custom Title");
+ assert_eq!(thread.title(), Some("My Custom Title".into()));
});
}
@@ -6196,7 +6197,7 @@ pub(crate) mod tests {
cx.new(|cx| {
AcpThread::new(
parent_session_id,
- "Test Thread",
+ None,
None,
connection,
project,
@@ -6272,8 +6273,10 @@ pub(crate) mod tests {
conversation.authorize_tool_call(
acp::SessionId::new("session-1"),
acp::ToolCallId::new("tc-1"),
- acp::PermissionOptionId::new("allow-1").into(),
- acp::PermissionOptionKind::AllowOnce,
+ SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow-1"),
+ acp::PermissionOptionKind::AllowOnce,
+ ),
cx,
);
});
@@ -6295,8 +6298,10 @@ pub(crate) mod tests {
conversation.authorize_tool_call(
acp::SessionId::new("session-1"),
acp::ToolCallId::new("tc-2"),
- acp::PermissionOptionId::new("allow-2").into(),
- acp::PermissionOptionKind::AllowOnce,
+ SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow-2"),
+ acp::PermissionOptionKind::AllowOnce,
+ ),
cx,
);
});
@@ -6434,8 +6439,10 @@ pub(crate) mod tests {
conversation.authorize_tool_call(
acp::SessionId::new("thread-a"),
acp::ToolCallId::new("tc-a"),
- acp::PermissionOptionId::new("allow-a").into(),
- acp::PermissionOptionKind::AllowOnce,
+ SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow-a"),
+ acp::PermissionOptionKind::AllowOnce,
+ ),
cx,
);
});
@@ -6704,7 +6711,7 @@ pub(crate) mod tests {
let thread = cx.new(|cx| {
AcpThread::new(
None,
- "CloseCapableConnection",
+ Some("CloseCapableConnection".into()),
Some(work_dirs),
self,
project,
@@ -1,4 +1,4 @@
-use crate::SelectPermissionGranularity;
+use crate::{DEFAULT_THREAD_TITLE, SelectPermissionGranularity};
use std::cell::RefCell;
use acp_thread::ContentBlock;
@@ -8,6 +8,7 @@ use editor::actions::OpenExcerpts;
use crate::StartThreadIn;
use crate::message_editor::SharedSessionCapabilities;
use gpui::{Corner, List};
+use heapless::Vec as ArrayVec;
use language_model::{LanguageModelEffortLevel, Speed};
use settings::update_settings_file;
use ui::{ButtonLike, SplitButton, SplitButtonStyle, Tab};
@@ -404,7 +405,11 @@ impl ThreadView {
let can_edit = thread.update(cx, |thread, cx| thread.can_set_title(cx));
let editor = cx.new(|cx| {
let mut editor = Editor::single_line(window, cx);
- editor.set_text(thread.read(cx).title(), window, cx);
+ if let Some(title) = thread.read(cx).title() {
+ editor.set_text(title, window, cx);
+ } else {
+ editor.set_text(DEFAULT_THREAD_TITLE, window, cx);
+ }
editor.set_read_only(!can_edit);
editor
});
@@ -1051,7 +1056,7 @@ impl ThreadView {
.ok();
}
});
- if is_first_message {
+ if is_first_message && thread.read_with(cx, |thread, _cx| thread.title().is_none())? {
let text: String = contents
.iter()
.filter_map(|block| match block {
@@ -1065,7 +1070,7 @@ impl ThreadView {
.join(" ");
let text = text.lines().next().unwrap_or("").trim();
if !text.is_empty() {
- let title: SharedString = util::truncate_and_trailoff(text, 20).into();
+ let title: SharedString = util::truncate_and_trailoff(text, 200).into();
thread.update(cx, |thread, cx| {
thread.set_provisional_title(title, cx);
})?;
@@ -1536,7 +1541,7 @@ impl ThreadView {
EditorEvent::Blurred => {
if title_editor.read(cx).text(cx).is_empty() {
title_editor.update(cx, |editor, cx| {
- editor.set_text("New Thread", window, cx);
+ editor.set_text(DEFAULT_THREAD_TITLE, window, cx);
});
}
}
@@ -1575,12 +1580,11 @@ impl ThreadView {
session_id: acp::SessionId,
tool_call_id: acp::ToolCallId,
outcome: SelectedPermissionOutcome,
- option_kind: acp::PermissionOptionKind,
window: &mut Window,
cx: &mut Context<Self>,
) {
self.conversation.update(cx, |conversation, cx| {
- conversation.authorize_tool_call(session_id, tool_call_id, outcome, option_kind, cx);
+ conversation.authorize_tool_call(session_id, tool_call_id, outcome, cx);
});
if self.should_be_following {
self.workspace
@@ -1643,8 +1647,7 @@ impl ThreadView {
self.authorize_tool_call(
self.id.clone(),
tool_call_id,
- option_id.into(),
- option_kind,
+ SelectedPermissionOutcome::new(option_id, option_kind),
window,
cx,
);
@@ -1735,16 +1738,9 @@ impl ThreadView {
window: &mut Window,
cx: &mut Context<Self>,
) -> Option<()> {
- let (choices, dropdown_with_patterns) = match options {
- PermissionOptions::Dropdown(choices) => (choices.as_slice(), None),
- PermissionOptions::DropdownWithPatterns {
- choices,
- patterns,
- tool_name,
- } => (
- choices.as_slice(),
- Some((patterns.as_slice(), tool_name.as_str())),
- ),
+ let choices = match options {
+ PermissionOptions::Dropdown(choices) => choices.as_slice(),
+ PermissionOptions::DropdownWithPatterns { choices, .. } => choices.as_slice(),
_ => {
let kind = if is_allow {
acp::PermissionOptionKind::AllowOnce
@@ -1758,34 +1754,9 @@ impl ThreadView {
let selection = self.permission_selections.get(&tool_call_id);
// When in per-command pattern mode, use the checked patterns.
- if let Some(PermissionSelection::SelectedPatterns(checked)) = selection
- && let Some((patterns, tool_name)) = dropdown_with_patterns
- {
- let checked_patterns: Vec<_> = patterns
- .iter()
- .enumerate()
- .filter(|(index, _)| checked.contains(index))
- .map(|(_, cp)| cp.pattern.clone())
- .collect();
-
- if !checked_patterns.is_empty() {
- let (option_id_str, kind) = if is_allow {
- (
- format!("always_allow:{}", tool_name),
- acp::PermissionOptionKind::AllowAlways,
- )
- } else {
- (
- format!("always_deny:{}", tool_name),
- acp::PermissionOptionKind::RejectAlways,
- )
- };
- let outcome =
- SelectedPermissionOutcome::new(acp::PermissionOptionId::new(option_id_str))
- .params(Some(SelectedPermissionParams::Terminal {
- patterns: checked_patterns,
- }));
- self.authorize_tool_call(session_id, tool_call_id, outcome, kind, window, cx);
+ if let Some(PermissionSelection::SelectedPatterns(checked)) = selection {
+ if let Some(outcome) = options.build_outcome_for_checked_patterns(checked, is_allow) {
+ self.authorize_tool_call(session_id, tool_call_id, outcome, window, cx);
return Some(());
}
}
@@ -1796,32 +1767,9 @@ impl ThreadView {
.unwrap_or_else(|| choices.len().saturating_sub(1));
let selected_choice = choices.get(selected_index).or(choices.last())?;
+ let outcome = selected_choice.build_outcome(is_allow);
- let selected_option = if is_allow {
- &selected_choice.allow
- } else {
- &selected_choice.deny
- };
-
- let params = if !selected_choice.sub_patterns.is_empty() {
- Some(SelectedPermissionParams::Terminal {
- patterns: selected_choice.sub_patterns.clone(),
- })
- } else {
- None
- };
-
- let outcome =
- SelectedPermissionOutcome::new(selected_option.option_id.clone()).params(params);
-
- self.authorize_tool_call(
- session_id,
- tool_call_id,
- outcome,
- selected_option.kind,
- window,
- cx,
- );
+ self.authorize_tool_call(session_id, tool_call_id, outcome, window, cx);
Some(())
}
@@ -4655,7 +4603,10 @@ impl ThreadView {
.language_for_name("Markdown");
let thread = self.thread.read(cx);
- let thread_title = thread.title().to_string();
+ let thread_title = thread
+ .title()
+ .unwrap_or_else(|| DEFAULT_THREAD_TITLE.into())
+ .to_string();
let markdown = thread.to_markdown(cx);
let project = workspace.read(cx).project().clone();
@@ -6367,7 +6318,7 @@ impl ThreadView {
focus_handle: &FocusHandle,
cx: &Context<Self>,
) -> Div {
- let mut seen_kinds: ArrayVec<acp::PermissionOptionKind, 3> = ArrayVec::new();
+ let mut seen_kinds: ArrayVec<acp::PermissionOptionKind, 3, u8> = ArrayVec::new();
div()
.p_1()
@@ -6417,7 +6368,7 @@ impl ThreadView {
return this;
}
- seen_kinds.push(option.kind);
+ seen_kinds.push(option.kind).unwrap();
this.key_binding(
KeyBinding::for_action_in(action, focus_handle, cx)
@@ -6434,8 +6385,7 @@ impl ThreadView {
this.authorize_tool_call(
session_id.clone(),
tool_call_id.clone(),
- option_id.clone().into(),
- option_kind,
+ SelectedPermissionOutcome::new(option_id.clone(), option_kind),
window,
cx,
);
@@ -7067,7 +7017,7 @@ impl ThreadView {
let thread_title = thread
.as_ref()
- .map(|t| t.read(cx).title())
+ .and_then(|t| t.read(cx).title())
.filter(|t| !t.is_empty());
let tool_call_label = tool_call.label.read(cx).source().to_string();
let has_tool_call_label = !tool_call_label.is_empty();
@@ -739,7 +739,7 @@ mod tests {
/// Inserts a list of images into the editor as context mentions.
/// This is the shared implementation used by both paste and file picker operations.
pub(crate) async fn insert_images_as_context(
- images: Vec<gpui::Image>,
+ images: Vec<(gpui::Image, SharedString)>,
editor: Entity<Editor>,
mention_set: Entity<MentionSet>,
workspace: WeakEntity<Workspace>,
@@ -751,7 +751,7 @@ pub(crate) async fn insert_images_as_context(
let replacement_text = MentionUri::PastedImage.as_link().to_string();
- for image in images {
+ for (image, name) in images {
let Some((excerpt_id, text_anchor, multibuffer_anchor)) = editor
.update_in(cx, |editor, window, cx| {
let snapshot = editor.snapshot(window, cx);
@@ -785,7 +785,7 @@ pub(crate) async fn insert_images_as_context(
excerpt_id,
text_anchor,
content_len,
- MentionUri::PastedImage.name().into(),
+ name.clone(),
IconName::Image.path().into(),
None,
None,
@@ -856,10 +856,11 @@ pub(crate) fn paste_images_as_context(
Some(window.spawn(cx, async move |mut cx| {
use itertools::Itertools;
- let (mut images, paths) = clipboard
+ let default_name: SharedString = MentionUri::PastedImage.name().into();
+ let (mut images, paths): (Vec<(gpui::Image, SharedString)>, Vec<_>) = clipboard
.into_entries()
.filter_map(|entry| match entry {
- ClipboardEntry::Image(image) => Some(Either::Left(image)),
+ ClipboardEntry::Image(image) => Some(Either::Left((image, default_name.clone()))),
ClipboardEntry::ExternalPaths(paths) => Some(Either::Right(paths)),
_ => None,
})
@@ -870,24 +871,32 @@ pub(crate) fn paste_images_as_context(
cx.background_spawn(async move {
let mut images = vec![];
for path in paths.into_iter().flat_map(|paths| paths.paths().to_owned()) {
- let Ok(content) = async_fs::read(path).await else {
+ let Ok(content) = async_fs::read(&path).await else {
continue;
};
let Ok(format) = image::guess_format(&content) else {
continue;
};
- images.push(gpui::Image::from_bytes(
- match format {
- image::ImageFormat::Png => gpui::ImageFormat::Png,
- image::ImageFormat::Jpeg => gpui::ImageFormat::Jpeg,
- image::ImageFormat::WebP => gpui::ImageFormat::Webp,
- image::ImageFormat::Gif => gpui::ImageFormat::Gif,
- image::ImageFormat::Bmp => gpui::ImageFormat::Bmp,
- image::ImageFormat::Tiff => gpui::ImageFormat::Tiff,
- image::ImageFormat::Ico => gpui::ImageFormat::Ico,
- _ => continue,
- },
- content,
+ let name: SharedString = path
+ .file_name()
+ .and_then(|n| n.to_str())
+ .map(|s| SharedString::from(s.to_owned()))
+ .unwrap_or_else(|| default_name.clone());
+ images.push((
+ gpui::Image::from_bytes(
+ match format {
+ image::ImageFormat::Png => gpui::ImageFormat::Png,
+ image::ImageFormat::Jpeg => gpui::ImageFormat::Jpeg,
+ image::ImageFormat::WebP => gpui::ImageFormat::Webp,
+ image::ImageFormat::Gif => gpui::ImageFormat::Gif,
+ image::ImageFormat::Bmp => gpui::ImageFormat::Bmp,
+ image::ImageFormat::Tiff => gpui::ImageFormat::Tiff,
+ image::ImageFormat::Ico => gpui::ImageFormat::Ico,
+ _ => continue,
+ },
+ content,
+ ),
+ name,
));
}
images
@@ -1,3 +1,4 @@
+use crate::DEFAULT_THREAD_TITLE;
use crate::SendImmediately;
use crate::ThreadHistory;
use crate::{
@@ -14,7 +15,6 @@ use acp_thread::MentionUri;
use agent::ThreadStore;
use agent_client_protocol as acp;
use anyhow::{Result, anyhow};
-use collections::HashSet;
use editor::{
Addon, AnchorRangeExt, ContextMenuOptions, Editor, EditorElement, EditorEvent, EditorMode,
EditorStyle, Inlay, MultiBuffer, MultiBufferOffset, MultiBufferSnapshot, ToOffset,
@@ -25,7 +25,7 @@ use gpui::{
AppContext, ClipboardEntry, Context, Entity, EventEmitter, FocusHandle, Focusable, ImageFormat,
KeyContext, SharedString, Subscription, Task, TextStyle, WeakEntity,
};
-use language::{Buffer, Language, language_settings::InlayHintKind};
+use language::{Buffer, language_settings::InlayHintKind};
use parking_lot::RwLock;
use project::AgentId;
use project::{CompletionIntent, InlayHint, InlayHintLabel, InlayId, Project, Worktree};
@@ -172,16 +172,18 @@ impl MessageEditor {
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
- let language = Language::new(
- language::LanguageConfig {
- completion_query_characters: HashSet::from_iter(['.', '-', '_', '@']),
- ..Default::default()
- },
- None,
- );
+ let language_registry = project
+ .upgrade()
+ .map(|project| project.read(cx).languages().clone());
let editor = cx.new(|cx| {
- let buffer = cx.new(|cx| Buffer::local("", cx).with_language(Arc::new(language), cx));
+ let buffer = cx.new(|cx| {
+ let buffer = Buffer::local("", cx);
+ if let Some(language_registry) = language_registry.as_ref() {
+ buffer.set_language_registry(language_registry.clone());
+ }
+ buffer
+ });
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
let mut editor = Editor::new(mode, buffer, None, window, cx);
@@ -287,6 +289,22 @@ impl MessageEditor {
}
}));
+ if let Some(language_registry) = language_registry {
+ let editor = editor.clone();
+ cx.spawn(async move |_, cx| {
+ let markdown = language_registry.language_for_name("Markdown").await?;
+ editor.update(cx, |editor, cx| {
+ if let Some(buffer) = editor.buffer().read(cx).as_singleton() {
+ buffer.update(cx, |buffer, cx| {
+ buffer.set_language(Some(markdown), cx);
+ });
+ }
+ });
+ anyhow::Ok(())
+ })
+ .detach_and_log_err(cx);
+ }
+
Self {
editor,
mention_set,
@@ -370,7 +388,7 @@ impl MessageEditor {
};
let thread_title = title
.filter(|title| !title.is_empty())
- .unwrap_or_else(|| SharedString::new_static("New Thread"));
+ .unwrap_or_else(|| SharedString::new_static(DEFAULT_THREAD_TITLE));
let uri = MentionUri::Thread {
id: session_id,
name: thread_title.to_string(),
@@ -1349,7 +1367,12 @@ impl MessageEditor {
continue;
};
- images.push(gpui::Image::from_bytes(format, content));
+ let name: gpui::SharedString = path
+ .file_name()
+ .and_then(|n| n.to_str())
+ .map(|s| gpui::SharedString::from(s.to_owned()))
+ .unwrap_or_else(|| "Image".into());
+ images.push((gpui::Image::from_bytes(format, content), name));
}
crate::mention_set::insert_images_as_context(
@@ -1,5 +1,7 @@
use crate::thread_history::ThreadHistory;
-use crate::{AgentPanel, ConversationView, RemoveHistory, RemoveSelectedThread};
+use crate::{
+ AgentPanel, ConversationView, DEFAULT_THREAD_TITLE, RemoveHistory, RemoveSelectedThread,
+};
use acp_thread::AgentSessionInfo;
use chrono::{Datelike as _, Local, NaiveDate, TimeDelta, Utc};
use editor::{Editor, EditorEvent};
@@ -16,14 +18,12 @@ use ui::{
WithScrollbar, prelude::*,
};
-const DEFAULT_TITLE: &SharedString = &SharedString::new_static("New Thread");
-
-pub(crate) fn thread_title(entry: &AgentSessionInfo) -> &SharedString {
+pub(crate) fn thread_title(entry: &AgentSessionInfo) -> SharedString {
entry
.title
- .as_ref()
+ .clone()
.filter(|title| !title.is_empty())
- .unwrap_or(DEFAULT_TITLE)
+ .unwrap_or_else(|| DEFAULT_THREAD_TITLE.into())
}
pub struct ThreadHistoryView {
@@ -203,7 +203,7 @@ impl ThreadHistoryView {
let mut candidates = Vec::with_capacity(entries.len());
for (idx, entry) in entries.iter().enumerate() {
- candidates.push(StringMatchCandidate::new(idx, thread_title(entry)));
+ candidates.push(StringMatchCandidate::new(idx, &thread_title(entry)));
}
const MAX_MATCHES: usize = 100;
@@ -429,7 +429,7 @@ impl ThreadHistoryView {
(_, None) => "β".to_string(),
};
- let title = thread_title(entry).clone();
+ let title = thread_title(entry);
let full_date = entry_time
.map(|time| {
EntryTimeFormat::DateAndTime.format_timestamp(time.timestamp(), self.local_timezone)
@@ -678,7 +678,7 @@ impl HistoryEntryElement {
impl RenderOnce for HistoryEntryElement {
fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
let id = ElementId::Name(self.entry.session_id.0.clone().into());
- let title = thread_title(&self.entry).clone();
+ let title = thread_title(&self.entry);
let formatted_time = self
.entry
.updated_at
@@ -21,6 +21,8 @@ use ui::{App, Context, SharedString};
use util::ResultExt as _;
use workspace::PathList;
+use crate::DEFAULT_THREAD_TITLE;
+
pub fn init(cx: &mut App) {
SidebarThreadMetadataStore::init_global(cx);
@@ -81,12 +83,16 @@ fn migrate_thread_metadata(cx: &mut App) {
.collect::<Vec<_>>()
});
+ log::info!("Migrating {} thread store entries", metadata.len());
+
// Manually save each entry to the database and call reload, otherwise
// we'll end up triggering lots of reloads after each save
for entry in metadata {
db.save(entry).await?;
}
+ log::info!("Finished migrating thread store entries");
+
let _ = store.update(cx, |store, cx| store.reload(cx));
Ok(())
})
@@ -134,7 +140,9 @@ impl ThreadMetadata {
pub fn from_thread(thread: &Entity<acp_thread::AcpThread>, cx: &App) -> Self {
let thread_ref = thread.read(cx);
let session_id = thread_ref.session_id().clone();
- let title = thread_ref.title();
+ let title = thread_ref
+ .title()
+ .unwrap_or_else(|| DEFAULT_THREAD_TITLE.into());
let updated_at = Utc::now();
let agent_id = thread_ref.connection().agent_id();
@@ -987,7 +995,7 @@ mod tests {
cx.new(|cx| {
acp_thread::AcpThread::new(
Some(regular_session_id.clone()),
- "Subagent Thread",
+ Some("Subagent Thread".into()),
None,
connection.clone(),
project.clone(),
@@ -575,7 +575,7 @@ impl ThreadsArchiveView {
.when(can_unarchive, |this| {
this.child(
Button::new("unarchive-thread", "Restore")
- .style(ButtonStyle::OutlinedGhost)
+ .style(ButtonStyle::Filled)
.label_size(LabelSize::Small)
.when(is_focused, |this| {
this.key_binding(
@@ -606,6 +606,7 @@ impl ThreadsArchiveView {
"delete-thread",
IconName::Trash,
)
+ .style(ButtonStyle::Filled)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.tooltip({
@@ -901,14 +901,16 @@ impl TextThreadStore {
cx,
);
}
- ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
+ ContextServerStatus::Stopped
+ | ContextServerStatus::Error(_)
+ | ContextServerStatus::AuthRequired => {
if let Some(slash_command_ids) =
self.context_server_slash_command_ids.remove(server_id)
{
self.slash_commands.remove(&slash_command_ids);
}
}
- _ => {}
+ ContextServerStatus::Starting | ContextServerStatus::Authenticating => {}
}
}
@@ -17,6 +17,7 @@ test-support = ["gpui/test-support"]
[dependencies]
anyhow.workspace = true
async-trait.workspace = true
+base64.workspace = true
collections.workspace = true
futures.workspace = true
gpui.workspace = true
@@ -24,14 +25,17 @@ http_client = { workspace = true, features = ["test-support"] }
log.workspace = true
net.workspace = true
parking_lot.workspace = true
+rand.workspace = true
postage.workspace = true
schemars.workspace = true
serde_json.workspace = true
serde.workspace = true
settings.workspace = true
+sha2.workspace = true
slotmap.workspace = true
smol.workspace = true
tempfile.workspace = true
+tiny_http.workspace = true
url = { workspace = true, features = ["serde"] }
util.workspace = true
terminal.workspace = true
@@ -35,7 +35,7 @@ pub const METHOD_NOT_FOUND: i32 = -32601;
pub const INVALID_PARAMS: i32 = -32602;
pub const INTERNAL_ERROR: i32 = -32603;
-type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>;
+type ResponseHandler = Box<dyn Send + FnOnce(String)>;
type NotificationHandler = Box<dyn Send + FnMut(Value, AsyncApp)>;
type RequestHandler = Box<dyn Send + FnMut(RequestId, &RawValue, AsyncApp)>;
@@ -62,6 +62,14 @@ pub(crate) struct Client {
#[allow(dead_code)]
transport: Arc<dyn Transport>,
request_timeout: Option<Duration>,
+ /// Single-slot side channel for the last transport-level error. When the
+ /// output task encounters a send failure it stashes the error here and
+ /// exits; the next request to observe cancellation `.take()`s it so it can
+ /// propagate a typed error (e.g. `TransportError::AuthRequired`) instead
+ /// of a generic "cancelled". This works because `initialize` is the sole
+ /// in-flight request at startup, but would need rethinking if concurrent
+ /// requests are ever issued during that phase.
+ last_transport_error: Arc<Mutex<Option<anyhow::Error>>>,
}
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
@@ -223,13 +231,16 @@ impl Client {
input.or(err)
});
+ let last_transport_error: Arc<Mutex<Option<anyhow::Error>>> = Arc::new(Mutex::new(None));
let output_task = cx.background_spawn({
let transport = transport.clone();
+ let last_transport_error = last_transport_error.clone();
Self::handle_output(
transport,
outbound_rx,
output_done_tx,
response_handlers.clone(),
+ last_transport_error,
)
.log_err()
});
@@ -246,6 +257,7 @@ impl Client {
output_done_rx: Mutex::new(Some(output_done_rx)),
transport,
request_timeout,
+ last_transport_error,
})
}
@@ -279,7 +291,7 @@ impl Client {
if let Some(handlers) = response_handlers.lock().as_mut()
&& let Some(handler) = handlers.remove(&response.id)
{
- handler(Ok(message.to_string()));
+ handler(message.to_string());
}
} else if let Ok(notification) = serde_json::from_str::<AnyNotification>(&message) {
subscription_set.lock().notify(
@@ -315,6 +327,7 @@ impl Client {
outbound_rx: channel::Receiver<String>,
output_done_tx: barrier::Sender,
response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
+ last_transport_error: Arc<Mutex<Option<anyhow::Error>>>,
) -> anyhow::Result<()> {
let _clear_response_handlers = util::defer({
let response_handlers = response_handlers.clone();
@@ -324,7 +337,11 @@ impl Client {
});
while let Ok(message) = outbound_rx.recv().await {
log::trace!("outgoing message: {}", message);
- transport.send(message).await?;
+ if let Err(err) = transport.send(message).await {
+ log::debug!("transport send failed: {:#}", err);
+ *last_transport_error.lock() = Some(err);
+ return Ok(());
+ }
}
drop(output_done_tx);
Ok(())
@@ -408,7 +425,7 @@ impl Client {
response = rx.fuse() => {
let elapsed = started.elapsed();
log::trace!("took {elapsed:?} to receive response to {method:?} id {id}");
- match response? {
+ match response {
Ok(response) => {
let parsed: AnyResponse = serde_json::from_str(&response)?;
if let Some(error) = parsed.error {
@@ -419,7 +436,12 @@ impl Client {
anyhow::bail!("Invalid response: no result or error");
}
}
- Err(_) => anyhow::bail!("cancelled")
+ Err(_canceled) => {
+ if let Some(err) = self.last_transport_error.lock().take() {
+ return Err(err);
+ }
+ anyhow::bail!("cancelled")
+ }
}
}
_ = cancel_fut => {
@@ -1,5 +1,6 @@
pub mod client;
pub mod listener;
+pub mod oauth;
pub mod protocol;
#[cfg(any(test, feature = "test-support"))]
pub mod test;
@@ -0,0 +1,2800 @@
+//! OAuth 2.0 authentication for MCP servers using the Authorization Code +
+//! PKCE flow, per the MCP spec's OAuth profile.
+//!
+//! The flow is split into two phases:
+//!
+//! 1. **Discovery** ([`discover`]) fetches Protected Resource Metadata and
+//! Authorization Server Metadata. This can happen early (e.g. on a 401
+//! during server startup) because it doesn't need the redirect URI yet.
+//!
+//! 2. **Client registration** ([`resolve_client_registration`]) is separate
+//! because DCR requires the actual loopback redirect URI, which includes an
+//! ephemeral port that only exists once the callback server has started.
+//!
+//! After authentication, the full state is captured in [`OAuthSession`] which
+//! is persisted to the keychain. On next startup, the stored session feeds
+//! directly into [`McpOAuthTokenProvider`], giving a refresh-capable provider
+//! without requiring another browser flow.
+
+use anyhow::{Context as _, Result, anyhow, bail};
+use async_trait::async_trait;
+use base64::Engine as _;
+use futures::AsyncReadExt as _;
+use futures::channel::mpsc;
+use http_client::{AsyncBody, HttpClient, Request};
+use parking_lot::Mutex as SyncMutex;
+use rand::Rng as _;
+use serde::{Deserialize, Serialize};
+use sha2::{Digest, Sha256};
+
+use std::str::FromStr;
+use std::sync::Arc;
+use std::time::{Duration, SystemTime};
+use url::Url;
+use util::ResultExt as _;
+
/// The CIMD URL where Zed's OAuth client metadata document is hosted.
///
/// Under CIMD this URL doubles as the OAuth `client_id` itself
/// (see [`determine_registration_strategy`]).
pub const CIMD_URL: &str = "https://zed.dev/oauth/client-metadata.json";
+
+/// Validate that a URL is safe to use as an OAuth endpoint.
+///
+/// OAuth endpoints carry sensitive material (authorization codes, PKCE
+/// verifiers, tokens) and must use TLS. Plain HTTP is only permitted for
+/// loopback addresses, per RFC 8252 Section 8.3.
+fn require_https_or_loopback(url: &Url) -> Result<()> {
+ if url.scheme() == "https" {
+ return Ok(());
+ }
+ if url.scheme() == "http" {
+ if let Some(host) = url.host() {
+ match host {
+ url::Host::Ipv4(ip) if ip.is_loopback() => return Ok(()),
+ url::Host::Ipv6(ip) if ip.is_loopback() => return Ok(()),
+ url::Host::Domain(d) if d.eq_ignore_ascii_case("localhost") => return Ok(()),
+ _ => {}
+ }
+ }
+ }
+ bail!(
+ "OAuth endpoint must use HTTPS (got {}://{})",
+ url.scheme(),
+ url.host_str().unwrap_or("?")
+ )
+}
+
+/// Validate that a URL is safe to use as an OAuth endpoint, including SSRF
+/// protections against private/reserved IP ranges.
+///
+/// This wraps [`require_https_or_loopback`] and adds IP-range checks to prevent
+/// an attacker-controlled MCP server from directing Zed to fetch internal
+/// network resources via metadata URLs.
+///
+/// **Known limitation:** Domain-name URLs that resolve to private IPs are *not*
+/// blocked here β full mitigation requires resolver-level validation (e.g. a
+/// custom `Resolve` implementation). This function only blocks IP-literal URLs.
+fn validate_oauth_url(url: &Url) -> Result<()> {
+ require_https_or_loopback(url)?;
+
+ if let Some(host) = url.host() {
+ match host {
+ url::Host::Ipv4(ip) => {
+ // Loopback is already allowed by require_https_or_loopback.
+ if ip.is_private() || ip.is_link_local() || ip.is_broadcast() || ip.is_unspecified()
+ {
+ bail!(
+ "OAuth endpoint must not point to private/reserved IP: {}",
+ ip
+ );
+ }
+ }
+ url::Host::Ipv6(ip) => {
+ // Check for IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) which
+ // could bypass the IPv4 checks above.
+ if let Some(mapped_v4) = ip.to_ipv4_mapped() {
+ if mapped_v4.is_private()
+ || mapped_v4.is_link_local()
+ || mapped_v4.is_broadcast()
+ || mapped_v4.is_unspecified()
+ {
+ bail!(
+ "OAuth endpoint must not point to private/reserved IP: ::ffff:{}",
+ mapped_v4
+ );
+ }
+ }
+
+ if ip.is_unspecified() || ip.is_multicast() {
+ bail!(
+ "OAuth endpoint must not point to reserved IPv6 address: {}",
+ ip
+ );
+ }
+ // IPv6 Unique Local Addresses (fc00::/7). is_unique_local() is
+ // nightly-only, so check the prefix manually.
+ if (ip.segments()[0] & 0xfe00) == 0xfc00 {
+ bail!(
+ "OAuth endpoint must not point to IPv6 unique-local address: {}",
+ ip
+ );
+ }
+ }
+ url::Host::Domain(_) => {
+ // Domain-based SSRF prevention requires resolver-level checks.
+ // See known limitation in the doc comment above.
+ }
+ }
+ }
+
+ Ok(())
+}
+
/// Parsed from the MCP server's WWW-Authenticate header or well-known endpoint
/// per RFC 9728 (OAuth 2.0 Protected Resource Metadata).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProtectedResourceMetadata {
    /// Canonical identifier of the protected resource.
    pub resource: Url,
    /// Authorization servers that can issue tokens for this resource.
    pub authorization_servers: Vec<Url>,
    /// Scopes the resource advertises, if any.
    pub scopes_supported: Option<Vec<String>>,
}
+
/// Parsed from the authorization server's .well-known endpoint
/// per RFC 8414 (OAuth 2.0 Authorization Server Metadata).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthServerMetadata {
    /// The authorization server's issuer identifier.
    pub issuer: Url,
    /// Where the browser is sent to authorize (code flow).
    pub authorization_endpoint: Url,
    /// Where authorization codes and refresh tokens are exchanged.
    pub token_endpoint: Url,
    /// RFC 7591 Dynamic Client Registration endpoint, if offered.
    pub registration_endpoint: Option<Url>,
    /// Scopes the server advertises, if any.
    pub scopes_supported: Option<Vec<String>>,
    /// PKCE methods the server supports (must include "S256").
    pub code_challenge_methods_supported: Option<Vec<String>>,
    /// Whether the server accepts a CIMD URL as client_id.
    pub client_id_metadata_document_supported: bool,
}
+
/// The result of client registration — either CIMD or DCR.
#[derive(Clone, Serialize, Deserialize)]
pub struct OAuthClientRegistration {
    /// OAuth client identifier (the CIMD URL, or the DCR-issued id).
    pub client_id: String,
    /// Only present for DCR-minted registrations.
    pub client_secret: Option<String>,
}
+
+impl std::fmt::Debug for OAuthClientRegistration {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("OAuthClientRegistration")
+ .field("client_id", &self.client_id)
+ .field(
+ "client_secret",
+ &self.client_secret.as_ref().map(|_| "[redacted]"),
+ )
+ .finish()
+ }
+}
+
/// Access and refresh tokens obtained from the token endpoint.
///
/// `Debug` is implemented manually to redact the token material.
#[derive(Clone, Serialize, Deserialize)]
pub struct OAuthTokens {
    /// The access token issued by the authorization server.
    pub access_token: String,
    /// Present only when the server issued a refresh token.
    pub refresh_token: Option<String>,
    /// Absolute expiry, computed from the token response's `expires_in`.
    pub expires_at: Option<SystemTime>,
}
+
+impl std::fmt::Debug for OAuthTokens {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("OAuthTokens")
+ .field("access_token", &"[redacted]")
+ .field(
+ "refresh_token",
+ &self.refresh_token.as_ref().map(|_| "[redacted]"),
+ )
+ .field("expires_at", &self.expires_at)
+ .finish()
+ }
+}
+
/// Everything discovered before the browser flow starts. Client registration is
/// resolved separately, once the real redirect URI is known.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthDiscovery {
    /// RFC 9728 metadata fetched from the MCP server.
    pub resource_metadata: ProtectedResourceMetadata,
    /// RFC 8414 metadata fetched from the chosen authorization server.
    pub auth_server_metadata: AuthServerMetadata,
    /// Scopes chosen by [`select_scopes`].
    pub scopes: Vec<String>,
}
+
/// The persisted OAuth session for a context server.
///
/// Stored in the keychain so startup can restore a refresh-capable provider
/// without another browser flow. Deliberately excludes the full discovery
/// metadata to keep the serialized size well within keychain item limits.
#[derive(Clone, Serialize, Deserialize)]
pub struct OAuthSession {
    /// Token endpoint used for refreshes.
    pub token_endpoint: Url,
    /// RFC 8707 `resource` value the tokens are bound to.
    pub resource: Url,
    /// The registration (CIMD or DCR) the tokens were issued under.
    pub client_registration: OAuthClientRegistration,
    /// The issued tokens themselves.
    pub tokens: OAuthTokens,
}
+
+impl std::fmt::Debug for OAuthSession {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("OAuthSession")
+ .field("token_endpoint", &self.token_endpoint)
+ .field("resource", &self.resource)
+ .field("client_registration", &self.client_registration)
+ .field("tokens", &self.tokens)
+ .finish()
+ }
+}
+
/// Error codes defined by RFC 6750 Section 3.1 for Bearer token authentication.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BearerError {
    /// The request is missing a required parameter, includes an unsupported
    /// parameter or parameter value, or is otherwise malformed.
    /// Wire value: `invalid_request`.
    InvalidRequest,
    /// The access token provided is expired, revoked, malformed, or invalid.
    /// Wire value: `invalid_token`.
    InvalidToken,
    /// The request requires higher privileges than provided by the access token.
    /// Wire value: `insufficient_scope`.
    InsufficientScope,
    /// An unrecognized error code (extension or future spec addition).
    Other,
}
+
+impl BearerError {
+ fn parse(value: &str) -> Self {
+ match value {
+ "invalid_request" => BearerError::InvalidRequest,
+ "invalid_token" => BearerError::InvalidToken,
+ "insufficient_scope" => BearerError::InsufficientScope,
+ _ => BearerError::Other,
+ }
+ }
+}
+
/// Fields extracted from a `WWW-Authenticate: Bearer` header.
///
/// Per RFC 9728 Section 5.1, MCP servers include `resource_metadata` to point
/// at the Protected Resource Metadata document. The optional `scope` parameter
/// (RFC 6750 Section 3) indicates scopes required for the request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WwwAuthenticate {
    /// URL of the Protected Resource Metadata document, if the server sent one.
    pub resource_metadata: Option<Url>,
    /// Required scopes, split from the space-separated `scope` parameter.
    pub scope: Option<Vec<String>>,
    /// The parsed `error` parameter per RFC 6750 Section 3.1.
    pub error: Option<BearerError>,
    /// Human-readable `error_description` parameter, if present.
    pub error_description: Option<String>,
}
+
+/// Parse a `WWW-Authenticate` header value.
+///
+/// Expects the `Bearer` scheme followed by comma-separated `key="value"` pairs.
+/// Per RFC 6750 and RFC 9728, the relevant parameters are:
+/// - `resource_metadata` β URL of the Protected Resource Metadata document
+/// - `scope` β space-separated list of required scopes
+/// - `error` β error code (e.g. "insufficient_scope")
+/// - `error_description` β human-readable error description
+pub fn parse_www_authenticate(header: &str) -> Result<WwwAuthenticate> {
+ let header = header.trim();
+
+ let params_str = if header.len() >= 6 && header[..6].eq_ignore_ascii_case("bearer") {
+ header[6..].trim()
+ } else {
+ bail!("WWW-Authenticate header does not use Bearer scheme");
+ };
+
+ if params_str.is_empty() {
+ return Ok(WwwAuthenticate {
+ resource_metadata: None,
+ scope: None,
+ error: None,
+ error_description: None,
+ });
+ }
+
+ let params = parse_auth_params(params_str);
+
+ let resource_metadata = params
+ .get("resource_metadata")
+ .map(|v| Url::parse(v))
+ .transpose()
+ .map_err(|e| anyhow!("invalid resource_metadata URL: {}", e))?;
+
+ let scope = params
+ .get("scope")
+ .map(|v| v.split_whitespace().map(String::from).collect());
+
+ let error = params.get("error").map(|v| BearerError::parse(v));
+ let error_description = params.get("error_description").cloned();
+
+ Ok(WwwAuthenticate {
+ resource_metadata,
+ scope,
+ error,
+ error_description,
+ })
+}
+
/// Parse comma-separated `key="value"` or `key=token` parameters from an
/// auth-param list (RFC 7235 Section 2.1).
///
/// Keys are lowercased. A duplicate key overwrites the earlier value. Any
/// trailing input containing no `=` is silently discarded.
fn parse_auth_params(input: &str) -> collections::HashMap<String, String> {
    let mut params = collections::HashMap::default();
    let mut remaining = input.trim();

    while !remaining.is_empty() {
        // Skip leading whitespace and commas.
        remaining = remaining.trim_start_matches(|c: char| c == ',' || c.is_whitespace());
        if remaining.is_empty() {
            break;
        }

        // Find the key (everything before '='). No '=' means the rest of the
        // input is not a parameter; stop parsing.
        let eq_pos = match remaining.find('=') {
            Some(pos) => pos,
            None => break,
        };

        let key = remaining[..eq_pos].trim().to_lowercase();
        remaining = &remaining[eq_pos + 1..];
        remaining = remaining.trim_start();

        // Parse the value: either quoted or unquoted (token).
        let value;
        if remaining.starts_with('"') {
            // Quoted string: find the closing quote, handling escaped chars.
            remaining = &remaining[1..]; // skip opening quote
            let mut val = String::new();
            let mut chars = remaining.char_indices();
            loop {
                match chars.next() {
                    Some((_, '\\')) => {
                        // Escaped character — take the next char literally.
                        if let Some((_, c)) = chars.next() {
                            val.push(c);
                        }
                    }
                    Some((i, '"')) => {
                        remaining = &remaining[i + 1..];
                        break;
                    }
                    Some((_, c)) => val.push(c),
                    None => {
                        // Unterminated quoted string: keep what we collected
                        // and treat the input as exhausted.
                        remaining = "";
                        break;
                    }
                }
            }
            value = val;
        } else {
            // Unquoted token: read until comma or whitespace.
            let end = remaining
                .find(|c: char| c == ',' || c.is_whitespace())
                .unwrap_or(remaining.len());
            value = remaining[..end].to_string();
            remaining = &remaining[end..];
        }

        // An empty key (e.g. a leading "=x") is dropped rather than stored.
        if !key.is_empty() {
            params.insert(key, value);
        }
    }

    params
}
+
+/// Construct the well-known Protected Resource Metadata URIs for a given MCP
+/// server URL, per RFC 9728 Section 3.
+///
+/// Returns URIs in priority order:
+/// 1. Path-specific: `https://<host>/.well-known/oauth-protected-resource/<path>`
+/// 2. Root: `https://<host>/.well-known/oauth-protected-resource`
+pub fn protected_resource_metadata_urls(server_url: &Url) -> Vec<Url> {
+ let mut urls = Vec::new();
+ let base = format!("{}://{}", server_url.scheme(), server_url.authority());
+
+ let path = server_url.path().trim_start_matches('/');
+ if !path.is_empty() {
+ if let Ok(url) = Url::parse(&format!(
+ "{}/.well-known/oauth-protected-resource/{}",
+ base, path
+ )) {
+ urls.push(url);
+ }
+ }
+
+ if let Ok(url) = Url::parse(&format!("{}/.well-known/oauth-protected-resource", base)) {
+ urls.push(url);
+ }
+
+ urls
+}
+
+/// Construct the well-known Authorization Server Metadata URIs for a given
+/// issuer URL, per RFC 8414 Section 3.1 and Section 5 (OIDC compat).
+///
+/// Returns URIs in priority order, which differs depending on whether the
+/// issuer URL has a path component.
+pub fn auth_server_metadata_urls(issuer: &Url) -> Vec<Url> {
+ let mut urls = Vec::new();
+ let base = format!("{}://{}", issuer.scheme(), issuer.authority());
+ let path = issuer.path().trim_matches('/');
+
+ if !path.is_empty() {
+ // Issuer with path: try path-inserted variants first.
+ if let Ok(url) = Url::parse(&format!(
+ "{}/.well-known/oauth-authorization-server/{}",
+ base, path
+ )) {
+ urls.push(url);
+ }
+ if let Ok(url) = Url::parse(&format!(
+ "{}/.well-known/openid-configuration/{}",
+ base, path
+ )) {
+ urls.push(url);
+ }
+ if let Ok(url) = Url::parse(&format!(
+ "{}/{}/.well-known/openid-configuration",
+ base, path
+ )) {
+ urls.push(url);
+ }
+ } else {
+ // No path: standard well-known locations.
+ if let Ok(url) = Url::parse(&format!("{}/.well-known/oauth-authorization-server", base)) {
+ urls.push(url);
+ }
+ if let Ok(url) = Url::parse(&format!("{}/.well-known/openid-configuration", base)) {
+ urls.push(url);
+ }
+ }
+
+ urls
+}
+
+// -- Canonical server URI (RFC 8707) -----------------------------------------
+
+/// Derive the canonical resource URI for an MCP server URL, suitable for the
+/// `resource` parameter in authorization and token requests per RFC 8707.
+///
+/// Lowercases the scheme and host, preserves the path (without trailing slash),
+/// strips fragments and query strings.
+pub fn canonical_server_uri(server_url: &Url) -> String {
+ let mut uri = format!(
+ "{}://{}",
+ server_url.scheme().to_ascii_lowercase(),
+ server_url.host_str().unwrap_or("").to_ascii_lowercase(),
+ );
+ if let Some(port) = server_url.port() {
+ uri.push_str(&format!(":{}", port));
+ }
+ let path = server_url.path();
+ if path != "/" {
+ uri.push_str(path.trim_end_matches('/'));
+ }
+ uri
+}
+
+// -- Scope selection ---------------------------------------------------------
+
+/// Select scopes following the MCP spec's Scope Selection Strategy:
+/// 1. Use `scope` from the `WWW-Authenticate` challenge if present.
+/// 2. Fall back to `scopes_supported` from Protected Resource Metadata.
+/// 3. Return empty if neither is available.
+pub fn select_scopes(
+ www_authenticate: &WwwAuthenticate,
+ resource_metadata: &ProtectedResourceMetadata,
+) -> Vec<String> {
+ if let Some(ref scopes) = www_authenticate.scope {
+ if !scopes.is_empty() {
+ return scopes.clone();
+ }
+ }
+ resource_metadata
+ .scopes_supported
+ .clone()
+ .unwrap_or_default()
+}
+
+// -- Client registration strategy --------------------------------------------
+
+/// The registration approach to use, determined from auth server metadata.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum ClientRegistrationStrategy {
+ /// The auth server supports CIMD. Use the CIMD URL as client_id directly.
+ Cimd { client_id: String },
+ /// The auth server has a registration endpoint. Caller must POST to it.
+ Dcr { registration_endpoint: Url },
+ /// No supported registration mechanism.
+ Unavailable,
+}
+
+/// Determine how to register with the authorization server, following the
+/// spec's recommended priority: CIMD first, DCR fallback.
+pub fn determine_registration_strategy(
+ auth_server_metadata: &AuthServerMetadata,
+) -> ClientRegistrationStrategy {
+ if auth_server_metadata.client_id_metadata_document_supported {
+ ClientRegistrationStrategy::Cimd {
+ client_id: CIMD_URL.to_string(),
+ }
+ } else if let Some(ref endpoint) = auth_server_metadata.registration_endpoint {
+ ClientRegistrationStrategy::Dcr {
+ registration_endpoint: endpoint.clone(),
+ }
+ } else {
+ ClientRegistrationStrategy::Unavailable
+ }
+}
+
+// -- PKCE (RFC 7636) ---------------------------------------------------------
+
/// A PKCE code verifier and its S256 challenge.
///
/// `Debug` is implemented manually so the secret verifier is never logged.
#[derive(Clone)]
pub struct PkceChallenge {
    /// The secret code verifier, sent only during the token exchange.
    pub verifier: String,
    /// `BASE64URL(SHA256(verifier))`, sent in the authorization request.
    pub challenge: String,
}
+
impl std::fmt::Debug for PkceChallenge {
    // Manual impl: the verifier is secret and must never appear in logs; the
    // challenge is public (it is sent in the authorization URL).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PkceChallenge")
            .field("verifier", &"[redacted]")
            .field("challenge", &self.challenge)
            .finish()
    }
}
+
+/// Generate a PKCE code verifier and S256 challenge per RFC 7636.
+///
+/// The verifier is 43 base64url characters derived from 32 random bytes.
+/// The challenge is `BASE64URL(SHA256(verifier))`.
+pub fn generate_pkce_challenge() -> PkceChallenge {
+ let mut random_bytes = [0u8; 32];
+ rand::rng().fill(&mut random_bytes);
+ let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
+ let verifier = engine.encode(&random_bytes);
+
+ let digest = Sha256::digest(verifier.as_bytes());
+ let challenge = engine.encode(digest);
+
+ PkceChallenge {
+ verifier,
+ challenge,
+ }
+}
+
+// -- Authorization URL construction ------------------------------------------
+
+/// Build the authorization URL for the OAuth Authorization Code + PKCE flow.
+pub fn build_authorization_url(
+ auth_server_metadata: &AuthServerMetadata,
+ client_id: &str,
+ redirect_uri: &str,
+ scopes: &[String],
+ resource: &str,
+ pkce: &PkceChallenge,
+ state: &str,
+) -> Url {
+ let mut url = auth_server_metadata.authorization_endpoint.clone();
+ {
+ let mut query = url.query_pairs_mut();
+ query.append_pair("response_type", "code");
+ query.append_pair("client_id", client_id);
+ query.append_pair("redirect_uri", redirect_uri);
+ if !scopes.is_empty() {
+ query.append_pair("scope", &scopes.join(" "));
+ }
+ query.append_pair("resource", resource);
+ query.append_pair("code_challenge", &pkce.challenge);
+ query.append_pair("code_challenge_method", "S256");
+ query.append_pair("state", state);
+ }
+ url
+}
+
+// -- Token endpoint request bodies -------------------------------------------
+
/// The JSON body returned by the token endpoint on success.
#[derive(Deserialize)]
pub struct TokenResponse {
    /// The issued access token (the only required field).
    pub access_token: String,
    /// Optional refresh token.
    #[serde(default)]
    pub refresh_token: Option<String>,
    /// Token lifetime in seconds, relative to the time of issuance.
    #[serde(default)]
    pub expires_in: Option<u64>,
    /// Token type reported by the server (e.g. "Bearer").
    #[serde(default)]
    pub token_type: Option<String>,
}
+
+impl std::fmt::Debug for TokenResponse {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("TokenResponse")
+ .field("access_token", &"[redacted]")
+ .field(
+ "refresh_token",
+ &self.refresh_token.as_ref().map(|_| "[redacted]"),
+ )
+ .field("expires_in", &self.expires_in)
+ .field("token_type", &self.token_type)
+ .finish()
+ }
+}
+
+impl TokenResponse {
+ /// Convert into `OAuthTokens`, computing `expires_at` from `expires_in`.
+ pub fn into_tokens(self) -> OAuthTokens {
+ let expires_at = self
+ .expires_in
+ .map(|secs| SystemTime::now() + Duration::from_secs(secs));
+ OAuthTokens {
+ access_token: self.access_token,
+ refresh_token: self.refresh_token,
+ expires_at,
+ }
+ }
+}
+
/// Build the form-encoded body for an authorization code token exchange
/// (RFC 6749 Section 4.1.3, plus the RFC 8707 `resource` parameter).
pub fn token_exchange_params(
    code: &str,
    client_id: &str,
    redirect_uri: &str,
    code_verifier: &str,
    resource: &str,
) -> Vec<(&'static str, String)> {
    let mut params = Vec::with_capacity(6);
    params.push(("grant_type", String::from("authorization_code")));
    params.push(("code", code.to_owned()));
    params.push(("redirect_uri", redirect_uri.to_owned()));
    params.push(("client_id", client_id.to_owned()));
    params.push(("code_verifier", code_verifier.to_owned()));
    params.push(("resource", resource.to_owned()));
    params
}
+
/// Build the form-encoded body for a token refresh request
/// (RFC 6749 Section 6, plus the RFC 8707 `resource` parameter).
pub fn token_refresh_params(
    refresh_token: &str,
    client_id: &str,
    resource: &str,
) -> Vec<(&'static str, String)> {
    [
        ("grant_type", "refresh_token"),
        ("refresh_token", refresh_token),
        ("client_id", client_id),
        ("resource", resource),
    ]
    .into_iter()
    .map(|(key, value)| (key, value.to_owned()))
    .collect()
}
+
+// -- DCR request body (RFC 7591) ---------------------------------------------
+
/// Build the JSON body for a Dynamic Client Registration request.
///
/// The `redirect_uri` should be the actual loopback URI with the ephemeral
/// port (e.g. `http://127.0.0.1:12345/callback`). Some auth servers do strict
/// redirect URI matching even for loopback addresses, so we register the
/// exact URI we intend to use.
///
/// `token_endpoint_auth_method: "none"` registers a public client; PKCE
/// (not a client secret) protects the code exchange.
pub fn dcr_registration_body(redirect_uri: &str) -> serde_json::Value {
    serde_json::json!({
        "client_name": "Zed",
        "redirect_uris": [redirect_uri],
        "grant_types": ["authorization_code"],
        "response_types": ["code"],
        "token_endpoint_auth_method": "none"
    })
}
+
+// -- Discovery (async, hits real endpoints) ----------------------------------
+
/// Fetch Protected Resource Metadata from the MCP server.
///
/// Tries the `resource_metadata` URL from the `WWW-Authenticate` header first
/// (only if it is same-origin with the server), then falls back to well-known
/// URIs constructed from `server_url`.
pub async fn fetch_protected_resource_metadata(
    http_client: &Arc<dyn HttpClient>,
    server_url: &Url,
    www_authenticate: &WwwAuthenticate,
) -> Result<ProtectedResourceMetadata> {
    // A cross-origin resource_metadata URL could point discovery at an
    // attacker-chosen host, so it is ignored rather than followed.
    let candidate_urls = match &www_authenticate.resource_metadata {
        Some(url) if url.origin() == server_url.origin() => vec![url.clone()],
        Some(url) => {
            log::warn!(
                "Ignoring cross-origin resource_metadata URL {} \
                 (server origin: {})",
                url,
                server_url.origin().unicode_serialization()
            );
            protected_resource_metadata_urls(server_url)
        }
        None => protected_resource_metadata_urls(server_url),
    };

    for url in &candidate_urls {
        match fetch_json::<ProtectedResourceMetadataResponse>(http_client, url).await {
            Ok(response) => {
                // NOTE(review): an empty authorization_servers list aborts the
                // whole discovery instead of trying the next candidate URL —
                // confirm fail-fast is the intended behavior here.
                if response.authorization_servers.is_empty() {
                    bail!(
                        "Protected Resource Metadata at {} has no authorization_servers",
                        url
                    );
                }
                // A missing `resource` field defaults to the server URL itself.
                return Ok(ProtectedResourceMetadata {
                    resource: response.resource.unwrap_or_else(|| server_url.clone()),
                    authorization_servers: response.authorization_servers,
                    scopes_supported: response.scopes_supported,
                });
            }
            Err(err) => {
                // Fetch/parse failures fall through to the next candidate.
                log::debug!(
                    "Failed to fetch Protected Resource Metadata from {}: {}",
                    url,
                    err
                );
            }
        }
    }

    bail!(
        "Could not fetch Protected Resource Metadata for {}",
        server_url
    )
}
+
/// Fetch Authorization Server Metadata, trying RFC 8414 and OIDC Discovery
/// endpoints in the priority order specified by the MCP spec.
///
/// Candidate URLs come from [`auth_server_metadata_urls`]; the first one
/// yielding parseable metadata wins.
pub async fn fetch_auth_server_metadata(
    http_client: &Arc<dyn HttpClient>,
    issuer: &Url,
) -> Result<AuthServerMetadata> {
    let candidate_urls = auth_server_metadata_urls(issuer);

    for url in &candidate_urls {
        match fetch_json::<AuthServerMetadataResponse>(http_client, url).await {
            Ok(response) => {
                // The metadata's self-reported issuer must match the issuer we
                // derived the URL from (RFC 8414 issuer validation).
                // NOTE(review): a mismatch aborts the whole lookup instead of
                // trying the next candidate URL — confirm fail-fast is intended.
                let reported_issuer = response.issuer.unwrap_or_else(|| issuer.clone());
                if reported_issuer != *issuer {
                    bail!(
                        "Auth server metadata issuer mismatch: expected {}, got {}",
                        issuer,
                        reported_issuer
                    );
                }

                return Ok(AuthServerMetadata {
                    issuer: reported_issuer,
                    // Both endpoints are required for the code flow.
                    authorization_endpoint: response
                        .authorization_endpoint
                        .ok_or_else(|| anyhow!("missing authorization_endpoint"))?,
                    token_endpoint: response
                        .token_endpoint
                        .ok_or_else(|| anyhow!("missing token_endpoint"))?,
                    registration_endpoint: response.registration_endpoint,
                    scopes_supported: response.scopes_supported,
                    code_challenge_methods_supported: response.code_challenge_methods_supported,
                    // An absent flag is treated as "CIMD not supported".
                    client_id_metadata_document_supported: response
                        .client_id_metadata_document_supported
                        .unwrap_or(false),
                });
            }
            Err(err) => {
                // Fall through to the next well-known location.
                log::debug!("Failed to fetch Auth Server Metadata from {}: {}", url, err);
            }
        }
    }

    bail!(
        "Could not fetch Authorization Server Metadata for {}",
        issuer
    )
}
+
+/// Run the full discovery flow: fetch resource metadata, then auth server
+/// metadata, then select scopes. Client registration is resolved separately,
+/// once the real redirect URI is known.
+pub async fn discover(
+ http_client: &Arc<dyn HttpClient>,
+ server_url: &Url,
+ www_authenticate: &WwwAuthenticate,
+) -> Result<OAuthDiscovery> {
+ let resource_metadata =
+ fetch_protected_resource_metadata(http_client, server_url, www_authenticate).await?;
+
+ let auth_server_url = resource_metadata
+ .authorization_servers
+ .first()
+ .ok_or_else(|| anyhow!("no authorization servers in resource metadata"))?;
+
+ let auth_server_metadata = fetch_auth_server_metadata(http_client, auth_server_url).await?;
+
+ // Verify PKCE S256 support (spec requirement).
+ match &auth_server_metadata.code_challenge_methods_supported {
+ Some(methods) if methods.iter().any(|m| m == "S256") => {}
+ Some(_) => bail!("authorization server does not support S256 PKCE"),
+ None => bail!("authorization server does not advertise code_challenge_methods_supported"),
+ }
+
+ // Verify there is at least one supported registration strategy before we
+ // present the server as ready to authenticate.
+ match determine_registration_strategy(&auth_server_metadata) {
+ ClientRegistrationStrategy::Cimd { .. } | ClientRegistrationStrategy::Dcr { .. } => {}
+ ClientRegistrationStrategy::Unavailable => {
+ bail!("authorization server supports neither CIMD nor DCR")
+ }
+ }
+
+ let scopes = select_scopes(www_authenticate, &resource_metadata);
+
+ Ok(OAuthDiscovery {
+ resource_metadata,
+ auth_server_metadata,
+ scopes,
+ })
+}
+
+/// Resolve the OAuth client registration for an authorization flow.
+///
+/// CIMD uses the static client metadata document directly. For DCR, a fresh
+/// registration is performed each time because the loopback redirect URI
+/// includes an ephemeral port that changes every flow.
+pub async fn resolve_client_registration(
+ http_client: &Arc<dyn HttpClient>,
+ discovery: &OAuthDiscovery,
+ redirect_uri: &str,
+) -> Result<OAuthClientRegistration> {
+ match determine_registration_strategy(&discovery.auth_server_metadata) {
+ ClientRegistrationStrategy::Cimd { client_id } => Ok(OAuthClientRegistration {
+ client_id,
+ client_secret: None,
+ }),
+ ClientRegistrationStrategy::Dcr {
+ registration_endpoint,
+ } => perform_dcr(http_client, ®istration_endpoint, redirect_uri).await,
+ ClientRegistrationStrategy::Unavailable => {
+ bail!("authorization server supports neither CIMD nor DCR")
+ }
+ }
+}
+
+// -- Dynamic Client Registration (RFC 7591) ----------------------------------
+
+/// Perform Dynamic Client Registration with the authorization server.
+pub async fn perform_dcr(
+ http_client: &Arc<dyn HttpClient>,
+ registration_endpoint: &Url,
+ redirect_uri: &str,
+) -> Result<OAuthClientRegistration> {
+ validate_oauth_url(registration_endpoint)?;
+
+ let body = dcr_registration_body(redirect_uri);
+ let body_bytes = serde_json::to_vec(&body)?;
+
+ let request = Request::builder()
+ .method(http_client::http::Method::POST)
+ .uri(registration_endpoint.as_str())
+ .header("Content-Type", "application/json")
+ .header("Accept", "application/json")
+ .body(AsyncBody::from(body_bytes))?;
+
+ let mut response = http_client.send(request).await?;
+
+ if !response.status().is_success() {
+ let mut error_body = String::new();
+ response.body_mut().read_to_string(&mut error_body).await?;
+ bail!(
+ "DCR failed with status {}: {}",
+ response.status(),
+ error_body
+ );
+ }
+
+ let mut response_body = String::new();
+ response
+ .body_mut()
+ .read_to_string(&mut response_body)
+ .await?;
+
+ let dcr_response: DcrResponse =
+ serde_json::from_str(&response_body).context("failed to parse DCR response")?;
+
+ Ok(OAuthClientRegistration {
+ client_id: dcr_response.client_id,
+ client_secret: dcr_response.client_secret,
+ })
+}
+
+// -- Token exchange and refresh (async) --------------------------------------
+
+/// Exchange an authorization code for tokens at the token endpoint.
+pub async fn exchange_code(
+ http_client: &Arc<dyn HttpClient>,
+ auth_server_metadata: &AuthServerMetadata,
+ code: &str,
+ client_id: &str,
+ redirect_uri: &str,
+ code_verifier: &str,
+ resource: &str,
+) -> Result<OAuthTokens> {
+ let params = token_exchange_params(code, client_id, redirect_uri, code_verifier, resource);
+ post_token_request(http_client, &auth_server_metadata.token_endpoint, ¶ms).await
+}
+
+/// Refresh tokens using a refresh token.
+pub async fn refresh_tokens(
+ http_client: &Arc<dyn HttpClient>,
+ token_endpoint: &Url,
+ refresh_token: &str,
+ client_id: &str,
+ resource: &str,
+) -> Result<OAuthTokens> {
+ let params = token_refresh_params(refresh_token, client_id, resource);
+ post_token_request(http_client, token_endpoint, ¶ms).await
+}
+
+/// POST form-encoded parameters to a token endpoint and parse the response.
+async fn post_token_request(
+ http_client: &Arc<dyn HttpClient>,
+ token_endpoint: &Url,
+ params: &[(&str, String)],
+) -> Result<OAuthTokens> {
+ validate_oauth_url(token_endpoint)?;
+
+ let body = url::form_urlencoded::Serializer::new(String::new())
+ .extend_pairs(params.iter().map(|(k, v)| (*k, v.as_str())))
+ .finish();
+
+ let request = Request::builder()
+ .method(http_client::http::Method::POST)
+ .uri(token_endpoint.as_str())
+ .header("Content-Type", "application/x-www-form-urlencoded")
+ .header("Accept", "application/json")
+ .body(AsyncBody::from(body.into_bytes()))?;
+
+ let mut response = http_client.send(request).await?;
+
+ if !response.status().is_success() {
+ let mut error_body = String::new();
+ response.body_mut().read_to_string(&mut error_body).await?;
+ bail!(
+ "token request failed with status {}: {}",
+ response.status(),
+ error_body
+ );
+ }
+
+ let mut response_body = String::new();
+ response
+ .body_mut()
+ .read_to_string(&mut response_body)
+ .await?;
+
+ let token_response: TokenResponse =
+ serde_json::from_str(&response_body).context("failed to parse token response")?;
+
+ Ok(token_response.into_tokens())
+}
+
+// -- Loopback HTTP callback server -------------------------------------------
+
/// An OAuth authorization callback received via the loopback HTTP server.
pub struct OAuthCallback {
    /// The authorization code returned by the authorization server.
    pub code: String,
    /// The `state` value echoed back by the authorization server; callers
    /// should verify it matches the value they originally sent.
    pub state: String,
}
+
impl std::fmt::Debug for OAuthCallback {
    // Manual impl: the authorization code and state are both sensitive, so
    // neither value is ever printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuthCallback")
            .field("code", &"[redacted]")
            .field("state", &"[redacted]")
            .finish()
    }
}
+
+impl OAuthCallback {
+ /// Parse the query string from a callback URL like
+ /// `http://127.0.0.1:<port>/callback?code=...&state=...`.
+ pub fn parse_query(query: &str) -> Result<Self> {
+ let mut code: Option<String> = None;
+ let mut state: Option<String> = None;
+ let mut error: Option<String> = None;
+ let mut error_description: Option<String> = None;
+
+ for (key, value) in url::form_urlencoded::parse(query.as_bytes()) {
+ match key.as_ref() {
+ "code" => {
+ if !value.is_empty() {
+ code = Some(value.into_owned());
+ }
+ }
+ "state" => {
+ if !value.is_empty() {
+ state = Some(value.into_owned());
+ }
+ }
+ "error" => {
+ if !value.is_empty() {
+ error = Some(value.into_owned());
+ }
+ }
+ "error_description" => {
+ if !value.is_empty() {
+ error_description = Some(value.into_owned());
+ }
+ }
+ _ => {}
+ }
+ }
+
+ // Check for OAuth error response (RFC 6749 Section 4.1.2.1) before
+ // checking for missing code/state.
+ if let Some(error_code) = error {
+ bail!(
+ "OAuth authorization failed: {} ({})",
+ error_code,
+ error_description.as_deref().unwrap_or("no description")
+ );
+ }
+
+ let code = code.ok_or_else(|| anyhow!("missing 'code' parameter in OAuth callback"))?;
+ let state = state.ok_or_else(|| anyhow!("missing 'state' parameter in OAuth callback"))?;
+
+ Ok(Self { code, state })
+ }
+}
+
/// How long to wait for the browser to complete the OAuth flow before giving
/// up and releasing the loopback port.
///
/// When this elapses the background server thread exits, dropping the sender
/// half of the callback channel, which the awaiting caller observes as a
/// cancellation.
const CALLBACK_TIMEOUT: Duration = Duration::from_secs(2 * 60);
+
+/// Start a loopback HTTP server to receive the OAuth authorization callback.
+///
+/// Binds to an ephemeral loopback port for each flow.
+///
+/// Returns `(redirect_uri, callback_future)`. The caller should use the
+/// redirect URI in the authorization request, open the browser, then await
+/// the future to receive the callback.
+///
+/// The server accepts exactly one request on `/callback`, validates that it
+/// contains `code` and `state` query parameters, responds with a minimal
+/// HTML page telling the user they can close the tab, and shuts down.
+///
+/// The callback server shuts down when the returned oneshot receiver is dropped
+/// (e.g. because the authentication task was cancelled), or after a timeout
+/// ([CALLBACK_TIMEOUT]).
+pub async fn start_callback_server() -> Result<(
+ String,
+ futures::channel::oneshot::Receiver<Result<OAuthCallback>>,
+)> {
+ let server = tiny_http::Server::http("127.0.0.1:0")
+ .map_err(|e| anyhow!(e).context("Failed to bind loopback listener for OAuth callback"))?;
+ let port = server
+ .server_addr()
+ .to_ip()
+ .context("server not bound to a TCP address")?
+ .port();
+
+ let redirect_uri = format!("http://127.0.0.1:{}/callback", port);
+
+ let (tx, rx) = futures::channel::oneshot::channel();
+
+ // `tiny_http` is blocking, so we run it on a background thread.
+ // The `recv_timeout` loop lets us check for cancellation (the receiver
+ // being dropped) and enforce an overall timeout.
+ std::thread::spawn(move || {
+ let deadline = std::time::Instant::now() + CALLBACK_TIMEOUT;
+
+ loop {
+ if tx.is_canceled() {
+ return;
+ }
+ let remaining = deadline.saturating_duration_since(std::time::Instant::now());
+ if remaining.is_zero() {
+ return;
+ }
+
+ let timeout = remaining.min(Duration::from_millis(500));
+ let Some(request) = (match server.recv_timeout(timeout) {
+ Ok(req) => req,
+ Err(_) => {
+ let _ = tx.send(Err(anyhow!("OAuth callback server I/O error")));
+ return;
+ }
+ }) else {
+ // Timeout with no request β loop back and check cancellation.
+ continue;
+ };
+
+ let result = handle_callback_request(&request);
+
+ let (status_code, body) = match &result {
+ Ok(_) => (
+ 200,
+ "<html><body><h1>Authorization successful</h1>\
+ <p>You can close this tab and return to Zed.</p></body></html>",
+ ),
+ Err(err) => {
+ log::error!("OAuth callback error: {}", err);
+ (
+ 400,
+ "<html><body><h1>Authorization failed</h1>\
+ <p>Something went wrong. Please try again from Zed.</p></body></html>",
+ )
+ }
+ };
+
+ let response = tiny_http::Response::from_string(body)
+ .with_status_code(status_code)
+ .with_header(
+ tiny_http::Header::from_str("Content-Type: text/html")
+ .expect("failed to construct response header"),
+ )
+ .with_header(
+ tiny_http::Header::from_str("Keep-Alive: timeout=0,max=0")
+ .expect("failed to construct response header"),
+ );
+ request.respond(response).log_err();
+
+ let _ = tx.send(result);
+ return;
+ }
+ });
+
+ Ok((redirect_uri, rx))
+}
+
+/// Extract the `code` and `state` query parameters from an OAuth callback
+/// request to `/callback`.
+fn handle_callback_request(request: &tiny_http::Request) -> Result<OAuthCallback> {
+ let url = Url::parse(&format!("http://localhost{}", request.url()))
+ .context("malformed callback request URL")?;
+
+ if url.path() != "/callback" {
+ bail!("unexpected path in OAuth callback: {}", url.path());
+ }
+
+ let query = url
+ .query()
+ .ok_or_else(|| anyhow!("OAuth callback has no query string"))?;
+ OAuthCallback::parse_query(query)
+}
+
+// -- JSON fetch helper -------------------------------------------------------
+
+async fn fetch_json<T: serde::de::DeserializeOwned>(
+ http_client: &Arc<dyn HttpClient>,
+ url: &Url,
+) -> Result<T> {
+ validate_oauth_url(url)?;
+
+ let request = Request::builder()
+ .method(http_client::http::Method::GET)
+ .uri(url.as_str())
+ .header("Accept", "application/json")
+ .body(AsyncBody::default())?;
+
+ let mut response = http_client.send(request).await?;
+
+ if !response.status().is_success() {
+ bail!("HTTP {} fetching {}", response.status(), url);
+ }
+
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ serde_json::from_str(&body).with_context(|| format!("failed to parse JSON from {}", url))
+}
+
// -- Serde response types for discovery --------------------------------------

/// Wire format of a protected resource metadata document
/// (`/.well-known/oauth-protected-resource`). Every field is optional at
/// parse time (`#[serde(default)]`); required fields are validated when this
/// is converted to the internal metadata type.
#[derive(Debug, Deserialize)]
struct ProtectedResourceMetadataResponse {
    // Canonical identifier of the protected resource.
    #[serde(default)]
    resource: Option<Url>,
    // Authorization servers able to issue tokens for this resource.
    #[serde(default)]
    authorization_servers: Vec<Url>,
    // Scopes the resource advertises, if any.
    #[serde(default)]
    scopes_supported: Option<Vec<String>>,
}
+
/// Wire format of authorization server metadata, as served from
/// `/.well-known/oauth-authorization-server` or
/// `/.well-known/openid-configuration`. All fields are optional at parse
/// time; required ones are presumably enforced when converting to
/// `AuthServerMetadata` — that conversion is outside this section.
#[derive(Debug, Deserialize)]
struct AuthServerMetadataResponse {
    #[serde(default)]
    issuer: Option<Url>,
    #[serde(default)]
    authorization_endpoint: Option<Url>,
    #[serde(default)]
    token_endpoint: Option<Url>,
    // Present only when the server supports dynamic client registration.
    #[serde(default)]
    registration_endpoint: Option<Url>,
    #[serde(default)]
    scopes_supported: Option<Vec<String>>,
    // Must include "S256" for the PKCE flow used here.
    #[serde(default)]
    code_challenge_methods_supported: Option<Vec<String>>,
    // Whether the server accepts a URL-valued client_id (CIMD).
    #[serde(default)]
    client_id_metadata_document_supported: Option<bool>,
}
+
/// Wire format of a dynamic client registration response. Only `client_id`
/// is required; public clients typically receive no `client_secret`.
#[derive(Debug, Deserialize)]
struct DcrResponse {
    client_id: String,
    #[serde(default)]
    client_secret: Option<String>,
}
+
/// Provides OAuth tokens to the HTTP transport layer.
///
/// The transport calls `access_token()` before each request. On a 401 response
/// it calls `try_refresh()` and retries once if the refresh succeeds.
#[async_trait]
pub trait OAuthTokenProvider: Send + Sync {
    /// Returns the current access token, if one is available.
    ///
    /// Implementations may return `None` for a token they consider expired,
    /// in which case the transport should attempt a refresh.
    fn access_token(&self) -> Option<String>;

    /// Attempts to refresh the access token. Returns `true` if a new token was
    /// obtained and the request should be retried.
    ///
    /// Returning `Ok(false)` means no new token could be obtained (e.g. there
    /// is no refresh token); it is not treated as an error.
    async fn try_refresh(&self) -> Result<bool>;
}
+
/// Concrete `OAuthTokenProvider` backed by a full persisted OAuth session and
/// an HTTP client for token refresh. The same provider type is used both after
/// an interactive authentication flow and when restoring a saved session from
/// the keychain on startup.
pub struct McpOAuthTokenProvider {
    // Current tokens + endpoints. A synchronous mutex is fine: it is only
    // held for short, await-free critical sections (see `try_refresh`).
    session: SyncMutex<OAuthSession>,
    // Used to perform the refresh-token grant against the token endpoint.
    http_client: Arc<dyn HttpClient>,
    // If present, receives a copy of the session after each successful
    // refresh — presumably so the updated tokens get re-persisted; confirm
    // at the construction site.
    token_refresh_tx: Option<mpsc::UnboundedSender<OAuthSession>>,
}
+
+impl McpOAuthTokenProvider {
+ pub fn new(
+ session: OAuthSession,
+ http_client: Arc<dyn HttpClient>,
+ token_refresh_tx: Option<mpsc::UnboundedSender<OAuthSession>>,
+ ) -> Self {
+ Self {
+ session: SyncMutex::new(session),
+ http_client,
+ token_refresh_tx,
+ }
+ }
+
+ fn access_token_is_expired(tokens: &OAuthTokens) -> bool {
+ tokens.expires_at.is_some_and(|expires_at| {
+ SystemTime::now()
+ .checked_add(Duration::from_secs(30))
+ .is_some_and(|now_with_buffer| expires_at <= now_with_buffer)
+ })
+ }
+}
+
#[async_trait]
impl OAuthTokenProvider for McpOAuthTokenProvider {
    fn access_token(&self) -> Option<String> {
        let session = self.session.lock();
        // Report an expired (or nearly-expired) token as absent so the
        // transport attempts a refresh instead of sending a stale bearer
        // token.
        if Self::access_token_is_expired(&session.tokens) {
            return None;
        }
        Some(session.tokens.access_token.clone())
    }

    async fn try_refresh(&self) -> Result<bool> {
        // Copy everything the refresh request needs while holding the lock,
        // then release it: the synchronous mutex must not be held across an
        // `.await`.
        let (refresh_token, token_endpoint, resource, client_id) = {
            let session = self.session.lock();
            match session.tokens.refresh_token.clone() {
                Some(refresh_token) => (
                    refresh_token,
                    session.token_endpoint.clone(),
                    session.resource.clone(),
                    session.client_registration.client_id.clone(),
                ),
                // No refresh token: nothing to do, and not an error.
                None => return Ok(false),
            }
        };

        let resource_str = canonical_server_uri(&resource);

        match refresh_tokens(
            &self.http_client,
            &token_endpoint,
            &refresh_token,
            &client_id,
            &resource_str,
        )
        .await
        {
            Ok(mut new_tokens) => {
                // Some servers omit the refresh token in the refresh
                // response; keep the old one so future refreshes still work.
                if new_tokens.refresh_token.is_none() {
                    new_tokens.refresh_token = Some(refresh_token);
                }

                {
                    let mut session = self.session.lock();
                    session.tokens = new_tokens;

                    // Best-effort notification with the updated session (for
                    // persistence); a closed channel is deliberately ignored.
                    if let Some(ref tx) = self.token_refresh_tx {
                        tx.unbounded_send(session.clone()).ok();
                    }
                }

                Ok(true)
            }
            Err(err) => {
                // A failed refresh is logged and surfaced as "no new token"
                // (`Ok(false)`) rather than propagated as an error.
                log::warn!("OAuth token refresh failed: {}", err);
                Ok(false)
            }
        }
    }
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use http_client::Response;
+
+ // -- require_https_or_loopback tests ------------------------------------
+
+ #[test]
+ fn test_require_https_or_loopback_accepts_https() {
+ let url = Url::parse("https://auth.example.com/token").unwrap();
+ assert!(require_https_or_loopback(&url).is_ok());
+ }
+
+ #[test]
+ fn test_require_https_or_loopback_rejects_http_remote() {
+ let url = Url::parse("http://auth.example.com/token").unwrap();
+ assert!(require_https_or_loopback(&url).is_err());
+ }
+
+ #[test]
+ fn test_require_https_or_loopback_accepts_http_127_0_0_1() {
+ let url = Url::parse("http://127.0.0.1:8080/callback").unwrap();
+ assert!(require_https_or_loopback(&url).is_ok());
+ }
+
+ #[test]
+ fn test_require_https_or_loopback_accepts_http_ipv6_loopback() {
+ let url = Url::parse("http://[::1]:8080/callback").unwrap();
+ assert!(require_https_or_loopback(&url).is_ok());
+ }
+
+ #[test]
+ fn test_require_https_or_loopback_accepts_http_localhost() {
+ let url = Url::parse("http://localhost:8080/callback").unwrap();
+ assert!(require_https_or_loopback(&url).is_ok());
+ }
+
+ #[test]
+ fn test_require_https_or_loopback_accepts_http_localhost_case_insensitive() {
+ let url = Url::parse("http://LOCALHOST:8080/callback").unwrap();
+ assert!(require_https_or_loopback(&url).is_ok());
+ }
+
+ #[test]
+ fn test_require_https_or_loopback_rejects_http_non_loopback_ip() {
+ let url = Url::parse("http://192.168.1.1:8080/token").unwrap();
+ assert!(require_https_or_loopback(&url).is_err());
+ }
+
+ #[test]
+ fn test_require_https_or_loopback_rejects_ftp() {
+ let url = Url::parse("ftp://auth.example.com/token").unwrap();
+ assert!(require_https_or_loopback(&url).is_err());
+ }
+
+ // -- validate_oauth_url (SSRF) tests ------------------------------------
+
+ #[test]
+ fn test_validate_oauth_url_accepts_https_public() {
+ let url = Url::parse("https://auth.example.com/token").unwrap();
+ assert!(validate_oauth_url(&url).is_ok());
+ }
+
+ #[test]
+ fn test_validate_oauth_url_rejects_private_ipv4_10() {
+ let url = Url::parse("https://10.0.0.1/token").unwrap();
+ assert!(validate_oauth_url(&url).is_err());
+ }
+
+ #[test]
+ fn test_validate_oauth_url_rejects_private_ipv4_172() {
+ let url = Url::parse("https://172.16.0.1/token").unwrap();
+ assert!(validate_oauth_url(&url).is_err());
+ }
+
+ #[test]
+ fn test_validate_oauth_url_rejects_private_ipv4_192() {
+ let url = Url::parse("https://192.168.1.1/token").unwrap();
+ assert!(validate_oauth_url(&url).is_err());
+ }
+
+ #[test]
+ fn test_validate_oauth_url_rejects_link_local() {
+ let url = Url::parse("https://169.254.169.254/latest/meta-data/").unwrap();
+ assert!(validate_oauth_url(&url).is_err());
+ }
+
+ #[test]
+ fn test_validate_oauth_url_rejects_ipv6_ula() {
+ let url = Url::parse("https://[fd12:3456:789a::1]/token").unwrap();
+ assert!(validate_oauth_url(&url).is_err());
+ }
+
+ #[test]
+ fn test_validate_oauth_url_rejects_ipv6_unspecified() {
+ let url = Url::parse("https://[::]/token").unwrap();
+ assert!(validate_oauth_url(&url).is_err());
+ }
+
+ #[test]
+ fn test_validate_oauth_url_rejects_ipv4_mapped_ipv6_private() {
+ let url = Url::parse("https://[::ffff:10.0.0.1]/token").unwrap();
+ assert!(validate_oauth_url(&url).is_err());
+ }
+
+ #[test]
+ fn test_validate_oauth_url_rejects_ipv4_mapped_ipv6_link_local() {
+ let url = Url::parse("https://[::ffff:169.254.169.254]/token").unwrap();
+ assert!(validate_oauth_url(&url).is_err());
+ }
+
+ #[test]
+ fn test_validate_oauth_url_allows_http_loopback() {
+ // Loopback is permitted (it's our callback server).
+ let url = Url::parse("http://127.0.0.1:8080/callback").unwrap();
+ assert!(validate_oauth_url(&url).is_ok());
+ }
+
+ #[test]
+ fn test_validate_oauth_url_allows_https_public_ip() {
+ let url = Url::parse("https://93.184.216.34/token").unwrap();
+ assert!(validate_oauth_url(&url).is_ok());
+ }
+
+ // -- parse_www_authenticate tests ----------------------------------------
+
+ #[test]
+ fn test_parse_www_authenticate_with_resource_metadata_and_scope() {
+ let header = r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", scope="files:read user:profile""#;
+ let result = parse_www_authenticate(header).unwrap();
+
+ assert_eq!(
+ result.resource_metadata.as_ref().map(|u| u.as_str()),
+ Some("https://mcp.example.com/.well-known/oauth-protected-resource")
+ );
+ assert_eq!(
+ result.scope,
+ Some(vec!["files:read".to_string(), "user:profile".to_string()])
+ );
+ assert_eq!(result.error, None);
+ }
+
+ #[test]
+ fn test_parse_www_authenticate_resource_metadata_only() {
+ let header = r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource""#;
+ let result = parse_www_authenticate(header).unwrap();
+
+ assert_eq!(
+ result.resource_metadata.as_ref().map(|u| u.as_str()),
+ Some("https://mcp.example.com/.well-known/oauth-protected-resource")
+ );
+ assert_eq!(result.scope, None);
+ }
+
+ #[test]
+ fn test_parse_www_authenticate_bare_bearer() {
+ let result = parse_www_authenticate("Bearer").unwrap();
+ assert_eq!(result.resource_metadata, None);
+ assert_eq!(result.scope, None);
+ }
+
+ #[test]
+ fn test_parse_www_authenticate_with_error() {
+ let header = r#"Bearer error="insufficient_scope", scope="files:read files:write", resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", error_description="Additional file write permission required""#;
+ let result = parse_www_authenticate(header).unwrap();
+
+ assert_eq!(result.error, Some(BearerError::InsufficientScope));
+ assert_eq!(
+ result.error_description.as_deref(),
+ Some("Additional file write permission required")
+ );
+ assert_eq!(
+ result.scope,
+ Some(vec!["files:read".to_string(), "files:write".to_string()])
+ );
+ assert!(result.resource_metadata.is_some());
+ }
+
+ #[test]
+ fn test_parse_www_authenticate_invalid_token_error() {
+ let header =
+ r#"Bearer error="invalid_token", error_description="The access token expired""#;
+ let result = parse_www_authenticate(header).unwrap();
+ assert_eq!(result.error, Some(BearerError::InvalidToken));
+ }
+
+ #[test]
+ fn test_parse_www_authenticate_invalid_request_error() {
+ let header = r#"Bearer error="invalid_request""#;
+ let result = parse_www_authenticate(header).unwrap();
+ assert_eq!(result.error, Some(BearerError::InvalidRequest));
+ }
+
+ #[test]
+ fn test_parse_www_authenticate_unknown_error() {
+ let header = r#"Bearer error="some_future_error""#;
+ let result = parse_www_authenticate(header).unwrap();
+ assert_eq!(result.error, Some(BearerError::Other));
+ }
+
+ #[test]
+ fn test_parse_www_authenticate_rejects_non_bearer() {
+ let result = parse_www_authenticate("Basic realm=\"example\"");
+ assert!(result.is_err());
+ }
+
+ #[test]
+ fn test_parse_www_authenticate_case_insensitive_scheme() {
+ let header = r#"bearer resource_metadata="https://example.com/.well-known/oauth-protected-resource""#;
+ let result = parse_www_authenticate(header).unwrap();
+ assert!(result.resource_metadata.is_some());
+ }
+
+ #[test]
+ fn test_parse_www_authenticate_multiline_style() {
+ // Some servers emit the header spread across multiple lines joined by
+ // whitespace, as shown in the spec examples.
+ let header = "Bearer resource_metadata=\"https://mcp.example.com/.well-known/oauth-protected-resource\",\n scope=\"files:read\"";
+ let result = parse_www_authenticate(header).unwrap();
+ assert!(result.resource_metadata.is_some());
+ assert_eq!(result.scope, Some(vec!["files:read".to_string()]));
+ }
+
+ #[test]
+ fn test_protected_resource_metadata_urls_with_path() {
+ let server_url = Url::parse("https://api.example.com/v1/mcp").unwrap();
+ let urls = protected_resource_metadata_urls(&server_url);
+
+ assert_eq!(urls.len(), 2);
+ assert_eq!(
+ urls[0].as_str(),
+ "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp"
+ );
+ assert_eq!(
+ urls[1].as_str(),
+ "https://api.example.com/.well-known/oauth-protected-resource"
+ );
+ }
+
+ #[test]
+ fn test_protected_resource_metadata_urls_without_path() {
+ let server_url = Url::parse("https://mcp.example.com").unwrap();
+ let urls = protected_resource_metadata_urls(&server_url);
+
+ assert_eq!(urls.len(), 1);
+ assert_eq!(
+ urls[0].as_str(),
+ "https://mcp.example.com/.well-known/oauth-protected-resource"
+ );
+ }
+
+ #[test]
+ fn test_auth_server_metadata_urls_with_path() {
+ let issuer = Url::parse("https://auth.example.com/tenant1").unwrap();
+ let urls = auth_server_metadata_urls(&issuer);
+
+ assert_eq!(urls.len(), 3);
+ assert_eq!(
+ urls[0].as_str(),
+ "https://auth.example.com/.well-known/oauth-authorization-server/tenant1"
+ );
+ assert_eq!(
+ urls[1].as_str(),
+ "https://auth.example.com/.well-known/openid-configuration/tenant1"
+ );
+ assert_eq!(
+ urls[2].as_str(),
+ "https://auth.example.com/tenant1/.well-known/openid-configuration"
+ );
+ }
+
+ #[test]
+ fn test_auth_server_metadata_urls_without_path() {
+ let issuer = Url::parse("https://auth.example.com").unwrap();
+ let urls = auth_server_metadata_urls(&issuer);
+
+ assert_eq!(urls.len(), 2);
+ assert_eq!(
+ urls[0].as_str(),
+ "https://auth.example.com/.well-known/oauth-authorization-server"
+ );
+ assert_eq!(
+ urls[1].as_str(),
+ "https://auth.example.com/.well-known/openid-configuration"
+ );
+ }
+
+ // -- Canonical server URI tests ------------------------------------------
+
+ #[test]
+ fn test_canonical_server_uri_simple() {
+ let url = Url::parse("https://mcp.example.com").unwrap();
+ assert_eq!(canonical_server_uri(&url), "https://mcp.example.com");
+ }
+
+ #[test]
+ fn test_canonical_server_uri_with_path() {
+ let url = Url::parse("https://mcp.example.com/v1/mcp").unwrap();
+ assert_eq!(canonical_server_uri(&url), "https://mcp.example.com/v1/mcp");
+ }
+
+ #[test]
+ fn test_canonical_server_uri_strips_trailing_slash() {
+ let url = Url::parse("https://mcp.example.com/").unwrap();
+ assert_eq!(canonical_server_uri(&url), "https://mcp.example.com");
+ }
+
+ #[test]
+ fn test_canonical_server_uri_preserves_port() {
+ let url = Url::parse("https://mcp.example.com:8443").unwrap();
+ assert_eq!(canonical_server_uri(&url), "https://mcp.example.com:8443");
+ }
+
+ #[test]
+ fn test_canonical_server_uri_lowercases() {
+ let url = Url::parse("HTTPS://MCP.Example.COM/Server/MCP").unwrap();
+ assert_eq!(
+ canonical_server_uri(&url),
+ "https://mcp.example.com/Server/MCP"
+ );
+ }
+
+ // -- Scope selection tests -----------------------------------------------
+
+ #[test]
+ fn test_select_scopes_prefers_www_authenticate() {
+ let www_auth = WwwAuthenticate {
+ resource_metadata: None,
+ scope: Some(vec!["files:read".into()]),
+ error: None,
+ error_description: None,
+ };
+ let resource_meta = ProtectedResourceMetadata {
+ resource: Url::parse("https://example.com").unwrap(),
+ authorization_servers: vec![],
+ scopes_supported: Some(vec!["files:read".into(), "files:write".into()]),
+ };
+ assert_eq!(select_scopes(&www_auth, &resource_meta), vec!["files:read"]);
+ }
+
+ #[test]
+ fn test_select_scopes_falls_back_to_resource_metadata() {
+ let www_auth = WwwAuthenticate {
+ resource_metadata: None,
+ scope: None,
+ error: None,
+ error_description: None,
+ };
+ let resource_meta = ProtectedResourceMetadata {
+ resource: Url::parse("https://example.com").unwrap(),
+ authorization_servers: vec![],
+ scopes_supported: Some(vec!["admin".into()]),
+ };
+ assert_eq!(select_scopes(&www_auth, &resource_meta), vec!["admin"]);
+ }
+
+ #[test]
+ fn test_select_scopes_empty_when_nothing_available() {
+ let www_auth = WwwAuthenticate {
+ resource_metadata: None,
+ scope: None,
+ error: None,
+ error_description: None,
+ };
+ let resource_meta = ProtectedResourceMetadata {
+ resource: Url::parse("https://example.com").unwrap(),
+ authorization_servers: vec![],
+ scopes_supported: None,
+ };
+ assert!(select_scopes(&www_auth, &resource_meta).is_empty());
+ }
+
+ // -- Client registration strategy tests ----------------------------------
+
+ #[test]
+ fn test_registration_strategy_prefers_cimd() {
+ let metadata = AuthServerMetadata {
+ issuer: Url::parse("https://auth.example.com").unwrap(),
+ authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
+ token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
+ registration_endpoint: Some(Url::parse("https://auth.example.com/register").unwrap()),
+ scopes_supported: None,
+ code_challenge_methods_supported: Some(vec!["S256".into()]),
+ client_id_metadata_document_supported: true,
+ };
+ assert_eq!(
+ determine_registration_strategy(&metadata),
+ ClientRegistrationStrategy::Cimd {
+ client_id: CIMD_URL.to_string(),
+ }
+ );
+ }
+
+ #[test]
+ fn test_registration_strategy_falls_back_to_dcr() {
+ let reg_endpoint = Url::parse("https://auth.example.com/register").unwrap();
+ let metadata = AuthServerMetadata {
+ issuer: Url::parse("https://auth.example.com").unwrap(),
+ authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
+ token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
+ registration_endpoint: Some(reg_endpoint.clone()),
+ scopes_supported: None,
+ code_challenge_methods_supported: Some(vec!["S256".into()]),
+ client_id_metadata_document_supported: false,
+ };
+ assert_eq!(
+ determine_registration_strategy(&metadata),
+ ClientRegistrationStrategy::Dcr {
+ registration_endpoint: reg_endpoint,
+ }
+ );
+ }
+
+ #[test]
+ fn test_registration_strategy_unavailable() {
+ let metadata = AuthServerMetadata {
+ issuer: Url::parse("https://auth.example.com").unwrap(),
+ authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
+ token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
+ registration_endpoint: None,
+ scopes_supported: None,
+ code_challenge_methods_supported: Some(vec!["S256".into()]),
+ client_id_metadata_document_supported: false,
+ };
+ assert_eq!(
+ determine_registration_strategy(&metadata),
+ ClientRegistrationStrategy::Unavailable,
+ );
+ }
+
+ // -- PKCE tests ----------------------------------------------------------
+
+ #[test]
+ fn test_pkce_challenge_verifier_length() {
+ let pkce = generate_pkce_challenge();
+ // 32 random bytes β 43 base64url chars (no padding).
+ assert_eq!(pkce.verifier.len(), 43);
+ }
+
+ #[test]
+ fn test_pkce_challenge_is_valid_base64url() {
+ let pkce = generate_pkce_challenge();
+ for c in pkce.verifier.chars().chain(pkce.challenge.chars()) {
+ assert!(
+ c.is_ascii_alphanumeric() || c == '-' || c == '_',
+ "invalid base64url character: {}",
+ c
+ );
+ }
+ }
+
+ #[test]
+ fn test_pkce_challenge_is_s256_of_verifier() {
+ let pkce = generate_pkce_challenge();
+ let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
+ let expected_digest = Sha256::digest(pkce.verifier.as_bytes());
+ let expected_challenge = engine.encode(expected_digest);
+ assert_eq!(pkce.challenge, expected_challenge);
+ }
+
+ #[test]
+ fn test_pkce_challenges_are_unique() {
+ let a = generate_pkce_challenge();
+ let b = generate_pkce_challenge();
+ assert_ne!(a.verifier, b.verifier);
+ }
+
+ // -- Authorization URL tests ---------------------------------------------
+
+ #[test]
+ fn test_build_authorization_url() {
+ let metadata = AuthServerMetadata {
+ issuer: Url::parse("https://auth.example.com").unwrap(),
+ authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
+ token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
+ registration_endpoint: None,
+ scopes_supported: None,
+ code_challenge_methods_supported: Some(vec!["S256".into()]),
+ client_id_metadata_document_supported: true,
+ };
+ let pkce = PkceChallenge {
+ verifier: "test_verifier".into(),
+ challenge: "test_challenge".into(),
+ };
+ let url = build_authorization_url(
+ &metadata,
+ "https://zed.dev/oauth/client-metadata.json",
+ "http://127.0.0.1:12345/callback",
+ &["files:read".into(), "files:write".into()],
+ "https://mcp.example.com",
+ &pkce,
+ "random_state_123",
+ );
+
+ let pairs: std::collections::HashMap<_, _> = url.query_pairs().collect();
+ assert_eq!(pairs.get("response_type").unwrap(), "code");
+ assert_eq!(
+ pairs.get("client_id").unwrap(),
+ "https://zed.dev/oauth/client-metadata.json"
+ );
+ assert_eq!(
+ pairs.get("redirect_uri").unwrap(),
+ "http://127.0.0.1:12345/callback"
+ );
+ assert_eq!(pairs.get("scope").unwrap(), "files:read files:write");
+ assert_eq!(pairs.get("resource").unwrap(), "https://mcp.example.com");
+ assert_eq!(pairs.get("code_challenge").unwrap(), "test_challenge");
+ assert_eq!(pairs.get("code_challenge_method").unwrap(), "S256");
+ assert_eq!(pairs.get("state").unwrap(), "random_state_123");
+ }
+
+ #[test]
+ fn test_build_authorization_url_omits_empty_scope() {
+ let metadata = AuthServerMetadata {
+ issuer: Url::parse("https://auth.example.com").unwrap(),
+ authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
+ token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
+ registration_endpoint: None,
+ scopes_supported: None,
+ code_challenge_methods_supported: Some(vec!["S256".into()]),
+ client_id_metadata_document_supported: false,
+ };
+ let pkce = PkceChallenge {
+ verifier: "v".into(),
+ challenge: "c".into(),
+ };
+ let url = build_authorization_url(
+ &metadata,
+ "client_123",
+ "http://127.0.0.1:9999/callback",
+ &[],
+ "https://mcp.example.com",
+ &pkce,
+ "state",
+ );
+
+ let pairs: std::collections::HashMap<_, _> = url.query_pairs().collect();
+ assert!(!pairs.contains_key("scope"));
+ }
+
+ // -- Token exchange / refresh param tests --------------------------------
+
+ #[test]
+ fn test_token_exchange_params() {
+ let params = token_exchange_params(
+ "auth_code_abc",
+ "client_xyz",
+ "http://127.0.0.1:5555/callback",
+ "verifier_123",
+ "https://mcp.example.com",
+ );
+ let map: std::collections::HashMap<&str, &str> =
+ params.iter().map(|(k, v)| (*k, v.as_str())).collect();
+
+ assert_eq!(map["grant_type"], "authorization_code");
+ assert_eq!(map["code"], "auth_code_abc");
+ assert_eq!(map["redirect_uri"], "http://127.0.0.1:5555/callback");
+ assert_eq!(map["client_id"], "client_xyz");
+ assert_eq!(map["code_verifier"], "verifier_123");
+ assert_eq!(map["resource"], "https://mcp.example.com");
+ }
+
+ #[test]
+ fn test_token_refresh_params() {
+ let params =
+ token_refresh_params("refresh_token_abc", "client_xyz", "https://mcp.example.com");
+ let map: std::collections::HashMap<&str, &str> =
+ params.iter().map(|(k, v)| (*k, v.as_str())).collect();
+
+ assert_eq!(map["grant_type"], "refresh_token");
+ assert_eq!(map["refresh_token"], "refresh_token_abc");
+ assert_eq!(map["client_id"], "client_xyz");
+ assert_eq!(map["resource"], "https://mcp.example.com");
+ }
+
+ // -- Token response tests ------------------------------------------------
+
+ #[test]
+ fn test_token_response_into_tokens_with_expiry() {
+ let response: TokenResponse = serde_json::from_str(
+ r#"{"access_token": "at_123", "refresh_token": "rt_456", "expires_in": 3600, "token_type": "Bearer"}"#,
+ )
+ .unwrap();
+
+ let tokens = response.into_tokens();
+ assert_eq!(tokens.access_token, "at_123");
+ assert_eq!(tokens.refresh_token.as_deref(), Some("rt_456"));
+ assert!(tokens.expires_at.is_some());
+ }
+
+ #[test]
+ fn test_token_response_into_tokens_minimal() {
+ let response: TokenResponse =
+ serde_json::from_str(r#"{"access_token": "at_789"}"#).unwrap();
+
+ let tokens = response.into_tokens();
+ assert_eq!(tokens.access_token, "at_789");
+ assert_eq!(tokens.refresh_token, None);
+ assert_eq!(tokens.expires_at, None);
+ }
+
+ // -- DCR body test -------------------------------------------------------
+
+ #[test]
+ fn test_dcr_registration_body_shape() {
+ let body = dcr_registration_body("http://127.0.0.1:12345/callback");
+ assert_eq!(body["client_name"], "Zed");
+ assert_eq!(body["redirect_uris"][0], "http://127.0.0.1:12345/callback");
+ assert_eq!(body["grant_types"][0], "authorization_code");
+ assert_eq!(body["response_types"][0], "code");
+ assert_eq!(body["token_endpoint_auth_method"], "none");
+ }
+
+ // -- Test helpers for async/HTTP tests -----------------------------------
+
+ fn make_fake_http_client(
+ handler: impl Fn(
+ http_client::Request<AsyncBody>,
+ ) -> std::pin::Pin<
+ Box<dyn std::future::Future<Output = anyhow::Result<Response<AsyncBody>>> + Send>,
+ > + Send
+ + Sync
+ + 'static,
+ ) -> Arc<dyn HttpClient> {
+ http_client::FakeHttpClient::create(handler) as Arc<dyn HttpClient>
+ }
+
+ fn json_response(status: u16, body: &str) -> anyhow::Result<Response<AsyncBody>> {
+ Ok(Response::builder()
+ .status(status)
+ .header("Content-Type", "application/json")
+ .body(AsyncBody::from(body.as_bytes().to_vec()))
+ .unwrap())
+ }
+
+ // -- Discovery integration tests -----------------------------------------
+
+ #[test]
+ fn test_fetch_protected_resource_metadata() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|req| {
+ Box::pin(async move {
+ let uri = req.uri().to_string();
+ if uri.contains(".well-known/oauth-protected-resource") {
+ json_response(
+ 200,
+ r#"{
+ "resource": "https://mcp.example.com",
+ "authorization_servers": ["https://auth.example.com"],
+ "scopes_supported": ["read", "write"]
+ }"#,
+ )
+ } else {
+ json_response(404, "{}")
+ }
+ })
+ });
+
+ let server_url = Url::parse("https://mcp.example.com").unwrap();
+ let www_auth = WwwAuthenticate {
+ resource_metadata: None,
+ scope: None,
+ error: None,
+ error_description: None,
+ };
+
+ let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
+ .await
+ .unwrap();
+
+ assert_eq!(metadata.resource.as_str(), "https://mcp.example.com/");
+ assert_eq!(metadata.authorization_servers.len(), 1);
+ assert_eq!(
+ metadata.authorization_servers[0].as_str(),
+ "https://auth.example.com/"
+ );
+ assert_eq!(
+ metadata.scopes_supported,
+ Some(vec!["read".to_string(), "write".to_string()])
+ );
+ });
+ }
+
+ #[test]
+ fn test_fetch_protected_resource_metadata_prefers_www_authenticate_url() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|req| {
+ Box::pin(async move {
+ let uri = req.uri().to_string();
+ if uri == "https://mcp.example.com/custom-resource-metadata" {
+ json_response(
+ 200,
+ r#"{
+ "resource": "https://mcp.example.com",
+ "authorization_servers": ["https://auth.example.com"]
+ }"#,
+ )
+ } else {
+ json_response(500, r#"{"error": "should not be called"}"#)
+ }
+ })
+ });
+
+ let server_url = Url::parse("https://mcp.example.com").unwrap();
+ let www_auth = WwwAuthenticate {
+ resource_metadata: Some(
+ Url::parse("https://mcp.example.com/custom-resource-metadata").unwrap(),
+ ),
+ scope: None,
+ error: None,
+ error_description: None,
+ };
+
+ let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
+ .await
+ .unwrap();
+
+ assert_eq!(metadata.authorization_servers.len(), 1);
+ });
+ }
+
+ #[test]
+ fn test_fetch_protected_resource_metadata_rejects_cross_origin_url() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|req| {
+ Box::pin(async move {
+ let uri = req.uri().to_string();
+ // The cross-origin URL should NOT be fetched; only the
+ // well-known fallback at the server's own origin should be.
+ if uri.contains("attacker.example.com") {
+ panic!("should not fetch cross-origin resource_metadata URL");
+ } else if uri.contains(".well-known/oauth-protected-resource") {
+ json_response(
+ 200,
+ r#"{
+ "resource": "https://mcp.example.com",
+ "authorization_servers": ["https://auth.example.com"]
+ }"#,
+ )
+ } else {
+ json_response(404, "{}")
+ }
+ })
+ });
+
+ let server_url = Url::parse("https://mcp.example.com").unwrap();
+ let www_auth = WwwAuthenticate {
+ resource_metadata: Some(
+ Url::parse("https://attacker.example.com/fake-metadata").unwrap(),
+ ),
+ scope: None,
+ error: None,
+ error_description: None,
+ };
+
+ let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
+ .await
+ .unwrap();
+
+ // Should have used the fallback well-known URL, not the attacker's.
+ assert_eq!(metadata.resource.as_str(), "https://mcp.example.com/");
+ });
+ }
+
+ #[test]
+ fn test_fetch_auth_server_metadata() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|req| {
+ Box::pin(async move {
+ let uri = req.uri().to_string();
+ if uri.contains(".well-known/oauth-authorization-server") {
+ json_response(
+ 200,
+ r#"{
+ "issuer": "https://auth.example.com",
+ "authorization_endpoint": "https://auth.example.com/authorize",
+ "token_endpoint": "https://auth.example.com/token",
+ "registration_endpoint": "https://auth.example.com/register",
+ "code_challenge_methods_supported": ["S256"],
+ "client_id_metadata_document_supported": true
+ }"#,
+ )
+ } else {
+ json_response(404, "{}")
+ }
+ })
+ });
+
+ let issuer = Url::parse("https://auth.example.com").unwrap();
+ let metadata = fetch_auth_server_metadata(&client, &issuer).await.unwrap();
+
+ assert_eq!(metadata.issuer.as_str(), "https://auth.example.com/");
+ assert_eq!(
+ metadata.authorization_endpoint.as_str(),
+ "https://auth.example.com/authorize"
+ );
+ assert_eq!(
+ metadata.token_endpoint.as_str(),
+ "https://auth.example.com/token"
+ );
+ assert!(metadata.registration_endpoint.is_some());
+ assert!(metadata.client_id_metadata_document_supported);
+ assert_eq!(
+ metadata.code_challenge_methods_supported,
+ Some(vec!["S256".to_string()])
+ );
+ });
+ }
+
+ #[test]
+ fn test_fetch_auth_server_metadata_falls_back_to_oidc() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|req| {
+ Box::pin(async move {
+ let uri = req.uri().to_string();
+ if uri.contains("openid-configuration") {
+ json_response(
+ 200,
+ r#"{
+ "issuer": "https://auth.example.com",
+ "authorization_endpoint": "https://auth.example.com/authorize",
+ "token_endpoint": "https://auth.example.com/token",
+ "code_challenge_methods_supported": ["S256"]
+ }"#,
+ )
+ } else {
+ json_response(404, "{}")
+ }
+ })
+ });
+
+ let issuer = Url::parse("https://auth.example.com").unwrap();
+ let metadata = fetch_auth_server_metadata(&client, &issuer).await.unwrap();
+
+ assert_eq!(
+ metadata.authorization_endpoint.as_str(),
+ "https://auth.example.com/authorize"
+ );
+ assert!(!metadata.client_id_metadata_document_supported);
+ });
+ }
+
+ #[test]
+ fn test_fetch_auth_server_metadata_rejects_issuer_mismatch() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|req| {
+ Box::pin(async move {
+ let uri = req.uri().to_string();
+ if uri.contains(".well-known/oauth-authorization-server") {
+ // Response claims to be a different issuer.
+ json_response(
+ 200,
+ r#"{
+ "issuer": "https://evil.example.com",
+ "authorization_endpoint": "https://evil.example.com/authorize",
+ "token_endpoint": "https://evil.example.com/token",
+ "code_challenge_methods_supported": ["S256"]
+ }"#,
+ )
+ } else {
+ json_response(404, "{}")
+ }
+ })
+ });
+
+ let issuer = Url::parse("https://auth.example.com").unwrap();
+ let result = fetch_auth_server_metadata(&client, &issuer).await;
+
+ assert!(result.is_err());
+ let err_msg = result.unwrap_err().to_string();
+ assert!(
+ err_msg.contains("issuer mismatch"),
+ "unexpected error: {}",
+ err_msg
+ );
+ });
+ }
+
+ // -- Full discover integration tests -------------------------------------
+
+ #[test]
+ fn test_full_discover_with_cimd() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|req| {
+ Box::pin(async move {
+ let uri = req.uri().to_string();
+ if uri.contains("oauth-protected-resource") {
+ json_response(
+ 200,
+ r#"{
+ "resource": "https://mcp.example.com",
+ "authorization_servers": ["https://auth.example.com"],
+ "scopes_supported": ["mcp:read"]
+ }"#,
+ )
+ } else if uri.contains("oauth-authorization-server") {
+ json_response(
+ 200,
+ r#"{
+ "issuer": "https://auth.example.com",
+ "authorization_endpoint": "https://auth.example.com/authorize",
+ "token_endpoint": "https://auth.example.com/token",
+ "code_challenge_methods_supported": ["S256"],
+ "client_id_metadata_document_supported": true
+ }"#,
+ )
+ } else {
+ json_response(404, "{}")
+ }
+ })
+ });
+
+ let server_url = Url::parse("https://mcp.example.com").unwrap();
+ let www_auth = WwwAuthenticate {
+ resource_metadata: None,
+ scope: None,
+ error: None,
+ error_description: None,
+ };
+
+ let discovery = discover(&client, &server_url, &www_auth).await.unwrap();
+ let registration =
+ resolve_client_registration(&client, &discovery, "http://127.0.0.1:12345/callback")
+ .await
+ .unwrap();
+
+ assert_eq!(registration.client_id, CIMD_URL);
+ assert_eq!(registration.client_secret, None);
+ assert_eq!(discovery.scopes, vec!["mcp:read"]);
+ });
+ }
+
+ #[test]
+ fn test_full_discover_with_dcr_fallback() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|req| {
+ Box::pin(async move {
+ let uri = req.uri().to_string();
+ if uri.contains("oauth-protected-resource") {
+ json_response(
+ 200,
+ r#"{
+ "resource": "https://mcp.example.com",
+ "authorization_servers": ["https://auth.example.com"]
+ }"#,
+ )
+ } else if uri.contains("oauth-authorization-server") {
+ json_response(
+ 200,
+ r#"{
+ "issuer": "https://auth.example.com",
+ "authorization_endpoint": "https://auth.example.com/authorize",
+ "token_endpoint": "https://auth.example.com/token",
+ "registration_endpoint": "https://auth.example.com/register",
+ "code_challenge_methods_supported": ["S256"],
+ "client_id_metadata_document_supported": false
+ }"#,
+ )
+ } else if uri.contains("/register") {
+ json_response(
+ 201,
+ r#"{
+ "client_id": "dcr-minted-id-123",
+ "client_secret": "dcr-secret-456"
+ }"#,
+ )
+ } else {
+ json_response(404, "{}")
+ }
+ })
+ });
+
+ let server_url = Url::parse("https://mcp.example.com").unwrap();
+ let www_auth = WwwAuthenticate {
+ resource_metadata: None,
+ scope: Some(vec!["files:read".into()]),
+ error: None,
+ error_description: None,
+ };
+
+ let discovery = discover(&client, &server_url, &www_auth).await.unwrap();
+ let registration =
+ resolve_client_registration(&client, &discovery, "http://127.0.0.1:9999/callback")
+ .await
+ .unwrap();
+
+ assert_eq!(registration.client_id, "dcr-minted-id-123");
+ assert_eq!(
+ registration.client_secret.as_deref(),
+ Some("dcr-secret-456")
+ );
+ assert_eq!(discovery.scopes, vec!["files:read"]);
+ });
+ }
+
+ #[test]
+ fn test_discover_fails_without_pkce_support() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|req| {
+ Box::pin(async move {
+ let uri = req.uri().to_string();
+ if uri.contains("oauth-protected-resource") {
+ json_response(
+ 200,
+ r#"{
+ "resource": "https://mcp.example.com",
+ "authorization_servers": ["https://auth.example.com"]
+ }"#,
+ )
+ } else if uri.contains("oauth-authorization-server") {
+ json_response(
+ 200,
+ r#"{
+ "issuer": "https://auth.example.com",
+ "authorization_endpoint": "https://auth.example.com/authorize",
+ "token_endpoint": "https://auth.example.com/token"
+ }"#,
+ )
+ } else {
+ json_response(404, "{}")
+ }
+ })
+ });
+
+ let server_url = Url::parse("https://mcp.example.com").unwrap();
+ let www_auth = WwwAuthenticate {
+ resource_metadata: None,
+ scope: None,
+ error: None,
+ error_description: None,
+ };
+
+ let result = discover(&client, &server_url, &www_auth).await;
+ assert!(result.is_err());
+ let err_msg = result.unwrap_err().to_string();
+ assert!(
+ err_msg.contains("code_challenge_methods_supported"),
+ "unexpected error: {}",
+ err_msg
+ );
+ });
+ }
+
+ // -- Token exchange integration tests ------------------------------------
+
+ #[test]
+ fn test_exchange_code_success() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|req| {
+ Box::pin(async move {
+ let uri = req.uri().to_string();
+ if uri.contains("/token") {
+ json_response(
+ 200,
+ r#"{
+ "access_token": "new_access_token",
+ "refresh_token": "new_refresh_token",
+ "expires_in": 3600,
+ "token_type": "Bearer"
+ }"#,
+ )
+ } else {
+ json_response(404, "{}")
+ }
+ })
+ });
+
+ let metadata = AuthServerMetadata {
+ issuer: Url::parse("https://auth.example.com").unwrap(),
+ authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
+ token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
+ registration_endpoint: None,
+ scopes_supported: None,
+ code_challenge_methods_supported: Some(vec!["S256".into()]),
+ client_id_metadata_document_supported: true,
+ };
+
+ let tokens = exchange_code(
+ &client,
+ &metadata,
+ "auth_code_123",
+ CIMD_URL,
+ "http://127.0.0.1:9999/callback",
+ "verifier_abc",
+ "https://mcp.example.com",
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(tokens.access_token, "new_access_token");
+ assert_eq!(tokens.refresh_token.as_deref(), Some("new_refresh_token"));
+ assert!(tokens.expires_at.is_some());
+ });
+ }
+
+ #[test]
+ fn test_refresh_tokens_success() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|req| {
+ Box::pin(async move {
+ let uri = req.uri().to_string();
+ if uri.contains("/token") {
+ json_response(
+ 200,
+ r#"{
+ "access_token": "refreshed_token",
+ "expires_in": 1800,
+ "token_type": "Bearer"
+ }"#,
+ )
+ } else {
+ json_response(404, "{}")
+ }
+ })
+ });
+
+ let token_endpoint = Url::parse("https://auth.example.com/token").unwrap();
+
+ let tokens = refresh_tokens(
+ &client,
+ &token_endpoint,
+ "old_refresh_token",
+ CIMD_URL,
+ "https://mcp.example.com",
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(tokens.access_token, "refreshed_token");
+ assert_eq!(tokens.refresh_token, None);
+ assert!(tokens.expires_at.is_some());
+ });
+ }
+
+ #[test]
+ fn test_exchange_code_failure() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|_req| {
+ Box::pin(async move { json_response(400, r#"{"error": "invalid_grant"}"#) })
+ });
+
+ let metadata = AuthServerMetadata {
+ issuer: Url::parse("https://auth.example.com").unwrap(),
+ authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
+ token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
+ registration_endpoint: None,
+ scopes_supported: None,
+ code_challenge_methods_supported: Some(vec!["S256".into()]),
+ client_id_metadata_document_supported: true,
+ };
+
+ let result = exchange_code(
+ &client,
+ &metadata,
+ "bad_code",
+ "client",
+ "http://127.0.0.1:1/callback",
+ "verifier",
+ "https://mcp.example.com",
+ )
+ .await;
+
+ assert!(result.is_err());
+ assert!(result.unwrap_err().to_string().contains("400"));
+ });
+ }
+
+ // -- DCR integration tests -----------------------------------------------
+
+ #[test]
+ fn test_perform_dcr() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|_req| {
+ Box::pin(async move {
+ json_response(
+ 201,
+ r#"{
+ "client_id": "dynamic-client-001",
+ "client_secret": "dynamic-secret-001"
+ }"#,
+ )
+ })
+ });
+
+ let endpoint = Url::parse("https://auth.example.com/register").unwrap();
+ let registration = perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback")
+ .await
+ .unwrap();
+
+ assert_eq!(registration.client_id, "dynamic-client-001");
+ assert_eq!(
+ registration.client_secret.as_deref(),
+ Some("dynamic-secret-001")
+ );
+ });
+ }
+
+ #[test]
+ fn test_perform_dcr_failure() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|_req| {
+ Box::pin(
+ async move { json_response(403, r#"{"error": "registration_not_allowed"}"#) },
+ )
+ });
+
+ let endpoint = Url::parse("https://auth.example.com/register").unwrap();
+ let result = perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback").await;
+
+ assert!(result.is_err());
+ assert!(result.unwrap_err().to_string().contains("403"));
+ });
+ }
+
+ // -- OAuthCallback parse tests -------------------------------------------
+
+ #[test]
+ fn test_oauth_callback_parse_query() {
+ let callback = OAuthCallback::parse_query("code=test_auth_code&state=test_state").unwrap();
+ assert_eq!(callback.code, "test_auth_code");
+ assert_eq!(callback.state, "test_state");
+ }
+
+ #[test]
+ fn test_oauth_callback_parse_query_reversed_order() {
+ let callback = OAuthCallback::parse_query("state=test_state&code=test_auth_code").unwrap();
+ assert_eq!(callback.code, "test_auth_code");
+ assert_eq!(callback.state, "test_state");
+ }
+
+ #[test]
+ fn test_oauth_callback_parse_query_with_extra_params() {
+ let callback =
+ OAuthCallback::parse_query("code=test_auth_code&state=test_state&extra=ignored")
+ .unwrap();
+ assert_eq!(callback.code, "test_auth_code");
+ assert_eq!(callback.state, "test_state");
+ }
+
+ #[test]
+ fn test_oauth_callback_parse_query_missing_code() {
+ let result = OAuthCallback::parse_query("state=test_state");
+ assert!(result.is_err());
+ assert!(result.unwrap_err().to_string().contains("code"));
+ }
+
+ #[test]
+ fn test_oauth_callback_parse_query_missing_state() {
+ let result = OAuthCallback::parse_query("code=test_auth_code");
+ assert!(result.is_err());
+ assert!(result.unwrap_err().to_string().contains("state"));
+ }
+
+ #[test]
+ fn test_oauth_callback_parse_query_empty_code() {
+ let result = OAuthCallback::parse_query("code=&state=test_state");
+ assert!(result.is_err());
+ }
+
+ #[test]
+ fn test_oauth_callback_parse_query_empty_state() {
+ let result = OAuthCallback::parse_query("code=test_auth_code&state=");
+ assert!(result.is_err());
+ }
+
+ #[test]
+ fn test_oauth_callback_parse_query_url_encoded_values() {
+ let callback = OAuthCallback::parse_query("code=abc%20def&state=test%3Dstate").unwrap();
+ assert_eq!(callback.code, "abc def");
+ assert_eq!(callback.state, "test=state");
+ }
+
+ #[test]
+ fn test_oauth_callback_parse_query_error_response() {
+ let result = OAuthCallback::parse_query(
+ "error=access_denied&error_description=User%20denied%20access&state=abc",
+ );
+ assert!(result.is_err());
+ let err_msg = result.unwrap_err().to_string();
+ assert!(
+ err_msg.contains("access_denied"),
+ "unexpected error: {}",
+ err_msg
+ );
+ assert!(
+ err_msg.contains("User denied access"),
+ "unexpected error: {}",
+ err_msg
+ );
+ }
+
+ #[test]
+ fn test_oauth_callback_parse_query_error_without_description() {
+ let result = OAuthCallback::parse_query("error=server_error&state=abc");
+ assert!(result.is_err());
+ let err_msg = result.unwrap_err().to_string();
+ assert!(
+ err_msg.contains("server_error"),
+ "unexpected error: {}",
+ err_msg
+ );
+ assert!(
+ err_msg.contains("no description"),
+ "unexpected error: {}",
+ err_msg
+ );
+ }
+
+ // -- McpOAuthTokenProvider tests -----------------------------------------
+
+ fn make_test_session(
+ access_token: &str,
+ refresh_token: Option<&str>,
+ expires_at: Option<SystemTime>,
+ ) -> OAuthSession {
+ OAuthSession {
+ token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
+ resource: Url::parse("https://mcp.example.com").unwrap(),
+ client_registration: OAuthClientRegistration {
+ client_id: "test-client".into(),
+ client_secret: None,
+ },
+ tokens: OAuthTokens {
+ access_token: access_token.into(),
+ refresh_token: refresh_token.map(String::from),
+ expires_at,
+ },
+ }
+ }
+
+ #[test]
+ fn test_mcp_oauth_provider_returns_none_when_token_expired() {
+ let expired = SystemTime::now() - Duration::from_secs(60);
+ let session = make_test_session("stale-token", Some("rt"), Some(expired));
+ let provider = McpOAuthTokenProvider::new(
+ session,
+ make_fake_http_client(|_| Box::pin(async { unreachable!() })),
+ None,
+ );
+
+ assert_eq!(provider.access_token(), None);
+ }
+
+ #[test]
+ fn test_mcp_oauth_provider_returns_token_when_not_expired() {
+ let far_future = SystemTime::now() + Duration::from_secs(3600);
+ let session = make_test_session("valid-token", Some("rt"), Some(far_future));
+ let provider = McpOAuthTokenProvider::new(
+ session,
+ make_fake_http_client(|_| Box::pin(async { unreachable!() })),
+ None,
+ );
+
+ assert_eq!(provider.access_token().as_deref(), Some("valid-token"));
+ }
+
+ #[test]
+ fn test_mcp_oauth_provider_returns_token_when_no_expiry() {
+ let session = make_test_session("no-expiry-token", Some("rt"), None);
+ let provider = McpOAuthTokenProvider::new(
+ session,
+ make_fake_http_client(|_| Box::pin(async { unreachable!() })),
+ None,
+ );
+
+ assert_eq!(provider.access_token().as_deref(), Some("no-expiry-token"));
+ }
+
+ #[test]
+ fn test_mcp_oauth_provider_refresh_without_refresh_token_returns_false() {
+ smol::block_on(async {
+ let session = make_test_session("token", None, None);
+ let provider = McpOAuthTokenProvider::new(
+ session,
+ make_fake_http_client(|_| {
+ Box::pin(async { unreachable!("no HTTP call expected") })
+ }),
+ None,
+ );
+
+ let refreshed = provider.try_refresh().await.unwrap();
+ assert!(!refreshed);
+ });
+ }
+
+ #[test]
+ fn test_mcp_oauth_provider_refresh_updates_session_and_notifies_channel() {
+ smol::block_on(async {
+ let session = make_test_session("old-access", Some("my-refresh-token"), None);
+ let (tx, mut rx) = futures::channel::mpsc::unbounded();
+
+ let http_client = make_fake_http_client(|_req| {
+ Box::pin(async {
+ json_response(
+ 200,
+ r#"{
+ "access_token": "new-access",
+ "refresh_token": "new-refresh",
+ "expires_in": 1800
+ }"#,
+ )
+ })
+ });
+
+ let provider = McpOAuthTokenProvider::new(session, http_client, Some(tx));
+
+ let refreshed = provider.try_refresh().await.unwrap();
+ assert!(refreshed);
+ assert_eq!(provider.access_token().as_deref(), Some("new-access"));
+
+ let notified_session = rx
+ .try_next()
+ .unwrap()
+ .expect("channel should have a session");
+ assert_eq!(notified_session.tokens.access_token, "new-access");
+ assert_eq!(
+ notified_session.tokens.refresh_token.as_deref(),
+ Some("new-refresh")
+ );
+ });
+ }
+
+ #[test]
+ fn test_mcp_oauth_provider_refresh_preserves_old_refresh_token_when_server_omits_it() {
+ smol::block_on(async {
+ let session = make_test_session("old-access", Some("original-refresh"), None);
+ let (tx, mut rx) = futures::channel::mpsc::unbounded();
+
+ let http_client = make_fake_http_client(|_req| {
+ Box::pin(async {
+ json_response(
+ 200,
+ r#"{
+ "access_token": "new-access",
+ "expires_in": 900
+ }"#,
+ )
+ })
+ });
+
+ let provider = McpOAuthTokenProvider::new(session, http_client, Some(tx));
+
+ let refreshed = provider.try_refresh().await.unwrap();
+ assert!(refreshed);
+
+ let notified_session = rx
+ .try_next()
+ .unwrap()
+ .expect("channel should have a session");
+ assert_eq!(notified_session.tokens.access_token, "new-access");
+ assert_eq!(
+ notified_session.tokens.refresh_token.as_deref(),
+ Some("original-refresh"),
+ );
+ });
+ }
+
+ #[test]
+ fn test_mcp_oauth_provider_refresh_returns_false_on_http_error() {
+ smol::block_on(async {
+ let session = make_test_session("old-access", Some("my-refresh"), None);
+
+ let http_client = make_fake_http_client(|_req| {
+ Box::pin(async { json_response(401, r#"{"error": "invalid_grant"}"#) })
+ });
+
+ let provider = McpOAuthTokenProvider::new(session, http_client, None);
+
+ let refreshed = provider.try_refresh().await.unwrap();
+ assert!(!refreshed);
+ // The old token should still be in place.
+ assert_eq!(provider.access_token().as_deref(), Some("old-access"));
+ });
+ }
+}
@@ -8,8 +8,30 @@ use parking_lot::Mutex as SyncMutex;
use smol::channel;
use std::{pin::Pin, sync::Arc};
+use crate::oauth::{self, OAuthTokenProvider, WwwAuthenticate};
use crate::transport::Transport;
+/// Typed errors returned by the HTTP transport that callers can downcast from
+/// `anyhow::Error` to handle specific failure modes.
+#[derive(Debug)]
+pub enum TransportError {
+ /// The server returned 401 and token refresh either wasn't possible or
+ /// failed. The caller should initiate the OAuth authorization flow.
+ AuthRequired { www_authenticate: WwwAuthenticate },
+}
+
+impl std::fmt::Display for TransportError {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ TransportError::AuthRequired { .. } => {
+ write!(f, "OAuth authorization required")
+ }
+ }
+ }
+}
+
+impl std::error::Error for TransportError {}
+
// Constants from MCP spec
const HEADER_SESSION_ID: &str = "Mcp-Session-Id";
const EVENT_STREAM_MIME_TYPE: &str = "text/event-stream";
@@ -25,8 +47,11 @@ pub struct HttpTransport {
response_rx: channel::Receiver<String>,
error_tx: channel::Sender<String>,
error_rx: channel::Receiver<String>,
- // Authentication headers to include in requests
+ /// Static headers to include in every request (e.g. from server config).
headers: HashMap<String, String>,
+ /// When set, the transport attaches `Authorization: Bearer` headers and
+ /// handles 401 responses with token refresh + retry.
+ token_provider: Option<Arc<dyn OAuthTokenProvider>>,
}
impl HttpTransport {
@@ -35,6 +60,16 @@ impl HttpTransport {
endpoint: String,
headers: HashMap<String, String>,
executor: BackgroundExecutor,
+ ) -> Self {
+ Self::new_with_token_provider(http_client, endpoint, headers, executor, None)
+ }
+
+ pub fn new_with_token_provider(
+ http_client: Arc<dyn HttpClient>,
+ endpoint: String,
+ headers: HashMap<String, String>,
+ executor: BackgroundExecutor,
+ token_provider: Option<Arc<dyn OAuthTokenProvider>>,
) -> Self {
let (response_tx, response_rx) = channel::unbounded();
let (error_tx, error_rx) = channel::unbounded();
@@ -49,14 +84,14 @@ impl HttpTransport {
error_tx,
error_rx,
headers,
+ token_provider,
}
}
- /// Send a message and handle the response based on content type
- async fn send_message(&self, message: String) -> Result<()> {
- let is_notification =
- !message.contains("\"id\":") || message.contains("notifications/initialized");
-
+ /// Build a POST request for the given message body, attaching all standard
+ /// headers (content-type, accept, session ID, static headers, and bearer
+ /// token if available).
+ fn build_request(&self, message: &[u8]) -> Result<http_client::Request<AsyncBody>> {
let mut request_builder = Request::builder()
.method(Method::POST)
.uri(&self.endpoint)
@@ -70,15 +105,71 @@ impl HttpTransport {
request_builder = request_builder.header(key.as_str(), value.as_str());
}
- // Add session ID if we have one (except for initialize)
+ // Attach bearer token when a token provider is present.
+ if let Some(token) = self.token_provider.as_ref().and_then(|p| p.access_token()) {
+ request_builder = request_builder.header("Authorization", format!("Bearer {}", token));
+ }
+
+ // Add session ID if we have one (except for initialize).
if let Some(ref session_id) = *self.session_id.lock() {
request_builder = request_builder.header(HEADER_SESSION_ID, session_id.as_str());
}
- let request = request_builder.body(AsyncBody::from(message.into_bytes()))?;
+ Ok(request_builder.body(AsyncBody::from(message.to_vec()))?)
+ }
+
+ /// Send a message and handle the response based on content type.
+ async fn send_message(&self, message: String) -> Result<()> {
+ let is_notification =
+ !message.contains("\"id\":") || message.contains("notifications/initialized");
+
+ // If we currently have no access token, try refreshing before sending
+ // the request so restored but expired sessions do not need an initial
+ // 401 round-trip before they can recover.
+ if let Some(ref provider) = self.token_provider {
+ if provider.access_token().is_none() {
+ provider.try_refresh().await.unwrap_or(false);
+ }
+ }
+
+ let request = self.build_request(message.as_bytes())?;
let mut response = self.http_client.send(request).await?;
- // Handle different response types based on status and content-type
+ // On 401, try refreshing the token and retry once.
+ if response.status().as_u16() == 401 {
+ let www_auth_header = response
+ .headers()
+ .get("www-authenticate")
+ .and_then(|v| v.to_str().ok())
+ .unwrap_or("Bearer");
+
+ let www_authenticate =
+ oauth::parse_www_authenticate(www_auth_header).unwrap_or(WwwAuthenticate {
+ resource_metadata: None,
+ scope: None,
+ error: None,
+ error_description: None,
+ });
+
+ if let Some(ref provider) = self.token_provider {
+ if provider.try_refresh().await.unwrap_or(false) {
+ // Retry with the refreshed token.
+ let retry_request = self.build_request(message.as_bytes())?;
+ response = self.http_client.send(retry_request).await?;
+
+ // If still 401 after refresh, give up.
+ if response.status().as_u16() == 401 {
+ return Err(TransportError::AuthRequired { www_authenticate }.into());
+ }
+ } else {
+ return Err(TransportError::AuthRequired { www_authenticate }.into());
+ }
+ } else {
+ return Err(TransportError::AuthRequired { www_authenticate }.into());
+ }
+ }
+
+ // Handle different response types based on status and content-type.
match response.status() {
status if status.is_success() => {
// Check content type
@@ -233,6 +324,7 @@ impl Drop for HttpTransport {
let endpoint = self.endpoint.clone();
let session_id = self.session_id.lock().clone();
let headers = self.headers.clone();
+ let access_token = self.token_provider.as_ref().and_then(|p| p.access_token());
if let Some(session_id) = session_id {
self.executor
@@ -242,11 +334,17 @@ impl Drop for HttpTransport {
.uri(&endpoint)
.header(HEADER_SESSION_ID, &session_id);
- // Add authentication headers if present
+ // Add static authentication headers.
for (key, value) in headers {
request_builder = request_builder.header(key.as_str(), value.as_str());
}
+ // Attach bearer token if available.
+ if let Some(token) = access_token {
+ request_builder =
+ request_builder.header("Authorization", format!("Bearer {}", token));
+ }
+
let request = request_builder.body(AsyncBody::empty());
if let Ok(request) = request {
@@ -257,3 +355,402 @@ impl Drop for HttpTransport {
}
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use async_trait::async_trait;
+ use gpui::TestAppContext;
+ use parking_lot::Mutex as SyncMutex;
+ use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
+
+ /// A mock token provider that returns a configurable token and tracks
+ /// refresh attempts.
+ struct FakeTokenProvider {
+ token: SyncMutex<Option<String>>,
+ refreshed_token: SyncMutex<Option<String>>,
+ refresh_succeeds: AtomicBool,
+ refresh_count: AtomicUsize,
+ }
+
+ impl FakeTokenProvider {
+ fn new(token: Option<&str>, refresh_succeeds: bool) -> Arc<Self> {
+ Self::with_refreshed_token(token, None, refresh_succeeds)
+ }
+
+ fn with_refreshed_token(
+ token: Option<&str>,
+ refreshed_token: Option<&str>,
+ refresh_succeeds: bool,
+ ) -> Arc<Self> {
+ Arc::new(Self {
+ token: SyncMutex::new(token.map(String::from)),
+ refreshed_token: SyncMutex::new(refreshed_token.map(String::from)),
+ refresh_succeeds: AtomicBool::new(refresh_succeeds),
+ refresh_count: AtomicUsize::new(0),
+ })
+ }
+
+ fn set_token(&self, token: &str) {
+ *self.token.lock() = Some(token.to_string());
+ }
+
+ fn refresh_count(&self) -> usize {
+ self.refresh_count.load(Ordering::SeqCst)
+ }
+ }
+
+ #[async_trait]
+ impl OAuthTokenProvider for FakeTokenProvider {
+ fn access_token(&self) -> Option<String> {
+ self.token.lock().clone()
+ }
+
+ async fn try_refresh(&self) -> Result<bool> {
+ self.refresh_count.fetch_add(1, Ordering::SeqCst);
+
+ let refresh_succeeds = self.refresh_succeeds.load(Ordering::SeqCst);
+ if refresh_succeeds {
+ if let Some(token) = self.refreshed_token.lock().clone() {
+ *self.token.lock() = Some(token);
+ }
+ }
+
+ Ok(refresh_succeeds)
+ }
+ }
+
+ fn make_fake_http_client(
+ handler: impl Fn(
+ http_client::Request<AsyncBody>,
+ ) -> std::pin::Pin<
+ Box<dyn std::future::Future<Output = anyhow::Result<Response<AsyncBody>>> + Send>,
+ > + Send
+ + Sync
+ + 'static,
+ ) -> Arc<dyn HttpClient> {
+ http_client::FakeHttpClient::create(handler) as Arc<dyn HttpClient>
+ }
+
+ fn json_response(status: u16, body: &str) -> anyhow::Result<Response<AsyncBody>> {
+ Ok(Response::builder()
+ .status(status)
+ .header("Content-Type", "application/json")
+ .body(AsyncBody::from(body.as_bytes().to_vec()))
+ .unwrap())
+ }
+
+ #[gpui::test]
+ async fn test_bearer_token_attached_to_requests(cx: &mut TestAppContext) {
+ // Capture the Authorization header from the request.
+ let captured_auth = Arc::new(SyncMutex::new(None::<String>));
+ let captured_auth_clone = captured_auth.clone();
+
+ let client = make_fake_http_client(move |req| {
+ let auth = req
+ .headers()
+ .get("Authorization")
+ .map(|v| v.to_str().unwrap().to_string());
+ *captured_auth_clone.lock() = auth;
+ Box::pin(async { json_response(200, r#"{"jsonrpc":"2.0","id":1,"result":{}}"#) })
+ });
+
+ let provider = FakeTokenProvider::new(Some("test-access-token"), false);
+ let transport = HttpTransport::new_with_token_provider(
+ client,
+ "http://mcp.example.com/mcp".to_string(),
+ HashMap::default(),
+ cx.background_executor.clone(),
+ Some(provider),
+ );
+
+ transport
+ .send(r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#.to_string())
+ .await
+ .expect("send should succeed");
+
+ assert_eq!(
+ captured_auth.lock().as_deref(),
+ Some("Bearer test-access-token"),
+ );
+ }
+
+ #[gpui::test]
+ async fn test_no_bearer_token_without_provider(cx: &mut TestAppContext) {
+ let captured_auth = Arc::new(SyncMutex::new(None::<String>));
+ let captured_auth_clone = captured_auth.clone();
+
+ let client = make_fake_http_client(move |req| {
+ let auth = req
+ .headers()
+ .get("Authorization")
+ .map(|v| v.to_str().unwrap().to_string());
+ *captured_auth_clone.lock() = auth;
+ Box::pin(async { json_response(200, r#"{"jsonrpc":"2.0","id":1,"result":{}}"#) })
+ });
+
+ let transport = HttpTransport::new(
+ client,
+ "http://mcp.example.com/mcp".to_string(),
+ HashMap::default(),
+ cx.background_executor.clone(),
+ );
+
+ transport
+ .send(r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#.to_string())
+ .await
+ .expect("send should succeed");
+
+ assert!(captured_auth.lock().is_none());
+ }
+
+ #[gpui::test]
+ async fn test_missing_token_triggers_refresh_before_first_request(cx: &mut TestAppContext) {
+ let captured_auth = Arc::new(SyncMutex::new(None::<String>));
+ let captured_auth_clone = captured_auth.clone();
+
+ let client = make_fake_http_client(move |req| {
+ let auth = req
+ .headers()
+ .get("Authorization")
+ .map(|v| v.to_str().unwrap().to_string());
+ *captured_auth_clone.lock() = auth;
+ Box::pin(async { json_response(200, r#"{"jsonrpc":"2.0","id":1,"result":{}}"#) })
+ });
+
+ let provider = FakeTokenProvider::with_refreshed_token(None, Some("refreshed-token"), true);
+ let transport = HttpTransport::new_with_token_provider(
+ client,
+ "http://mcp.example.com/mcp".to_string(),
+ HashMap::default(),
+ cx.background_executor.clone(),
+ Some(provider.clone()),
+ );
+
+ transport
+ .send(r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#.to_string())
+ .await
+ .expect("send should succeed after proactive refresh");
+
+ assert_eq!(provider.refresh_count(), 1);
+ assert_eq!(
+ captured_auth.lock().as_deref(),
+ Some("Bearer refreshed-token"),
+ );
+ }
+
+ #[gpui::test]
+ async fn test_invalid_token_still_triggers_refresh_and_retry(cx: &mut TestAppContext) {
+ let request_count = Arc::new(AtomicUsize::new(0));
+ let request_count_clone = request_count.clone();
+
+ let client = make_fake_http_client(move |_req| {
+ let count = request_count_clone.fetch_add(1, Ordering::SeqCst);
+ Box::pin(async move {
+ if count == 0 {
+ Ok(Response::builder()
+ .status(401)
+ .header(
+ "WWW-Authenticate",
+ r#"Bearer error="invalid_token", resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource""#,
+ )
+ .body(AsyncBody::from(b"Unauthorized".to_vec()))
+ .unwrap())
+ } else {
+ json_response(200, r#"{"jsonrpc":"2.0","id":1,"result":{}}"#)
+ }
+ })
+ });
+
+ let provider = FakeTokenProvider::with_refreshed_token(
+ Some("old-token"),
+ Some("refreshed-token"),
+ true,
+ );
+ let transport = HttpTransport::new_with_token_provider(
+ client,
+ "http://mcp.example.com/mcp".to_string(),
+ HashMap::default(),
+ cx.background_executor.clone(),
+ Some(provider.clone()),
+ );
+
+ transport
+ .send(r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#.to_string())
+ .await
+ .expect("send should succeed after refresh");
+
+ assert_eq!(provider.refresh_count(), 1);
+ assert_eq!(request_count.load(Ordering::SeqCst), 2);
+ }
+
+ #[gpui::test]
+ async fn test_401_triggers_refresh_and_retry(cx: &mut TestAppContext) {
+ let request_count = Arc::new(AtomicUsize::new(0));
+ let request_count_clone = request_count.clone();
+
+ let client = make_fake_http_client(move |_req| {
+ let count = request_count_clone.fetch_add(1, Ordering::SeqCst);
+ Box::pin(async move {
+ if count == 0 {
+ // First request: 401.
+ Ok(Response::builder()
+ .status(401)
+ .header(
+ "WWW-Authenticate",
+ r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource""#,
+ )
+ .body(AsyncBody::from(b"Unauthorized".to_vec()))
+ .unwrap())
+ } else {
+ // Retry after refresh: 200.
+ json_response(200, r#"{"jsonrpc":"2.0","id":1,"result":{}}"#)
+ }
+ })
+ });
+
+ let provider = FakeTokenProvider::new(Some("old-token"), true);
+ // Simulate the refresh updating the token.
+ let provider_ref = provider.clone();
+ let transport = HttpTransport::new_with_token_provider(
+ client,
+ "http://mcp.example.com/mcp".to_string(),
+ HashMap::default(),
+ cx.background_executor.clone(),
+ Some(provider.clone()),
+ );
+
+ // Set the new token that will be used on retry.
+ provider_ref.set_token("refreshed-token");
+
+ transport
+ .send(r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#.to_string())
+ .await
+ .expect("send should succeed after refresh");
+
+ assert_eq!(provider_ref.refresh_count(), 1);
+ assert_eq!(request_count.load(Ordering::SeqCst), 2);
+ }
+
+ #[gpui::test]
+ async fn test_401_returns_auth_required_when_refresh_fails(cx: &mut TestAppContext) {
+ let client = make_fake_http_client(|_req| {
+ Box::pin(async {
+ Ok(Response::builder()
+ .status(401)
+ .header(
+ "WWW-Authenticate",
+ r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", scope="read write""#,
+ )
+ .body(AsyncBody::from(b"Unauthorized".to_vec()))
+ .unwrap())
+ })
+ });
+
+        // Refresh returns false — no new token available.
+ let provider = FakeTokenProvider::new(Some("stale-token"), false);
+ let transport = HttpTransport::new_with_token_provider(
+ client,
+ "http://mcp.example.com/mcp".to_string(),
+ HashMap::default(),
+ cx.background_executor.clone(),
+ Some(provider.clone()),
+ );
+
+ let err = transport
+ .send(r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#.to_string())
+ .await
+ .unwrap_err();
+
+ let transport_err = err
+ .downcast_ref::<TransportError>()
+ .expect("error should be TransportError");
+ match transport_err {
+ TransportError::AuthRequired { www_authenticate } => {
+ assert_eq!(
+ www_authenticate
+ .resource_metadata
+ .as_ref()
+ .map(|u| u.as_str()),
+ Some("https://mcp.example.com/.well-known/oauth-protected-resource"),
+ );
+ assert_eq!(
+ www_authenticate.scope,
+ Some(vec!["read".to_string(), "write".to_string()]),
+ );
+ }
+ }
+ assert_eq!(provider.refresh_count(), 1);
+ }
+
+ #[gpui::test]
+ async fn test_401_returns_auth_required_without_provider(cx: &mut TestAppContext) {
+ let client = make_fake_http_client(|_req| {
+ Box::pin(async {
+ Ok(Response::builder()
+ .status(401)
+ .header("WWW-Authenticate", "Bearer")
+ .body(AsyncBody::from(b"Unauthorized".to_vec()))
+ .unwrap())
+ })
+ });
+
+ // No token provider at all.
+ let transport = HttpTransport::new(
+ client,
+ "http://mcp.example.com/mcp".to_string(),
+ HashMap::default(),
+ cx.background_executor.clone(),
+ );
+
+ let err = transport
+ .send(r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#.to_string())
+ .await
+ .unwrap_err();
+
+ let transport_err = err
+ .downcast_ref::<TransportError>()
+ .expect("error should be TransportError");
+ match transport_err {
+ TransportError::AuthRequired { www_authenticate } => {
+ assert!(www_authenticate.resource_metadata.is_none());
+ assert!(www_authenticate.scope.is_none());
+ }
+ }
+ }
+
+ #[gpui::test]
+ async fn test_401_after_successful_refresh_still_returns_auth_required(
+ cx: &mut TestAppContext,
+ ) {
+        // Both requests return 401 — the server rejects the refreshed token too.
+ let client = make_fake_http_client(|_req| {
+ Box::pin(async {
+ Ok(Response::builder()
+ .status(401)
+ .header("WWW-Authenticate", "Bearer")
+ .body(AsyncBody::from(b"Unauthorized".to_vec()))
+ .unwrap())
+ })
+ });
+
+ let provider = FakeTokenProvider::new(Some("token"), true);
+ let transport = HttpTransport::new_with_token_provider(
+ client,
+ "http://mcp.example.com/mcp".to_string(),
+ HashMap::default(),
+ cx.background_executor.clone(),
+ Some(provider.clone()),
+ );
+
+ let err = transport
+ .send(r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#.to_string())
+ .await
+ .unwrap_err();
+
+ err.downcast_ref::<TransportError>()
+ .expect("error should be TransportError");
+ // Refresh was attempted exactly once.
+ assert_eq!(provider.refresh_count(), 1);
+ }
+}
@@ -17,7 +17,7 @@ cli-support = []
[dependencies]
ai_onboarding.workspace = true
anyhow.workspace = true
-arrayvec.workspace = true
+heapless.workspace = true
brotli.workspace = true
buffer_diff.workspace = true
client.workspace = true
@@ -1,5 +1,4 @@
use anyhow::Result;
-use arrayvec::ArrayVec;
use client::{Client, EditPredictionUsage, UserStore};
use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
use cloud_llm_client::predict_edits_v3::{
@@ -27,6 +26,7 @@ use gpui::{
http_client::{self, AsyncBody, Method},
prelude::*,
};
+use heapless::Vec as ArrayVec;
use language::language_settings::all_language_settings;
use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
use language::{BufferSnapshot, OffsetRangeExt};
@@ -332,7 +332,7 @@ struct ProjectState {
registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
current_prediction: Option<CurrentEditPrediction>,
next_pending_prediction_id: usize,
- pending_predictions: ArrayVec<PendingPrediction, 2>,
+ pending_predictions: ArrayVec<PendingPrediction, 2, u8>,
debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
last_edit_prediction_refresh: Option<(EntityId, Instant)>,
last_jump_prediction_refresh: Option<(EntityId, Instant)>,
@@ -2311,18 +2311,24 @@ impl EditPredictionStore {
});
if project_state.pending_predictions.len() < max_pending_predictions {
- project_state.pending_predictions.push(PendingPrediction {
- id: pending_prediction_id,
- task,
- drop_on_cancel,
- });
+ project_state
+ .pending_predictions
+ .push(PendingPrediction {
+ id: pending_prediction_id,
+ task,
+ drop_on_cancel,
+ })
+ .unwrap();
} else {
let pending_prediction = project_state.pending_predictions.pop().unwrap();
- project_state.pending_predictions.push(PendingPrediction {
- id: pending_prediction_id,
- task,
- drop_on_cancel,
- });
+ project_state
+ .pending_predictions
+ .push(PendingPrediction {
+ id: pending_prediction_id,
+ task,
+ drop_on_cancel,
+ })
+ .unwrap();
project_state.cancel_pending_prediction(pending_prediction, cx);
}
}
@@ -82,6 +82,10 @@ pub struct ExamplePrediction {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
pub provider: PredictionProvider,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub cumulative_logprob: Option<f64>,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub avg_logprob: Option<f64>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
@@ -166,6 +170,10 @@ pub struct ExampleScore {
pub inserted_tokens: usize,
#[serde(default)]
pub deleted_tokens: usize,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub cumulative_logprob: Option<f64>,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub avg_logprob: Option<f64>,
}
impl Example {
@@ -150,6 +150,20 @@ pub fn zeta2_output_for_patch(
);
}
+ if version == ZetaFormat::V0318SeedMultiRegions {
+ let cursor_in_new = cursor_offset.map(|cursor_offset| {
+ let hunk_start = first_hunk_offset.unwrap_or(0);
+ result.floor_char_boundary((hunk_start + cursor_offset).min(result.len()))
+ });
+ return multi_region::encode_from_old_and_new_v0318(
+ &old_editable_region,
+ &result,
+ cursor_in_new,
+ zeta_prompt::CURSOR_MARKER,
+ multi_region::V0318_END_MARKER,
+ );
+ }
+
if version == ZetaFormat::V0316SeedMultiRegions {
let cursor_in_new = cursor_offset.map(|cursor_offset| {
let hunk_start = first_hunk_offset.unwrap_or(0);
@@ -237,7 +251,10 @@ impl TeacherPrompt {
}
}
- if response.trim().ends_with(Self::NO_EDITS) {
+ if response
+ .trim_end_matches(&[' ', '\n', '`'])
+ .ends_with(Self::NO_EDITS)
+ {
return Ok(no_edits);
}
@@ -872,4 +889,42 @@ mod tests {
let result = extract_last_codeblock(text).unwrap();
assert_eq!(result, "content here\n");
}
+
+ #[test]
+ fn test_parse_no_edits_response_with_trailing_backticks() {
+ let response = "NO_EDITS```";
+
+ let parsed = TeacherPrompt::parse(
+ &Example {
+ spec: edit_prediction::example_spec::ExampleSpec {
+ name: "test".to_string(),
+ repository_url: "https://github.com/zed-industries/zed.git".to_string(),
+ revision: "HEAD".to_string(),
+ tags: Vec::new(),
+ reasoning: None,
+ uncommitted_diff: String::new(),
+ cursor_path: std::sync::Arc::from(std::path::Path::new("src/main.rs")),
+ cursor_position: "0:0".to_string(),
+ edit_history: String::new(),
+ expected_patches: Vec::new(),
+ rejected_patch: None,
+ telemetry: None,
+ human_feedback: Vec::new(),
+ rating: None,
+ },
+ prompt_inputs: None,
+ prompt: None,
+ predictions: Vec::new(),
+ score: Vec::new(),
+ qa: Vec::new(),
+ zed_version: None,
+ state: None,
+ },
+ response,
+ )
+ .unwrap();
+
+ assert!(parsed.0.is_empty());
+ assert!(parsed.1.is_none());
+ }
}
@@ -263,6 +263,8 @@ pub async fn run_prediction(
actual_cursor: None,
error: None,
provider,
+ cumulative_logprob: None,
+ avg_logprob: None,
});
step_progress.set_substatus("requesting prediction");
@@ -455,6 +457,8 @@ async fn predict_anthropic(
_ => PredictionProvider::TeacherNonBatching(backend),
}
},
+ cumulative_logprob: None,
+ avg_logprob: None,
};
example.predictions.push(prediction);
@@ -572,6 +576,8 @@ async fn predict_openai(
_ => PredictionProvider::TeacherNonBatching(backend),
}
},
+ cumulative_logprob: None,
+ avg_logprob: None,
};
example.predictions.push(prediction);
@@ -656,6 +662,8 @@ pub async fn predict_baseten(
actual_cursor,
error: None,
provider: PredictionProvider::Baseten(format),
+ cumulative_logprob: None,
+ avg_logprob: None,
};
example.predictions.push(prediction);
@@ -10,7 +10,7 @@ use crate::{
BatchProvider, PredictionProvider,
anthropic_client::AnthropicClient,
example::{ActualCursor, Example, ExamplePrediction},
- format_prompt::{TeacherPrompt, extract_last_codeblock},
+ format_prompt::TeacherPrompt,
metrics::count_patch_token_changes,
openai_client::OpenAiClient,
parse_output::run_parse_output,
@@ -227,10 +227,7 @@ pub fn needs_repair(example: &Example, confidence_threshold: u8) -> bool {
/// Handles the `KEEP_PREVIOUS` sentinel by copying the teacher's prediction,
/// and delegates normal output to `TeacherPrompt::parse`.
pub fn parse(example: &Example, actual_output: &str) -> Result<(String, Option<ActualCursor>)> {
- let last_codeblock =
- extract_last_codeblock(actual_output).unwrap_or_else(|| actual_output.to_string());
-
- if last_codeblock.contains(KEEP_PREVIOUS) {
+ if actual_output.contains(KEEP_PREVIOUS) {
let original = example
.predictions
.first()
@@ -426,6 +423,8 @@ pub async fn run_repair(
actual_cursor,
error: err,
provider: PredictionProvider::Repair,
+ cumulative_logprob: None,
+ avg_logprob: None,
});
Ok(())
@@ -454,3 +453,71 @@ pub async fn sync_batches(args: &RepairArgs) -> Result<()> {
Ok(())
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::{PredictionProvider, TeacherBackend};
+ use edit_prediction::example_spec::ExampleSpec;
+ use std::{path::Path, sync::Arc};
+
+ fn example_with_previous_prediction() -> Example {
+ Example {
+ spec: ExampleSpec {
+ name: "example".to_string(),
+ repository_url: "https://github.com/zed-industries/zed.git".to_string(),
+ revision: "HEAD".to_string(),
+ tags: Vec::new(),
+ reasoning: None,
+ uncommitted_diff: String::new(),
+ cursor_path: Arc::from(Path::new("src/main.rs")),
+ cursor_position: "0:0".to_string(),
+ edit_history: String::new(),
+ expected_patches: Vec::new(),
+ rejected_patch: None,
+ telemetry: None,
+ human_feedback: Vec::new(),
+ rating: None,
+ },
+ prompt_inputs: None,
+ prompt: None,
+ predictions: vec![ExamplePrediction {
+ actual_patch: Some("previous patch".to_string()),
+ actual_output: String::new(),
+ actual_cursor: Some(ActualCursor {
+ path: "src/main.rs".to_string(),
+ row: 1,
+ column: 2,
+ offset: 3,
+ editable_region_offset: Some(4),
+ }),
+ error: None,
+ provider: PredictionProvider::Teacher(TeacherBackend::Sonnet45),
+ cumulative_logprob: None,
+ avg_logprob: None,
+ }],
+ score: Vec::new(),
+ qa: Vec::new(),
+ zed_version: None,
+ state: None,
+ }
+ }
+
+ #[test]
+ fn test_parse_keeps_previous_when_sentinel_appears_outside_last_codeblock() {
+ let example = example_with_previous_prediction();
+ let actual_output = indoc::indoc! {"
+ After reviewing the feedback, the previous prediction is still correct.
+ Use `KEEP_PREVIOUS`.
+
+ ```
+ unrelated trailing code block
+ ```
+ "};
+
+ let (patch, cursor) = parse(&example, actual_output).unwrap();
+
+ assert_eq!(patch, "previous patch");
+ assert_eq!(cursor.unwrap().offset, 3);
+ }
+}
@@ -78,6 +78,8 @@ pub async fn run_scoring(
has_isolated_whitespace_changes: false,
inserted_tokens: 0,
deleted_tokens: 0,
+ cumulative_logprob: None,
+ avg_logprob: None,
};
let cursor_path = example.spec.cursor_path.as_ref();
@@ -189,6 +191,8 @@ pub async fn run_scoring(
has_isolated_whitespace_changes,
inserted_tokens: token_changes.inserted_tokens,
deleted_tokens: token_changes.deleted_tokens,
+ cumulative_logprob: prediction.cumulative_logprob,
+ avg_logprob: prediction.avg_logprob,
});
}
@@ -1028,6 +1028,7 @@ fn assert_related_files_impl(
pretty_assertions::assert_eq!(actual, expected)
}
+#[track_caller]
fn assert_definitions(definitions: &[LocationLink], first_lines: &[&str], cx: &mut TestAppContext) {
let actual_first_lines = definitions
.iter()
@@ -174,7 +174,7 @@ pub fn register_fake_definition_server(
struct DefinitionIndex {
language: Arc<Language>,
definitions: HashMap<String, Vec<lsp::Location>>,
- type_annotations: HashMap<String, String>,
+ type_annotations_by_file: HashMap<Uri, HashMap<String, String>>,
files: HashMap<Uri, FileEntry>,
}
@@ -189,7 +189,7 @@ impl DefinitionIndex {
Self {
language,
definitions: HashMap::default(),
- type_annotations: HashMap::default(),
+ type_annotations_by_file: HashMap::default(),
files: HashMap::default(),
}
}
@@ -199,6 +199,7 @@ impl DefinitionIndex {
locations.retain(|loc| &loc.uri != uri);
!locations.is_empty()
});
+ self.type_annotations_by_file.remove(uri);
self.files.remove(uri);
}
@@ -243,11 +244,11 @@ impl DefinitionIndex {
.push(location);
}
- for (identifier_name, type_name) in extract_type_annotations(content) {
- self.type_annotations
- .entry(identifier_name)
- .or_insert(type_name);
- }
+ let type_annotations = extract_type_annotations(content)
+ .into_iter()
+ .collect::<HashMap<_, _>>();
+ self.type_annotations_by_file
+ .insert(uri.clone(), type_annotations);
self.files.insert(
uri,
@@ -279,7 +280,11 @@ impl DefinitionIndex {
let entry = self.files.get(&uri)?;
let name = word_at_position(&entry.contents, position)?;
- if let Some(type_name) = self.type_annotations.get(name) {
+ if let Some(type_name) = self
+ .type_annotations_by_file
+ .get(&uri)
+ .and_then(|annotations| annotations.get(name))
+ {
if let Some(locations) = self.definitions.get(type_name) {
return Some(lsp::GotoDefinitionResponse::Array(locations.clone()));
}
@@ -367,6 +372,20 @@ fn extract_base_type_name(type_str: &str) -> String {
return outer.to_string();
}
+ if let Some(call_start) = trimmed.find("::") {
+ let outer = &trimmed[..call_start];
+ if matches!(outer, "Arc" | "Box" | "Rc" | "Option" | "Vec" | "Cow") {
+ let rest = trimmed[call_start + 2..].trim_start();
+ if let Some(paren_start) = rest.find('(') {
+ let inner = &rest[paren_start + 1..];
+ let inner = inner.trim();
+ if !inner.is_empty() {
+ return extract_base_type_name(inner);
+ }
+ }
+ }
+ }
+
trimmed
.split(|c: char| !c.is_alphanumeric() && c != '_')
.next()
@@ -392,6 +392,20 @@ where
&bracket_colors_markup(&mut cx),
"All markdown brackets should be colored based on their depth, again"
);
+
+ cx.set_state(indoc! {r#"Λ('')('')
+
+((''))('')
+
+('')((''))"#});
+ cx.executor().advance_clock(Duration::from_millis(100));
+ cx.executor().run_until_parked();
+
+ assert_eq!(
+ "«1('')1»«1('')1»\n\n«1(«2('')2»)1»«1('')1»\n\n«1('')1»«1(«2('')2»)1»\n1 hsla(207.80, 16.20%, 69.19%, 1.00)\n2 hsla(29.00, 54.00%, 65.88%, 1.00)\n",
+ &bracket_colors_markup(&mut cx),
+ "Markdown quote pairs should not interfere with parenthesis pairing"
+ );
}
#[gpui::test]
@@ -1,13 +1,17 @@
use edit_prediction_types::{
EditPredictionDelegate, EditPredictionIconSet, PredictedCursorPosition,
};
-use gpui::{Entity, KeyBinding, Modifiers, prelude::*};
+use gpui::{
+ Entity, KeyBinding, KeybindingKeystroke, Keystroke, Modifiers, NoAction, Task, prelude::*,
+};
use indoc::indoc;
-use language::Buffer;
use language::EditPredictionsMode;
-use multi_buffer::{Anchor, MultiBufferSnapshot, ToPoint};
+use language::{Buffer, CodeLabel};
+use multi_buffer::{Anchor, ExcerptId, MultiBufferSnapshot, ToPoint};
+use project::{Completion, CompletionResponse, CompletionSource};
use std::{
ops::Range,
+ rc::Rc,
sync::{
Arc,
atomic::{self, AtomicUsize},
@@ -17,8 +21,9 @@ use text::{Point, ToOffset};
use ui::prelude::*;
use crate::{
- AcceptEditPrediction, EditPrediction, EditPredictionKeybindAction,
- EditPredictionKeybindSurface, MenuEditPredictionsPolicy,
+ AcceptEditPrediction, CompletionContext, CompletionProvider, EditPrediction,
+ EditPredictionKeybindAction, EditPredictionKeybindSurface, MenuEditPredictionsPolicy,
+ ShowCompletions,
editor_tests::{init_test, update_test_language_settings},
test::editor_test_context::EditorTestContext,
};
@@ -482,57 +487,10 @@ async fn test_edit_prediction_preview_cleanup_on_toggle_off(cx: &mut gpui::TestA
});
}
-fn load_default_keymap(cx: &mut gpui::TestAppContext) {
- cx.update(|cx| {
- cx.bind_keys(
- settings::KeymapFile::load_asset_allow_partial_failure(
- settings::DEFAULT_KEYMAP_PATH,
- cx,
- )
- .expect("failed to load default keymap"),
- );
- });
-}
-
#[gpui::test]
-async fn test_tab_is_preferred_accept_binding_over_alt_tab(cx: &mut gpui::TestAppContext) {
- init_test(cx, |_| {});
- load_default_keymap(cx);
-
- let mut cx = EditorTestContext::new(cx).await;
- let provider = cx.new(|_| FakeEditPredictionDelegate::default());
- assign_editor_completion_provider(provider.clone(), &mut cx);
- cx.set_state("let x = Λ;");
-
- propose_edits(&provider, vec![(8..8, "42")], &mut cx);
- cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx));
-
- cx.update_editor(|editor, window, cx| {
- assert!(editor.has_active_edit_prediction());
- let keybind_display = editor.edit_prediction_keybind_display(
- EditPredictionKeybindSurface::Inline,
- window,
- cx,
- );
- let keystroke = keybind_display
- .accept_keystroke
- .as_ref()
- .expect("should have an accept binding");
- assert!(
- !keystroke.modifiers().modified(),
- "preferred accept binding should be unmodified (tab), got modifiers: {:?}",
- keystroke.modifiers()
- );
- assert_eq!(
- keystroke.key(),
- "tab",
- "preferred accept binding should be tab"
- );
- });
-}
-
-#[gpui::test]
-async fn test_subtle_in_code_indicator_prefers_preview_binding(cx: &mut gpui::TestAppContext) {
+async fn test_edit_prediction_preview_activates_when_prediction_arrives_with_modifier_held(
+ cx: &mut gpui::TestAppContext,
+) {
init_test(cx, |_| {});
load_default_keymap(cx);
update_test_language_settings(cx, &|settings| {
@@ -544,227 +502,324 @@ async fn test_subtle_in_code_indicator_prefers_preview_binding(cx: &mut gpui::Te
assign_editor_completion_provider(provider.clone(), &mut cx);
cx.set_state("let x = Λ;");
- propose_edits(&provider, vec![(8..8, "42")], &mut cx);
- cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx));
-
- cx.update_editor(|editor, window, cx| {
- assert!(editor.has_active_edit_prediction());
- assert!(
- editor.edit_prediction_requires_modifier(),
- "subtle mode should require a modifier"
- );
-
- let inline_keybind_display = editor.edit_prediction_keybind_display(
- EditPredictionKeybindSurface::Inline,
- window,
- cx,
- );
- let compact_keybind_display = editor.edit_prediction_keybind_display(
- EditPredictionKeybindSurface::CursorPopoverCompact,
- window,
- cx,
- );
-
- let accept_keystroke = inline_keybind_display
- .accept_keystroke
- .as_ref()
- .expect("should have an accept binding");
- let preview_keystroke = inline_keybind_display
- .preview_keystroke
- .as_ref()
- .expect("should have a preview binding");
- let in_code_keystroke = inline_keybind_display
- .displayed_keystroke
- .as_ref()
- .expect("should have an in-code binding");
- let compact_cursor_popover_keystroke = compact_keybind_display
- .displayed_keystroke
- .as_ref()
- .expect("should have a compact cursor popover binding");
-
- assert_eq!(accept_keystroke.key(), "tab");
- assert!(
- !editor.has_visible_completions_menu(),
- "compact cursor-popover branch should be used without a completions menu"
- );
- assert!(
- preview_keystroke.modifiers().modified(),
- "preview binding should use modifiers in subtle mode"
- );
- assert_eq!(
- compact_cursor_popover_keystroke.key(),
- preview_keystroke.key(),
- "subtle compact cursor popover should prefer the preview binding"
- );
- assert_eq!(
- compact_cursor_popover_keystroke.modifiers(),
- preview_keystroke.modifiers(),
- "subtle compact cursor popover should use the preview binding modifiers"
- );
- assert_eq!(
- in_code_keystroke.key(),
- preview_keystroke.key(),
- "subtle in-code indicator should prefer the preview binding"
- );
- assert_eq!(
- in_code_keystroke.modifiers(),
- preview_keystroke.modifiers(),
- "subtle in-code indicator should use the preview binding modifiers"
- );
+ cx.editor(|editor, _, _| {
+ assert!(!editor.has_active_edit_prediction());
+ assert!(!editor.edit_prediction_preview_is_active());
});
-}
-#[gpui::test]
-async fn test_tab_accepts_edit_prediction_over_completion(cx: &mut gpui::TestAppContext) {
- init_test(cx, |_| {});
- load_default_keymap(cx);
-
- let mut cx = EditorTestContext::new(cx).await;
- let provider = cx.new(|_| FakeEditPredictionDelegate::default());
- assign_editor_completion_provider(provider.clone(), &mut cx);
- cx.set_state("let x = Λ;");
-
- propose_edits(&provider, vec![(8..8, "42")], &mut cx);
- cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx));
-
- assert_editor_active_edit_completion(&mut cx, |_, edits| {
- assert_eq!(edits.len(), 1);
- assert_eq!(edits[0].1.as_ref(), "42");
+ let preview_modifiers = cx.update_editor(|editor, window, cx| {
+ *editor
+ .preview_edit_prediction_keystroke(window, cx)
+ .unwrap()
+ .modifiers()
});
- cx.simulate_keystroke("tab");
+ cx.simulate_modifiers_change(preview_modifiers);
cx.run_until_parked();
- cx.assert_editor_state("let x = 42Λ;");
-}
-
-#[gpui::test]
-async fn test_single_line_prediction_uses_accept_cursor_popover_action(
- cx: &mut gpui::TestAppContext,
-) {
- init_test(cx, |_| {});
- load_default_keymap(cx);
-
- let mut cx = EditorTestContext::new(cx).await;
- let provider = cx.new(|_| FakeEditPredictionDelegate::default());
- assign_editor_completion_provider(provider.clone(), &mut cx);
- cx.set_state("let x = Λ;");
+ cx.editor(|editor, _, _| {
+ assert!(!editor.has_active_edit_prediction());
+ assert!(editor.edit_prediction_preview_is_active());
+ });
propose_edits(&provider, vec![(8..8, "42")], &mut cx);
- cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx));
-
cx.update_editor(|editor, window, cx| {
- assert!(editor.has_active_edit_prediction());
+ editor.set_menu_edit_predictions_policy(MenuEditPredictionsPolicy::ByProvider);
+ editor.update_visible_edit_prediction(window, cx)
+ });
- let keybind_display = editor.edit_prediction_keybind_display(
- EditPredictionKeybindSurface::CursorPopoverExpanded,
- window,
- cx,
+ cx.editor(|editor, _, _| {
+ assert!(editor.has_active_edit_prediction());
+ assert!(
+ editor.edit_prediction_preview_is_active(),
+ "prediction preview should activate immediately when the prediction arrives while the preview modifier is still held",
);
+ });
+}
- let accept_keystroke = keybind_display
- .accept_keystroke
- .as_ref()
- .expect("should have an accept binding");
- let preview_keystroke = keybind_display
- .preview_keystroke
- .as_ref()
- .expect("should have a preview binding");
-
- assert_eq!(
- keybind_display.action,
- EditPredictionKeybindAction::Accept,
- "single-line prediction should show the accept action"
+fn load_default_keymap(cx: &mut gpui::TestAppContext) {
+ cx.update(|cx| {
+ cx.bind_keys(
+ settings::KeymapFile::load_asset_allow_partial_failure(
+ settings::DEFAULT_KEYMAP_PATH,
+ cx,
+ )
+ .expect("failed to load default keymap"),
);
- assert_eq!(accept_keystroke.key(), "tab");
- assert!(preview_keystroke.modifiers().modified());
});
}
#[gpui::test]
-async fn test_multi_line_prediction_uses_preview_cursor_popover_action(
- cx: &mut gpui::TestAppContext,
-) {
- init_test(cx, |_| {});
- load_default_keymap(cx);
-
- let mut cx = EditorTestContext::new(cx).await;
- let provider = cx.new(|_| FakeEditPredictionDelegate::default());
- assign_editor_completion_provider(provider.clone(), &mut cx);
- cx.set_state("let x = Λ;");
-
- propose_edits(&provider, vec![(8..8, "42\n43")], &mut cx);
- cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx));
-
- cx.update_editor(|editor, window, cx| {
- assert!(editor.has_active_edit_prediction());
+async fn test_inline_edit_prediction_keybind_selection_cases(cx: &mut gpui::TestAppContext) {
+ enum InlineKeybindState {
+ Normal,
+ ShowingCompletions,
+ InLeadingWhitespace,
+ ShowingCompletionsAndLeadingWhitespace,
+ }
- let keybind_display = editor.edit_prediction_keybind_display(
- EditPredictionKeybindSurface::CursorPopoverExpanded,
- window,
- cx,
- );
- let preview_keystroke = keybind_display
- .preview_keystroke
- .as_ref()
- .expect("should have a preview binding");
+ enum ExpectedKeystroke {
+ DefaultAccept,
+ DefaultPreview,
+ Literal(&'static str),
+ }
- assert_eq!(
- keybind_display.action,
- EditPredictionKeybindAction::Preview,
- "multi-line prediction should show the preview action"
- );
- assert!(preview_keystroke.modifiers().modified());
- });
-}
+ struct InlineKeybindCase {
+ name: &'static str,
+ use_default_keymap: bool,
+ mode: EditPredictionsMode,
+ extra_bindings: Vec<KeyBinding>,
+ state: InlineKeybindState,
+ expected_accept_keystroke: ExpectedKeystroke,
+ expected_preview_keystroke: ExpectedKeystroke,
+ expected_displayed_keystroke: ExpectedKeystroke,
+ }
-#[gpui::test]
-async fn test_single_line_prediction_with_preview_uses_accept_cursor_popover_action(
- cx: &mut gpui::TestAppContext,
-) {
init_test(cx, |_| {});
load_default_keymap(cx);
+ let mut default_cx = EditorTestContext::new(cx).await;
+ let provider = default_cx.new(|_| FakeEditPredictionDelegate::default());
+ assign_editor_completion_provider(provider.clone(), &mut default_cx);
+ default_cx.set_state("let x = Λ;");
+ propose_edits(&provider, vec![(8..8, "42")], &mut default_cx);
+ default_cx
+ .update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx));
+
+ let (default_accept_keystroke, default_preview_keystroke) =
+ default_cx.update_editor(|editor, window, cx| {
+ let keybind_display = editor.edit_prediction_keybind_display(
+ EditPredictionKeybindSurface::Inline,
+ window,
+ cx,
+ );
+ let accept_keystroke = keybind_display
+ .accept_keystroke
+ .as_ref()
+ .expect("default inline edit prediction should have an accept binding")
+ .clone();
+ let preview_keystroke = keybind_display
+ .preview_keystroke
+ .as_ref()
+ .expect("default inline edit prediction should have a preview binding")
+ .clone();
+ (accept_keystroke, preview_keystroke)
+ });
+
+ let cases = [
+ InlineKeybindCase {
+ name: "default setup prefers tab over alt-tab for accept",
+ use_default_keymap: true,
+ mode: EditPredictionsMode::Eager,
+ extra_bindings: Vec::new(),
+ state: InlineKeybindState::Normal,
+ expected_accept_keystroke: ExpectedKeystroke::DefaultAccept,
+ expected_preview_keystroke: ExpectedKeystroke::DefaultPreview,
+ expected_displayed_keystroke: ExpectedKeystroke::DefaultAccept,
+ },
+ InlineKeybindCase {
+ name: "subtle mode displays preview binding inline",
+ use_default_keymap: true,
+ mode: EditPredictionsMode::Subtle,
+ extra_bindings: Vec::new(),
+ state: InlineKeybindState::Normal,
+ expected_accept_keystroke: ExpectedKeystroke::DefaultPreview,
+ expected_preview_keystroke: ExpectedKeystroke::DefaultPreview,
+ expected_displayed_keystroke: ExpectedKeystroke::DefaultPreview,
+ },
+ InlineKeybindCase {
+ name: "removing default tab binding still displays tab",
+ use_default_keymap: true,
+ mode: EditPredictionsMode::Eager,
+ extra_bindings: vec![KeyBinding::new(
+ "tab",
+ NoAction,
+ Some("Editor && edit_prediction && edit_prediction_mode == eager"),
+ )],
+ state: InlineKeybindState::Normal,
+ expected_accept_keystroke: ExpectedKeystroke::DefaultPreview,
+ expected_preview_keystroke: ExpectedKeystroke::DefaultPreview,
+ expected_displayed_keystroke: ExpectedKeystroke::DefaultPreview,
+ },
+ InlineKeybindCase {
+ name: "custom-only rebound accept key uses replacement key",
+ use_default_keymap: true,
+ mode: EditPredictionsMode::Eager,
+ extra_bindings: vec![KeyBinding::new(
+ "ctrl-enter",
+ AcceptEditPrediction,
+ Some("Editor && edit_prediction"),
+ )],
+ state: InlineKeybindState::Normal,
+ expected_accept_keystroke: ExpectedKeystroke::Literal("ctrl-enter"),
+ expected_preview_keystroke: ExpectedKeystroke::Literal("ctrl-enter"),
+ expected_displayed_keystroke: ExpectedKeystroke::Literal("ctrl-enter"),
+ },
+ InlineKeybindCase {
+ name: "showing completions restores conflict-context binding",
+ use_default_keymap: true,
+ mode: EditPredictionsMode::Eager,
+ extra_bindings: vec![KeyBinding::new(
+ "ctrl-enter",
+ AcceptEditPrediction,
+ Some("Editor && edit_prediction && showing_completions"),
+ )],
+ state: InlineKeybindState::ShowingCompletions,
+ expected_accept_keystroke: ExpectedKeystroke::Literal("ctrl-enter"),
+ expected_preview_keystroke: ExpectedKeystroke::Literal("ctrl-enter"),
+ expected_displayed_keystroke: ExpectedKeystroke::Literal("ctrl-enter"),
+ },
+ InlineKeybindCase {
+ name: "leading whitespace restores conflict-context binding",
+ use_default_keymap: false,
+ mode: EditPredictionsMode::Eager,
+ extra_bindings: vec![KeyBinding::new(
+ "ctrl-enter",
+ AcceptEditPrediction,
+ Some("Editor && edit_prediction && in_leading_whitespace"),
+ )],
+ state: InlineKeybindState::InLeadingWhitespace,
+ expected_accept_keystroke: ExpectedKeystroke::Literal("ctrl-enter"),
+ expected_preview_keystroke: ExpectedKeystroke::Literal("ctrl-enter"),
+ expected_displayed_keystroke: ExpectedKeystroke::Literal("ctrl-enter"),
+ },
+ InlineKeybindCase {
+ name: "showing completions and leading whitespace restore combined conflict binding",
+ use_default_keymap: false,
+ mode: EditPredictionsMode::Eager,
+ extra_bindings: vec![KeyBinding::new(
+ "ctrl-enter",
+ AcceptEditPrediction,
+ Some("Editor && edit_prediction && showing_completions && in_leading_whitespace"),
+ )],
+ state: InlineKeybindState::ShowingCompletionsAndLeadingWhitespace,
+ expected_accept_keystroke: ExpectedKeystroke::Literal("ctrl-enter"),
+ expected_preview_keystroke: ExpectedKeystroke::Literal("ctrl-enter"),
+ expected_displayed_keystroke: ExpectedKeystroke::Literal("ctrl-enter"),
+ },
+ ];
+
+ for case in cases {
+ init_test(cx, |_| {});
+ if case.use_default_keymap {
+ load_default_keymap(cx);
+ }
+ update_test_language_settings(cx, &|settings| {
+ settings.edit_predictions.get_or_insert_default().mode = Some(case.mode);
+ });
- let mut cx = EditorTestContext::new(cx).await;
- let provider = cx.new(|_| FakeEditPredictionDelegate::default());
- assign_editor_completion_provider(provider.clone(), &mut cx);
- cx.set_state("let x = Λ;");
+ if !case.extra_bindings.is_empty() {
+ cx.update(|cx| cx.bind_keys(case.extra_bindings.clone()));
+ }
- propose_edits_with_preview(&provider, vec![(8..8, "42")], &mut cx).await;
- cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx));
+ let mut cx = EditorTestContext::new(cx).await;
+ let provider = cx.new(|_| FakeEditPredictionDelegate::default());
+ assign_editor_completion_provider(provider.clone(), &mut cx);
- cx.update_editor(|editor, window, cx| {
- assert!(editor.has_active_edit_prediction());
+ match case.state {
+ InlineKeybindState::Normal | InlineKeybindState::ShowingCompletions => {
+ cx.set_state("let x = Λ;");
+ }
+ InlineKeybindState::InLeadingWhitespace
+ | InlineKeybindState::ShowingCompletionsAndLeadingWhitespace => {
+ cx.set_state(indoc! {"
+ fn main() {
+ Λ
+ }
+ "});
+ }
+ }
- let keybind_display = editor.edit_prediction_keybind_display(
- EditPredictionKeybindSurface::CursorPopoverExpanded,
- window,
- cx,
- );
+ propose_edits(&provider, vec![(8..8, "42")], &mut cx);
+ cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx));
+
+ if matches!(
+ case.state,
+ InlineKeybindState::ShowingCompletions
+ | InlineKeybindState::ShowingCompletionsAndLeadingWhitespace
+ ) {
+ assign_editor_completion_menu_provider(&mut cx);
+ cx.update_editor(|editor, window, cx| {
+ editor.show_completions(&ShowCompletions, window, cx);
+ });
+ cx.run_until_parked();
+ }
- let accept_keystroke = keybind_display
- .accept_keystroke
- .as_ref()
- .expect("should have an accept binding");
- let preview_keystroke = keybind_display
- .preview_keystroke
- .as_ref()
- .expect("should have a preview binding");
+ cx.update_editor(|editor, window, cx| {
+ assert!(
+ editor.has_active_edit_prediction(),
+ "case '{}' should have an active edit prediction",
+ case.name
+ );
- assert_eq!(
- keybind_display.action,
- EditPredictionKeybindAction::Accept,
- "single-line prediction should show the accept action even with edit_preview"
- );
- assert_eq!(accept_keystroke.key(), "tab");
- assert!(preview_keystroke.modifiers().modified());
- });
+ let keybind_display = editor.edit_prediction_keybind_display(
+ EditPredictionKeybindSurface::Inline,
+ window,
+ cx,
+ );
+ let accept_keystroke = keybind_display
+ .accept_keystroke
+ .as_ref()
+ .unwrap_or_else(|| panic!("case '{}' should have an accept binding", case.name));
+ let preview_keystroke = keybind_display
+ .preview_keystroke
+ .as_ref()
+ .unwrap_or_else(|| panic!("case '{}' should have a preview binding", case.name));
+ let displayed_keystroke = keybind_display
+ .displayed_keystroke
+ .as_ref()
+ .unwrap_or_else(|| panic!("case '{}' should have a displayed binding", case.name));
+
+ let expected_accept_keystroke = match case.expected_accept_keystroke {
+ ExpectedKeystroke::DefaultAccept => default_accept_keystroke.clone(),
+ ExpectedKeystroke::DefaultPreview => default_preview_keystroke.clone(),
+ ExpectedKeystroke::Literal(keystroke) => KeybindingKeystroke::from_keystroke(
+ Keystroke::parse(keystroke).expect("expected test keystroke to parse"),
+ ),
+ };
+ let expected_preview_keystroke = match case.expected_preview_keystroke {
+ ExpectedKeystroke::DefaultAccept => default_accept_keystroke.clone(),
+ ExpectedKeystroke::DefaultPreview => default_preview_keystroke.clone(),
+ ExpectedKeystroke::Literal(keystroke) => KeybindingKeystroke::from_keystroke(
+ Keystroke::parse(keystroke).expect("expected test keystroke to parse"),
+ ),
+ };
+ let expected_displayed_keystroke = match case.expected_displayed_keystroke {
+ ExpectedKeystroke::DefaultAccept => default_accept_keystroke.clone(),
+ ExpectedKeystroke::DefaultPreview => default_preview_keystroke.clone(),
+ ExpectedKeystroke::Literal(keystroke) => KeybindingKeystroke::from_keystroke(
+ Keystroke::parse(keystroke).expect("expected test keystroke to parse"),
+ ),
+ };
+
+ assert_eq!(
+ accept_keystroke, &expected_accept_keystroke,
+ "case '{}' selected the wrong accept binding",
+ case.name
+ );
+ assert_eq!(
+ preview_keystroke, &expected_preview_keystroke,
+ "case '{}' selected the wrong preview binding",
+ case.name
+ );
+ assert_eq!(
+ displayed_keystroke, &expected_displayed_keystroke,
+ "case '{}' selected the wrong displayed binding",
+ case.name
+ );
+
+ if matches!(case.mode, EditPredictionsMode::Subtle) {
+ assert!(
+ editor.edit_prediction_requires_modifier(),
+ "case '{}' should require a modifier",
+ case.name
+ );
+ }
+ });
+ }
}
#[gpui::test]
-async fn test_multi_line_prediction_with_preview_uses_preview_cursor_popover_action(
- cx: &mut gpui::TestAppContext,
-) {
+async fn test_tab_accepts_edit_prediction_over_completion(cx: &mut gpui::TestAppContext) {
init_test(cx, |_| {});
load_default_keymap(cx);
@@ -773,131 +828,194 @@ async fn test_multi_line_prediction_with_preview_uses_preview_cursor_popover_act
assign_editor_completion_provider(provider.clone(), &mut cx);
cx.set_state("let x = Λ;");
- propose_edits_with_preview(&provider, vec![(8..8, "42\n43")], &mut cx).await;
+ propose_edits(&provider, vec![(8..8, "42")], &mut cx);
cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx));
- cx.update_editor(|editor, window, cx| {
- assert!(editor.has_active_edit_prediction());
- let keybind_display = editor.edit_prediction_keybind_display(
- EditPredictionKeybindSurface::CursorPopoverExpanded,
- window,
- cx,
- );
- let preview_keystroke = keybind_display
- .preview_keystroke
- .as_ref()
- .expect("should have a preview binding");
-
- assert_eq!(
- keybind_display.action,
- EditPredictionKeybindAction::Preview,
- "multi-line prediction should show the preview action with edit_preview"
- );
- assert!(preview_keystroke.modifiers().modified());
+ assert_editor_active_edit_completion(&mut cx, |_, edits| {
+ assert_eq!(edits.len(), 1);
+ assert_eq!(edits[0].1.as_ref(), "42");
});
-}
-
-#[gpui::test]
-async fn test_single_line_deletion_of_newline_uses_accept_cursor_popover_action(
- cx: &mut gpui::TestAppContext,
-) {
- init_test(cx, |_| {});
- load_default_keymap(cx);
-
- let mut cx = EditorTestContext::new(cx).await;
- let provider = cx.new(|_| FakeEditPredictionDelegate::default());
- assign_editor_completion_provider(provider.clone(), &mut cx);
- cx.set_state(indoc! {"
- fn main() {
- let value = 1;
- Λprintln!(\"done\");
- }
- "});
-
- propose_edits(
- &provider,
- vec![(Point::new(1, 18)..Point::new(2, 17), "")],
- &mut cx,
- );
- cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx));
-
- cx.update_editor(|editor, window, cx| {
- assert!(editor.has_active_edit_prediction());
- let keybind_display = editor.edit_prediction_keybind_display(
- EditPredictionKeybindSurface::CursorPopoverExpanded,
- window,
- cx,
- );
-
- let accept_keystroke = keybind_display
- .accept_keystroke
- .as_ref()
- .expect("should have an accept binding");
- let preview_keystroke = keybind_display
- .preview_keystroke
- .as_ref()
- .expect("should have a preview binding");
+ cx.simulate_keystroke("tab");
+ cx.run_until_parked();
- assert_eq!(
- keybind_display.action,
- EditPredictionKeybindAction::Accept,
- "deleting one newline plus adjacent text should show the accept action"
- );
- assert_eq!(accept_keystroke.key(), "tab");
- assert!(preview_keystroke.modifiers().modified());
- });
+ cx.assert_editor_state("let x = 42Λ;");
}
#[gpui::test]
-async fn test_stale_single_line_prediction_does_not_force_preview_cursor_popover_action(
- cx: &mut gpui::TestAppContext,
-) {
- init_test(cx, |_| {});
- load_default_keymap(cx);
-
- let mut cx = EditorTestContext::new(cx).await;
- let provider = cx.new(|_| FakeEditPredictionDelegate::default());
- assign_editor_completion_provider(provider.clone(), &mut cx);
- cx.set_state("let x = Λ;");
-
- propose_edits(&provider, vec![(8..8, "42\n43")], &mut cx);
- cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx));
- cx.update_editor(|editor, _window, cx| {
- assert!(editor.active_edit_prediction.is_some());
- assert!(editor.stale_edit_prediction_in_menu.is_none());
- editor.take_active_edit_prediction(cx);
- assert!(editor.active_edit_prediction.is_none());
- assert!(editor.stale_edit_prediction_in_menu.is_some());
- });
+async fn test_cursor_popover_edit_prediction_keybind_cases(cx: &mut gpui::TestAppContext) {
+ enum CursorPopoverPredictionKind {
+ SingleLine,
+ MultiLine,
+ SingleLineWithPreview,
+ MultiLineWithPreview,
+ DeleteSingleNewline,
+ StaleSingleLineAfterMultiLine,
+ }
- propose_edits(&provider, vec![(8..8, "42")], &mut cx);
- cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx));
+ struct CursorPopoverCase {
+ name: &'static str,
+ prediction_kind: CursorPopoverPredictionKind,
+ expected_action: EditPredictionKeybindAction,
+ }
- cx.update_editor(|editor, window, cx| {
- assert!(editor.has_active_edit_prediction());
+ let cases = [
+ CursorPopoverCase {
+ name: "single line prediction uses accept action",
+ prediction_kind: CursorPopoverPredictionKind::SingleLine,
+ expected_action: EditPredictionKeybindAction::Accept,
+ },
+ CursorPopoverCase {
+ name: "multi line prediction uses preview action",
+ prediction_kind: CursorPopoverPredictionKind::MultiLine,
+ expected_action: EditPredictionKeybindAction::Preview,
+ },
+ CursorPopoverCase {
+ name: "single line prediction with preview still uses accept action",
+ prediction_kind: CursorPopoverPredictionKind::SingleLineWithPreview,
+ expected_action: EditPredictionKeybindAction::Accept,
+ },
+ CursorPopoverCase {
+ name: "multi line prediction with preview uses preview action",
+ prediction_kind: CursorPopoverPredictionKind::MultiLineWithPreview,
+ expected_action: EditPredictionKeybindAction::Preview,
+ },
+ CursorPopoverCase {
+ name: "single line newline deletion uses accept action",
+ prediction_kind: CursorPopoverPredictionKind::DeleteSingleNewline,
+ expected_action: EditPredictionKeybindAction::Accept,
+ },
+ CursorPopoverCase {
+ name: "stale multi line prediction does not force preview action",
+ prediction_kind: CursorPopoverPredictionKind::StaleSingleLineAfterMultiLine,
+ expected_action: EditPredictionKeybindAction::Accept,
+ },
+ ];
+
+ for case in cases {
+ init_test(cx, |_| {});
+ load_default_keymap(cx);
+
+ let mut cx = EditorTestContext::new(cx).await;
+ let provider = cx.new(|_| FakeEditPredictionDelegate::default());
+ assign_editor_completion_provider(provider.clone(), &mut cx);
+
+ match case.prediction_kind {
+ CursorPopoverPredictionKind::SingleLine => {
+ cx.set_state("let x = Λ;");
+ propose_edits(&provider, vec![(8..8, "42")], &mut cx);
+ cx.update_editor(|editor, window, cx| {
+ editor.update_visible_edit_prediction(window, cx)
+ });
+ }
+ CursorPopoverPredictionKind::MultiLine => {
+ cx.set_state("let x = Λ;");
+ propose_edits(&provider, vec![(8..8, "42\n43")], &mut cx);
+ cx.update_editor(|editor, window, cx| {
+ editor.update_visible_edit_prediction(window, cx)
+ });
+ }
+ CursorPopoverPredictionKind::SingleLineWithPreview => {
+ cx.set_state("let x = Λ;");
+ propose_edits_with_preview(&provider, vec![(8..8, "42")], &mut cx).await;
+ cx.update_editor(|editor, window, cx| {
+ editor.update_visible_edit_prediction(window, cx)
+ });
+ }
+ CursorPopoverPredictionKind::MultiLineWithPreview => {
+ cx.set_state("let x = Λ;");
+ propose_edits_with_preview(&provider, vec![(8..8, "42\n43")], &mut cx).await;
+ cx.update_editor(|editor, window, cx| {
+ editor.update_visible_edit_prediction(window, cx)
+ });
+ }
+ CursorPopoverPredictionKind::DeleteSingleNewline => {
+ cx.set_state(indoc! {"
+ fn main() {
+ let value = 1;
+ Λprintln!(\"done\");
+ }
+ "});
+ propose_edits(
+ &provider,
+ vec![(Point::new(1, 18)..Point::new(2, 17), "")],
+ &mut cx,
+ );
+ cx.update_editor(|editor, window, cx| {
+ editor.update_visible_edit_prediction(window, cx)
+ });
+ }
+ CursorPopoverPredictionKind::StaleSingleLineAfterMultiLine => {
+ cx.set_state("let x = Λ;");
+ propose_edits(&provider, vec![(8..8, "42\n43")], &mut cx);
+ cx.update_editor(|editor, window, cx| {
+ editor.update_visible_edit_prediction(window, cx)
+ });
+ cx.update_editor(|editor, _window, cx| {
+ assert!(editor.active_edit_prediction.is_some());
+ assert!(editor.stale_edit_prediction_in_menu.is_none());
+ editor.take_active_edit_prediction(cx);
+ assert!(editor.active_edit_prediction.is_none());
+ assert!(editor.stale_edit_prediction_in_menu.is_some());
+ });
+
+ propose_edits(&provider, vec![(8..8, "42")], &mut cx);
+ cx.update_editor(|editor, window, cx| {
+ editor.update_visible_edit_prediction(window, cx)
+ });
+ }
+ }
- let keybind_display = editor.edit_prediction_keybind_display(
- EditPredictionKeybindSurface::CursorPopoverExpanded,
- window,
- cx,
- );
- let accept_keystroke = keybind_display
- .accept_keystroke
- .as_ref()
- .expect("should have an accept binding");
+ cx.update_editor(|editor, window, cx| {
+ assert!(
+ editor.has_active_edit_prediction(),
+ "case '{}' should have an active edit prediction",
+ case.name
+ );
- assert_eq!(
- keybind_display.action,
- EditPredictionKeybindAction::Accept,
- "single-line active prediction should show the accept action"
- );
- assert!(
- editor.stale_edit_prediction_in_menu.is_none(),
- "refreshing the visible prediction should clear stale menu state"
- );
- assert_eq!(accept_keystroke.key(), "tab");
- });
+ let keybind_display = editor.edit_prediction_keybind_display(
+ EditPredictionKeybindSurface::CursorPopoverExpanded,
+ window,
+ cx,
+ );
+ let accept_keystroke = keybind_display
+ .accept_keystroke
+ .as_ref()
+ .unwrap_or_else(|| panic!("case '{}' should have an accept binding", case.name));
+ let preview_keystroke = keybind_display
+ .preview_keystroke
+ .as_ref()
+ .unwrap_or_else(|| panic!("case '{}' should have a preview binding", case.name));
+
+ assert_eq!(
+ keybind_display.action, case.expected_action,
+ "case '{}' selected the wrong cursor popover action",
+ case.name
+ );
+ assert_eq!(
+ accept_keystroke.key(),
+ "tab",
+ "case '{}' selected the wrong accept binding",
+ case.name
+ );
+ assert!(
+ preview_keystroke.modifiers().modified(),
+ "case '{}' should use a modified preview binding",
+ case.name
+ );
+
+ if matches!(
+ case.prediction_kind,
+ CursorPopoverPredictionKind::StaleSingleLineAfterMultiLine
+ ) {
+ assert!(
+ editor.stale_edit_prediction_in_menu.is_none(),
+ "case '{}' should clear stale menu state",
+ case.name
+ );
+ }
+ });
+ }
}
fn assert_editor_active_edit_completion(
@@ -1054,6 +1172,12 @@ fn assign_editor_completion_provider(
})
}
+fn assign_editor_completion_menu_provider(cx: &mut EditorTestContext) {
+ cx.update_editor(|editor, _, _| {
+ editor.set_completion_provider(Some(Rc::new(FakeCompletionMenuProvider)));
+ });
+}
+
fn propose_edits_non_zed<T: ToOffset>(
provider: &Entity<FakeNonZedEditPredictionDelegate>,
edits: Vec<(Range<T>, &str)>,
@@ -1086,6 +1210,54 @@ fn assign_editor_completion_provider_non_zed(
})
}
+struct FakeCompletionMenuProvider;
+
+impl CompletionProvider for FakeCompletionMenuProvider {
+ fn completions(
+ &self,
+ _excerpt_id: ExcerptId,
+ _buffer: &Entity<Buffer>,
+ _buffer_position: text::Anchor,
+ _trigger: CompletionContext,
+ _window: &mut Window,
+ _cx: &mut Context<crate::Editor>,
+ ) -> Task<anyhow::Result<Vec<CompletionResponse>>> {
+ let completion = Completion {
+ replace_range: text::Anchor::MIN..text::Anchor::MAX,
+ new_text: "fake_completion".to_string(),
+ label: CodeLabel::plain("fake_completion".to_string(), None),
+ documentation: None,
+ source: CompletionSource::Custom,
+ icon_path: None,
+ match_start: None,
+ snippet_deduplication_key: None,
+ insert_text_mode: None,
+ confirm: None,
+ };
+
+ Task::ready(Ok(vec![CompletionResponse {
+ completions: vec![completion],
+ display_options: Default::default(),
+ is_incomplete: false,
+ }]))
+ }
+
+ fn is_completion_trigger(
+ &self,
+ _buffer: &Entity<Buffer>,
+ _position: language::Anchor,
+ _text: &str,
+ _trigger_in_words: bool,
+ _cx: &mut Context<crate::Editor>,
+ ) -> bool {
+ false
+ }
+
+ fn filter_completions(&self) -> bool {
+ false
+ }
+}
+
#[derive(Default, Clone)]
pub struct FakeEditPredictionDelegate {
pub completion: Option<edit_prediction_types::EditPrediction>,
@@ -1869,6 +1869,7 @@ pub enum MultibufferSelectionMode {
pub struct RewrapOptions {
pub override_language_settings: bool,
pub preserve_existing_whitespace: bool,
+ pub line_length: Option<usize>,
}
impl Editor {
@@ -2885,6 +2886,11 @@ impl Editor {
if self.in_leading_whitespace {
key_context.add("in_leading_whitespace");
}
+ if self.edit_prediction_requires_modifier() {
+ key_context.set("edit_prediction_mode", "subtle")
+ } else {
+ key_context.set("edit_prediction_mode", "eager");
+ }
if self.selection_mark_mode {
key_context.add("selection_mode");
@@ -2952,7 +2958,7 @@ impl Editor {
window: &mut Window,
cx: &mut App,
) -> Option<gpui::KeybindingKeystroke> {
- let key_context = self.key_context_internal(self.has_active_edit_prediction(), window, cx);
+ let key_context = self.key_context_internal(true, window, cx);
let bindings =
match granularity {
@@ -2979,7 +2985,7 @@ impl Editor {
window: &mut Window,
cx: &mut App,
) -> Option<gpui::KeybindingKeystroke> {
- let key_context = self.key_context_internal(self.has_active_edit_prediction(), window, cx);
+ let key_context = self.key_context_internal(true, window, cx);
let bindings = window.bindings_for_action_in_context(&AcceptEditPrediction, key_context);
bindings
.into_iter()
@@ -2990,6 +2996,32 @@ impl Editor {
})
}
+ fn edit_prediction_preview_modifiers_held(
+ &self,
+ modifiers: &Modifiers,
+ window: &mut Window,
+ cx: &mut App,
+ ) -> bool {
+ let key_context = self.key_context_internal(true, window, cx);
+ let actions: [&dyn Action; 3] = [
+ &AcceptEditPrediction,
+ &AcceptNextWordEditPrediction,
+ &AcceptNextLineEditPrediction,
+ ];
+
+ actions.into_iter().any(|action| {
+ window
+ .bindings_for_action_in_context(action, key_context.clone())
+ .into_iter()
+ .rev()
+ .any(|binding| {
+ binding.keystrokes().first().is_some_and(|keystroke| {
+ keystroke.modifiers().modified() && keystroke.modifiers() == modifiers
+ })
+ })
+ })
+ }
+
fn edit_prediction_cursor_popover_prefers_preview(
&self,
completion: &EditPredictionState,
@@ -5119,6 +5151,7 @@ impl Editor {
RewrapOptions {
override_language_settings: true,
preserve_existing_whitespace: true,
+ line_length: None,
},
cx,
)
@@ -8498,9 +8531,12 @@ impl Editor {
window: &mut Window,
cx: &mut Context<Self>,
) {
+ self.update_edit_prediction_settings(cx);
+
// Ensure that the edit prediction preview is updated, even when not
// enabled, if there's an active edit prediction preview.
if self.show_edit_predictions_in_menu()
+ || self.edit_prediction_requires_modifier()
|| matches!(
self.edit_prediction_preview,
EditPredictionPreview::Active { .. }
@@ -8593,24 +8629,7 @@ impl Editor {
window: &mut Window,
cx: &mut Context<Self>,
) {
- let mut modifiers_held = false;
-
- let key_context = self.key_context_internal(self.has_active_edit_prediction(), window, cx);
- let actions: [&dyn Action; 3] = [
- &AcceptEditPrediction,
- &AcceptNextWordEditPrediction,
- &AcceptNextLineEditPrediction,
- ];
-
- for action in actions {
- let bindings = window.bindings_for_action_in_context(action, key_context.clone());
- for binding in bindings {
- if let Some(keystroke) = binding.keystrokes().first() {
- modifiers_held = modifiers_held
- || (keystroke.modifiers() == modifiers && keystroke.modifiers().modified());
- }
- }
- }
+ let modifiers_held = self.edit_prediction_preview_modifiers_held(modifiers, window, cx);
if modifiers_held {
if matches!(
@@ -13704,7 +13723,7 @@ impl Editor {
continue;
};
- let wrap_column = self.hard_wrap.unwrap_or_else(|| {
+ let wrap_column = options.line_length.or(self.hard_wrap).unwrap_or_else(|| {
buffer
.language_settings_at(Point::new(start_row, 0), cx)
.preferred_line_length as usize
@@ -14648,6 +14648,107 @@ async fn test_organize_imports_manual_trigger(cx: &mut TestAppContext) {
);
}
+#[gpui::test]
+async fn test_formatter_failure_does_not_abort_subsequent_formatters(cx: &mut TestAppContext) {
+ init_test(cx, |settings| {
+ settings.defaults.formatter = Some(FormatterList::Vec(vec![
+ Formatter::LanguageServer(settings::LanguageServerFormatterSpecifier::Current),
+ Formatter::CodeAction("organize-imports".into()),
+ ]))
+ });
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_file(path!("/file.rs"), "fn main() {}\n".into())
+ .await;
+
+ let project = Project::test(fs, [path!("/").as_ref()], cx).await;
+ let language_registry = project.read_with(cx, |project, _| project.languages().clone());
+ language_registry.add(rust_lang());
+
+ let mut fake_servers = language_registry.register_fake_lsp(
+ "Rust",
+ FakeLspAdapter {
+ capabilities: lsp::ServerCapabilities {
+ document_formatting_provider: Some(lsp::OneOf::Left(true)),
+ code_action_provider: Some(lsp::CodeActionProviderCapability::Simple(true)),
+ ..Default::default()
+ },
+ ..Default::default()
+ },
+ );
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer(path!("/file.rs"), cx)
+ })
+ .await
+ .unwrap();
+
+ let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
+ let (editor, cx) = cx.add_window_view(|window, cx| {
+ build_editor_with_project(project.clone(), buffer, window, cx)
+ });
+
+ let fake_server = fake_servers.next().await.unwrap();
+
+ // Formatter #1 (LanguageServer) returns an error to simulate failure
+ fake_server.set_request_handler::<lsp::request::Formatting, _, _>(
+ move |_params, _| async move { Err(anyhow::anyhow!("Simulated formatter failure")) },
+ );
+
+ // Formatter #2 (CodeAction) returns a successful edit
+ fake_server.set_request_handler::<lsp::request::CodeActionRequest, _, _>(
+ move |_params, _| async move {
+ let uri = lsp::Uri::from_file_path(path!("/file.rs")).unwrap();
+ Ok(Some(vec![lsp::CodeActionOrCommand::CodeAction(
+ lsp::CodeAction {
+ kind: Some("organize-imports".into()),
+ edit: Some(lsp::WorkspaceEdit::new(
+ [(
+ uri,
+ vec![lsp::TextEdit::new(
+ lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 0)),
+ "use std::io;\n".to_string(),
+ )],
+ )]
+ .into_iter()
+ .collect(),
+ )),
+ ..Default::default()
+ },
+ )]))
+ },
+ );
+
+ fake_server.set_request_handler::<lsp::request::CodeActionResolveRequest, _, _>({
+ move |params, _| async move { Ok(params) }
+ });
+
+ editor
+ .update_in(cx, |editor, window, cx| {
+ editor.perform_format(
+ project.clone(),
+ FormatTrigger::Manual,
+ FormatTarget::Buffers(editor.buffer().read(cx).all_buffers()),
+ window,
+ cx,
+ )
+ })
+ .unwrap()
+ .await;
+
+ // Formatter #1 (LanguageServer) failed, but formatter #2 (CodeAction) should still have been applied
+ editor.update(cx, |editor, cx| {
+ assert_eq!(editor.text(cx), "use std::io;\nfn main() {}\n");
+ });
+
+ // The entire format operation should undo as one transaction
+ editor.update_in(cx, |editor, window, cx| {
+ editor.undo(&Default::default(), window, cx);
+ assert_eq!(editor.text(cx), "fn main() {}\n");
+ });
+}
+
#[gpui::test]
async fn test_concurrent_format_requests(cx: &mut TestAppContext) {
init_test(cx, |_| {});
@@ -16,10 +16,7 @@ use project::project_settings::ProjectSettings;
use settings::Settings;
use std::sync::Arc;
use time::OffsetDateTime;
-use ui::{
- Divider, HighlightedLabel, KeyBinding, ListHeader, ListItem, ListItemSpacing, Tooltip,
- prelude::*,
-};
+use ui::{Divider, HighlightedLabel, KeyBinding, ListItem, ListItemSpacing, Tooltip, prelude::*};
use ui_input::ErasedEditor;
use util::ResultExt;
use workspace::notifications::DetachAndPromptErr;
@@ -1084,21 +1081,6 @@ impl PickerDelegate for BranchListDelegate {
)
}
- fn render_header(
- &self,
- _window: &mut Window,
- _cx: &mut Context<Picker<Self>>,
- ) -> Option<AnyElement> {
- matches!(self.state, PickerState::List).then(|| {
- let label = match self.branch_filter {
- BranchFilter::All => "Branches",
- BranchFilter::Remote => "Remotes",
- };
-
- ListHeader::new(label).inset(true).into_any_element()
- })
- }
-
fn render_footer(&self, _: &mut Window, cx: &mut Context<Picker<Self>>) -> Option<AnyElement> {
if self.editor_position() == PickerEditorPosition::End {
return None;
@@ -1193,7 +1175,11 @@ impl PickerDelegate for BranchListDelegate {
this.justify_between()
.child({
let focus_handle = focus_handle.clone();
- Button::new("filter-remotes", "Filter Remotes")
+ let filter_label = match self.branch_filter {
+ BranchFilter::All => "Filter Remote",
+ BranchFilter::Remote => "Show All",
+ };
+ Button::new("filter-remotes", filter_label)
.toggle_state(matches!(
self.branch_filter,
BranchFilter::Remote
@@ -2276,6 +2276,7 @@ impl GitPanel {
RewrapOptions {
override_language_settings: false,
preserve_existing_whitespace: true,
+ line_length: None,
},
cx,
);
@@ -25,8 +25,8 @@ actions!(
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GitPickerTab {
- Branches,
Worktrees,
+ Branches,
Stash,
}
@@ -190,9 +190,9 @@ impl GitPicker {
fn activate_next_tab(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.tab = match self.tab {
- GitPickerTab::Branches => GitPickerTab::Worktrees,
- GitPickerTab::Worktrees => GitPickerTab::Stash,
- GitPickerTab::Stash => GitPickerTab::Branches,
+ GitPickerTab::Worktrees => GitPickerTab::Branches,
+ GitPickerTab::Branches => GitPickerTab::Stash,
+ GitPickerTab::Stash => GitPickerTab::Worktrees,
};
self.ensure_active_picker(window, cx);
self.focus_active_picker(window, cx);
@@ -201,9 +201,9 @@ impl GitPicker {
fn activate_previous_tab(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.tab = match self.tab {
- GitPickerTab::Branches => GitPickerTab::Stash,
- GitPickerTab::Worktrees => GitPickerTab::Branches,
- GitPickerTab::Stash => GitPickerTab::Worktrees,
+ GitPickerTab::Worktrees => GitPickerTab::Stash,
+ GitPickerTab::Branches => GitPickerTab::Worktrees,
+ GitPickerTab::Stash => GitPickerTab::Branches,
};
self.ensure_active_picker(window, cx);
self.focus_active_picker(window, cx);
@@ -241,9 +241,9 @@ impl GitPicker {
"git-picker-tabs",
[
ToggleButtonSimple::new(
- GitPickerTab::Branches.to_string(),
+ GitPickerTab::Worktrees.to_string(),
cx.listener(|this, _, window, cx| {
- this.tab = GitPickerTab::Branches;
+ this.tab = GitPickerTab::Worktrees;
this.ensure_active_picker(window, cx);
this.focus_active_picker(window, cx);
cx.notify();
@@ -251,16 +251,16 @@ impl GitPicker {
)
.tooltip(move |_, cx| {
Tooltip::for_action_in(
- "Toggle Branch Picker",
- &ActivateBranchesTab,
- &branches_focus_handle,
+ "Toggle Worktree Picker",
+ &ActivateWorktreesTab,
+ &worktrees_focus_handle,
cx,
)
}),
ToggleButtonSimple::new(
- GitPickerTab::Worktrees.to_string(),
+ GitPickerTab::Branches.to_string(),
cx.listener(|this, _, window, cx| {
- this.tab = GitPickerTab::Worktrees;
+ this.tab = GitPickerTab::Branches;
this.ensure_active_picker(window, cx);
this.focus_active_picker(window, cx);
cx.notify();
@@ -268,9 +268,9 @@ impl GitPicker {
)
.tooltip(move |_, cx| {
Tooltip::for_action_in(
- "Toggle Worktree Picker",
- &ActivateWorktreesTab,
- &worktrees_focus_handle,
+ "Toggle Branch Picker",
+ &ActivateBranchesTab,
+ &branches_focus_handle,
cx,
)
}),
@@ -297,8 +297,8 @@ impl GitPicker {
.style(ToggleButtonGroupStyle::Outlined)
.auto_width()
.selected_index(match self.tab {
- GitPickerTab::Branches => 0,
- GitPickerTab::Worktrees => 1,
+ GitPickerTab::Worktrees => 0,
+ GitPickerTab::Branches => 1,
GitPickerTab::Stash => 2,
}),
)
@@ -2,7 +2,10 @@
use anyhow::Result;
use buffer_diff::BufferDiff;
-use editor::{Editor, EditorEvent, MultiBuffer, ToPoint, actions::DiffClipboardWithSelectionData};
+use editor::{
+ Editor, EditorEvent, EditorSettings, MultiBuffer, SplittableEditor, ToPoint,
+ actions::DiffClipboardWithSelectionData,
+};
use futures::{FutureExt, select_biased};
use gpui::{
AnyElement, App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, FocusHandle,
@@ -10,6 +13,7 @@ use gpui::{
};
use language::{self, Buffer, Point};
use project::Project;
+use settings::Settings;
use std::{
any::{Any, TypeId},
cmp,
@@ -22,13 +26,13 @@ use ui::{Color, Icon, IconName, Label, LabelCommon as _, SharedString};
use util::paths::PathExt;
use workspace::{
- Item, ItemHandle as _, ItemNavHistory, Workspace,
+ Item, ItemNavHistory, Workspace,
item::{ItemEvent, SaveOptions, TabContentParams},
searchable::SearchableItemHandle,
};
pub struct TextDiffView {
- diff_editor: Entity<Editor>,
+ diff_editor: Entity<SplittableEditor>,
title: SharedString,
path: Option<SharedString>,
buffer_changes_tx: watch::Sender<()>,
@@ -125,11 +129,11 @@ impl TextDiffView {
);
let task = window.spawn(cx, async move |cx| {
- let project = workspace.update(cx, |workspace, _| workspace.project().clone())?;
-
update_diff_buffer(&diff_buffer, &source_buffer, &clipboard_buffer, cx).await?;
workspace.update_in(cx, |workspace, window, cx| {
+ let project = workspace.project().clone();
+ let workspace_entity = cx.entity();
let diff_view = cx.new(|cx| {
TextDiffView::new(
clipboard_buffer,
@@ -138,6 +142,7 @@ impl TextDiffView {
expanded_selection_range,
diff_buffer,
project,
+ workspace_entity,
window,
cx,
)
@@ -162,6 +167,7 @@ impl TextDiffView {
source_range: Range<Point>,
diff_buffer: Entity<BufferDiff>,
project: Entity<Project>,
+ workspace: Entity<Workspace>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
@@ -174,15 +180,24 @@ impl TextDiffView {
multibuffer
});
let diff_editor = cx.new(|cx| {
- let mut editor = Editor::for_multibuffer(multibuffer, Some(project), window, cx);
- editor.start_temporary_diff_override();
- editor.disable_diagnostics(cx);
- editor.set_expand_all_diff_hunks(cx);
- editor.set_render_diff_hunk_controls(
+ let splittable = SplittableEditor::new(
+ EditorSettings::get_global(cx).diff_view_style,
+ multibuffer,
+ project,
+ workspace,
+ window,
+ cx,
+ );
+ splittable.set_render_diff_hunk_controls(
Arc::new(|_, _, _, _, _, _, _, _| gpui::Empty.into_any_element()),
cx,
);
- editor
+ splittable.rhs_editor().update(cx, |editor, cx| {
+ editor.start_temporary_diff_override();
+ editor.disable_diagnostics(cx);
+ editor.set_expand_all_diff_hunks(cx);
+ });
+ splittable
});
let (buffer_changes_tx, mut buffer_changes_rx) = watch::channel(());
@@ -352,12 +367,14 @@ impl Item for TextDiffView {
&'a self,
type_id: TypeId,
self_handle: &'a Entity<Self>,
- _: &'a App,
+ cx: &'a App,
) -> Option<gpui::AnyEntity> {
if type_id == TypeId::of::<Self>() {
Some(self_handle.clone().into())
- } else if type_id == TypeId::of::<Editor>() {
+ } else if type_id == TypeId::of::<SplittableEditor>() {
Some(self.diff_editor.clone().into())
+ } else if type_id == TypeId::of::<Editor>() {
+ Some(self.diff_editor.read(cx).rhs_editor().clone().into())
} else {
None
}
@@ -372,7 +389,7 @@ impl Item for TextDiffView {
cx: &App,
f: &mut dyn FnMut(gpui::EntityId, &dyn project::ProjectItem),
) {
- self.diff_editor.for_each_project_item(cx, f)
+ self.diff_editor.read(cx).for_each_project_item(cx, f)
}
fn set_nav_history(
@@ -381,7 +398,8 @@ impl Item for TextDiffView {
_: &mut Window,
cx: &mut Context<Self>,
) {
- self.diff_editor.update(cx, |editor, _| {
+ let rhs = self.diff_editor.read(cx).rhs_editor().clone();
+ rhs.update(cx, |editor, _| {
editor.set_nav_history(Some(nav_history));
});
}
@@ -463,11 +481,11 @@ impl Render for TextDiffView {
mod tests {
use super::*;
use editor::{MultiBufferOffset, PathKey, test::editor_test_context::assert_state_with_diff};
- use gpui::{TestAppContext, VisualContext};
+ use gpui::{BorrowAppContext, TestAppContext, VisualContext};
use language::Point;
use project::{FakeFs, Project};
use serde_json::json;
- use settings::SettingsStore;
+ use settings::{DiffViewStyle, SettingsStore};
use unindent::unindent;
use util::{path, test::marked_text_ranges};
use workspace::MultiWorkspace;
@@ -476,6 +494,11 @@ mod tests {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
+ cx.update_global::<SettingsStore, _>(|store, cx| {
+ store.update_user_settings(cx, |settings| {
+ settings.editor.diff_view_style = Some(DiffViewStyle::Unified);
+ });
+ });
theme::init(theme::LoadThemes::JustBase, cx);
});
}
@@ -918,7 +941,9 @@ mod tests {
cx.executor().run_until_parked();
assert_state_with_diff(
- &diff_view.read_with(cx, |diff_view, _| diff_view.diff_editor.clone()),
+ &diff_view.read_with(cx, |diff_view, cx| {
+ diff_view.diff_editor.read(cx).rhs_editor().clone()
+ }),
cx,
expected_diff,
);
@@ -1,7 +1,7 @@
use anyhow::{Context as _, Result};
use collections::HashMap;
pub use gpui_macros::Action;
-pub use no_action::{NoAction, is_no_action};
+pub use no_action::{NoAction, Unbind, is_no_action, is_unbind};
use serde_json::json;
use std::{
any::{Any, TypeId},
@@ -290,19 +290,6 @@ impl ActionRegistry {
}
}
- #[cfg(test)]
- pub(crate) fn load_action<A: Action>(&mut self) {
- self.insert_action(MacroActionData {
- name: A::name_for_type(),
- type_id: TypeId::of::<A>(),
- build: A::build,
- json_schema: A::action_json_schema,
- deprecated_aliases: A::deprecated_aliases(),
- deprecation_message: A::deprecation_message(),
- documentation: A::documentation(),
- });
- }
-
fn insert_action(&mut self, action: MacroActionData) {
let name = action.name;
if self.by_name.contains_key(name) {
@@ -432,7 +419,8 @@ pub fn generate_list_of_all_registered_actions() -> impl Iterator<Item = MacroAc
mod no_action {
use crate as gpui;
- use std::any::Any as _;
+ use schemars::JsonSchema;
+ use serde::Deserialize;
actions!(
zed,
@@ -443,8 +431,23 @@ mod no_action {
]
);
+ /// Action with special handling which unbinds later bindings for the same keystrokes when they
+ /// dispatch the named action, regardless of that action's context.
+ ///
+ /// In keymap JSON this is written as:
+ ///
+ /// `["zed::Unbind", "editor::NewLine"]`
+ #[derive(Clone, Debug, PartialEq, Deserialize, JsonSchema, gpui::Action)]
+ #[action(namespace = zed)]
+ pub struct Unbind(pub gpui::SharedString);
+
/// Returns whether or not this action represents a removed key binding.
pub fn is_no_action(action: &dyn gpui::Action) -> bool {
- action.as_any().type_id() == (NoAction {}).type_id()
+ action.as_any().is::<NoAction>()
+ }
+
+ /// Returns whether or not this action represents an unbind marker.
+ pub fn is_unbind(action: &dyn gpui::Action) -> bool {
+ action.as_any().is::<Unbind>()
}
}
@@ -49,7 +49,8 @@ use crate::{
PlatformKeyboardMapper, Point, Priority, PromptBuilder, PromptButton, PromptHandle,
PromptLevel, Render, RenderImage, RenderablePromptHandle, Reservation, ScreenCaptureSource,
SharedString, SubscriberSet, Subscription, SvgRenderer, Task, TextRenderingMode, TextSystem,
- ThermalState, Window, WindowAppearance, WindowHandle, WindowId, WindowInvalidator,
+ ThermalState, Window, WindowAppearance, WindowButtonLayout, WindowHandle, WindowId,
+ WindowInvalidator,
colors::{Colors, GlobalColors},
hash, init_app_menus,
};
@@ -1177,6 +1178,11 @@ impl App {
self.platform.window_appearance()
}
+ /// Returns the window button layout configuration when supported.
+ pub fn button_layout(&self) -> Option<WindowButtonLayout> {
+ self.platform.button_layout()
+ }
+
/// Reads data from the platform clipboard.
pub fn read_from_clipboard(&self) -> Option<ClipboardItem> {
self.platform.read_from_clipboard()
@@ -479,6 +479,24 @@ impl<'a, T: 'static> Context<'a, T> {
subscription
}
+ /// Registers a callback to be invoked when the window button layout changes.
+ pub fn observe_button_layout_changed(
+ &self,
+ window: &mut Window,
+ mut callback: impl FnMut(&mut T, &mut Window, &mut Context<T>) + 'static,
+ ) -> Subscription {
+ let view = self.weak_entity();
+ let (subscription, activate) = window.button_layout_observers.insert(
+ (),
+ Box::new(move |window, cx| {
+ view.update(cx, |view, cx| callback(view, window, cx))
+ .is_ok()
+ }),
+ );
+ activate();
+ subscription
+ }
+
/// Register a callback to be invoked when a keystroke is received by the application
/// in any window. Note that this fires after all other action and event mechanisms have resolved
/// and that this API will not be invoked if the event's propagation is stopped.
@@ -629,66 +629,99 @@ mod tests {
use std::{cell::RefCell, ops::Range, rc::Rc};
use crate::{
- Action, ActionRegistry, App, Bounds, Context, DispatchTree, FocusHandle, InputHandler,
- IntoElement, KeyBinding, KeyContext, Keymap, Pixels, Point, Render, Subscription,
- TestAppContext, UTF16Selection, Window,
+ ActionRegistry, App, Bounds, Context, DispatchTree, FocusHandle, InputHandler, IntoElement,
+ KeyBinding, KeyContext, Keymap, Pixels, Point, Render, Subscription, TestAppContext,
+ UTF16Selection, Unbind, Window,
};
- #[derive(PartialEq, Eq)]
- struct TestAction;
+ actions!(dispatch_test, [TestAction, SecondaryTestAction]);
- impl Action for TestAction {
- fn name(&self) -> &'static str {
- "test::TestAction"
- }
-
- fn name_for_type() -> &'static str
- where
- Self: ::std::marker::Sized,
- {
- "test::TestAction"
- }
-
- fn partial_eq(&self, action: &dyn Action) -> bool {
- action.as_any().downcast_ref::<Self>() == Some(self)
- }
-
- fn boxed_clone(&self) -> std::boxed::Box<dyn Action> {
- Box::new(TestAction)
- }
+ fn test_dispatch_tree(bindings: Vec<KeyBinding>) -> DispatchTree {
+ let registry = ActionRegistry::default();
- fn build(_value: serde_json::Value) -> anyhow::Result<Box<dyn Action>>
- where
- Self: Sized,
- {
- Ok(Box::new(TestAction))
- }
+ DispatchTree::new(
+ Rc::new(RefCell::new(Keymap::new(bindings))),
+ Rc::new(registry),
+ )
}
#[test]
fn test_keybinding_for_action_bounds() {
- let keymap = Keymap::new(vec![KeyBinding::new(
+ let tree = test_dispatch_tree(vec![KeyBinding::new(
"cmd-n",
TestAction,
Some("ProjectPanel"),
)]);
- let mut registry = ActionRegistry::default();
+ let contexts = vec![
+ KeyContext::parse("Workspace").unwrap(),
+ KeyContext::parse("ProjectPanel").unwrap(),
+ ];
+
+ let keybinding = tree.bindings_for_action(&TestAction, &contexts);
+
+ assert!(keybinding[0].action.partial_eq(&TestAction))
+ }
+
+ #[test]
+ fn test_bindings_for_action_hides_targeted_unbind_in_active_context() {
+ let tree = test_dispatch_tree(vec![
+ KeyBinding::new("tab", TestAction, Some("Editor")),
+ KeyBinding::new(
+ "tab",
+ Unbind("dispatch_test::TestAction".into()),
+ Some("Editor && edit_prediction"),
+ ),
+ KeyBinding::new(
+ "tab",
+ SecondaryTestAction,
+ Some("Editor && showing_completions"),
+ ),
+ ]);
+
+ let contexts = vec![
+ KeyContext::parse("Workspace").unwrap(),
+ KeyContext::parse("Editor showing_completions edit_prediction").unwrap(),
+ ];
- registry.load_action::<TestAction>();
+ let bindings = tree.bindings_for_action(&TestAction, &contexts);
+ assert!(bindings.is_empty());
- let keymap = Rc::new(RefCell::new(keymap));
+ let highest = tree.highest_precedence_binding_for_action(&TestAction, &contexts);
+ assert!(highest.is_none());
+
+ let fallback_bindings = tree.bindings_for_action(&SecondaryTestAction, &contexts);
+ assert_eq!(fallback_bindings.len(), 1);
+ assert!(fallback_bindings[0].action.partial_eq(&SecondaryTestAction));
+ }
- let tree = DispatchTree::new(keymap, Rc::new(registry));
+ #[test]
+ fn test_bindings_for_action_keeps_targeted_binding_outside_unbind_context() {
+ let tree = test_dispatch_tree(vec![
+ KeyBinding::new("tab", TestAction, Some("Editor")),
+ KeyBinding::new(
+ "tab",
+ Unbind("dispatch_test::TestAction".into()),
+ Some("Editor && edit_prediction"),
+ ),
+ KeyBinding::new(
+ "tab",
+ SecondaryTestAction,
+ Some("Editor && showing_completions"),
+ ),
+ ]);
let contexts = vec![
KeyContext::parse("Workspace").unwrap(),
- KeyContext::parse("ProjectPanel").unwrap(),
+ KeyContext::parse("Editor").unwrap(),
];
- let keybinding = tree.bindings_for_action(&TestAction, &contexts);
+ let bindings = tree.bindings_for_action(&TestAction, &contexts);
+ assert_eq!(bindings.len(), 1);
+ assert!(bindings[0].action.partial_eq(&TestAction));
- assert!(keybinding[0].action.partial_eq(&TestAction))
+ let highest = tree.highest_precedence_binding_for_action(&TestAction, &contexts);
+ assert!(highest.is_some_and(|binding| binding.action.partial_eq(&TestAction)));
}
#[test]
@@ -698,10 +731,7 @@ mod tests {
KeyBinding::new("space", TestAction, Some("ContextA")),
KeyBinding::new("space f g", TestAction, Some("ContextB")),
];
- let keymap = Rc::new(RefCell::new(Keymap::new(bindings)));
- let mut registry = ActionRegistry::default();
- registry.load_action::<TestAction>();
- let mut tree = DispatchTree::new(keymap, Rc::new(registry));
+ let mut tree = test_dispatch_tree(bindings);
type DispatchPath = SmallVec<[super::DispatchNodeId; 32]>;
fn dispatch(
@@ -4,7 +4,7 @@ mod context;
pub use binding::*;
pub use context::*;
-use crate::{Action, AsKeystroke, Keystroke, is_no_action};
+use crate::{Action, AsKeystroke, Keystroke, Unbind, is_no_action, is_unbind};
use collections::{HashMap, HashSet};
use smallvec::SmallVec;
use std::any::TypeId;
@@ -19,7 +19,7 @@ pub struct KeymapVersion(usize);
pub struct Keymap {
bindings: Vec<KeyBinding>,
binding_indices_by_action_id: HashMap<TypeId, SmallVec<[usize; 3]>>,
- no_action_binding_indices: Vec<usize>,
+ disabled_binding_indices: Vec<usize>,
version: KeymapVersion,
}
@@ -27,6 +27,26 @@ pub struct Keymap {
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub struct BindingIndex(usize);
+fn disabled_binding_matches_context(disabled_binding: &KeyBinding, binding: &KeyBinding) -> bool {
+ match (
+ &disabled_binding.context_predicate,
+ &binding.context_predicate,
+ ) {
+ (None, _) => true,
+ (Some(_), None) => false,
+ (Some(disabled_predicate), Some(predicate)) => disabled_predicate.is_superset(predicate),
+ }
+}
+
+fn binding_is_unbound(disabled_binding: &KeyBinding, binding: &KeyBinding) -> bool {
+ disabled_binding.keystrokes == binding.keystrokes
+ && disabled_binding
+ .action()
+ .as_any()
+ .downcast_ref::<Unbind>()
+ .is_some_and(|unbind| unbind.0.as_ref() == binding.action.name())
+}
+
impl Keymap {
/// Create a new keymap with the given bindings.
pub fn new(bindings: Vec<KeyBinding>) -> Self {
@@ -44,8 +64,8 @@ impl Keymap {
pub fn add_bindings<T: IntoIterator<Item = KeyBinding>>(&mut self, bindings: T) {
for binding in bindings {
let action_id = binding.action().as_any().type_id();
- if is_no_action(&*binding.action) {
- self.no_action_binding_indices.push(self.bindings.len());
+ if is_no_action(&*binding.action) || is_unbind(&*binding.action) {
+ self.disabled_binding_indices.push(self.bindings.len());
} else {
self.binding_indices_by_action_id
.entry(action_id)
@@ -62,7 +82,7 @@ impl Keymap {
pub fn clear(&mut self) {
self.bindings.clear();
self.binding_indices_by_action_id.clear();
- self.no_action_binding_indices.clear();
+ self.disabled_binding_indices.clear();
self.version.0 += 1;
}
@@ -90,21 +110,22 @@ impl Keymap {
return None;
}
- for null_ix in &self.no_action_binding_indices {
- if null_ix > ix {
- let null_binding = &self.bindings[*null_ix];
- if null_binding.keystrokes == binding.keystrokes {
- let null_binding_matches =
- match (&null_binding.context_predicate, &binding.context_predicate) {
- (None, _) => true,
- (Some(_), None) => false,
- (Some(null_predicate), Some(predicate)) => {
- null_predicate.is_superset(predicate)
- }
- };
- if null_binding_matches {
+ for disabled_ix in &self.disabled_binding_indices {
+ if disabled_ix > ix {
+ let disabled_binding = &self.bindings[*disabled_ix];
+ if disabled_binding.keystrokes != binding.keystrokes {
+ continue;
+ }
+
+ if is_no_action(&*disabled_binding.action) {
+ if disabled_binding_matches_context(disabled_binding, binding) {
return None;
}
+ } else if is_unbind(&*disabled_binding.action)
+ && disabled_binding_matches_context(disabled_binding, binding)
+ && binding_is_unbound(disabled_binding, binding)
+ {
+ return None;
}
}
}
@@ -170,6 +191,7 @@ impl Keymap {
let mut bindings: SmallVec<[_; 1]> = SmallVec::new();
let mut first_binding_index = None;
+ let mut unbound_bindings: Vec<&KeyBinding> = Vec::new();
for (_, ix, binding) in matched_bindings {
if is_no_action(&*binding.action) {
@@ -186,6 +208,19 @@ impl Keymap {
// For non-user NoAction bindings, continue searching for user overrides
continue;
}
+
+ if is_unbind(&*binding.action) {
+ unbound_bindings.push(binding);
+ continue;
+ }
+
+ if unbound_bindings
+ .iter()
+ .any(|disabled_binding| binding_is_unbound(disabled_binding, binding))
+ {
+ continue;
+ }
+
bindings.push(binding.clone());
first_binding_index.get_or_insert(ix);
}
@@ -197,7 +232,7 @@ impl Keymap {
{
continue;
}
- if is_no_action(&*binding.action) {
+ if is_no_action(&*binding.action) || is_unbind(&*binding.action) {
pending.remove(&&binding.keystrokes);
continue;
}
@@ -232,7 +267,10 @@ impl Keymap {
match pending {
None => None,
Some(is_pending) => {
- if !is_pending || is_no_action(&*binding.action) {
+ if !is_pending
+ || is_no_action(&*binding.action)
+ || is_unbind(&*binding.action)
+ {
return None;
}
Some((depth, BindingIndex(ix), binding))
@@ -256,7 +294,7 @@ impl Keymap {
mod tests {
use super::*;
use crate as gpui;
- use gpui::NoAction;
+ use gpui::{NoAction, Unbind};
actions!(
test_only,
@@ -720,6 +758,76 @@ mod tests {
}
}
+ #[test]
+ fn test_targeted_unbind_ignores_target_context() {
+ let bindings = [
+ KeyBinding::new("tab", ActionAlpha {}, Some("Editor")),
+ KeyBinding::new("tab", ActionBeta {}, Some("Editor && showing_completions")),
+ KeyBinding::new(
+ "tab",
+ Unbind("test_only::ActionAlpha".into()),
+ Some("Editor && edit_prediction"),
+ ),
+ ];
+
+ let mut keymap = Keymap::default();
+ keymap.add_bindings(bindings);
+
+ let (result, pending) = keymap.bindings_for_input(
+ &[Keystroke::parse("tab").unwrap()],
+ &[KeyContext::parse("Editor showing_completions edit_prediction").unwrap()],
+ );
+
+ assert!(!pending);
+ assert_eq!(result.len(), 1);
+ assert!(result[0].action.partial_eq(&ActionBeta {}));
+ }
+
+ #[test]
+ fn test_bindings_for_action_keeps_binding_for_narrower_targeted_unbind() {
+ let bindings = [
+ KeyBinding::new("tab", ActionAlpha {}, Some("Editor")),
+ KeyBinding::new(
+ "tab",
+ Unbind("test_only::ActionAlpha".into()),
+ Some("Editor && edit_prediction"),
+ ),
+ KeyBinding::new("tab", ActionBeta {}, Some("Editor && showing_completions")),
+ ];
+
+ let mut keymap = Keymap::default();
+ keymap.add_bindings(bindings);
+
+ assert_bindings(&keymap, &ActionAlpha {}, &["tab"]);
+ assert_bindings(&keymap, &ActionBeta {}, &["tab"]);
+
+ #[track_caller]
+ fn assert_bindings(keymap: &Keymap, action: &dyn Action, expected: &[&str]) {
+ let actual = keymap
+ .bindings_for_action(action)
+ .map(|binding| binding.keystrokes[0].inner().unparse())
+ .collect::<Vec<_>>();
+ assert_eq!(actual, expected, "{:?}", action);
+ }
+ }
+
+ #[test]
+ fn test_bindings_for_action_removes_binding_for_broader_targeted_unbind() {
+ let bindings = [
+ KeyBinding::new("tab", ActionAlpha {}, Some("Editor && edit_prediction")),
+ KeyBinding::new(
+ "tab",
+ Unbind("test_only::ActionAlpha".into()),
+ Some("Editor"),
+ ),
+ ];
+
+ let mut keymap = Keymap::default();
+ keymap.add_bindings(bindings);
+
+ assert!(keymap.bindings_for_action(&ActionAlpha {}).next().is_none());
+ }
+
#[test]
fn test_source_precedence_sorting() {
// KeybindSource precedence: User (0) > Vim (1) > Base (2) > Default (3)
@@ -37,6 +37,8 @@ use crate::{
ThreadTaskTimings, Window, WindowControlArea, hash, point, px, size,
};
use anyhow::Result;
+#[cfg(any(target_os = "linux", target_os = "freebsd"))]
+use anyhow::bail;
use async_task::Runnable;
use futures::channel::oneshot;
#[cfg(any(test, feature = "test-support"))]
@@ -156,6 +158,11 @@ pub trait Platform: 'static {
/// Returns the appearance of the application's windows.
fn window_appearance(&self) -> WindowAppearance;
+ /// Returns the window button layout configuration when supported.
+ fn button_layout(&self) -> Option<WindowButtonLayout> {
+ None
+ }
+
fn open_url(&self, url: &str);
fn on_open_urls(&self, callback: Box<dyn FnMut(Vec<String>)>);
fn register_url_scheme(&self, url: &str) -> Task<Result<()>>;
@@ -407,6 +414,145 @@ impl Default for WindowControls {
}
}
+/// A window control button type used in [`WindowButtonLayout`].
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum WindowButton {
+ /// The minimize button
+ Minimize,
+ /// The maximize button
+ Maximize,
+ /// The close button
+ Close,
+}
+
+impl WindowButton {
+ /// Returns a stable element ID for rendering this button.
+ pub fn id(&self) -> &'static str {
+ match self {
+ WindowButton::Minimize => "minimize",
+ WindowButton::Maximize => "maximize",
+ WindowButton::Close => "close",
+ }
+ }
+
+ #[cfg(any(target_os = "linux", target_os = "freebsd"))]
+ fn index(&self) -> usize {
+ match self {
+ WindowButton::Minimize => 0,
+ WindowButton::Maximize => 1,
+ WindowButton::Close => 2,
+ }
+ }
+}
+
+/// Maximum number of [`WindowButton`]s per side in the titlebar.
+pub const MAX_BUTTONS_PER_SIDE: usize = 3;
+
+/// Describes which [`WindowButton`]s appear on each side of the titlebar.
+///
+/// On Linux, this is read from the desktop environment's configuration
+/// (e.g. GNOME's `button-layout` key in the `org.gnome.desktop.wm.preferences` schema) via [`WindowButtonLayout::parse`].
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub struct WindowButtonLayout {
+ /// Buttons on the left side of the titlebar.
+ pub left: [Option<WindowButton>; MAX_BUTTONS_PER_SIDE],
+ /// Buttons on the right side of the titlebar.
+ pub right: [Option<WindowButton>; MAX_BUTTONS_PER_SIDE],
+}
+
+#[cfg(any(target_os = "linux", target_os = "freebsd"))]
+impl WindowButtonLayout {
+ /// Returns Zed's built-in fallback button layout for Linux titlebars.
+ pub fn linux_default() -> Self {
+ Self {
+ left: [None; MAX_BUTTONS_PER_SIDE],
+ right: [
+ Some(WindowButton::Minimize),
+ Some(WindowButton::Maximize),
+ Some(WindowButton::Close),
+ ],
+ }
+ }
+
+ /// Parses a GNOME-style `button-layout` string (e.g. `"close,minimize:maximize"`).
+ pub fn parse(layout_string: &str) -> Result<Self> {
+ fn parse_side(
+ s: &str,
+ seen_buttons: &mut [bool; MAX_BUTTONS_PER_SIDE],
+ unrecognized: &mut Vec<String>,
+ ) -> [Option<WindowButton>; MAX_BUTTONS_PER_SIDE] {
+ let mut result = [None; MAX_BUTTONS_PER_SIDE];
+ let mut i = 0;
+ for name in s.split(',') {
+ let trimmed = name.trim();
+ if trimmed.is_empty() {
+ continue;
+ }
+ let button = match trimmed {
+ "minimize" => Some(WindowButton::Minimize),
+ "maximize" => Some(WindowButton::Maximize),
+ "close" => Some(WindowButton::Close),
+ other => {
+ unrecognized.push(other.to_string());
+ None
+ }
+ };
+ if let Some(button) = button {
+ if seen_buttons[button.index()] {
+ continue;
+ }
+ if let Some(slot) = result.get_mut(i) {
+ *slot = Some(button);
+ seen_buttons[button.index()] = true;
+ i += 1;
+ }
+ }
+ }
+ result
+ }
+
+ let (left_str, right_str) = layout_string.split_once(':').unwrap_or(("", layout_string));
+ let mut unrecognized = Vec::new();
+ let mut seen_buttons = [false; MAX_BUTTONS_PER_SIDE];
+ let layout = Self {
+ left: parse_side(left_str, &mut seen_buttons, &mut unrecognized),
+ right: parse_side(right_str, &mut seen_buttons, &mut unrecognized),
+ };
+
+ if !unrecognized.is_empty()
+ && layout.left.iter().all(Option::is_none)
+ && layout.right.iter().all(Option::is_none)
+ {
+ bail!(
+ "button layout string {:?} contains no valid buttons (unrecognized: {})",
+ layout_string,
+ unrecognized.join(", ")
+ );
+ }
+
+ Ok(layout)
+ }
+
+ /// Formats the layout back into a GNOME-style `button-layout` string.
+ #[cfg(test)]
+ pub fn format(&self) -> String {
+ fn format_side(buttons: &[Option<WindowButton>; MAX_BUTTONS_PER_SIDE]) -> String {
+ buttons
+ .iter()
+ .flatten()
+ .map(|button| match button {
+ WindowButton::Minimize => "minimize",
+ WindowButton::Maximize => "maximize",
+ WindowButton::Close => "close",
+ })
+ .collect::<Vec<_>>()
+ .join(",")
+ }
+
+ format!("{}:{}", format_side(&self.left), format_side(&self.right))
+ }
+}
+
/// A type to describe which sides of the window are currently tiled in some way
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Default)]
pub struct Tiling {
@@ -488,6 +634,7 @@ pub trait PlatformWindow: HasWindowHandle + HasDisplayHandle {
fn on_hit_test_window_control(&self, callback: Box<dyn FnMut() -> Option<WindowControlArea>>);
fn on_close(&self, callback: Box<dyn FnOnce()>);
fn on_appearance_changed(&self, callback: Box<dyn FnMut()>);
+ fn on_button_layout_changed(&self, _callback: Box<dyn FnMut()>) {}
fn draw(&self, scene: &Scene);
fn completed_frame(&self) {}
fn sprite_atlas(&self) -> Arc<dyn PlatformAtlas>;
@@ -2023,3 +2170,185 @@ impl From<String> for ClipboardString {
}
}
}
+
+#[cfg(all(test, any(target_os = "linux", target_os = "freebsd")))]
+mod tests {
+ use super::*;
+ use std::collections::HashSet;
+
+ #[test]
+ fn test_window_button_layout_parse_standard() {
+ let layout = WindowButtonLayout::parse("close,minimize:maximize").unwrap();
+ assert_eq!(
+ layout.left,
+ [
+ Some(WindowButton::Close),
+ Some(WindowButton::Minimize),
+ None
+ ]
+ );
+ assert_eq!(layout.right, [Some(WindowButton::Maximize), None, None]);
+ }
+
+ #[test]
+ fn test_window_button_layout_parse_right_only() {
+ let layout = WindowButtonLayout::parse("minimize,maximize,close").unwrap();
+ assert_eq!(layout.left, [None, None, None]);
+ assert_eq!(
+ layout.right,
+ [
+ Some(WindowButton::Minimize),
+ Some(WindowButton::Maximize),
+ Some(WindowButton::Close)
+ ]
+ );
+ }
+
+ #[test]
+ fn test_window_button_layout_parse_left_only() {
+ let layout = WindowButtonLayout::parse("close,minimize,maximize:").unwrap();
+ assert_eq!(
+ layout.left,
+ [
+ Some(WindowButton::Close),
+ Some(WindowButton::Minimize),
+ Some(WindowButton::Maximize)
+ ]
+ );
+ assert_eq!(layout.right, [None, None, None]);
+ }
+
+ #[test]
+ fn test_window_button_layout_parse_with_whitespace() {
+ let layout = WindowButtonLayout::parse(" close , minimize : maximize ").unwrap();
+ assert_eq!(
+ layout.left,
+ [
+ Some(WindowButton::Close),
+ Some(WindowButton::Minimize),
+ None
+ ]
+ );
+ assert_eq!(layout.right, [Some(WindowButton::Maximize), None, None]);
+ }
+
+ #[test]
+ fn test_window_button_layout_parse_empty() {
+ let layout = WindowButtonLayout::parse("").unwrap();
+ assert_eq!(layout.left, [None, None, None]);
+ assert_eq!(layout.right, [None, None, None]);
+ }
+
+ #[test]
+ fn test_window_button_layout_parse_intentionally_empty() {
+ let layout = WindowButtonLayout::parse(":").unwrap();
+ assert_eq!(layout.left, [None, None, None]);
+ assert_eq!(layout.right, [None, None, None]);
+ }
+
+ #[test]
+ fn test_window_button_layout_parse_invalid_buttons() {
+ let layout = WindowButtonLayout::parse("close,invalid,minimize:maximize,foo").unwrap();
+ assert_eq!(
+ layout.left,
+ [
+ Some(WindowButton::Close),
+ Some(WindowButton::Minimize),
+ None
+ ]
+ );
+ assert_eq!(layout.right, [Some(WindowButton::Maximize), None, None]);
+ }
+
+ #[test]
+ fn test_window_button_layout_parse_deduplicates_same_side_buttons() {
+ let layout = WindowButtonLayout::parse("close,close,minimize").unwrap();
+ assert_eq!(
+ layout.right,
+ [
+ Some(WindowButton::Close),
+ Some(WindowButton::Minimize),
+ None
+ ]
+ );
+ assert_eq!(layout.format(), ":close,minimize");
+ }
+
+ #[test]
+ fn test_window_button_layout_parse_deduplicates_buttons_across_sides() {
+ let layout = WindowButtonLayout::parse("close:maximize,close,minimize").unwrap();
+ assert_eq!(layout.left, [Some(WindowButton::Close), None, None]);
+ assert_eq!(
+ layout.right,
+ [
+ Some(WindowButton::Maximize),
+ Some(WindowButton::Minimize),
+ None
+ ]
+ );
+
+ let button_ids: Vec<_> = layout
+ .left
+ .iter()
+ .chain(layout.right.iter())
+ .flatten()
+ .map(WindowButton::id)
+ .collect();
+ let unique_button_ids = button_ids.iter().copied().collect::<HashSet<_>>();
+ assert_eq!(unique_button_ids.len(), button_ids.len());
+ assert_eq!(layout.format(), "close:maximize,minimize");
+ }
+
+ #[test]
+ fn test_window_button_layout_parse_gnome_style() {
+ let layout = WindowButtonLayout::parse("close").unwrap();
+ assert_eq!(layout.left, [None, None, None]);
+ assert_eq!(layout.right, [Some(WindowButton::Close), None, None]);
+ }
+
+ #[test]
+ fn test_window_button_layout_parse_elementary_style() {
+ let layout = WindowButtonLayout::parse("close:maximize").unwrap();
+ assert_eq!(layout.left, [Some(WindowButton::Close), None, None]);
+ assert_eq!(layout.right, [Some(WindowButton::Maximize), None, None]);
+ }
+
+ #[test]
+ fn test_window_button_layout_round_trip() {
+ let cases = [
+ "close:minimize,maximize",
+ "minimize,maximize,close:",
+ ":close",
+ "close:",
+ "close:maximize",
+ ":",
+ ];
+
+ for case in cases {
+ let layout = WindowButtonLayout::parse(case).unwrap();
+ assert_eq!(layout.format(), case, "Round-trip failed for: {}", case);
+ }
+ }
+
+ #[test]
+ fn test_window_button_layout_linux_default() {
+ let layout = WindowButtonLayout::linux_default();
+ assert_eq!(layout.left, [None, None, None]);
+ assert_eq!(
+ layout.right,
+ [
+ Some(WindowButton::Minimize),
+ Some(WindowButton::Maximize),
+ Some(WindowButton::Close)
+ ]
+ );
+
+ let round_tripped = WindowButtonLayout::parse(&layout.format()).unwrap();
+ assert_eq!(round_tripped, layout);
+ }
+
+ #[test]
+ fn test_window_button_layout_parse_all_invalid() {
+ assert!(WindowButtonLayout::parse("asdfghjkl").is_err());
+ }
+}
@@ -951,6 +951,7 @@ pub struct Window {
pub(crate) bounds_observers: SubscriberSet<(), AnyObserver>,
appearance: WindowAppearance,
pub(crate) appearance_observers: SubscriberSet<(), AnyObserver>,
+ pub(crate) button_layout_observers: SubscriberSet<(), AnyObserver>,
active: Rc<Cell<bool>>,
hovered: Rc<Cell<bool>>,
pub(crate) needs_present: Rc<Cell<bool>>,
@@ -1288,6 +1289,14 @@ impl Window {
.log_err();
}
}));
+ platform_window.on_button_layout_changed(Box::new({
+ let mut cx = cx.to_async();
+ move || {
+ handle
+ .update(&mut cx, |_, window, cx| window.button_layout_changed(cx))
+ .log_err();
+ }
+ }));
platform_window.on_active_status_change(Box::new({
let mut cx = cx.to_async();
move |active| {
@@ -1442,6 +1451,7 @@ impl Window {
bounds_observers: SubscriberSet::new(),
appearance,
appearance_observers: SubscriberSet::new(),
+ button_layout_observers: SubscriberSet::new(),
active,
hovered,
needs_present,
@@ -1534,6 +1544,22 @@ impl Window {
subscription
}
+ /// Registers a callback to be invoked when the window button layout changes.
+ pub fn observe_button_layout_changed(
+ &self,
+ mut callback: impl FnMut(&mut Window, &mut App) + 'static,
+ ) -> Subscription {
+ let (subscription, activate) = self.button_layout_observers.insert(
+ (),
+ Box::new(move |window, cx| {
+ callback(window, cx);
+ true
+ }),
+ );
+ activate();
+ subscription
+ }
+
/// Replaces the root entity of the window with a new one.
pub fn replace_root<E>(
&mut self,
@@ -1956,6 +1982,12 @@ impl Window {
.retain(&(), |callback| callback(self, cx));
}
+ pub(crate) fn button_layout_changed(&mut self, cx: &mut App) {
+ self.button_layout_observers
+ .clone()
+ .retain(&(), |callback| callback(self, cx));
+ }
+
/// Returns the appearance of the current window.
pub fn appearance(&self) -> WindowAppearance {
self.appearance
@@ -26,7 +26,8 @@ use gpui::{
Action, AnyWindowHandle, BackgroundExecutor, ClipboardItem, CursorStyle, DisplayId,
ForegroundExecutor, Keymap, Menu, MenuItem, OwnedMenu, PathPromptOptions, Platform,
PlatformDisplay, PlatformKeyboardLayout, PlatformKeyboardMapper, PlatformTextSystem,
- PlatformWindow, Result, RunnableVariant, Task, ThermalState, WindowAppearance, WindowParams,
+ PlatformWindow, Result, RunnableVariant, Task, ThermalState, WindowAppearance,
+ WindowButtonLayout, WindowParams,
};
#[cfg(any(feature = "wayland", feature = "x11"))]
use gpui::{Pixels, Point, px};
@@ -114,6 +115,7 @@ pub(crate) struct LinuxCommon {
pub(crate) text_system: Arc<dyn PlatformTextSystem>,
pub(crate) appearance: WindowAppearance,
pub(crate) auto_hide_scrollbars: bool,
+ pub(crate) button_layout: WindowButtonLayout,
pub(crate) callbacks: PlatformHandlers,
pub(crate) signal: LoopSignal,
pub(crate) menus: Vec<OwnedMenu>,
@@ -140,6 +142,7 @@ impl LinuxCommon {
text_system,
appearance: WindowAppearance::Light,
auto_hide_scrollbars: false,
+ button_layout: WindowButtonLayout::linux_default(),
callbacks,
signal,
menus: Vec::new(),
@@ -601,6 +604,10 @@ impl<P: LinuxClient + 'static> Platform for LinuxPlatform<P> {
self.inner.with_common(|common| common.appearance)
}
+ fn button_layout(&self) -> Option<WindowButtonLayout> {
+ Some(self.inner.with_common(|common| common.button_layout))
+ }
+
fn register_url_scheme(&self, _: &str) -> Task<anyhow::Result<()>> {
Task::ready(Err(anyhow!("register_url_scheme unimplemented")))
}
@@ -95,8 +95,8 @@ use gpui::{
ForegroundExecutor, KeyDownEvent, KeyUpEvent, Keystroke, Modifiers, ModifiersChangedEvent,
MouseButton, MouseDownEvent, MouseExitEvent, MouseMoveEvent, MouseUpEvent, NavigationDirection,
Pixels, PlatformDisplay, PlatformInput, PlatformKeyboardLayout, PlatformWindow, Point,
- ScrollDelta, ScrollWheelEvent, SharedString, Size, TaskTiming, TouchPhase, WindowParams, point,
- profiler, px, size,
+ ScrollDelta, ScrollWheelEvent, SharedString, Size, TaskTiming, TouchPhase, WindowButtonLayout,
+ WindowParams, point, profiler, px, size,
};
use gpui_wgpu::{CompositorGpuHint, GpuContext};
use wayland_protocols::wp::linux_dmabuf::zv1::client::{
@@ -567,6 +567,19 @@ impl WaylandClient {
}
}
}
+ XDPEvent::ButtonLayout(layout_str) => {
+ if let Some(client) = client.0.upgrade() {
+ let layout = WindowButtonLayout::parse(&layout_str)
+ .log_err()
+ .unwrap_or_else(WindowButtonLayout::linux_default);
+ let mut client = client.borrow_mut();
+ client.common.button_layout = layout;
+
+ for window in client.windows.values_mut() {
+ window.set_button_layout();
+ }
+ }
+ }
XDPEvent::CursorTheme(theme) => {
if let Some(client) = client.0.upgrade() {
let mut client = client.borrow_mut();
@@ -50,6 +50,7 @@ pub(crate) struct Callbacks {
should_close: Option<Box<dyn FnMut() -> bool>>,
close: Option<Box<dyn FnOnce()>>,
appearance_changed: Option<Box<dyn FnMut()>>,
+ button_layout_changed: Option<Box<dyn FnMut()>>,
}
#[derive(Debug, Clone, Copy)]
@@ -1038,6 +1039,14 @@ impl WaylandWindowStatePtr {
}
}
+ pub fn set_button_layout(&self) {
+ let callback = self.callbacks.borrow_mut().button_layout_changed.take();
+ if let Some(mut fun) = callback {
+ fun();
+ self.callbacks.borrow_mut().button_layout_changed = Some(fun);
+ }
+ }
+
pub fn primary_output_scale(&self) -> i32 {
self.state.borrow_mut().primary_output_scale()
}
@@ -1335,6 +1344,10 @@ impl PlatformWindow for WaylandWindow {
self.0.callbacks.borrow_mut().appearance_changed = Some(callback);
}
+ fn on_button_layout_changed(&self, callback: Box<dyn FnMut()>) {
+ self.0.callbacks.borrow_mut().button_layout_changed = Some(callback);
+ }
+
fn draw(&self, scene: &Scene) {
let mut state = self.borrow_mut();
@@ -62,7 +62,7 @@ use gpui::{
AnyWindowHandle, Bounds, ClipboardItem, CursorStyle, DisplayId, FileDropEvent, Keystroke,
Modifiers, ModifiersChangedEvent, MouseButton, Pixels, PlatformDisplay, PlatformInput,
PlatformKeyboardLayout, PlatformWindow, Point, RequestFrameOptions, ScrollDelta, Size,
- TouchPhase, WindowParams, point, px,
+ TouchPhase, WindowButtonLayout, WindowParams, point, px,
};
use gpui_wgpu::{CompositorGpuHint, GpuContext};
@@ -472,6 +472,15 @@ impl X11Client {
window.window.set_appearance(appearance);
}
}
+ XDPEvent::ButtonLayout(layout_str) => {
+ let layout = WindowButtonLayout::parse(&layout_str)
+ .log_err()
+ .unwrap_or_else(WindowButtonLayout::linux_default);
+ client.with_common(|common| common.button_layout = layout);
+ for window in client.0.borrow_mut().windows.values_mut() {
+ window.window.set_button_layout();
+ }
+ }
XDPEvent::CursorTheme(_) | XDPEvent::CursorSize(_) => {
// noop, X11 manages this for us.
}
@@ -250,6 +250,7 @@ pub struct Callbacks {
should_close: Option<Box<dyn FnMut() -> bool>>,
close: Option<Box<dyn FnOnce()>>,
appearance_changed: Option<Box<dyn FnMut()>>,
+ button_layout_changed: Option<Box<dyn FnMut()>>,
}
pub struct X11WindowState {
@@ -1256,6 +1257,14 @@ impl X11WindowStatePtr {
self.callbacks.borrow_mut().appearance_changed = Some(fun);
}
}
+
+ pub fn set_button_layout(&self) {
+ let callback = self.callbacks.borrow_mut().button_layout_changed.take();
+ if let Some(mut fun) = callback {
+ fun();
+ self.callbacks.borrow_mut().button_layout_changed = Some(fun);
+ }
+ }
}
impl PlatformWindow for X11Window {
@@ -1602,6 +1611,10 @@ impl PlatformWindow for X11Window {
self.0.callbacks.borrow_mut().appearance_changed = Some(callback);
}
+ fn on_button_layout_changed(&self, callback: Box<dyn FnMut()>) {
+ self.0.callbacks.borrow_mut().button_layout_changed = Some(callback);
+ }
+
fn draw(&self, scene: &Scene) {
let mut inner = self.0.state.borrow_mut();
@@ -15,6 +15,7 @@ pub enum Event {
CursorTheme(String),
#[cfg_attr(feature = "x11", allow(dead_code))]
CursorSize(u32),
+ ButtonLayout(String),
}
pub struct XDPEventSource {
@@ -51,6 +52,13 @@ impl XDPEventSource {
sender.send(Event::CursorSize(initial_size as u32))?;
}
+ if let Ok(initial_layout) = settings
+ .read::<String>("org.gnome.desktop.wm.preferences", "button-layout")
+ .await
+ {
+ sender.send(Event::ButtonLayout(initial_layout))?;
+ }
+
if let Ok(mut cursor_theme_changed) = settings
.receive_setting_changed_with_args(
"org.gnome.desktop.interface",
@@ -89,6 +97,25 @@ impl XDPEventSource {
.detach();
}
+ if let Ok(mut button_layout_changed) = settings
+ .receive_setting_changed_with_args(
+ "org.gnome.desktop.wm.preferences",
+ "button-layout",
+ )
+ .await
+ {
+ let sender = sender.clone();
+ background
+ .spawn(async move {
+ while let Some(layout) = button_layout_changed.next().await {
+ let layout = layout?;
+ sender.send(Event::ButtonLayout(layout))?;
+ }
+ anyhow::Ok(())
+ })
+ .detach();
+ }
+
let mut appearance_changed = settings.receive_color_scheme_changed().await?;
while let Some(scheme) = appearance_changed.next().await {
sender.send(Event::WindowAppearance(
@@ -414,7 +414,7 @@ impl MacPlatform {
submenu.addItem_(Self::create_menu_item(item, delegate, actions, keymap));
}
item.setSubmenu_(submenu);
- item.setEnabled_(!disabled);
+ item.setEnabled_(if *disabled { NO } else { YES });
item.setTitle_(ns_string(name));
item
}
@@ -55,7 +55,10 @@ use std::{
path::PathBuf,
ptr::{self, NonNull},
rc::Rc,
- sync::{Arc, Weak},
+ sync::{
+ Arc, Weak,
+ atomic::{AtomicBool, Ordering},
+ },
time::Duration,
};
use util::ResultExt;
@@ -440,6 +443,7 @@ struct MacWindowState {
select_previous_tab_callback: Option<Box<dyn FnMut()>>,
toggle_tab_bar_callback: Option<Box<dyn FnMut()>>,
activated_least_once: bool,
+ closed: Arc<AtomicBool>,
// The parent window if this window is a sheet (Dialog kind)
sheet_parent: Option<id>,
}
@@ -764,6 +768,7 @@ impl MacWindow {
select_previous_tab_callback: None,
toggle_tab_bar_callback: None,
activated_least_once: false,
+ closed: Arc::new(AtomicBool::new(false)),
sheet_parent: None,
})));
@@ -1020,6 +1025,17 @@ impl Drop for MacWindow {
}
}
+/// Calls `f` if the window is not closed.
+///
+/// This should be used when spawning foreground tasks interacting with the
+/// window, as some messages will end hard faulting if dispatched to no longer
+/// valid window handles.
+fn if_window_not_closed(closed: Arc<AtomicBool>, f: impl FnOnce()) {
+ if !closed.load(Ordering::Acquire) {
+ f();
+ }
+}
+
impl PlatformWindow for MacWindow {
fn bounds(&self) -> Bounds<Pixels> {
self.0.as_ref().lock().bounds()
@@ -1040,14 +1056,15 @@ impl PlatformWindow for MacWindow {
fn resize(&mut self, size: Size<Pixels>) {
let this = self.0.lock();
let window = this.native_window;
+ let closed = this.closed.clone();
this.foreground_executor
.spawn(async move {
- unsafe {
+ if_window_not_closed(closed, || unsafe {
window.setContentSize_(NSSize {
width: size.width.as_f32() as f64,
height: size.height.as_f32() as f64,
});
- }
+ })
})
.detach();
}
@@ -1260,15 +1277,21 @@ impl PlatformWindow for MacWindow {
}
});
let block = block.copy();
- let native_window = self.0.lock().native_window;
- let executor = self.0.lock().foreground_executor.clone();
+ let lock = self.0.lock();
+ let native_window = lock.native_window;
+ let closed = lock.closed.clone();
+ let executor = lock.foreground_executor.clone();
executor
.spawn(async move {
- let _: () = msg_send![
- alert,
- beginSheetModalForWindow: native_window
- completionHandler: block
- ];
+ if !closed.load(Ordering::Acquire) {
+ let _: () = msg_send![
+ alert,
+ beginSheetModalForWindow: native_window
+ completionHandler: block
+ ];
+ } else {
+ let _: () = msg_send![alert, release];
+ }
})
.detach();
@@ -1277,12 +1300,16 @@ impl PlatformWindow for MacWindow {
}
fn activate(&self) {
- let window = self.0.lock().native_window;
- let executor = self.0.lock().foreground_executor.clone();
+ let lock = self.0.lock();
+ let window = lock.native_window;
+ let closed = lock.closed.clone();
+ let executor = lock.foreground_executor.clone();
executor
.spawn(async move {
- unsafe {
- let _: () = msg_send![window, makeKeyAndOrderFront: nil];
+ if !closed.load(Ordering::Acquire) {
+ unsafe {
+ let _: () = msg_send![window, makeKeyAndOrderFront: nil];
+ }
}
})
.detach();
@@ -1420,11 +1447,12 @@ impl PlatformWindow for MacWindow {
fn zoom(&self) {
let this = self.0.lock();
let window = this.native_window;
+ let closed = this.closed.clone();
this.foreground_executor
.spawn(async move {
- unsafe {
+ if_window_not_closed(closed, || unsafe {
window.zoom_(nil);
- }
+ })
})
.detach();
}
@@ -1432,11 +1460,12 @@ impl PlatformWindow for MacWindow {
fn toggle_fullscreen(&self) {
let this = self.0.lock();
let window = this.native_window;
+ let closed = this.closed.clone();
this.foreground_executor
.spawn(async move {
- unsafe {
+ if_window_not_closed(closed, || unsafe {
window.toggleFullScreen_(nil);
- }
+ })
})
.detach();
}
@@ -1577,45 +1606,48 @@ impl PlatformWindow for MacWindow {
fn titlebar_double_click(&self) {
let this = self.0.lock();
let window = this.native_window;
+ let closed = this.closed.clone();
this.foreground_executor
.spawn(async move {
- unsafe {
- let defaults: id = NSUserDefaults::standardUserDefaults();
- let domain = ns_string("NSGlobalDomain");
- let key = ns_string("AppleActionOnDoubleClick");
-
- let dict: id = msg_send![defaults, persistentDomainForName: domain];
- let action: id = if !dict.is_null() {
- msg_send![dict, objectForKey: key]
- } else {
- nil
- };
+ if_window_not_closed(closed, || {
+ unsafe {
+ let defaults: id = NSUserDefaults::standardUserDefaults();
+ let domain = ns_string("NSGlobalDomain");
+ let key = ns_string("AppleActionOnDoubleClick");
+
+ let dict: id = msg_send![defaults, persistentDomainForName: domain];
+ let action: id = if !dict.is_null() {
+ msg_send![dict, objectForKey: key]
+ } else {
+ nil
+ };
- let action_str = if !action.is_null() {
- CStr::from_ptr(NSString::UTF8String(action)).to_string_lossy()
- } else {
- "".into()
- };
+ let action_str = if !action.is_null() {
+ CStr::from_ptr(NSString::UTF8String(action)).to_string_lossy()
+ } else {
+ "".into()
+ };
- match action_str.as_ref() {
- "None" => {
- // "Do Nothing" selected, so do no action
- }
- "Minimize" => {
- window.miniaturize_(nil);
- }
- "Maximize" => {
- window.zoom_(nil);
- }
- "Fill" => {
- // There is no documented API for "Fill" action, so we'll just zoom the window
- window.zoom_(nil);
- }
- _ => {
- window.zoom_(nil);
+ match action_str.as_ref() {
+ "None" => {
+ // "Do Nothing" selected, so do no action
+ }
+ "Minimize" => {
+ window.miniaturize_(nil);
+ }
+ "Maximize" => {
+ window.zoom_(nil);
+ }
+ "Fill" => {
+ // There is no documented API for "Fill" action, so we'll just zoom the window
+ window.zoom_(nil);
+ }
+ _ => {
+ window.zoom_(nil);
+ }
}
}
- }
+ })
})
.detach();
}
@@ -2185,6 +2217,7 @@ extern "C" fn close_window(this: &Object, _: Sel) {
let close_callback = {
let window_state = get_window_state(this);
let mut lock = window_state.as_ref().lock();
+ lock.closed.store(true, Ordering::Release);
lock.close_callback.take()
};
@@ -22,6 +22,7 @@ pub enum IconName {
AiOllama,
AiOpenAi,
AiOpenAiCompat,
+ AiOpenCode,
AiOpenRouter,
AiVercel,
AiVZero,
@@ -151,6 +152,7 @@ pub enum IconName {
GitCommit,
GitGraph,
GitMergeConflict,
+ GitWorktree,
Github,
Hash,
HistoryRerun,
@@ -4610,7 +4610,7 @@ impl BufferSnapshot {
continue;
}
- let mut all_brackets: Vec<(BracketMatch<usize>, bool)> = Vec::new();
+ let mut all_brackets: Vec<(BracketMatch<usize>, usize, bool)> = Vec::new();
let mut opens = Vec::new();
let mut color_pairs = Vec::new();
@@ -4636,8 +4636,9 @@ impl BufferSnapshot {
let mut open = None;
let mut close = None;
let syntax_layer_depth = mat.depth;
+ let pattern_index = mat.pattern_index;
let config = configs[mat.grammar_index];
- let pattern = &config.patterns[mat.pattern_index];
+ let pattern = &config.patterns[pattern_index];
for capture in mat.captures {
if capture.index == config.open_capture_ix {
open = Some(capture.node.byte_range());
@@ -4658,7 +4659,7 @@ impl BufferSnapshot {
}
open_to_close_ranges
- .entry((open_range.start, open_range.end))
+ .entry((open_range.start, open_range.end, pattern_index))
.or_insert_with(BTreeMap::new)
.insert(
(close_range.start, close_range.end),
@@ -4679,6 +4680,7 @@ impl BufferSnapshot {
newline_only: pattern.newline_only,
color_index: None,
},
+ pattern_index,
pattern.rainbow_exclude,
));
}
@@ -4692,22 +4694,43 @@ impl BufferSnapshot {
// For each close, we know the expected open_len from tree-sitter matches.
// Map each close to its expected open length (for inferring opens)
- let close_to_open_len: HashMap<(usize, usize), usize> = all_brackets
+ let close_to_open_len: HashMap<(usize, usize, usize), usize> = all_brackets
.iter()
- .map(|(m, _)| ((m.close_range.start, m.close_range.end), m.open_range.len()))
+ .map(|(bracket_match, pattern_index, _)| {
+ (
+ (
+ bracket_match.close_range.start,
+ bracket_match.close_range.end,
+ *pattern_index,
+ ),
+ bracket_match.open_range.len(),
+ )
+ })
.collect();
// Collect unique opens and closes within this chunk
- let mut unique_opens: HashSet<(usize, usize)> = all_brackets
+ let mut unique_opens: HashSet<(usize, usize, usize)> = all_brackets
.iter()
- .map(|(m, _)| (m.open_range.start, m.open_range.end))
- .filter(|(start, _)| chunk_range.contains(start))
+ .map(|(bracket_match, pattern_index, _)| {
+ (
+ bracket_match.open_range.start,
+ bracket_match.open_range.end,
+ *pattern_index,
+ )
+ })
+ .filter(|(start, _, _)| chunk_range.contains(start))
.collect();
- let mut unique_closes: Vec<(usize, usize)> = all_brackets
+ let mut unique_closes: Vec<(usize, usize, usize)> = all_brackets
.iter()
- .map(|(m, _)| (m.close_range.start, m.close_range.end))
- .filter(|(start, _)| chunk_range.contains(start))
+ .map(|(bracket_match, pattern_index, _)| {
+ (
+ bracket_match.close_range.start,
+ bracket_match.close_range.end,
+ *pattern_index,
+ )
+ })
+ .filter(|(start, _, _)| chunk_range.contains(start))
.collect();
unique_closes.sort();
unique_closes.dedup();
@@ -4716,8 +4739,9 @@ impl BufferSnapshot {
let mut unique_opens_vec: Vec<_> = unique_opens.iter().copied().collect();
unique_opens_vec.sort();
- let mut valid_pairs: HashSet<((usize, usize), (usize, usize))> = HashSet::default();
- let mut open_stack: Vec<(usize, usize)> = Vec::new();
+ let mut valid_pairs: HashSet<((usize, usize, usize), (usize, usize, usize))> =
+ HashSet::default();
+ let mut open_stacks: HashMap<usize, Vec<(usize, usize)>> = HashMap::default();
let mut open_idx = 0;
for close in &unique_closes {
@@ -4725,36 +4749,53 @@ impl BufferSnapshot {
while open_idx < unique_opens_vec.len()
&& unique_opens_vec[open_idx].0 < close.0
{
- open_stack.push(unique_opens_vec[open_idx]);
+ let (start, end, pattern_index) = unique_opens_vec[open_idx];
+ open_stacks
+ .entry(pattern_index)
+ .or_default()
+ .push((start, end));
open_idx += 1;
}
// Try to match with most recent open
- if let Some(open) = open_stack.pop() {
- valid_pairs.insert((open, *close));
+ let (close_start, close_end, pattern_index) = *close;
+ if let Some(open) = open_stacks
+ .get_mut(&pattern_index)
+ .and_then(|open_stack| open_stack.pop())
+ {
+ valid_pairs.insert(((open.0, open.1, pattern_index), *close));
} else if let Some(&open_len) = close_to_open_len.get(close) {
// No open on stack - infer one based on expected open_len
- if close.0 >= open_len {
- let inferred = (close.0 - open_len, close.0);
+ if close_start >= open_len {
+ let inferred = (close_start - open_len, close_start, pattern_index);
unique_opens.insert(inferred);
valid_pairs.insert((inferred, *close));
all_brackets.push((
BracketMatch {
open_range: inferred.0..inferred.1,
- close_range: close.0..close.1,
+ close_range: close_start..close_end,
newline_only: false,
syntax_layer_depth: 0,
color_index: None,
},
+ pattern_index,
false,
));
}
}
}
- all_brackets.retain(|(m, _)| {
- let open = (m.open_range.start, m.open_range.end);
- let close = (m.close_range.start, m.close_range.end);
+ all_brackets.retain(|(bracket_match, pattern_index, _)| {
+ let open = (
+ bracket_match.open_range.start,
+ bracket_match.open_range.end,
+ *pattern_index,
+ );
+ let close = (
+ bracket_match.close_range.start,
+ bracket_match.close_range.end,
+ *pattern_index,
+ );
valid_pairs.contains(&(open, close))
});
}
@@ -4762,7 +4803,7 @@ impl BufferSnapshot {
let mut all_brackets = all_brackets
.into_iter()
.enumerate()
- .map(|(index, (bracket_match, rainbow_exclude))| {
+ .map(|(index, (bracket_match, _, rainbow_exclude))| {
// Certain languages have "brackets" that are not brackets, e.g. tags. and such
// bracket will match the entire tag with all text inside.
// For now, avoid highlighting any pair that has more than single char in each bracket.
@@ -47,6 +47,7 @@ menu.workspace = true
mistral = { workspace = true, features = ["schemars"] }
ollama = { workspace = true, features = ["schemars"] }
open_ai = { workspace = true, features = ["schemars"] }
+opencode = { workspace = true, features = ["schemars"] }
open_router = { workspace = true, features = ["schemars"] }
partial-json-fixer.workspace = true
release_channel.workspace = true
@@ -24,6 +24,7 @@ use crate::provider::ollama::OllamaLanguageModelProvider;
use crate::provider::open_ai::OpenAiLanguageModelProvider;
use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
use crate::provider::open_router::OpenRouterLanguageModelProvider;
+use crate::provider::opencode::OpenCodeLanguageModelProvider;
use crate::provider::vercel::VercelLanguageModelProvider;
use crate::provider::vercel_ai_gateway::VercelAiGatewayLanguageModelProvider;
use crate::provider::x_ai::XAiLanguageModelProvider;
@@ -220,5 +221,9 @@ fn register_language_model_providers(
Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)),
cx,
);
+ registry.register_provider(
+ Arc::new(OpenCodeLanguageModelProvider::new(client.http_client(), cx)),
+ cx,
+ );
registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
}
@@ -10,6 +10,7 @@ pub mod ollama;
pub mod open_ai;
pub mod open_ai_compatible;
pub mod open_router;
+pub mod opencode;
mod util;
pub mod vercel;
pub mod vercel_ai_gateway;
@@ -331,15 +331,25 @@ pub fn into_deepseek(
for message in request.messages {
for content in message.content {
match content {
- MessageContent::Text(text) => messages.push(match message.role {
- Role::User => deepseek::RequestMessage::User { content: text },
- Role::Assistant => deepseek::RequestMessage::Assistant {
- content: Some(text),
- tool_calls: Vec::new(),
- reasoning_content: current_reasoning.take(),
- },
- Role::System => deepseek::RequestMessage::System { content: text },
- }),
+ MessageContent::Text(text) => {
+ let should_add = if message.role == Role::User {
+ !text.trim().is_empty()
+ } else {
+ !text.is_empty()
+ };
+
+ if should_add {
+ messages.push(match message.role {
+ Role::User => deepseek::RequestMessage::User { content: text },
+ Role::Assistant => deepseek::RequestMessage::Assistant {
+ content: Some(text),
+ tool_calls: Vec::new(),
+ reasoning_content: current_reasoning.take(),
+ },
+ Role::System => deepseek::RequestMessage::System { content: text },
+ });
+ }
+ }
MessageContent::Thinking { text, .. } => {
// Accumulate reasoning content for next assistant message
current_reasoning.get_or_insert_default().push_str(&text);
@@ -445,7 +455,9 @@ impl DeepSeekEventMapper {
};
let mut events = Vec::new();
- if let Some(content) = choice.delta.content.clone() {
+ if let Some(content) = choice.delta.content.clone()
+ && !content.is_empty()
+ {
events.push(Ok(LanguageModelCompletionEvent::Text(content)));
}
@@ -514,8 +514,7 @@ pub fn into_open_ai(
temperature: request.temperature.or(Some(1.0)),
max_completion_tokens: max_output_tokens,
parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() {
- // Disable parallel tool calls, as the Agent currently expects a maximum of one per turn.
- Some(false)
+            Some(true)
} else {
None
},
@@ -0,0 +1,646 @@
+use anyhow::Result;
+use collections::BTreeMap;
+use futures::{FutureExt, StreamExt, future::BoxFuture};
+use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
+use http_client::HttpClient;
+use language_model::{
+ ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
+ LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
+ LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
+ LanguageModelRequest, LanguageModelToolChoice, RateLimiter, Role, env_var,
+};
+use opencode::{ApiProtocol, OPENCODE_API_URL};
+pub use settings::OpenCodeAvailableModel as AvailableModel;
+use settings::{Settings, SettingsStore};
+use std::sync::{Arc, LazyLock};
+use strum::IntoEnumIterator;
+use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
+use ui_input::InputField;
+use util::ResultExt;
+
+use crate::provider::anthropic::{AnthropicEventMapper, into_anthropic};
+use crate::provider::google::{GoogleEventMapper, into_google};
+use crate::provider::open_ai::{
+ OpenAiEventMapper, OpenAiResponseEventMapper, into_open_ai, into_open_ai_response,
+};
+
+const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("opencode");
+const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("OpenCode Zen");
+
+const API_KEY_ENV_VAR_NAME: &str = "OPENCODE_API_KEY";
+static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
+
+#[derive(Default, Clone, Debug, PartialEq)]
+pub struct OpenCodeSettings {
+ pub api_url: String,
+ pub available_models: Vec<AvailableModel>,
+}
+
+pub struct OpenCodeLanguageModelProvider {
+ http_client: Arc<dyn HttpClient>,
+ state: Entity<State>,
+}
+
+pub struct State {
+ api_key_state: ApiKeyState,
+}
+
+impl State {
+ fn is_authenticated(&self) -> bool {
+ self.api_key_state.has_key()
+ }
+
+ fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+ let api_url = OpenCodeLanguageModelProvider::api_url(cx);
+ self.api_key_state
+ .store(api_url, api_key, |this| &mut this.api_key_state, cx)
+ }
+
+ fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+ let api_url = OpenCodeLanguageModelProvider::api_url(cx);
+ self.api_key_state
+ .load_if_needed(api_url, |this| &mut this.api_key_state, cx)
+ }
+}
+
+impl OpenCodeLanguageModelProvider {
+ pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+ let state = cx.new(|cx| {
+ cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
+ let api_url = Self::api_url(cx);
+ this.api_key_state
+ .handle_url_change(api_url, |this| &mut this.api_key_state, cx);
+ cx.notify();
+ })
+ .detach();
+ State {
+ api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
+ }
+ });
+
+ Self { http_client, state }
+ }
+
+ fn create_language_model(&self, model: opencode::Model) -> Arc<dyn LanguageModel> {
+ Arc::new(OpenCodeLanguageModel {
+ id: LanguageModelId::from(model.id().to_string()),
+ model,
+ state: self.state.clone(),
+ http_client: self.http_client.clone(),
+ request_limiter: RateLimiter::new(4),
+ })
+ }
+
+ pub fn settings(cx: &App) -> &OpenCodeSettings {
+ &crate::AllLanguageModelSettings::get_global(cx).opencode
+ }
+
+ fn api_url(cx: &App) -> SharedString {
+ let api_url = &Self::settings(cx).api_url;
+ if api_url.is_empty() {
+ OPENCODE_API_URL.into()
+ } else {
+ SharedString::new(api_url.as_str())
+ }
+ }
+}
+
+impl LanguageModelProviderState for OpenCodeLanguageModelProvider {
+ type ObservableEntity = State;
+
+ fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
+ Some(self.state.clone())
+ }
+}
+
+impl LanguageModelProvider for OpenCodeLanguageModelProvider {
+ fn id(&self) -> LanguageModelProviderId {
+ PROVIDER_ID
+ }
+
+ fn name(&self) -> LanguageModelProviderName {
+ PROVIDER_NAME
+ }
+
+ fn icon(&self) -> IconOrSvg {
+ IconOrSvg::Icon(IconName::AiOpenCode)
+ }
+
+ fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+ Some(self.create_language_model(opencode::Model::default()))
+ }
+
+ fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+ Some(self.create_language_model(opencode::Model::default_fast()))
+ }
+
+ fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
+ let mut models = BTreeMap::default();
+
+ for model in opencode::Model::iter() {
+ if !matches!(model, opencode::Model::Custom { .. }) {
+ models.insert(model.id().to_string(), model);
+ }
+ }
+
+ for model in &Self::settings(cx).available_models {
+ let protocol = match model.protocol.as_str() {
+ "anthropic" => ApiProtocol::Anthropic,
+ "openai_responses" => ApiProtocol::OpenAiResponses,
+ "openai_chat" => ApiProtocol::OpenAiChat,
+ "google" => ApiProtocol::Google,
+                _ => ApiProtocol::OpenAiChat, // unknown protocol string: fall back to OpenAI chat
+ };
+ models.insert(
+ model.name.clone(),
+ opencode::Model::Custom {
+ name: model.name.clone(),
+ display_name: model.display_name.clone(),
+ max_tokens: model.max_tokens,
+ max_output_tokens: model.max_output_tokens,
+ protocol,
+ },
+ );
+ }
+
+ models
+ .into_values()
+ .map(|model| self.create_language_model(model))
+ .collect()
+ }
+
+ fn is_authenticated(&self, cx: &App) -> bool {
+ self.state.read(cx).is_authenticated()
+ }
+
+ fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
+ self.state.update(cx, |state, cx| state.authenticate(cx))
+ }
+
+ fn configuration_view(
+ &self,
+ _target_agent: language_model::ConfigurationViewTargetAgent,
+ window: &mut Window,
+ cx: &mut App,
+ ) -> AnyView {
+ cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
+ .into()
+ }
+
+ fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
+ self.state
+ .update(cx, |state, cx| state.set_api_key(None, cx))
+ }
+}
+
+pub struct OpenCodeLanguageModel {
+ id: LanguageModelId,
+ model: opencode::Model,
+ state: Entity<State>,
+ http_client: Arc<dyn HttpClient>,
+ request_limiter: RateLimiter,
+}
+
+impl OpenCodeLanguageModel {
+ /// Returns the base API URL (e.g., "https://opencode.ai/zen").
+ fn base_api_url(&self, cx: &AsyncApp) -> SharedString {
+ self.state
+ .read_with(cx, |_, cx| OpenCodeLanguageModelProvider::api_url(cx))
+ }
+
+ fn api_key(&self, cx: &AsyncApp) -> Option<Arc<str>> {
+ self.state.read_with(cx, |state, cx| {
+ let api_url = OpenCodeLanguageModelProvider::api_url(cx);
+ state.api_key_state.key(&api_url)
+ })
+ }
+
+ fn stream_anthropic(
+ &self,
+ request: anthropic::Request,
+ cx: &AsyncApp,
+ ) -> BoxFuture<
+ 'static,
+ Result<
+ futures::stream::BoxStream<
+ 'static,
+ Result<anthropic::Event, anthropic::AnthropicError>,
+ >,
+ LanguageModelCompletionError,
+ >,
+ > {
+ let http_client = self.http_client.clone();
+ // Anthropic crate appends /v1/messages to api_url
+ let api_url = self.base_api_url(cx);
+ let api_key = self.api_key(cx);
+
+ let future = self.request_limiter.stream(async move {
+ let Some(api_key) = api_key else {
+ return Err(LanguageModelCompletionError::NoApiKey {
+ provider: PROVIDER_NAME,
+ });
+ };
+ let request = anthropic::stream_completion(
+ http_client.as_ref(),
+ &api_url,
+ &api_key,
+ request,
+ None,
+ );
+ let response = request.await?;
+ Ok(response)
+ });
+
+ async move { Ok(future.await?.boxed()) }.boxed()
+ }
+
+ fn stream_openai_chat(
+ &self,
+ request: open_ai::Request,
+ cx: &AsyncApp,
+ ) -> BoxFuture<
+ 'static,
+ Result<futures::stream::BoxStream<'static, Result<open_ai::ResponseStreamEvent>>>,
+ > {
+ let http_client = self.http_client.clone();
+ // OpenAI crate appends /chat/completions to api_url, so we pass base + "/v1"
+ let base_url = self.base_api_url(cx);
+ let api_url: SharedString = format!("{base_url}/v1").into();
+ let api_key = self.api_key(cx);
+ let provider_name = PROVIDER_NAME.0.to_string();
+
+ let future = self.request_limiter.stream(async move {
+ let Some(api_key) = api_key else {
+ return Err(LanguageModelCompletionError::NoApiKey {
+ provider: PROVIDER_NAME,
+ });
+ };
+ let request = open_ai::stream_completion(
+ http_client.as_ref(),
+ &provider_name,
+ &api_url,
+ &api_key,
+ request,
+ );
+ let response = request.await?;
+ Ok(response)
+ });
+
+ async move { Ok(future.await?.boxed()) }.boxed()
+ }
+
+ fn stream_openai_response(
+ &self,
+ request: open_ai::responses::Request,
+ cx: &AsyncApp,
+ ) -> BoxFuture<
+ 'static,
+ Result<futures::stream::BoxStream<'static, Result<open_ai::responses::StreamEvent>>>,
+ > {
+ let http_client = self.http_client.clone();
+ // Responses crate appends /responses to api_url, so we pass base + "/v1"
+ let base_url = self.base_api_url(cx);
+ let api_url: SharedString = format!("{base_url}/v1").into();
+ let api_key = self.api_key(cx);
+ let provider_name = PROVIDER_NAME.0.to_string();
+
+ let future = self.request_limiter.stream(async move {
+ let Some(api_key) = api_key else {
+ return Err(LanguageModelCompletionError::NoApiKey {
+ provider: PROVIDER_NAME,
+ });
+ };
+ let request = open_ai::responses::stream_response(
+ http_client.as_ref(),
+ &provider_name,
+ &api_url,
+ &api_key,
+ request,
+ );
+ let response = request.await?;
+ Ok(response)
+ });
+
+ async move { Ok(future.await?.boxed()) }.boxed()
+ }
+
+ fn stream_google_zen(
+ &self,
+ request: google_ai::GenerateContentRequest,
+ cx: &AsyncApp,
+ ) -> BoxFuture<
+ 'static,
+ Result<futures::stream::BoxStream<'static, Result<google_ai::GenerateContentResponse>>>,
+ > {
+ let http_client = self.http_client.clone();
+ let api_url = self.base_api_url(cx);
+ let api_key = self.api_key(cx);
+
+ let future = self.request_limiter.stream(async move {
+ let Some(api_key) = api_key else {
+ return Err(LanguageModelCompletionError::NoApiKey {
+ provider: PROVIDER_NAME,
+ });
+ };
+ let request = opencode::stream_generate_content_zen(
+ http_client.as_ref(),
+ &api_url,
+ &api_key,
+ request,
+ );
+ let response = request.await?;
+ Ok(response)
+ });
+
+ async move { Ok(future.await?.boxed()) }.boxed()
+ }
+}
+
+impl LanguageModel for OpenCodeLanguageModel {
+ fn id(&self) -> LanguageModelId {
+ self.id.clone()
+ }
+
+ fn name(&self) -> LanguageModelName {
+ LanguageModelName::from(self.model.display_name().to_string())
+ }
+
+ fn provider_id(&self) -> LanguageModelProviderId {
+ PROVIDER_ID
+ }
+
+ fn provider_name(&self) -> LanguageModelProviderName {
+ PROVIDER_NAME
+ }
+
+ fn supports_tools(&self) -> bool {
+ self.model.supports_tools()
+ }
+
+ fn supports_images(&self) -> bool {
+ self.model.supports_images()
+ }
+
+ fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
+ match choice {
+ LanguageModelToolChoice::Auto | LanguageModelToolChoice::Any => true,
+ LanguageModelToolChoice::None => {
+ // Google models don't support None tool choice
+ self.model.protocol() != ApiProtocol::Google
+ }
+ }
+ }
+
+ fn telemetry_id(&self) -> String {
+ format!("opencode/{}", self.model.id())
+ }
+
+ fn max_token_count(&self) -> u64 {
+ self.model.max_token_count()
+ }
+
+ fn max_output_tokens(&self) -> Option<u64> {
+ self.model.max_output_tokens()
+ }
+
+ fn count_tokens(
+ &self,
+ request: LanguageModelRequest,
+ cx: &App,
+ ) -> BoxFuture<'static, Result<u64>> {
+ cx.background_spawn(async move {
+ let messages = request
+ .messages
+ .into_iter()
+ .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
+ role: match message.role {
+ Role::User => "user".into(),
+ Role::Assistant => "assistant".into(),
+ Role::System => "system".into(),
+ },
+ content: Some(message.string_contents()),
+ name: None,
+ function_call: None,
+ })
+ .collect::<Vec<_>>();
+
+ tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages).map(|tokens| tokens as u64)
+ })
+ .boxed()
+ }
+
+ fn stream_completion(
+ &self,
+ request: LanguageModelRequest,
+ cx: &AsyncApp,
+ ) -> BoxFuture<
+ 'static,
+ Result<
+ futures::stream::BoxStream<
+ 'static,
+ Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
+ >,
+ LanguageModelCompletionError,
+ >,
+ > {
+ match self.model.protocol() {
+ ApiProtocol::Anthropic => {
+ let anthropic_request = into_anthropic(
+ request,
+ self.model.id().to_string(),
+ 1.0,
+ self.model.max_output_tokens().unwrap_or(8192),
+ anthropic::AnthropicModelMode::Default,
+ );
+ let stream = self.stream_anthropic(anthropic_request, cx);
+ async move {
+ let mapper = AnthropicEventMapper::new();
+ Ok(mapper.map_stream(stream.await?).boxed())
+ }
+ .boxed()
+ }
+ ApiProtocol::OpenAiChat => {
+ let openai_request = into_open_ai(
+ request,
+ self.model.id(),
+ false,
+ false,
+ self.model.max_output_tokens(),
+ None,
+ );
+ let stream = self.stream_openai_chat(openai_request, cx);
+ async move {
+ let mapper = OpenAiEventMapper::new();
+ Ok(mapper.map_stream(stream.await?).boxed())
+ }
+ .boxed()
+ }
+ ApiProtocol::OpenAiResponses => {
+ let response_request = into_open_ai_response(
+ request,
+ self.model.id(),
+ false,
+ false,
+ self.model.max_output_tokens(),
+ None,
+ );
+ let stream = self.stream_openai_response(response_request, cx);
+ async move {
+ let mapper = OpenAiResponseEventMapper::new();
+ Ok(mapper.map_stream(stream.await?).boxed())
+ }
+ .boxed()
+ }
+ ApiProtocol::Google => {
+ let google_request = into_google(
+ request,
+ self.model.id().to_string(),
+ google_ai::GoogleModelMode::Default,
+ );
+ let stream = self.stream_google_zen(google_request, cx);
+ async move {
+ let mapper = GoogleEventMapper::new();
+ Ok(mapper.map_stream(stream.await?.boxed()).boxed())
+ }
+ .boxed()
+ }
+ }
+ }
+}
+
+struct ConfigurationView {
+ api_key_editor: Entity<InputField>,
+ state: Entity<State>,
+ load_credentials_task: Option<Task<()>>,
+}
+
+impl ConfigurationView {
+ fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
+ let api_key_editor = cx.new(|cx| {
+ InputField::new(window, cx, "sk-00000000000000000000000000000000").label("API key")
+ });
+
+ cx.observe(&state, |_, _, cx| {
+ cx.notify();
+ })
+ .detach();
+
+ let load_credentials_task = Some(cx.spawn_in(window, {
+ let state = state.clone();
+ async move |this, cx| {
+                // `if let Some(task) = Some(..)` always matches; await the
+                // authentication task directly.
+                let _ = state.update(cx, |state, cx| state.authenticate(cx)).await;
+ this.update(cx, |this, cx| {
+ this.load_credentials_task = None;
+ cx.notify();
+ })
+ .log_err();
+ }
+ }));
+
+ Self {
+ api_key_editor,
+ state,
+ load_credentials_task,
+ }
+ }
+
+ fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
+ let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
+ if api_key.is_empty() {
+ return;
+ }
+
+ self.api_key_editor
+ .update(cx, |editor, cx| editor.set_text("", window, cx));
+
+ let state = self.state.clone();
+ cx.spawn_in(window, async move |_, cx| {
+ state
+ .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))
+ .await
+ })
+ .detach_and_log_err(cx);
+ }
+
+ fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
+ self.api_key_editor
+ .update(cx, |editor, cx| editor.set_text("", window, cx));
+
+ let state = self.state.clone();
+ cx.spawn_in(window, async move |_, cx| {
+ state
+ .update(cx, |state, cx| state.set_api_key(None, cx))
+ .await
+ })
+ .detach_and_log_err(cx);
+ }
+
+ fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
+ !self.state.read(cx).is_authenticated()
+ }
+}
+
+impl Render for ConfigurationView {
+ fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
+ let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
+ let configured_card_label = if env_var_set {
+ format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
+ } else {
+ let api_url = OpenCodeLanguageModelProvider::api_url(cx);
+ if api_url == OPENCODE_API_URL {
+ "API key configured".to_string()
+ } else {
+ format!("API key configured for {}", api_url)
+ }
+ };
+
+ let api_key_section = if self.should_render_editor(cx) {
+ v_flex()
+ .on_action(cx.listener(Self::save_api_key))
+ .child(Label::new(
+ "To use OpenCode Zen models in Zed, you need an API key:",
+ ))
+ .child(
+ List::new()
+ .child(
+ ListBulletItem::new("")
+ .child(Label::new("Sign in and get your key at"))
+ .child(ButtonLink::new(
+ "OpenCode Zen Console",
+ "https://opencode.ai/zen",
+ )),
+ )
+ .child(ListBulletItem::new(
+ "Paste your API key below and hit enter to start using OpenCode Zen",
+ )),
+ )
+ .child(self.api_key_editor.clone())
+ .child(
+ Label::new(format!(
+ "You can also set the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."
+ ))
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ )
+ .into_any_element()
+ } else {
+ ConfiguredApiCard::new(configured_card_label)
+ .disabled(env_var_set)
+ .when(env_var_set, |this| {
+ this.tooltip_label(format!(
+ "To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."
+ ))
+ })
+ .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
+ .into_any_element()
+ };
+
+ if self.load_credentials_task.is_some() {
+ div().child(Label::new("Loading credentials...")).into_any()
+ } else {
+ v_flex().size_full().child(api_key_section).into_any()
+ }
+ }
+}
@@ -8,7 +8,8 @@ use crate::provider::{
deepseek::DeepSeekSettings, google::GoogleSettings, lmstudio::LmStudioSettings,
mistral::MistralSettings, ollama::OllamaSettings, open_ai::OpenAiSettings,
open_ai_compatible::OpenAiCompatibleSettings, open_router::OpenRouterSettings,
- vercel::VercelSettings, vercel_ai_gateway::VercelAiGatewaySettings, x_ai::XAiSettings,
+ opencode::OpenCodeSettings, vercel::VercelSettings, vercel_ai_gateway::VercelAiGatewaySettings,
+ x_ai::XAiSettings,
};
#[derive(Debug, RegisterSetting)]
@@ -20,6 +21,7 @@ pub struct AllLanguageModelSettings {
pub lmstudio: LmStudioSettings,
pub mistral: MistralSettings,
pub ollama: OllamaSettings,
+ pub opencode: OpenCodeSettings,
pub open_router: OpenRouterSettings,
pub openai: OpenAiSettings,
pub openai_compatible: HashMap<Arc<str>, OpenAiCompatibleSettings>,
@@ -41,6 +43,7 @@ impl settings::Settings for AllLanguageModelSettings {
let lmstudio = language_models.lmstudio.unwrap();
let mistral = language_models.mistral.unwrap();
let ollama = language_models.ollama.unwrap();
+ let opencode = language_models.opencode.unwrap();
let open_router = language_models.open_router.unwrap();
let openai = language_models.openai.unwrap();
let openai_compatible = language_models.openai_compatible.unwrap();
@@ -85,6 +88,10 @@ impl settings::Settings for AllLanguageModelSettings {
available_models: ollama.available_models.unwrap_or_default(),
context_window: ollama.context_window,
},
+ opencode: OpenCodeSettings {
+ api_url: opencode.api_url.unwrap(),
+ available_models: opencode.available_models.unwrap_or_default(),
+ },
open_router: OpenRouterSettings {
api_url: open_router.api_url.unwrap(),
available_models: open_router.available_models.unwrap_or_default(),
@@ -247,7 +247,6 @@
"abstract"
"as"
"async"
- "await"
"debugger"
"declare"
"default"
@@ -294,6 +293,7 @@
] @keyword.import
[
+ "await"
"break"
"case"
"catch"
@@ -191,7 +191,7 @@ pub fn init(languages: Arc<LanguageRegistry>, fs: Arc<dyn Fs>, node: NodeRuntime
context: Some(python_context_provider),
toolchain: Some(python_toolchain_provider),
manifest_name: Some(SharedString::new_static("pyproject.toml").into()),
- ..Default::default()
+ semantic_token_rules: Some(python::semantic_token_rules()),
},
LanguageInfo {
name: "rust",
@@ -24,7 +24,7 @@ use project::lsp_store::language_server_settings;
use semver::Version;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
-use settings::Settings;
+use settings::{SemanticTokenRules, Settings};
use terminal::terminal_settings::TerminalSettings;
use smol::lock::OnceCell;
@@ -37,6 +37,7 @@ use util::fs::{make_file_executable, remove_matching};
use util::paths::PathStyle;
use util::rel_path::RelPath;
+use crate::LanguageDir;
use http_client::github_download::{GithubBinaryMetadata, download_server_binary};
use parking_lot::Mutex;
use std::str::FromStr;
@@ -49,6 +50,14 @@ use std::{
use task::{ShellKind, TaskTemplate, TaskTemplates, VariableName};
use util::{ResultExt, maybe};
+pub(crate) fn semantic_token_rules() -> SemanticTokenRules {
+ let content = LanguageDir::get("python/semantic_token_rules.json")
+ .expect("missing python/semantic_token_rules.json");
+ let json = std::str::from_utf8(&content.data).expect("invalid utf-8 in semantic_token_rules");
+ settings::parse_json_with_comments::<SemanticTokenRules>(json)
+ .expect("failed to parse python semantic_token_rules.json")
+}
+
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct PythonToolchainData {
#[serde(flatten)]
@@ -0,0 +1,15 @@
+[
+ {
+ "token_type": "selfParameter",
+ "style": ["variable.special"]
+ },
+ {
+ "token_type": "clsParameter",
+ "style": ["variable.special"]
+ },
+ // ty specific
+ {
+ "token_type": "builtinConstant",
+ "style": ["constant.builtin"]
+ }
+]
@@ -268,7 +268,6 @@
"abstract"
"as"
"async"
- "await"
"debugger"
"declare"
"default"
@@ -318,6 +317,7 @@
] @keyword.import
[
+ "await"
"break"
"case"
"catch"
@@ -387,7 +387,6 @@
"abstract"
"as"
"async"
- "await"
"debugger"
"declare"
"default"
@@ -437,6 +436,7 @@
] @keyword.import
[
+ "await"
"break"
"case"
"catch"
@@ -49,6 +49,7 @@ livekit.workspace = true
[target.'cfg(target_os = "linux")'.dependencies]
tokio = { workspace = true, features = ["time"] }
+webrtc-sys.workspace = true
[target.'cfg(any(target_os = "linux", target_os = "freebsd", target_os = "windows"))'.dependencies]
scap.workspace = true
@@ -14,6 +14,7 @@ use std::sync::{
};
static NEXT_WAYLAND_SHARE_ID: AtomicU64 = AtomicU64::new(1);
+const PIPEWIRE_TIMEOUT_S: u64 = 30;
pub struct WaylandScreenCaptureStream {
id: u64,
@@ -64,6 +65,17 @@ pub(crate) async fn start_wayland_desktop_capture(
};
use libwebrtc::native::yuv_helper::argb_to_nv12;
use std::time::Duration;
+ use webrtc_sys::webrtc::ffi as webrtc_ffi;
+
+ fn webrtc_log_callback(message: String, severity: webrtc_ffi::LoggingSeverity) {
+ match severity {
+ webrtc_ffi::LoggingSeverity::Error => log::error!("[webrtc] {}", message.trim()),
+ _ => log::debug!("[webrtc] {}", message.trim()),
+ }
+ }
+
+ let _webrtc_log_sink = webrtc_ffi::new_log_sink(webrtc_log_callback);
+ log::debug!("Wayland desktop capture: WebRTC internal logging enabled");
let stop_flag = Arc::new(AtomicBool::new(false));
let (mut video_source_tx, mut video_source_rx) = mpsc::channel::<NativeVideoSource>(1);
@@ -79,7 +91,6 @@ pub(crate) async fn start_wayland_desktop_capture(
})?;
let permanent_error = Arc::new(AtomicBool::new(false));
-
let stop_cb = stop_flag.clone();
let permanent_error_cb = permanent_error.clone();
capturer.start_capture(None, {
@@ -136,6 +147,8 @@ pub(crate) async fn start_wayland_desktop_capture(
}
});
+ log::info!("Wayland desktop capture: starting capture loop");
+
let stop = stop_flag.clone();
let tokio_task = gpui_tokio::Tokio::spawn(cx, async move {
loop {
@@ -162,10 +175,11 @@ pub(crate) async fn start_wayland_desktop_capture(
let executor = cx.background_executor().clone();
let video_source = video_source_rx
.next()
- .with_timeout(Duration::from_secs(15), &executor)
+ .with_timeout(Duration::from_secs(PIPEWIRE_TIMEOUT_S), &executor)
.await
.map_err(|_| {
stop_flag.store(true, Ordering::Relaxed);
+ log::error!("Wayland desktop capture timed out.");
anyhow::anyhow!(
"Screen sharing timed out waiting for the first frame. \
Check that xdg-desktop-portal and PipeWire are running, \
@@ -0,0 +1,27 @@
+[package]
+name = "opencode"
+version = "0.1.0"
+edition.workspace = true
+publish.workspace = true
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/opencode.rs"
+test = false
+
+[features]
+default = []
+schemars = ["dep:schemars"]
+
+[dependencies]
+anyhow.workspace = true
+futures.workspace = true
+google_ai.workspace = true
+http_client.workspace = true
+schemars = { workspace = true, optional = true }
+serde.workspace = true
+serde_json.workspace = true
+strum.workspace = true
@@ -0,0 +1 @@
+../../LICENSE-GPL
@@ -0,0 +1,453 @@
+use anyhow::{Result, anyhow};
+use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
+use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
+use serde::{Deserialize, Serialize};
+use strum::EnumIter;
+
+pub const OPENCODE_API_URL: &str = "https://opencode.ai/zen";
+
+#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
+#[serde(rename_all = "snake_case")]
+pub enum ApiProtocol {
+ #[default]
+ Anthropic,
+ OpenAiResponses,
+ OpenAiChat,
+ Google,
+}
+
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
+#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
+pub enum Model {
+ // -- Anthropic protocol models --
+ #[serde(rename = "claude-opus-4-6")]
+ ClaudeOpus4_6,
+ #[serde(rename = "claude-opus-4-5")]
+ ClaudeOpus4_5,
+ #[serde(rename = "claude-opus-4-1")]
+ ClaudeOpus4_1,
+ #[default]
+ #[serde(rename = "claude-sonnet-4-6")]
+ ClaudeSonnet4_6,
+ #[serde(rename = "claude-sonnet-4-5")]
+ ClaudeSonnet4_5,
+ #[serde(rename = "claude-sonnet-4")]
+ ClaudeSonnet4,
+ #[serde(rename = "claude-haiku-4-5")]
+ ClaudeHaiku4_5,
+ #[serde(rename = "claude-3-5-haiku")]
+ Claude3_5Haiku,
+
+ // -- OpenAI Responses API models --
+ #[serde(rename = "gpt-5.4")]
+ Gpt5_4,
+ #[serde(rename = "gpt-5.4-pro")]
+ Gpt5_4Pro,
+ #[serde(rename = "gpt-5.4-mini")]
+ Gpt5_4Mini,
+ #[serde(rename = "gpt-5.4-nano")]
+ Gpt5_4Nano,
+ #[serde(rename = "gpt-5.3-codex")]
+ Gpt5_3Codex,
+ #[serde(rename = "gpt-5.3-codex-spark")]
+ Gpt5_3Spark,
+ #[serde(rename = "gpt-5.2")]
+ Gpt5_2,
+ #[serde(rename = "gpt-5.2-codex")]
+ Gpt5_2Codex,
+ #[serde(rename = "gpt-5.1")]
+ Gpt5_1,
+ #[serde(rename = "gpt-5.1-codex")]
+ Gpt5_1Codex,
+ #[serde(rename = "gpt-5.1-codex-max")]
+ Gpt5_1CodexMax,
+ #[serde(rename = "gpt-5.1-codex-mini")]
+ Gpt5_1CodexMini,
+ #[serde(rename = "gpt-5")]
+ Gpt5,
+ #[serde(rename = "gpt-5-codex")]
+ Gpt5Codex,
+ #[serde(rename = "gpt-5-nano")]
+ Gpt5Nano,
+
+ // -- Google protocol models --
+ #[serde(rename = "gemini-3.1-pro")]
+ Gemini3_1Pro,
+ #[serde(rename = "gemini-3-flash")]
+ Gemini3Flash,
+
+ // -- OpenAI Chat Completions protocol models --
+ #[serde(rename = "minimax-m2.5")]
+ MiniMaxM2_5,
+ #[serde(rename = "minimax-m2.5-free")]
+ MiniMaxM2_5Free,
+ #[serde(rename = "glm-5")]
+ Glm5,
+ #[serde(rename = "kimi-k2.5")]
+ KimiK2_5,
+ #[serde(rename = "mimo-v2-pro-free")]
+ MimoV2ProFree,
+ #[serde(rename = "mimo-v2-omni-free")]
+ MimoV2OmniFree,
+ #[serde(rename = "mimo-v2-flash-free")]
+ MimoV2FlashFree,
+ #[serde(rename = "trinity-large-preview-free")]
+ TrinityLargePreviewFree,
+ #[serde(rename = "big-pickle")]
+ BigPickle,
+ #[serde(rename = "nemotron-3-super-free")]
+ Nemotron3SuperFree,
+
+ // -- Custom model --
+ #[serde(rename = "custom")]
+ Custom {
+ name: String,
+ display_name: Option<String>,
+ max_tokens: u64,
+ max_output_tokens: Option<u64>,
+ protocol: ApiProtocol,
+ },
+}
+
+impl Model {
+ pub fn default_fast() -> Self {
+ Self::ClaudeHaiku4_5
+ }
+
+ pub fn id(&self) -> &str {
+ match self {
+ Self::ClaudeOpus4_6 => "claude-opus-4-6",
+ Self::ClaudeOpus4_5 => "claude-opus-4-5",
+ Self::ClaudeOpus4_1 => "claude-opus-4-1",
+ Self::ClaudeSonnet4_6 => "claude-sonnet-4-6",
+ Self::ClaudeSonnet4_5 => "claude-sonnet-4-5",
+ Self::ClaudeSonnet4 => "claude-sonnet-4",
+ Self::ClaudeHaiku4_5 => "claude-haiku-4-5",
+ Self::Claude3_5Haiku => "claude-3-5-haiku",
+
+ Self::Gpt5_4 => "gpt-5.4",
+ Self::Gpt5_4Pro => "gpt-5.4-pro",
+ Self::Gpt5_4Mini => "gpt-5.4-mini",
+ Self::Gpt5_4Nano => "gpt-5.4-nano",
+ Self::Gpt5_3Codex => "gpt-5.3-codex",
+ Self::Gpt5_3Spark => "gpt-5.3-codex-spark",
+ Self::Gpt5_2 => "gpt-5.2",
+ Self::Gpt5_2Codex => "gpt-5.2-codex",
+ Self::Gpt5_1 => "gpt-5.1",
+ Self::Gpt5_1Codex => "gpt-5.1-codex",
+ Self::Gpt5_1CodexMax => "gpt-5.1-codex-max",
+ Self::Gpt5_1CodexMini => "gpt-5.1-codex-mini",
+ Self::Gpt5 => "gpt-5",
+ Self::Gpt5Codex => "gpt-5-codex",
+ Self::Gpt5Nano => "gpt-5-nano",
+
+ Self::Gemini3_1Pro => "gemini-3.1-pro",
+ Self::Gemini3Flash => "gemini-3-flash",
+
+ Self::MiniMaxM2_5 => "minimax-m2.5",
+ Self::MiniMaxM2_5Free => "minimax-m2.5-free",
+ Self::Glm5 => "glm-5",
+ Self::KimiK2_5 => "kimi-k2.5",
+ Self::MimoV2ProFree => "mimo-v2-pro-free",
+ Self::MimoV2OmniFree => "mimo-v2-omni-free",
+ Self::MimoV2FlashFree => "mimo-v2-flash-free",
+ Self::TrinityLargePreviewFree => "trinity-large-preview-free",
+ Self::BigPickle => "big-pickle",
+ Self::Nemotron3SuperFree => "nemotron-3-super-free",
+
+ Self::Custom { name, .. } => name,
+ }
+ }
+
+ pub fn display_name(&self) -> &str {
+ match self {
+ Self::ClaudeOpus4_6 => "Claude Opus 4.6",
+ Self::ClaudeOpus4_5 => "Claude Opus 4.5",
+ Self::ClaudeOpus4_1 => "Claude Opus 4.1",
+ Self::ClaudeSonnet4_6 => "Claude Sonnet 4.6",
+ Self::ClaudeSonnet4_5 => "Claude Sonnet 4.5",
+ Self::ClaudeSonnet4 => "Claude Sonnet 4",
+ Self::ClaudeHaiku4_5 => "Claude Haiku 4.5",
+ Self::Claude3_5Haiku => "Claude Haiku 3.5",
+
+ Self::Gpt5_4 => "GPT 5.4",
+ Self::Gpt5_4Pro => "GPT 5.4 Pro",
+ Self::Gpt5_4Mini => "GPT 5.4 Mini",
+ Self::Gpt5_4Nano => "GPT 5.4 Nano",
+ Self::Gpt5_3Codex => "GPT 5.3 Codex",
+ Self::Gpt5_3Spark => "GPT 5.3 Codex Spark",
+ Self::Gpt5_2 => "GPT 5.2",
+ Self::Gpt5_2Codex => "GPT 5.2 Codex",
+ Self::Gpt5_1 => "GPT 5.1",
+ Self::Gpt5_1Codex => "GPT 5.1 Codex",
+ Self::Gpt5_1CodexMax => "GPT 5.1 Codex Max",
+ Self::Gpt5_1CodexMini => "GPT 5.1 Codex Mini",
+ Self::Gpt5 => "GPT 5",
+ Self::Gpt5Codex => "GPT 5 Codex",
+ Self::Gpt5Nano => "GPT 5 Nano",
+
+ Self::Gemini3_1Pro => "Gemini 3.1 Pro",
+ Self::Gemini3Flash => "Gemini 3 Flash",
+
+ Self::MiniMaxM2_5 => "MiniMax M2.5",
+ Self::MiniMaxM2_5Free => "MiniMax M2.5 Free",
+ Self::Glm5 => "GLM 5",
+ Self::KimiK2_5 => "Kimi K2.5",
+ Self::MimoV2ProFree => "MiMo V2 Pro Free",
+ Self::MimoV2OmniFree => "MiMo V2 Omni Free",
+ Self::MimoV2FlashFree => "MiMo V2 Flash Free",
+ Self::TrinityLargePreviewFree => "Trinity Large Preview Free",
+ Self::BigPickle => "Big Pickle",
+ Self::Nemotron3SuperFree => "Nemotron 3 Super Free",
+
+ Self::Custom {
+ name, display_name, ..
+ } => display_name.as_deref().unwrap_or(name),
+ }
+ }
+
+ pub fn protocol(&self) -> ApiProtocol {
+ match self {
+ Self::ClaudeOpus4_6
+ | Self::ClaudeOpus4_5
+ | Self::ClaudeOpus4_1
+ | Self::ClaudeSonnet4_6
+ | Self::ClaudeSonnet4_5
+ | Self::ClaudeSonnet4
+ | Self::ClaudeHaiku4_5
+ | Self::Claude3_5Haiku => ApiProtocol::Anthropic,
+
+ Self::Gpt5_4
+ | Self::Gpt5_4Pro
+ | Self::Gpt5_4Mini
+ | Self::Gpt5_4Nano
+ | Self::Gpt5_3Codex
+ | Self::Gpt5_3Spark
+ | Self::Gpt5_2
+ | Self::Gpt5_2Codex
+ | Self::Gpt5_1
+ | Self::Gpt5_1Codex
+ | Self::Gpt5_1CodexMax
+ | Self::Gpt5_1CodexMini
+ | Self::Gpt5
+ | Self::Gpt5Codex
+ | Self::Gpt5Nano => ApiProtocol::OpenAiResponses,
+
+ Self::Gemini3_1Pro | Self::Gemini3Flash => ApiProtocol::Google,
+
+ Self::MiniMaxM2_5
+ | Self::MiniMaxM2_5Free
+ | Self::Glm5
+ | Self::KimiK2_5
+ | Self::MimoV2ProFree
+ | Self::MimoV2OmniFree
+ | Self::MimoV2FlashFree
+ | Self::TrinityLargePreviewFree
+ | Self::BigPickle
+ | Self::Nemotron3SuperFree => ApiProtocol::OpenAiChat,
+
+ Self::Custom { protocol, .. } => *protocol,
+ }
+ }
+
+ pub fn max_token_count(&self) -> u64 {
+ match self {
+ // Anthropic models
+ Self::ClaudeOpus4_6 | Self::ClaudeSonnet4_6 => 1_000_000,
+ Self::ClaudeOpus4_5 | Self::ClaudeSonnet4_5 | Self::ClaudeSonnet4 => 200_000,
+ Self::ClaudeOpus4_1 => 200_000,
+ Self::ClaudeHaiku4_5 => 200_000,
+ Self::Claude3_5Haiku => 200_000,
+
+ // OpenAI models
+ Self::Gpt5_4 | Self::Gpt5_4Pro => 1_050_000,
+ Self::Gpt5_4Mini | Self::Gpt5_4Nano => 400_000,
+ Self::Gpt5_3Codex => 400_000,
+ Self::Gpt5_3Spark => 128_000,
+ Self::Gpt5_2 | Self::Gpt5_2Codex => 400_000,
+ Self::Gpt5_1 | Self::Gpt5_1Codex | Self::Gpt5_1CodexMax | Self::Gpt5_1CodexMini => {
+ 400_000
+ }
+ Self::Gpt5 | Self::Gpt5Codex | Self::Gpt5Nano => 400_000,
+
+ // Google models
+ Self::Gemini3_1Pro => 1_048_576,
+ Self::Gemini3Flash => 1_048_576,
+
+ // OpenAI-compatible models
+ Self::MiniMaxM2_5 | Self::MiniMaxM2_5Free => 196_608,
+ Self::Glm5 => 200_000,
+ Self::KimiK2_5 => 262_144,
+ Self::MimoV2ProFree => 1_048_576,
+ Self::MimoV2OmniFree | Self::MimoV2FlashFree => 262_144,
+ Self::TrinityLargePreviewFree => 131_072,
+ Self::BigPickle => 200_000,
+ Self::Nemotron3SuperFree => 262_144,
+
+ Self::Custom { max_tokens, .. } => *max_tokens,
+ }
+ }
+
+ pub fn max_output_tokens(&self) -> Option<u64> {
+ match self {
+ // Anthropic models
+ Self::ClaudeOpus4_6 => Some(128_000),
+ Self::ClaudeSonnet4_6 => Some(64_000),
+ Self::ClaudeOpus4_5
+ | Self::ClaudeOpus4_1
+ | Self::ClaudeSonnet4_5
+ | Self::ClaudeSonnet4
+ | Self::ClaudeHaiku4_5 => Some(64_000),
+ Self::Claude3_5Haiku => Some(8_192),
+
+ // OpenAI models
+ Self::Gpt5_4
+ | Self::Gpt5_4Pro
+ | Self::Gpt5_4Mini
+ | Self::Gpt5_4Nano
+ | Self::Gpt5_3Codex
+ | Self::Gpt5_3Spark
+ | Self::Gpt5_2
+ | Self::Gpt5_2Codex
+ | Self::Gpt5_1
+ | Self::Gpt5_1Codex
+ | Self::Gpt5_1CodexMax
+ | Self::Gpt5_1CodexMini
+ | Self::Gpt5
+ | Self::Gpt5Codex
+ | Self::Gpt5Nano => Some(128_000),
+
+ // Google models
+ Self::Gemini3_1Pro | Self::Gemini3Flash => Some(65_536),
+
+ // OpenAI-compatible models
+ Self::MiniMaxM2_5 | Self::MiniMaxM2_5Free => Some(65_536),
+ Self::Glm5 | Self::BigPickle => Some(128_000),
+ Self::KimiK2_5 => Some(65_536),
+ Self::MimoV2ProFree => Some(131_072),
+ Self::MimoV2OmniFree | Self::MimoV2FlashFree => Some(65_536),
+ Self::TrinityLargePreviewFree | Self::Nemotron3SuperFree => Some(16_384),
+
+ Self::Custom {
+ max_output_tokens, ..
+ } => *max_output_tokens,
+ }
+ }
+
+ pub fn supports_tools(&self) -> bool {
+ true
+ }
+
+ pub fn supports_images(&self) -> bool {
+ match self {
+ // Anthropic models support images
+ Self::ClaudeOpus4_6
+ | Self::ClaudeOpus4_5
+ | Self::ClaudeOpus4_1
+ | Self::ClaudeSonnet4_6
+ | Self::ClaudeSonnet4_5
+ | Self::ClaudeSonnet4
+ | Self::ClaudeHaiku4_5
+ | Self::Claude3_5Haiku => true,
+
+ // OpenAI models support images
+ Self::Gpt5_4
+ | Self::Gpt5_4Pro
+ | Self::Gpt5_4Mini
+ | Self::Gpt5_4Nano
+ | Self::Gpt5_3Codex
+ | Self::Gpt5_3Spark
+ | Self::Gpt5_2
+ | Self::Gpt5_2Codex
+ | Self::Gpt5_1
+ | Self::Gpt5_1Codex
+ | Self::Gpt5_1CodexMax
+ | Self::Gpt5_1CodexMini
+ | Self::Gpt5
+ | Self::Gpt5Codex
+ | Self::Gpt5Nano => true,
+
+ // Google models support images
+ Self::Gemini3_1Pro | Self::Gemini3Flash => true,
+
+ // OpenAI-compatible models — conservative default
+ Self::MiniMaxM2_5
+ | Self::MiniMaxM2_5Free
+ | Self::Glm5
+ | Self::KimiK2_5
+ | Self::MimoV2ProFree
+ | Self::MimoV2OmniFree
+ | Self::MimoV2FlashFree
+ | Self::TrinityLargePreviewFree
+ | Self::BigPickle
+ | Self::Nemotron3SuperFree => false,
+
+ Self::Custom { protocol, .. } => matches!(
+ protocol,
+ ApiProtocol::Anthropic
+ | ApiProtocol::OpenAiResponses
+ | ApiProtocol::OpenAiChat
+ | ApiProtocol::Google
+ ),
+ }
+ }
+}
+
+/// Stream generate content for Google models via OpenCode Zen.
+///
+/// Unlike `google_ai::stream_generate_content()`, this uses:
+/// - `/v1/models/{model}` path (not `/v1beta/models/{model}`)
+/// - `Authorization: Bearer` header (not `key=` query param)
+pub async fn stream_generate_content_zen(
+ client: &dyn HttpClient,
+ api_url: &str,
+ api_key: &str,
+ request: google_ai::GenerateContentRequest,
+) -> Result<BoxStream<'static, Result<google_ai::GenerateContentResponse>>> {
+ let api_key = api_key.trim();
+
+ let model_id = &request.model.model_id;
+
+ let uri = format!("{api_url}/v1/models/{model_id}:streamGenerateContent?alt=sse");
+
+ let request_builder = HttpRequest::builder()
+ .method(Method::POST)
+ .uri(uri)
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {api_key}"));
+
+ let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
+ let mut response = client.send(request).await?;
+ if response.status().is_success() {
+ let reader = BufReader::new(response.into_body());
+ Ok(reader
+ .lines()
+ .filter_map(|line| async move {
+ match line {
+ Ok(line) => {
+ if let Some(line) = line.strip_prefix("data: ") {
+ match serde_json::from_str(line) {
+ Ok(response) => Some(Ok(response)),
+ Err(error) => {
+ Some(Err(anyhow!("Error parsing JSON: {error:?}\n{line:?}")))
+ }
+ }
+ } else {
+ None
+ }
+ }
+ Err(error) => Some(Err(anyhow!(error))),
+ }
+ })
+ .boxed())
+ } else {
+ let mut text = String::new();
+ response.body_mut().read_to_string(&mut text).await?;
+ Err(anyhow!(
+ "error during streamGenerateContent via OpenCode Zen, status code: {:?}, body: {}",
+ response.status(),
+ text
+ ))
+ }
+}
@@ -3,9 +3,9 @@ mod system_window_tabs;
use feature_flags::{AgentV2FeatureFlag, FeatureFlagAppExt};
use gpui::{
- AnyElement, App, Context, Decorations, Entity, Hsla, InteractiveElement, IntoElement,
- MouseButton, ParentElement, StatefulInteractiveElement, Styled, Window, WindowControlArea, div,
- px,
+ Action, AnyElement, App, Context, Decorations, Entity, Hsla, InteractiveElement, IntoElement,
+ MouseButton, ParentElement, StatefulInteractiveElement, Styled, Window, WindowButtonLayout,
+ WindowControlArea, div, px,
};
use project::DisableAiSettings;
use settings::Settings;
@@ -31,6 +31,7 @@ pub struct PlatformTitleBar {
children: SmallVec<[AnyElement; 2]>,
should_move: bool,
system_window_tabs: Entity<SystemWindowTabs>,
+ button_layout: Option<WindowButtonLayout>,
workspace_sidebar_open: bool,
}
@@ -45,6 +46,7 @@ impl PlatformTitleBar {
children: SmallVec::new(),
should_move: false,
system_window_tabs,
+ button_layout: None,
workspace_sidebar_open: false,
}
}
@@ -68,6 +70,24 @@ impl PlatformTitleBar {
self.children = children.into_iter().collect();
}
+ pub fn set_button_layout(&mut self, button_layout: Option<WindowButtonLayout>) {
+ self.button_layout = button_layout;
+ }
+
+ fn effective_button_layout(
+ &self,
+ decorations: &Decorations,
+ cx: &App,
+ ) -> Option<WindowButtonLayout> {
+ if self.platform_style == PlatformStyle::Linux
+ && matches!(decorations, Decorations::Client { .. })
+ {
+ self.button_layout.or_else(|| cx.button_layout())
+ } else {
+ None
+ }
+ }
+
pub fn init(cx: &mut App) {
SystemWindowTabs::init(cx);
}
@@ -95,6 +115,7 @@ impl Render for PlatformTitleBar {
let close_action = Box::new(workspace::CloseWindow);
let children = mem::take(&mut self.children);
+ let button_layout = self.effective_button_layout(&decorations, cx);
let is_multiworkspace_sidebar_open =
PlatformTitleBar::is_multi_workspace_enabled(cx) && self.is_workspace_sidebar_open();
@@ -150,6 +171,14 @@ impl Render for PlatformTitleBar {
&& !is_multiworkspace_sidebar_open
{
this.pl(px(TRAFFIC_LIGHT_PADDING))
+ } else if let Some(button_layout) =
+ button_layout.filter(|button_layout| button_layout.left[0].is_some())
+ {
+ this.child(platform_linux::LinuxWindowControls::new(
+ "left-window-controls",
+ button_layout.left,
+ close_action.as_ref().boxed_clone(),
+ ))
} else {
this.pl_2()
}
@@ -188,14 +217,22 @@ impl Render for PlatformTitleBar {
PlatformStyle::Mac => title_bar,
PlatformStyle::Linux => {
if matches!(decorations, Decorations::Client { .. }) {
- title_bar
- .child(platform_linux::LinuxWindowControls::new(close_action))
- .when(supported_controls.window_menu, |titlebar| {
- titlebar
- .on_mouse_down(MouseButton::Right, move |ev, window, _| {
- window.show_window_menu(ev.position)
- })
+ let mut result = title_bar;
+ if let Some(button_layout) = button_layout
+ .filter(|button_layout| button_layout.right[0].is_some())
+ {
+ result = result.child(platform_linux::LinuxWindowControls::new(
+ "right-window-controls",
+ button_layout.right,
+ close_action.as_ref().boxed_clone(),
+ ));
+ }
+
+ result.when(supported_controls.window_menu, |titlebar| {
+ titlebar.on_mouse_down(MouseButton::Right, move |ev, window, _| {
+ window.show_window_menu(ev.position)
})
+ })
} else {
title_bar
}
@@ -1,46 +1,83 @@
-use gpui::{Action, Hsla, MouseButton, prelude::*, svg};
+use gpui::{
+ Action, AnyElement, Hsla, MAX_BUTTONS_PER_SIDE, MouseButton, WindowButton, prelude::*, svg,
+};
use ui::prelude::*;
#[derive(IntoElement)]
pub struct LinuxWindowControls {
- close_window_action: Box<dyn Action>,
+ id: &'static str,
+ buttons: [Option<WindowButton>; MAX_BUTTONS_PER_SIDE],
+ close_action: Box<dyn Action>,
}
impl LinuxWindowControls {
- pub fn new(close_window_action: Box<dyn Action>) -> Self {
+ pub fn new(
+ id: &'static str,
+ buttons: [Option<WindowButton>; MAX_BUTTONS_PER_SIDE],
+ close_action: Box<dyn Action>,
+ ) -> Self {
Self {
- close_window_action,
+ id,
+ buttons,
+ close_action,
}
}
}
impl RenderOnce for LinuxWindowControls {
fn render(self, window: &mut Window, cx: &mut App) -> impl IntoElement {
+ let is_maximized = window.is_maximized();
+ let supported_controls = window.window_controls();
+ let button_elements: Vec<AnyElement> = self
+ .buttons
+ .iter()
+ .filter_map(|b| *b)
+ .filter(|button| match button {
+ WindowButton::Minimize => supported_controls.minimize,
+ WindowButton::Maximize => supported_controls.maximize,
+ WindowButton::Close => true,
+ })
+ .map(|button| {
+ create_window_button(button, button.id(), is_maximized, &*self.close_action, cx)
+ })
+ .collect();
+
h_flex()
- .id("generic-window-controls")
- .px_3()
- .gap_3()
- .on_mouse_down(MouseButton::Left, |_, _, cx| cx.stop_propagation())
- .child(WindowControl::new(
- "minimize",
- WindowControlType::Minimize,
- cx,
- ))
- .child(WindowControl::new(
- "maximize-or-restore",
- if window.is_maximized() {
- WindowControlType::Restore
- } else {
- WindowControlType::Maximize
- },
- cx,
- ))
- .child(WindowControl::new_close(
- "close",
- WindowControlType::Close,
- self.close_window_action,
- cx,
- ))
+ .id(self.id)
+ .when(!button_elements.is_empty(), |el| {
+ el.gap_3()
+ .px_3()
+ .on_mouse_down(MouseButton::Left, |_, _, cx| cx.stop_propagation())
+ .children(button_elements)
+ })
+ }
+}
+
+fn create_window_button(
+ button: WindowButton,
+ id: &'static str,
+ is_maximized: bool,
+ close_action: &dyn Action,
+ cx: &mut App,
+) -> AnyElement {
+ match button {
+ WindowButton::Minimize => {
+ WindowControl::new(id, WindowControlType::Minimize, cx).into_any_element()
+ }
+ WindowButton::Maximize => WindowControl::new(
+ id,
+ if is_maximized {
+ WindowControlType::Restore
+ } else {
+ WindowControlType::Maximize
+ },
+ cx,
+ )
+ .into_any_element(),
+ WindowButton::Close => {
+ WindowControl::new_close(id, WindowControlType::Close, close_action.boxed_clone(), cx)
+ .into_any_element()
+ }
}
}
@@ -45,6 +45,7 @@ client.workspace = true
clock.workspace = true
collections.workspace = true
context_server.workspace = true
+credentials_provider.workspace = true
dap.workspace = true
extension.workspace = true
fancy-regex.workspace = true
@@ -7,10 +7,16 @@ use std::time::Duration;
use anyhow::{Context as _, Result};
use collections::{HashMap, HashSet};
+use context_server::oauth::{self, McpOAuthTokenProvider, OAuthDiscovery, OAuthSession};
+use context_server::transport::{HttpTransport, TransportError};
use context_server::{ContextServer, ContextServerCommand, ContextServerId};
-use futures::{FutureExt as _, future::Either, future::join_all};
+use credentials_provider::CredentialsProvider;
+use futures::future::Either;
+use futures::{FutureExt as _, StreamExt as _, future::join_all};
use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
+use http_client::HttpClient;
use itertools::Itertools;
+use rand::Rng as _;
use registry::ContextServerDescriptorRegistry;
use remote::RemoteClient;
use rpc::{AnyProtoClient, TypedEnvelope, proto};
@@ -45,6 +51,12 @@ pub enum ContextServerStatus {
Running,
Stopped,
Error(Arc<str>),
+ /// The server returned 401 and OAuth authorization is needed. The UI
+ /// should show an "Authenticate" button.
+ AuthRequired,
+ /// The OAuth browser flow is in progress — the user has been redirected
+ /// to the authorization server and we're waiting for the callback.
+ Authenticating,
}
impl ContextServerStatus {
@@ -54,6 +66,8 @@ impl ContextServerStatus {
ContextServerState::Running { .. } => ContextServerStatus::Running,
ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
+ ContextServerState::AuthRequired { .. } => ContextServerStatus::AuthRequired,
+ ContextServerState::Authenticating { .. } => ContextServerStatus::Authenticating,
}
}
}
@@ -77,24 +91,42 @@ enum ContextServerState {
configuration: Arc<ContextServerConfiguration>,
error: Arc<str>,
},
+ /// The server requires OAuth authorization before it can be used. The
+ /// `OAuthDiscovery` holds everything needed to start the browser flow.
+ AuthRequired {
+ server: Arc<ContextServer>,
+ configuration: Arc<ContextServerConfiguration>,
+ discovery: Arc<OAuthDiscovery>,
+ },
+ /// The OAuth browser flow is in progress. The user has been redirected
+ /// to the authorization server and we're waiting for the callback.
+ Authenticating {
+ server: Arc<ContextServer>,
+ configuration: Arc<ContextServerConfiguration>,
+ _task: Task<()>,
+ },
}
impl ContextServerState {
pub fn server(&self) -> Arc<ContextServer> {
match self {
- ContextServerState::Starting { server, .. } => server.clone(),
- ContextServerState::Running { server, .. } => server.clone(),
- ContextServerState::Stopped { server, .. } => server.clone(),
- ContextServerState::Error { server, .. } => server.clone(),
+ ContextServerState::Starting { server, .. }
+ | ContextServerState::Running { server, .. }
+ | ContextServerState::Stopped { server, .. }
+ | ContextServerState::Error { server, .. }
+ | ContextServerState::AuthRequired { server, .. }
+ | ContextServerState::Authenticating { server, .. } => server.clone(),
}
}
pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
match self {
- ContextServerState::Starting { configuration, .. } => configuration.clone(),
- ContextServerState::Running { configuration, .. } => configuration.clone(),
- ContextServerState::Stopped { configuration, .. } => configuration.clone(),
- ContextServerState::Error { configuration, .. } => configuration.clone(),
+ ContextServerState::Starting { configuration, .. }
+ | ContextServerState::Running { configuration, .. }
+ | ContextServerState::Stopped { configuration, .. }
+ | ContextServerState::Error { configuration, .. }
+ | ContextServerState::AuthRequired { configuration, .. }
+ | ContextServerState::Authenticating { configuration, .. } => configuration.clone(),
}
}
}
@@ -126,6 +158,15 @@ impl ContextServerConfiguration {
}
}
+ pub fn has_static_auth_header(&self) -> bool {
+ match self {
+ ContextServerConfiguration::Http { headers, .. } => headers
+ .keys()
+ .any(|k| k.eq_ignore_ascii_case("authorization")),
+ _ => false,
+ }
+ }
+
pub fn remote(&self) -> bool {
match self {
ContextServerConfiguration::Custom { remote, .. } => *remote,
@@ -517,9 +558,10 @@ impl ContextServerStore {
pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
cx.spawn(async move |this, cx| {
let this = this.upgrade().context("Context server store dropped")?;
+ let id = server.id();
let settings = this
.update(cx, |this, _| {
- this.context_server_settings.get(&server.id().0).cloned()
+ this.context_server_settings.get(&id.0).cloned()
})
.context("Failed to get context server settings")?;
@@ -532,7 +574,7 @@ impl ContextServerStore {
});
let configuration = ContextServerConfiguration::from_settings(
settings,
- server.id(),
+ id.clone(),
registry,
worktree_store,
cx,
@@ -590,7 +632,11 @@ impl ContextServerStore {
let id = server.id();
if matches!(
self.servers.get(&id),
- Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
+ Some(
+ ContextServerState::Starting { .. }
+ | ContextServerState::Running { .. }
+ | ContextServerState::Authenticating { .. },
+ )
) {
self.stop_server(&id, cx).log_err();
}
@@ -600,38 +646,20 @@ impl ContextServerStore {
let configuration = configuration.clone();
async move |this, cx| {
- match server.clone().start(cx).await {
+ let new_state = match server.clone().start(cx).await {
Ok(_) => {
debug_assert!(server.client().is_some());
-
- this.update(cx, |this, cx| {
- this.update_server_state(
- id.clone(),
- ContextServerState::Running {
- server,
- configuration,
- },
- cx,
- )
- })
- .log_err()
- }
- Err(err) => {
- log::error!("{} context server failed to start: {}", id, err);
- this.update(cx, |this, cx| {
- this.update_server_state(
- id.clone(),
- ContextServerState::Error {
- configuration,
- server,
- error: err.to_string().into(),
- },
- cx,
- )
- })
- .log_err()
+ ContextServerState::Running {
+ server,
+ configuration,
+ }
}
+ Err(err) => resolve_start_failure(&id, err, server, configuration, cx).await,
};
+ this.update(cx, |this, cx| {
+ this.update_server_state(id.clone(), new_state, cx)
+ })
+ .log_err();
}
});
@@ -651,6 +679,20 @@ impl ContextServerStore {
.servers
.remove(id)
.context("Context server not found")?;
+
+ if let ContextServerConfiguration::Http { url, .. } = state.configuration().as_ref() {
+ let server_url = url.clone();
+ let id = id.clone();
+ cx.spawn(async move |_this, cx| {
+ let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx));
+ if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await
+ {
+ log::warn!("{} failed to clear OAuth session on removal: {}", id, err);
+ }
+ })
+ .detach();
+ }
+
drop(state);
cx.emit(ServerStatusChangedEvent {
server_id: id.clone(),
@@ -742,29 +784,71 @@ impl ContextServerStore {
configuration
};
+ if let Some(server) = this.update(cx, |this, _| {
+ this.context_server_factory
+ .as_ref()
+ .map(|factory| factory(id.clone(), configuration.clone()))
+ })? {
+ return Ok((server, configuration));
+ }
+
+ let cached_token_provider: Option<Arc<dyn oauth::OAuthTokenProvider>> =
+ if let ContextServerConfiguration::Http { url, .. } = configuration.as_ref() {
+ if configuration.has_static_auth_header() {
+ None
+ } else {
+ let credentials_provider =
+ cx.update(|cx| <dyn CredentialsProvider>::global(cx));
+ let http_client = cx.update(|cx| cx.http_client());
+
+ match Self::load_session(&credentials_provider, url, &cx).await {
+ Ok(Some(session)) => {
+ log::info!("{} loaded cached OAuth session from keychain", id);
+ Some(Self::create_oauth_token_provider(
+ &id,
+ url,
+ session,
+ http_client,
+ credentials_provider,
+ cx,
+ ))
+ }
+ Ok(None) => None,
+ Err(err) => {
+ log::warn!("{} failed to load cached OAuth session: {}", id, err);
+ None
+ }
+ }
+ }
+ } else {
+ None
+ };
+
let server: Arc<ContextServer> = this.update(cx, |this, cx| {
let global_timeout =
Self::resolve_project_settings(&this.worktree_store, cx).context_server_timeout;
- if let Some(factory) = this.context_server_factory.as_ref() {
- return anyhow::Ok(factory(id.clone(), configuration.clone()));
- }
-
match configuration.as_ref() {
ContextServerConfiguration::Http {
url,
headers,
timeout,
- } => anyhow::Ok(Arc::new(ContextServer::http(
- id,
- url,
- headers.clone(),
- cx.http_client(),
- cx.background_executor().clone(),
- Some(Duration::from_secs(
- timeout.unwrap_or(global_timeout).min(MAX_TIMEOUT_SECS),
- )),
- )?)),
+ } => {
+ let transport = HttpTransport::new_with_token_provider(
+ cx.http_client(),
+ url.to_string(),
+ headers.clone(),
+ cx.background_executor().clone(),
+ cached_token_provider.clone(),
+ );
+ anyhow::Ok(Arc::new(ContextServer::new_with_timeout(
+ id,
+ Arc::new(transport),
+ Some(Duration::from_secs(
+ timeout.unwrap_or(global_timeout).min(MAX_TIMEOUT_SECS),
+ )),
+ )))
+ }
_ => {
let mut command = configuration
.command()
@@ -861,6 +945,310 @@ impl ContextServerStore {
ProjectSettings::get(location, cx)
}
+ fn create_oauth_token_provider(
+ id: &ContextServerId,
+ server_url: &url::Url,
+ session: OAuthSession,
+ http_client: Arc<dyn HttpClient>,
+ credentials_provider: Arc<dyn CredentialsProvider>,
+ cx: &mut AsyncApp,
+ ) -> Arc<dyn oauth::OAuthTokenProvider> {
+ let (token_refresh_tx, mut token_refresh_rx) = futures::channel::mpsc::unbounded();
+ let id = id.clone();
+ let server_url = server_url.clone();
+
+ cx.spawn(async move |cx| {
+ while let Some(refreshed_session) = token_refresh_rx.next().await {
+ if let Err(err) =
+ Self::store_session(&credentials_provider, &server_url, &refreshed_session, &cx)
+ .await
+ {
+ log::warn!("{} failed to persist refreshed OAuth session: {}", id, err);
+ }
+ }
+ log::debug!("{} OAuth session persistence task ended", id);
+ })
+ .detach();
+
+ Arc::new(McpOAuthTokenProvider::new(
+ session,
+ http_client,
+ Some(token_refresh_tx),
+ ))
+ }
+
+ /// Initiate the OAuth browser flow for a server in the `AuthRequired` state.
+ ///
+ /// This starts a loopback HTTP callback server on an ephemeral port, builds
+ /// the authorization URL, opens the user's browser, waits for the callback,
+ /// exchanges the code for tokens, persists them in the keychain, and restarts
+ /// the server with the new token provider.
+ pub fn authenticate_server(
+ &mut self,
+ id: &ContextServerId,
+ cx: &mut Context<Self>,
+ ) -> Result<()> {
+ let state = self.servers.get(id).context("Context server not found")?;
+
+ let (discovery, server, configuration) = match state {
+ ContextServerState::AuthRequired {
+ discovery,
+ server,
+ configuration,
+ } => (discovery.clone(), server.clone(), configuration.clone()),
+ _ => anyhow::bail!("Server is not in AuthRequired state"),
+ };
+
+ let id = id.clone();
+
+ let task = cx.spawn({
+ let id = id.clone();
+ let server = server.clone();
+ let configuration = configuration.clone();
+ async move |this, cx| {
+ let result = Self::run_oauth_flow(
+ this.clone(),
+ id.clone(),
+ discovery.clone(),
+ configuration.clone(),
+ cx,
+ )
+ .await;
+
+ if let Err(err) = &result {
+ log::error!("{} OAuth authentication failed: {:?}", id, err);
+ // Transition back to AuthRequired so the user can retry
+ // rather than landing in a terminal Error state.
+ this.update(cx, |this, cx| {
+ this.update_server_state(
+ id.clone(),
+ ContextServerState::AuthRequired {
+ server,
+ configuration,
+ discovery,
+ },
+ cx,
+ )
+ })
+ .log_err();
+ }
+ }
+ });
+
+ self.update_server_state(
+ id,
+ ContextServerState::Authenticating {
+ server,
+ configuration,
+ _task: task,
+ },
+ cx,
+ );
+
+ Ok(())
+ }
+
+ async fn run_oauth_flow(
+ this: WeakEntity<Self>,
+ id: ContextServerId,
+ discovery: Arc<OAuthDiscovery>,
+ configuration: Arc<ContextServerConfiguration>,
+ cx: &mut AsyncApp,
+ ) -> Result<()> {
+ let resource = oauth::canonical_server_uri(&discovery.resource_metadata.resource);
+ let pkce = oauth::generate_pkce_challenge();
+
+ let mut state_bytes = [0u8; 32];
+ rand::rng().fill(&mut state_bytes);
+ let state_param: String = state_bytes.iter().map(|b| format!("{:02x}", b)).collect();
+
+ // Start a loopback HTTP server on an ephemeral port. The redirect URI
+ // includes this port so the browser sends the callback directly to our
+ // process.
+ let (redirect_uri, callback_rx) = oauth::start_callback_server()
+ .await
+ .context("Failed to start OAuth callback server")?;
+
+ let http_client = cx.update(|cx| cx.http_client());
+ let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx));
+ let server_url = match configuration.as_ref() {
+ ContextServerConfiguration::Http { url, .. } => url.clone(),
+ _ => anyhow::bail!("OAuth authentication only supported for HTTP servers"),
+ };
+
+ let client_registration =
+ oauth::resolve_client_registration(&http_client, &discovery, &redirect_uri)
+ .await
+ .context("Failed to resolve OAuth client registration")?;
+
+ let auth_url = oauth::build_authorization_url(
+ &discovery.auth_server_metadata,
+ &client_registration.client_id,
+ &redirect_uri,
+ &discovery.scopes,
+ &resource,
+ &pkce,
+ &state_param,
+ );
+
+ cx.update(|cx| cx.open_url(auth_url.as_str()));
+
+ let callback = callback_rx
+ .await
+ .map_err(|_| {
+ anyhow::anyhow!("OAuth callback server was shut down before receiving a response")
+ })?
+ .context("OAuth callback server received an invalid request")?;
+
+ if callback.state != state_param {
+ anyhow::bail!("OAuth state parameter mismatch (possible CSRF)");
+ }
+
+ let tokens = oauth::exchange_code(
+ &http_client,
+ &discovery.auth_server_metadata,
+ &callback.code,
+ &client_registration.client_id,
+ &redirect_uri,
+ &pkce.verifier,
+ &resource,
+ )
+ .await
+ .context("Failed to exchange authorization code for tokens")?;
+
+ let session = OAuthSession {
+ token_endpoint: discovery.auth_server_metadata.token_endpoint.clone(),
+ resource: discovery.resource_metadata.resource.clone(),
+ client_registration,
+ tokens,
+ };
+
+ Self::store_session(&credentials_provider, &server_url, &session, cx)
+ .await
+ .context("Failed to persist OAuth session in keychain")?;
+
+ let token_provider = Self::create_oauth_token_provider(
+ &id,
+ &server_url,
+ session,
+ http_client.clone(),
+ credentials_provider,
+ cx,
+ );
+
+ let new_server = this.update(cx, |this, cx| {
+ let global_timeout =
+ Self::resolve_project_settings(&this.worktree_store, cx).context_server_timeout;
+
+ match configuration.as_ref() {
+ ContextServerConfiguration::Http {
+ url,
+ headers,
+ timeout,
+ } => {
+ let transport = HttpTransport::new_with_token_provider(
+ http_client.clone(),
+ url.to_string(),
+ headers.clone(),
+ cx.background_executor().clone(),
+ Some(token_provider.clone()),
+ );
+ Ok(Arc::new(ContextServer::new_with_timeout(
+ id.clone(),
+ Arc::new(transport),
+ Some(Duration::from_secs(
+ timeout.unwrap_or(global_timeout).min(MAX_TIMEOUT_SECS),
+ )),
+ )))
+ }
+ _ => anyhow::bail!("OAuth authentication only supported for HTTP servers"),
+ }
+ })??;
+
+ this.update(cx, |this, cx| {
+ this.run_server(new_server, configuration, cx);
+ })?;
+
+ Ok(())
+ }
+
+ /// Store the full OAuth session in the system keychain, keyed by the
+ /// server's canonical URI.
+ async fn store_session(
+ credentials_provider: &Arc<dyn CredentialsProvider>,
+ server_url: &url::Url,
+ session: &OAuthSession,
+ cx: &AsyncApp,
+ ) -> Result<()> {
+ let key = Self::keychain_key(server_url);
+ let json = serde_json::to_string(session)?;
+ credentials_provider
+ .write_credentials(&key, "mcp-oauth", json.as_bytes(), cx)
+ .await
+ }
+
+ /// Load the full OAuth session from the system keychain for the given
+ /// server URL.
+ async fn load_session(
+ credentials_provider: &Arc<dyn CredentialsProvider>,
+ server_url: &url::Url,
+ cx: &AsyncApp,
+ ) -> Result<Option<OAuthSession>> {
+ let key = Self::keychain_key(server_url);
+ match credentials_provider.read_credentials(&key, cx).await? {
+ Some((_username, password_bytes)) => {
+ let session: OAuthSession = serde_json::from_slice(&password_bytes)?;
+ Ok(Some(session))
+ }
+ None => Ok(None),
+ }
+ }
+
+ /// Clear the stored OAuth session from the system keychain.
+ async fn clear_session(
+ credentials_provider: &Arc<dyn CredentialsProvider>,
+ server_url: &url::Url,
+ cx: &AsyncApp,
+ ) -> Result<()> {
+ let key = Self::keychain_key(server_url);
+ credentials_provider.delete_credentials(&key, cx).await
+ }
+
+ fn keychain_key(server_url: &url::Url) -> String {
+ format!("mcp-oauth:{}", oauth::canonical_server_uri(server_url))
+ }
+
+ /// Log out of an OAuth-authenticated MCP server: clear the stored OAuth
+ /// session from the keychain and stop the server.
+ pub fn logout_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
+ let state = self.servers.get(id).context("Context server not found")?;
+ let configuration = state.configuration();
+
+ let server_url = match configuration.as_ref() {
+ ContextServerConfiguration::Http { url, .. } => url.clone(),
+ _ => anyhow::bail!("logout only applies to HTTP servers with OAuth"),
+ };
+
+ let id = id.clone();
+ self.stop_server(&id, cx)?;
+
+ cx.spawn(async move |this, cx| {
+ let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx));
+ if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await {
+ log::error!("{} failed to clear OAuth session: {}", id, err);
+ }
+ // Trigger server recreation so the next start uses a fresh
+ // transport without the old (now-invalidated) token provider.
+ this.update(cx, |this, cx| {
+ this.available_context_servers_changed(cx);
+ })
+ .log_err();
+ })
+ .detach();
+
+ Ok(())
+ }
+
fn update_server_state(
&mut self,
id: ContextServerId,
@@ -1014,3 +1402,104 @@ impl ContextServerStore {
Ok(())
}
}
+
+/// Determines the appropriate server state after a start attempt fails.
+///
+/// When the error is an HTTP 401 with no static auth header configured,
+/// attempts OAuth discovery so the UI can offer an authentication flow.
+async fn resolve_start_failure(
+ id: &ContextServerId,
+ err: anyhow::Error,
+ server: Arc<ContextServer>,
+ configuration: Arc<ContextServerConfiguration>,
+ cx: &AsyncApp,
+) -> ContextServerState {
+ let www_authenticate = err.downcast_ref::<TransportError>().map(|e| match e {
+ TransportError::AuthRequired { www_authenticate } => www_authenticate.clone(),
+ });
+
+ if www_authenticate.is_some() && configuration.has_static_auth_header() {
+ log::warn!("{id} received 401 with a static Authorization header configured");
+ return ContextServerState::Error {
+ configuration,
+ server,
+ error: "Server returned 401 Unauthorized. Check your configured Authorization header."
+ .into(),
+ };
+ }
+
+ let server_url = match configuration.as_ref() {
+ ContextServerConfiguration::Http { url, .. } if !configuration.has_static_auth_header() => {
+ url.clone()
+ }
+ _ => {
+ if www_authenticate.is_some() {
+ log::error!("{id} got OAuth 401 on a non-HTTP transport or with static auth");
+ } else {
+ log::error!("{id} context server failed to start: {err}");
+ }
+ return ContextServerState::Error {
+ configuration,
+ server,
+ error: err.to_string().into(),
+ };
+ }
+ };
+
+ // When the error is NOT a 401 but there is a cached OAuth session in the
+ // keychain, the session is likely stale/expired and caused the failure
+ // (e.g. timeout because the server rejected the token silently). Clear it
+ // so the next start attempt can get a clean 401 and trigger the auth flow.
+ if www_authenticate.is_none() {
+ let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx));
+ match ContextServerStore::load_session(&credentials_provider, &server_url, cx).await {
+ Ok(Some(_)) => {
+ log::info!("{id} start failed with a cached OAuth session present; clearing it");
+ ContextServerStore::clear_session(&credentials_provider, &server_url, cx)
+ .await
+ .log_err();
+ }
+ _ => {
+ log::error!("{id} context server failed to start: {err}");
+ return ContextServerState::Error {
+ configuration,
+ server,
+ error: err.to_string().into(),
+ };
+ }
+ }
+ }
+
+ let default_www_authenticate = oauth::WwwAuthenticate {
+ resource_metadata: None,
+ scope: None,
+ error: None,
+ error_description: None,
+ };
+ let www_authenticate = www_authenticate
+ .as_ref()
+ .unwrap_or(&default_www_authenticate);
+ let http_client = cx.update(|cx| cx.http_client());
+
+ match context_server::oauth::discover(&http_client, &server_url, www_authenticate).await {
+ Ok(discovery) => {
+ log::info!(
+ "{id} requires OAuth authorization (auth server: {})",
+ discovery.auth_server_metadata.issuer,
+ );
+ ContextServerState::AuthRequired {
+ server,
+ configuration,
+ discovery: Arc::new(discovery),
+ }
+ }
+ Err(discovery_err) => {
+ log::error!("{id} OAuth discovery failed: {discovery_err}");
+ ContextServerState::Error {
+ configuration,
+ server,
+ error: format!("OAuth discovery failed: {discovery_err}").into(),
+ }
+ }
+ }
+}
@@ -6345,22 +6345,9 @@ impl Repository {
let RepositoryState::Local(LocalRepositoryState { backend, .. }) = state else {
bail!("not a local repository")
};
- let compute_snapshot = this.update(&mut cx, |this, _| {
- this.paths_needing_status_update.clear();
- compute_snapshot(
- this.id,
- this.work_directory_abs_path.clone(),
- this.snapshot.clone(),
- backend.clone(),
- )
- });
- let (snapshot, events) = cx.background_spawn(compute_snapshot).await?;
+ let snapshot = compute_snapshot(this.clone(), backend.clone(), &mut cx).await?;
this.update(&mut cx, |this, cx| {
- this.snapshot = snapshot.clone();
this.clear_pending_ops(cx);
- for event in events {
- cx.emit(event);
- }
});
if let Some(updates_tx) = updates_tx {
updates_tx
@@ -7087,47 +7074,124 @@ fn proto_to_commit_details(proto: &proto::GitCommitDetails) -> CommitDetails {
}
}
+/// This snapshot computes the repository state on the foreground thread while
+/// running the git commands on the background thread. We update branch, head,
+/// remotes, and worktrees first so the UI can react sooner, then compute file
+/// state and emit those events immediately after.
async fn compute_snapshot(
- id: RepositoryId,
- work_directory_abs_path: Arc<Path>,
- prev_snapshot: RepositorySnapshot,
+ this: Entity<Repository>,
backend: Arc<dyn GitRepository>,
-) -> Result<(RepositorySnapshot, Vec<RepositoryEvent>)> {
- let mut events = Vec::new();
- let branches = backend.branches().await?;
- let branch = branches.into_iter().find(|branch| branch.is_head);
-
- // Useful when branch is None in detached head state
- let head_commit = match backend.head_sha().await {
- Some(head_sha) => backend.show(head_sha).await.log_err(),
- None => None,
- };
+ cx: &mut AsyncApp,
+) -> Result<RepositorySnapshot> {
+ let (id, work_directory_abs_path, prev_snapshot) = this.update(cx, |this, _| {
+ this.paths_needing_status_update.clear();
+ (
+ this.id,
+ this.work_directory_abs_path.clone(),
+ this.snapshot.clone(),
+ )
+ });
- let diff_stat_future: BoxFuture<'_, Result<status::GitDiffStat>> = if head_commit.is_some() {
- backend.diff_stat(&[])
- } else {
- future::ready(Ok(status::GitDiffStat {
- entries: Arc::default(),
- }))
- .boxed()
+ let head_commit_future = {
+ let backend = backend.clone();
+ async move {
+ Ok(match backend.head_sha().await {
+ Some(head_sha) => backend.show(head_sha).await.log_err(),
+ None => None,
+ })
+ }
};
- let (statuses, diff_stats, all_worktrees) = futures::future::try_join3(
- backend.status(&[RepoPath::from_rel_path(
- &RelPath::new(".".as_ref(), PathStyle::local()).unwrap(),
- )]),
- diff_stat_future,
- backend.worktrees(),
- )
- .await?;
+ let (branches, head_commit, all_worktrees) = cx
+ .background_spawn({
+ let backend = backend.clone();
+ async move {
+ futures::future::try_join3(
+ backend.branches(),
+ head_commit_future,
+ backend.worktrees(),
+ )
+ .await
+ }
+ })
+ .await?;
+ let branch = branches.into_iter().find(|branch| branch.is_head);
let linked_worktrees: Arc<[GitWorktree]> = all_worktrees
.into_iter()
.filter(|wt| wt.path != *work_directory_abs_path)
.collect();
+ let (remote_origin_url, remote_upstream_url) = cx
+ .background_spawn({
+ let backend = backend.clone();
+ async move {
+ Ok::<_, anyhow::Error>(
+ futures::future::join(
+ backend.remote_url("origin"),
+ backend.remote_url("upstream"),
+ )
+ .await,
+ )
+ }
+ })
+ .await?;
+
+ let snapshot = this.update(cx, |this, cx| {
+ let branch_changed =
+ branch != this.snapshot.branch || head_commit != this.snapshot.head_commit;
+ let worktrees_changed = *linked_worktrees != *this.snapshot.linked_worktrees;
+
+ this.snapshot = RepositorySnapshot {
+ id,
+ work_directory_abs_path,
+ branch,
+ head_commit,
+ remote_origin_url,
+ remote_upstream_url,
+ linked_worktrees,
+ scan_id: prev_snapshot.scan_id + 1,
+ ..prev_snapshot
+ };
+
+ if branch_changed {
+ cx.emit(RepositoryEvent::BranchChanged);
+ }
+
+ if worktrees_changed {
+ cx.emit(RepositoryEvent::GitWorktreeListChanged);
+ }
+
+ this.snapshot.clone()
+ });
+
+ let (statuses, diff_stats, stash_entries) = cx
+ .background_spawn({
+ let backend = backend.clone();
+ let snapshot = snapshot.clone();
+ async move {
+ let diff_stat_future: BoxFuture<'_, Result<status::GitDiffStat>> =
+ if snapshot.head_commit.is_some() {
+ backend.diff_stat(&[])
+ } else {
+ future::ready(Ok(status::GitDiffStat {
+ entries: Arc::default(),
+ }))
+ .boxed()
+ };
+ futures::future::try_join3(
+ backend.status(&[RepoPath::from_rel_path(
+ &RelPath::new(".".as_ref(), PathStyle::local()).unwrap(),
+ )]),
+ diff_stat_future,
+ backend.stash_entries(),
+ )
+ .await
+ }
+ })
+ .await?;
+
let diff_stat_map: HashMap<&RepoPath, DiffStat> =
diff_stats.entries.iter().map(|(p, s)| (p, *s)).collect();
- let stash_entries = backend.stash_entries().await?;
let mut conflicted_paths = Vec::new();
let statuses_by_path = SumTree::from_iter(
statuses.entries.iter().map(|(repo_path, status)| {
@@ -7142,42 +7206,35 @@ async fn compute_snapshot(
}),
(),
);
- let mut merge_details = prev_snapshot.merge;
- let conflicts_changed = merge_details.update(&backend, conflicted_paths).await?;
- log::debug!("new merge details: {merge_details:?}");
-
- if conflicts_changed || statuses_by_path != prev_snapshot.statuses_by_path {
- events.push(RepositoryEvent::StatusesChanged)
- }
- if branch != prev_snapshot.branch || head_commit != prev_snapshot.head_commit {
- events.push(RepositoryEvent::BranchChanged);
- }
-
- if *linked_worktrees != *prev_snapshot.linked_worktrees {
- events.push(RepositoryEvent::GitWorktreeListChanged);
- }
+ let merge_details = cx
+ .background_spawn({
+ let backend = backend.clone();
+ let mut merge_details = snapshot.merge.clone();
+ async move {
+ let conflicts_changed = merge_details.update(&backend, conflicted_paths).await?;
+ Ok::<_, anyhow::Error>((merge_details, conflicts_changed))
+ }
+ })
+ .await?;
+ let (merge_details, conflicts_changed) = merge_details;
+ log::debug!("new merge details: {merge_details:?}");
- let remote_origin_url = backend.remote_url("origin").await;
- let remote_upstream_url = backend.remote_url("upstream").await;
+ Ok(this.update(cx, |this, cx| {
+ if conflicts_changed || statuses_by_path != this.snapshot.statuses_by_path {
+ cx.emit(RepositoryEvent::StatusesChanged);
+ }
+ if stash_entries != this.snapshot.stash_entries {
+ cx.emit(RepositoryEvent::StashEntriesChanged);
+ }
- let snapshot = RepositorySnapshot {
- id,
- statuses_by_path,
- work_directory_abs_path,
- original_repo_abs_path: prev_snapshot.original_repo_abs_path,
- path_style: prev_snapshot.path_style,
- scan_id: prev_snapshot.scan_id + 1,
- branch,
- head_commit,
- merge: merge_details,
- remote_origin_url,
- remote_upstream_url,
- stash_entries,
- linked_worktrees,
- };
+ this.snapshot.scan_id += 1;
+ this.snapshot.merge = merge_details;
+ this.snapshot.statuses_by_path = statuses_by_path;
+ this.snapshot.stash_entries = stash_entries;
- Ok((snapshot, events))
+ this.snapshot.clone()
+ }))
}
fn status_from_proto(
@@ -1611,28 +1611,6 @@ impl LocalLspStore {
})
})?;
- /// Apply edits to the buffer that will become part of the formatting transaction.
- /// Fails if the buffer has been edited since the start of that transaction.
- fn extend_formatting_transaction(
- buffer: &FormattableBuffer,
- formatting_transaction_id: text::TransactionId,
- cx: &mut AsyncApp,
- operation: impl FnOnce(&mut Buffer, &mut Context<Buffer>),
- ) -> anyhow::Result<()> {
- buffer.handle.update(cx, |buffer, cx| {
- let last_transaction_id = buffer.peek_undo_stack().map(|t| t.transaction_id());
- if last_transaction_id != Some(formatting_transaction_id) {
- anyhow::bail!("Buffer edited while formatting. Aborting")
- }
- buffer.start_transaction();
- operation(buffer, cx);
- if let Some(transaction_id) = buffer.end_transaction(cx) {
- buffer.merge_transactions(transaction_id, formatting_transaction_id);
- }
- Ok(())
- })
- }
-
// handle whitespace formatting
if settings.remove_trailing_whitespace_on_save {
zlog::trace!(logger => "removing trailing whitespace");
@@ -1702,508 +1680,532 @@ impl LocalLspStore {
} else {
formatter
};
- match formatter {
- Formatter::None => {
- zlog::trace!(logger => "skipping formatter 'none'");
- continue;
- }
- Formatter::Auto => unreachable!("Auto resolved above"),
- Formatter::Prettier => {
- let logger = zlog::scoped!(logger => "prettier");
- zlog::trace!(logger => "formatting");
- let _timer = zlog::time!(logger => "Formatting buffer via prettier");
+ if let Err(err) = Self::apply_formatter(
+ formatter,
+ &lsp_store,
+ buffer,
+ formatting_transaction_id,
+ &adapters_and_servers,
+ &settings,
+ request_timeout,
+ logger,
+ cx,
+ )
+ .await
+ {
+ zlog::error!(logger => "Formatter failed, skipping: {err:#}");
+ }
+ }
- let prettier = lsp_store.read_with(cx, |lsp_store, _cx| {
- lsp_store.prettier_store().unwrap().downgrade()
- })?;
- let diff = prettier_store::format_with_prettier(&prettier, &buffer.handle, cx)
+ Ok(())
+ }
+
+ async fn apply_formatter(
+ formatter: &Formatter,
+ lsp_store: &WeakEntity<LspStore>,
+ buffer: &FormattableBuffer,
+ formatting_transaction_id: clock::Lamport,
+ adapters_and_servers: &[(Arc<CachedLspAdapter>, Arc<LanguageServer>)],
+ settings: &LanguageSettings,
+ request_timeout: Duration,
+ logger: zlog::Logger,
+ cx: &mut AsyncApp,
+ ) -> anyhow::Result<()> {
+ match formatter {
+ Formatter::None => {
+ zlog::trace!(logger => "skipping formatter 'none'");
+ return Ok(());
+ }
+ Formatter::Auto => {
+ debug_panic!("Auto resolved above");
+ return Ok(());
+ }
+ Formatter::Prettier => {
+ let logger = zlog::scoped!(logger => "prettier");
+ zlog::trace!(logger => "formatting");
+ let _timer = zlog::time!(logger => "Formatting buffer via prettier");
+
+ let prettier = lsp_store.read_with(cx, |lsp_store, _cx| {
+ lsp_store.prettier_store().unwrap().downgrade()
+ })?;
+ let diff = prettier_store::format_with_prettier(&prettier, &buffer.handle, cx)
+ .await
+ .transpose()?;
+ let Some(diff) = diff else {
+ zlog::trace!(logger => "No changes");
+ return Ok(());
+ };
+
+ extend_formatting_transaction(
+ buffer,
+ formatting_transaction_id,
+ cx,
+ |buffer, cx| {
+ buffer.apply_diff(diff, cx);
+ },
+ )?;
+ }
+ Formatter::External { command, arguments } => {
+ let logger = zlog::scoped!(logger => "command");
+ zlog::trace!(logger => "formatting");
+ let _timer = zlog::time!(logger => "Formatting buffer via external command");
+
+ let diff =
+ Self::format_via_external_command(buffer, &command, arguments.as_deref(), cx)
.await
- .transpose()?;
- let Some(diff) = diff else {
- zlog::trace!(logger => "No changes");
- continue;
- };
+ .with_context(|| {
+ format!("Failed to format buffer via external command: {}", command)
+ })?;
+ let Some(diff) = diff else {
+ zlog::trace!(logger => "No changes");
+ return Ok(());
+ };
- extend_formatting_transaction(
- buffer,
- formatting_transaction_id,
- cx,
- |buffer, cx| {
- buffer.apply_diff(diff, cx);
- },
- )?;
- }
- Formatter::External { command, arguments } => {
- let logger = zlog::scoped!(logger => "command");
- zlog::trace!(logger => "formatting");
- let _timer = zlog::time!(logger => "Formatting buffer via external command");
+ extend_formatting_transaction(
+ buffer,
+ formatting_transaction_id,
+ cx,
+ |buffer, cx| {
+ buffer.apply_diff(diff, cx);
+ },
+ )?;
+ }
+ Formatter::LanguageServer(specifier) => {
+ let logger = zlog::scoped!(logger => "language-server");
+ zlog::trace!(logger => "formatting");
+ let _timer = zlog::time!(logger => "Formatting buffer using language server");
- let diff = Self::format_via_external_command(
- buffer,
- &command,
- arguments.as_deref(),
+ let Some(buffer_path_abs) = buffer.abs_path.as_ref() else {
+ zlog::warn!(logger => "Cannot format buffer that is not backed by a file on disk using language servers. Skipping");
+ return Ok(());
+ };
+
+ let language_server = match specifier {
+ settings::LanguageServerFormatterSpecifier::Specific { name } => {
+ adapters_and_servers.iter().find_map(|(adapter, server)| {
+ if adapter.name.0.as_ref() == name {
+ Some(server.clone())
+ } else {
+ None
+ }
+ })
+ }
+ settings::LanguageServerFormatterSpecifier::Current => adapters_and_servers
+ .iter()
+ .find(|(_, server)| Self::server_supports_formatting(server))
+ .map(|(_, server)| server.clone()),
+ };
+
+ let Some(language_server) = language_server else {
+ log::debug!(
+ "No language server found to format buffer '{:?}'. Skipping",
+ buffer_path_abs.as_path().to_string_lossy()
+ );
+ return Ok(());
+ };
+
+ zlog::trace!(
+ logger =>
+ "Formatting buffer '{:?}' using language server '{:?}'",
+ buffer_path_abs.as_path().to_string_lossy(),
+ language_server.name()
+ );
+
+ let edits = if let Some(ranges) = buffer.ranges.as_ref() {
+ zlog::trace!(logger => "formatting ranges");
+ Self::format_ranges_via_lsp(
+ &lsp_store,
+ &buffer.handle,
+ ranges,
+ buffer_path_abs,
+ &language_server,
+ &settings,
cx,
)
.await
- .with_context(|| {
- format!("Failed to format buffer via external command: {}", command)
- })?;
- let Some(diff) = diff else {
- zlog::trace!(logger => "No changes");
- continue;
- };
-
- extend_formatting_transaction(
- buffer,
- formatting_transaction_id,
+ .context("Failed to format ranges via language server")?
+ } else {
+ zlog::trace!(logger => "formatting full");
+ Self::format_via_lsp(
+ &lsp_store,
+ &buffer.handle,
+ buffer_path_abs,
+ &language_server,
+ &settings,
cx,
- |buffer, cx| {
- buffer.apply_diff(diff, cx);
- },
- )?;
+ )
+ .await
+ .context("failed to format via language server")?
+ };
+
+ if edits.is_empty() {
+ zlog::trace!(logger => "No changes");
+ return Ok(());
}
- Formatter::LanguageServer(specifier) => {
- let logger = zlog::scoped!(logger => "language-server");
- zlog::trace!(logger => "formatting");
- let _timer = zlog::time!(logger => "Formatting buffer using language server");
+ extend_formatting_transaction(
+ buffer,
+ formatting_transaction_id,
+ cx,
+ |buffer, cx| {
+ buffer.edit(edits, None, cx);
+ },
+ )?;
+ }
+ Formatter::CodeAction(code_action_name) => {
+ let logger = zlog::scoped!(logger => "code-actions");
+ zlog::trace!(logger => "formatting");
+ let _timer = zlog::time!(logger => "Formatting buffer using code actions");
- let Some(buffer_path_abs) = buffer.abs_path.as_ref() else {
- zlog::warn!(logger => "Cannot format buffer that is not backed by a file on disk using language servers. Skipping");
- continue;
- };
+ let Some(buffer_path_abs) = buffer.abs_path.as_ref() else {
+ zlog::warn!(logger => "Cannot format buffer that is not backed by a file on disk using code actions. Skipping");
+ return Ok(());
+ };
- let language_server = match specifier {
- settings::LanguageServerFormatterSpecifier::Specific { name } => {
- adapters_and_servers.iter().find_map(|(adapter, server)| {
- if adapter.name.0.as_ref() == name {
- Some(server.clone())
- } else {
- None
- }
- })
- }
- settings::LanguageServerFormatterSpecifier::Current => adapters_and_servers
- .iter()
- .find(|(_, server)| Self::server_supports_formatting(server))
- .map(|(_, server)| server.clone()),
- };
+ let code_action_kind: CodeActionKind = code_action_name.clone().into();
+ zlog::trace!(logger => "Attempting to resolve code actions {:?}", &code_action_kind);
- let Some(language_server) = language_server else {
- log::debug!(
- "No language server found to format buffer '{:?}'. Skipping",
- buffer_path_abs.as_path().to_string_lossy()
+ let mut actions_and_servers = Vec::new();
+
+ for (index, (_, language_server)) in adapters_and_servers.iter().enumerate() {
+ let actions_result = Self::get_server_code_actions_from_action_kinds(
+ &lsp_store,
+ language_server.server_id(),
+ vec![code_action_kind.clone()],
+ &buffer.handle,
+ cx,
+ )
+ .await
+ .with_context(|| {
+ format!(
+ "Failed to resolve code action {:?} with language server {}",
+ code_action_kind,
+ language_server.name()
+ )
+ });
+ let Ok(actions) = actions_result else {
+ // note: it may be better to set result to the error and break formatters here
+ // but for now we try to execute the actions that we can resolve and skip the rest
+ zlog::error!(
+ logger =>
+ "Failed to resolve code action {:?} with language server {}",
+ code_action_kind,
+ language_server.name()
);
continue;
};
+ for action in actions {
+ actions_and_servers.push((action, index));
+ }
+ }
- zlog::trace!(
- logger =>
- "Formatting buffer '{:?}' using language server '{:?}'",
- buffer_path_abs.as_path().to_string_lossy(),
- language_server.name()
- );
+ if actions_and_servers.is_empty() {
+ zlog::warn!(logger => "No code actions were resolved, continuing");
+ return Ok(());
+ }
- let edits = if let Some(ranges) = buffer.ranges.as_ref() {
- zlog::trace!(logger => "formatting ranges");
- Self::format_ranges_via_lsp(
- &lsp_store,
- &buffer.handle,
- ranges,
- buffer_path_abs,
- &language_server,
- &settings,
- cx,
- )
- .await
- .context("Failed to format ranges via language server")?
- } else {
- zlog::trace!(logger => "formatting full");
- Self::format_via_lsp(
- &lsp_store,
- &buffer.handle,
- buffer_path_abs,
- &language_server,
- &settings,
- cx,
+ 'actions: for (mut action, server_index) in actions_and_servers {
+ let server = &adapters_and_servers[server_index].1;
+
+ let describe_code_action = |action: &CodeAction| {
+ format!(
+ "code action '{}' with title \"{}\" on server {}",
+ action
+ .lsp_action
+ .action_kind()
+ .unwrap_or("unknown".into())
+ .as_str(),
+ action.lsp_action.title(),
+ server.name(),
)
- .await
- .context("failed to format via language server")?
};
- if edits.is_empty() {
- zlog::trace!(logger => "No changes");
- continue;
- }
- extend_formatting_transaction(
- buffer,
- formatting_transaction_id,
- cx,
- |buffer, cx| {
- buffer.edit(edits, None, cx);
- },
- )?;
- }
- Formatter::CodeAction(code_action_name) => {
- let logger = zlog::scoped!(logger => "code-actions");
- zlog::trace!(logger => "formatting");
- let _timer = zlog::time!(logger => "Formatting buffer using code actions");
+ zlog::trace!(logger => "Executing {}", describe_code_action(&action));
- let Some(buffer_path_abs) = buffer.abs_path.as_ref() else {
- zlog::warn!(logger => "Cannot format buffer that is not backed by a file on disk using code actions. Skipping");
+ if let Err(err) =
+ Self::try_resolve_code_action(server, &mut action, request_timeout).await
+ {
+ zlog::error!(
+ logger =>
+ "Failed to resolve {}. Error: {}",
+ describe_code_action(&action),
+ err
+ );
continue;
- };
-
- let code_action_kind: CodeActionKind = code_action_name.clone().into();
- zlog::trace!(logger => "Attempting to resolve code actions {:?}", &code_action_kind);
-
- let mut actions_and_servers = Vec::new();
+ }
- for (index, (_, language_server)) in adapters_and_servers.iter().enumerate() {
- let actions_result = Self::get_server_code_actions_from_action_kinds(
- &lsp_store,
- language_server.server_id(),
- vec![code_action_kind.clone()],
- &buffer.handle,
- cx,
- )
- .await
- .with_context(|| {
- format!(
- "Failed to resolve code action {:?} with language server {}",
- code_action_kind,
- language_server.name()
- )
- });
- let Ok(actions) = actions_result else {
- // note: it may be better to set result to the error and break formatters here
- // but for now we try to execute the actions that we can resolve and skip the rest
- zlog::error!(
+ if let Some(edit) = action.lsp_action.edit().cloned() {
+ // NOTE: code below duplicated from `Self::deserialize_workspace_edit`
+ // but filters out and logs warnings for code actions that require unreasonably
+ // difficult handling on our part, such as:
+ // - applying edits that call commands
+ // which can result in arbitrary workspace edits being sent from the server that
+ // have no way of being tied back to the command that initiated them (i.e. we
+ // can't know which edits are part of the format request, or if the server is done sending
+ // actions in response to the command)
+ // - actions that create/delete/modify/rename files other than the one we are formatting
+ // as we then would need to handle such changes correctly in the local history as well
+ // as the remote history through the ProjectTransaction
+ // - actions with snippet edits, as these simply don't make sense in the context of a format request
+ // Supporting these actions is not impossible, but not supported as of yet.
+ if edit.changes.is_none() && edit.document_changes.is_none() {
+ zlog::trace!(
logger =>
- "Failed to resolve code action {:?} with language server {}",
- code_action_kind,
- language_server.name()
+ "No changes for code action. Skipping {}",
+ describe_code_action(&action),
);
continue;
- };
- for action in actions {
- actions_and_servers.push((action, index));
}
- }
-
- if actions_and_servers.is_empty() {
- zlog::warn!(logger => "No code actions were resolved, continuing");
- continue;
- }
- 'actions: for (mut action, server_index) in actions_and_servers {
- let server = &adapters_and_servers[server_index].1;
-
- let describe_code_action = |action: &CodeAction| {
- format!(
- "code action '{}' with title \"{}\" on server {}",
- action
- .lsp_action
- .action_kind()
- .unwrap_or("unknown".into())
- .as_str(),
- action.lsp_action.title(),
- server.name(),
- )
- };
+ let mut operations = Vec::new();
+ if let Some(document_changes) = edit.document_changes {
+ match document_changes {
+ lsp::DocumentChanges::Edits(edits) => operations.extend(
+ edits.into_iter().map(lsp::DocumentChangeOperation::Edit),
+ ),
+ lsp::DocumentChanges::Operations(ops) => operations = ops,
+ }
+ } else if let Some(changes) = edit.changes {
+ operations.extend(changes.into_iter().map(|(uri, edits)| {
+ lsp::DocumentChangeOperation::Edit(lsp::TextDocumentEdit {
+ text_document: lsp::OptionalVersionedTextDocumentIdentifier {
+ uri,
+ version: None,
+ },
+ edits: edits.into_iter().map(Edit::Plain).collect(),
+ })
+ }));
+ }
- zlog::trace!(logger => "Executing {}", describe_code_action(&action));
+ let mut edits = Vec::with_capacity(operations.len());
- if let Err(err) =
- Self::try_resolve_code_action(server, &mut action, request_timeout)
- .await
- {
- zlog::error!(
+ if operations.is_empty() {
+ zlog::trace!(
logger =>
- "Failed to resolve {}. Error: {}",
+ "No changes for code action. Skipping {}",
describe_code_action(&action),
- err
);
continue;
}
-
- if let Some(edit) = action.lsp_action.edit().cloned() {
- // NOTE: code below duplicated from `Self::deserialize_workspace_edit`
- // but filters out and logs warnings for code actions that require unreasonably
- // difficult handling on our part, such as:
- // - applying edits that call commands
- // which can result in arbitrary workspace edits being sent from the server that
- // have no way of being tied back to the command that initiated them (i.e. we
- // can't know which edits are part of the format request, or if the server is done sending
- // actions in response to the command)
- // - actions that create/delete/modify/rename files other than the one we are formatting
- // as we then would need to handle such changes correctly in the local history as well
- // as the remote history through the ProjectTransaction
- // - actions with snippet edits, as these simply don't make sense in the context of a format request
- // Supporting these actions is not impossible, but not supported as of yet.
- if edit.changes.is_none() && edit.document_changes.is_none() {
- zlog::trace!(
+ for operation in operations {
+ let op = match operation {
+ lsp::DocumentChangeOperation::Edit(op) => op,
+ lsp::DocumentChangeOperation::Op(_) => {
+ zlog::warn!(
+ logger =>
+ "Code actions which create, delete, or rename files are not supported on format. Skipping {}",
+ describe_code_action(&action),
+ );
+ continue 'actions;
+ }
+ };
+ let Ok(file_path) = op.text_document.uri.to_file_path() else {
+ zlog::warn!(
logger =>
- "No changes for code action. Skipping {}",
+ "Failed to convert URI '{:?}' to file path. Skipping {}",
+ &op.text_document.uri,
describe_code_action(&action),
);
- continue;
- }
-
- let mut operations = Vec::new();
- if let Some(document_changes) = edit.document_changes {
- match document_changes {
- lsp::DocumentChanges::Edits(edits) => operations.extend(
- edits.into_iter().map(lsp::DocumentChangeOperation::Edit),
- ),
- lsp::DocumentChanges::Operations(ops) => operations = ops,
- }
- } else if let Some(changes) = edit.changes {
- operations.extend(changes.into_iter().map(|(uri, edits)| {
- lsp::DocumentChangeOperation::Edit(lsp::TextDocumentEdit {
- text_document:
- lsp::OptionalVersionedTextDocumentIdentifier {
- uri,
- version: None,
- },
- edits: edits.into_iter().map(Edit::Plain).collect(),
- })
- }));
- }
-
- let mut edits = Vec::with_capacity(operations.len());
-
- if operations.is_empty() {
- zlog::trace!(
+ continue 'actions;
+ };
+ if &file_path != buffer_path_abs {
+ zlog::warn!(
logger =>
- "No changes for code action. Skipping {}",
+ "File path '{:?}' does not match buffer path '{:?}'. Skipping {}",
+ file_path,
+ buffer_path_abs,
describe_code_action(&action),
);
- continue;
+ continue 'actions;
}
- for operation in operations {
- let op = match operation {
- lsp::DocumentChangeOperation::Edit(op) => op,
- lsp::DocumentChangeOperation::Op(_) => {
+
+ let mut lsp_edits = Vec::new();
+ for edit in op.edits {
+ match edit {
+ Edit::Plain(edit) => {
+ if !lsp_edits.contains(&edit) {
+ lsp_edits.push(edit);
+ }
+ }
+ Edit::Annotated(edit) => {
+ if !lsp_edits.contains(&edit.text_edit) {
+ lsp_edits.push(edit.text_edit);
+ }
+ }
+ Edit::Snippet(_) => {
zlog::warn!(
logger =>
- "Code actions which create, delete, or rename files are not supported on format. Skipping {}",
+ "Code actions which produce snippet edits are not supported during formatting. Skipping {}",
describe_code_action(&action),
);
continue 'actions;
}
- };
- let Ok(file_path) = op.text_document.uri.to_file_path() else {
- zlog::warn!(
- logger =>
- "Failed to convert URI '{:?}' to file path. Skipping {}",
- &op.text_document.uri,
- describe_code_action(&action),
- );
- continue 'actions;
- };
- if &file_path != buffer_path_abs {
- zlog::warn!(
- logger =>
- "File path '{:?}' does not match buffer path '{:?}'. Skipping {}",
- file_path,
- buffer_path_abs,
- describe_code_action(&action),
- );
- continue 'actions;
}
-
- let mut lsp_edits = Vec::new();
- for edit in op.edits {
- match edit {
- Edit::Plain(edit) => {
- if !lsp_edits.contains(&edit) {
- lsp_edits.push(edit);
- }
- }
- Edit::Annotated(edit) => {
- if !lsp_edits.contains(&edit.text_edit) {
- lsp_edits.push(edit.text_edit);
- }
- }
- Edit::Snippet(_) => {
- zlog::warn!(
- logger =>
- "Code actions which produce snippet edits are not supported during formatting. Skipping {}",
- describe_code_action(&action),
- );
- continue 'actions;
- }
- }
- }
- let edits_result = lsp_store
- .update(cx, |lsp_store, cx| {
- lsp_store.as_local_mut().unwrap().edits_from_lsp(
- &buffer.handle,
- lsp_edits,
- server.server_id(),
- op.text_document.version,
- cx,
- )
- })?
- .await;
- let Ok(resolved_edits) = edits_result else {
- zlog::warn!(
- logger =>
- "Failed to resolve edits from LSP for buffer {:?} while handling {}",
- buffer_path_abs.as_path(),
- describe_code_action(&action),
- );
- continue 'actions;
- };
- edits.extend(resolved_edits);
}
-
- if edits.is_empty() {
- zlog::warn!(logger => "No edits resolved from LSP");
- continue;
- }
-
- extend_formatting_transaction(
- buffer,
- formatting_transaction_id,
- cx,
- |buffer, cx| {
- zlog::info!(
- "Applying edits {edits:?}. Content: {:?}",
- buffer.text()
- );
- buffer.edit(edits, None, cx);
- zlog::info!("Applied edits. New Content: {:?}", buffer.text());
- },
- )?;
+ let edits_result = lsp_store
+ .update(cx, |lsp_store, cx| {
+ lsp_store.as_local_mut().unwrap().edits_from_lsp(
+ &buffer.handle,
+ lsp_edits,
+ server.server_id(),
+ op.text_document.version,
+ cx,
+ )
+ })?
+ .await;
+ let Ok(resolved_edits) = edits_result else {
+ zlog::warn!(
+ logger =>
+ "Failed to resolve edits from LSP for buffer {:?} while handling {}",
+ buffer_path_abs.as_path(),
+ describe_code_action(&action),
+ );
+ continue 'actions;
+ };
+ edits.extend(resolved_edits);
}
- // bail early if command is invalid
- let Some(command) = action.lsp_action.command() else {
- continue;
- };
-
- zlog::warn!(
- logger =>
- "Executing code action command '{}'. This may cause formatting to abort unnecessarily as well as splitting formatting into two entries in the undo history",
- &command.command,
- );
-
- let server_capabilities = server.capabilities();
- let available_commands = server_capabilities
- .execute_command_provider
- .as_ref()
- .map(|options| options.commands.as_slice())
- .unwrap_or_default();
- if !available_commands.contains(&command.command) {
- zlog::warn!(
- logger =>
- "Cannot execute a command {} not listed in the language server capabilities of server {}",
- command.command,
- server.name(),
- );
+ if edits.is_empty() {
+ zlog::warn!(logger => "No edits resolved from LSP");
continue;
}
- // noop so we just ensure buffer hasn't been edited since resolving code actions
extend_formatting_transaction(
buffer,
formatting_transaction_id,
cx,
- |_, _| {},
+ |buffer, cx| {
+ zlog::info!(
+ "Applying edits {edits:?}. Content: {:?}",
+ buffer.text()
+ );
+ buffer.edit(edits, None, cx);
+ zlog::info!("Applied edits. New Content: {:?}", buffer.text());
+ },
)?;
- zlog::info!(logger => "Executing command {}", &command.command);
+ }
- lsp_store.update(cx, |this, _| {
- this.as_local_mut()
- .unwrap()
- .last_workspace_edits_by_language_server
- .remove(&server.server_id());
- })?;
+ let Some(command) = action.lsp_action.command() else {
+ continue;
+ };
- let execute_command_result = server
- .request::<lsp::request::ExecuteCommand>(
- lsp::ExecuteCommandParams {
- command: command.command.clone(),
- arguments: command.arguments.clone().unwrap_or_default(),
- ..Default::default()
- },
- request_timeout,
- )
- .await
- .into_response();
+ zlog::warn!(
+ logger =>
+ "Executing code action command '{}'. This may cause formatting to abort unnecessarily as well as splitting formatting into two entries in the undo history",
+ &command.command,
+ );
- if execute_command_result.is_err() {
- zlog::error!(
- logger =>
- "Failed to execute command '{}' as part of {}",
- &command.command,
- describe_code_action(&action),
- );
- continue 'actions;
- }
+ let server_capabilities = server.capabilities();
+ let available_commands = server_capabilities
+ .execute_command_provider
+ .as_ref()
+ .map(|options| options.commands.as_slice())
+ .unwrap_or_default();
+ if !available_commands.contains(&command.command) {
+ zlog::warn!(
+ logger =>
+ "Cannot execute a command {} not listed in the language server capabilities of server {}",
+ command.command,
+ server.name(),
+ );
+ continue;
+ }
- let mut project_transaction_command = lsp_store.update(cx, |this, _| {
- this.as_local_mut()
- .unwrap()
- .last_workspace_edits_by_language_server
- .remove(&server.server_id())
- .unwrap_or_default()
- })?;
+ extend_formatting_transaction(
+ buffer,
+ formatting_transaction_id,
+ cx,
+ |_, _| {},
+ )?;
+ zlog::info!(logger => "Executing command {}", &command.command);
- if let Some(transaction) =
- project_transaction_command.0.remove(&buffer.handle)
- {
- zlog::trace!(
- logger =>
- "Successfully captured {} edits that resulted from command {}",
- transaction.edit_ids.len(),
- &command.command,
- );
- let transaction_id_project_transaction = transaction.id;
- buffer.handle.update(cx, |buffer, _| {
- // it may have been removed from history if push_to_history was
- // false in deserialize_workspace_edit. If so push it so we
- // can merge it with the format transaction
- // and pop the combined transaction off the history stack
- // later if push_to_history is false
- if buffer.get_transaction(transaction.id).is_none() {
- buffer.push_transaction(transaction, Instant::now());
- }
- buffer.merge_transactions(
- transaction_id_project_transaction,
- formatting_transaction_id,
- );
- });
- }
+ lsp_store.update(cx, |this, _| {
+ this.as_local_mut()
+ .unwrap()
+ .last_workspace_edits_by_language_server
+ .remove(&server.server_id());
+ })?;
- if project_transaction_command.0.is_empty() {
- continue;
- }
+ let execute_command_result = server
+ .request::<lsp::request::ExecuteCommand>(
+ lsp::ExecuteCommandParams {
+ command: command.command.clone(),
+ arguments: command.arguments.clone().unwrap_or_default(),
+ ..Default::default()
+ },
+ request_timeout,
+ )
+ .await
+ .into_response();
- let mut extra_buffers = String::new();
- for buffer in project_transaction_command.0.keys() {
- buffer.read_with(cx, |b, cx| {
- let Some(path) = b.project_path(cx) else {
- return;
- };
+ if execute_command_result.is_err() {
+ zlog::error!(
+ logger =>
+ "Failed to execute command '{}' as part of {}",
+ &command.command,
+ describe_code_action(&action),
+ );
+ continue 'actions;
+ }
- if !extra_buffers.is_empty() {
- extra_buffers.push_str(", ");
- }
- extra_buffers.push_str(path.path.as_unix_str());
- });
- }
- zlog::warn!(
+ let mut project_transaction_command = lsp_store.update(cx, |this, _| {
+ this.as_local_mut()
+ .unwrap()
+ .last_workspace_edits_by_language_server
+ .remove(&server.server_id())
+ .unwrap_or_default()
+ })?;
+
+ if let Some(transaction) = project_transaction_command.0.remove(&buffer.handle)
+ {
+ zlog::trace!(
logger =>
- "Unexpected edits to buffers other than the buffer actively being formatted due to command {}. Impacted buffers: [{}].",
+ "Successfully captured {} edits that resulted from command {}",
+ transaction.edit_ids.len(),
&command.command,
- extra_buffers,
);
- // NOTE: if this case is hit, the proper thing to do is to for each buffer, merge the extra transaction
- // into the existing transaction in project_transaction if there is one, and if there isn't one in project_transaction,
- // add it so it's included, and merge it into the format transaction when its created later
+ let transaction_id_project_transaction = transaction.id;
+ buffer.handle.update(cx, |buffer, _| {
+ // it may have been removed from history if push_to_history was
+ // false in deserialize_workspace_edit. If so, push it so we
+ // can merge it with the format transaction,
+ // and pop the combined transaction off the history stack
+ // later if push_to_history is false
+ if buffer.get_transaction(transaction.id).is_none() {
+ buffer.push_transaction(transaction, Instant::now());
+ }
+ buffer.merge_transactions(
+ transaction_id_project_transaction,
+ formatting_transaction_id,
+ );
+ });
+ }
+
+ if project_transaction_command.0.is_empty() {
+ continue;
+ }
+
+ let mut extra_buffers = String::new();
+ for buffer in project_transaction_command.0.keys() {
+ buffer.read_with(cx, |b, cx| {
+ let Some(path) = b.project_path(cx) else {
+ return;
+ };
+
+ if !extra_buffers.is_empty() {
+ extra_buffers.push_str(", ");
+ }
+ extra_buffers.push_str(path.path.as_unix_str());
+ });
}
+ zlog::warn!(
+ logger =>
+ "Unexpected edits to buffers other than the buffer actively being formatted due to command {}. Impacted buffers: [{}].",
+ &command.command,
+ extra_buffers,
+ );
+ // NOTE: if this case is hit, the proper thing to do is, for each buffer, to merge the extra transaction
+ // into the existing transaction in project_transaction if there is one; if there isn't one in project_transaction,
+ // add it so it's included, and merge it into the format transaction when it's created later
}
}
}
@@ -3918,6 +3920,7 @@ pub struct LspStore {
pub lsp_server_capabilities: HashMap<LanguageServerId, lsp::ServerCapabilities>,
semantic_token_config: SemanticTokenConfig,
lsp_data: HashMap<BufferId, BufferLspData>,
+ buffer_reload_tasks: HashMap<BufferId, Task<anyhow::Result<()>>>,
next_hint_id: Arc<AtomicUsize>,
}
@@ -4245,6 +4248,7 @@ impl LspStore {
lsp_server_capabilities: HashMap::default(),
semantic_token_config: SemanticTokenConfig::new(cx),
lsp_data: HashMap::default(),
+ buffer_reload_tasks: HashMap::default(),
next_hint_id: Arc::default(),
active_entry: None,
_maintain_workspace_config,
@@ -120,6 +120,7 @@ use std::{
borrow::Cow,
collections::BTreeMap,
ffi::OsString,
+ future::Future,
ops::{Not as _, Range},
path::{Path, PathBuf},
pin::pin,
@@ -2078,6 +2079,12 @@ impl Project {
self.worktree_store.clone()
}
+ /// Returns a future that resolves when all visible worktrees have completed
+ /// their initial scan.
+ pub fn wait_for_initial_scan(&self, cx: &App) -> impl Future<Output = ()> + use<> {
+ self.worktree_store.read(cx).wait_for_initial_scan()
+ }
+
#[inline]
pub fn context_server_store(&self) -> Entity<ContextServerStore> {
self.context_server_store.clone()
@@ -1,4 +1,5 @@
use std::{
+ future::Future,
path::{Path, PathBuf},
sync::{
Arc,
@@ -15,6 +16,7 @@ use gpui::{
WeakEntity,
};
use itertools::Either;
+use postage::{prelude::Stream as _, watch};
use rpc::{
AnyProtoClient, ErrorExt, TypedEnvelope,
proto::{self, REMOTE_SERVER_PROJECT_ID},
@@ -75,6 +77,7 @@ pub struct WorktreeStore {
#[allow(clippy::type_complexity)]
loading_worktrees:
HashMap<Arc<SanitizedPath>, Shared<Task<Result<Entity<Worktree>, Arc<anyhow::Error>>>>>,
+ initial_scan_complete: (watch::Sender<bool>, watch::Receiver<bool>),
state: WorktreeStoreState,
}
@@ -119,6 +122,7 @@ impl WorktreeStore {
worktrees_reordered: false,
scanning_enabled: true,
retain_worktrees,
+ initial_scan_complete: watch::channel_with(true),
state: WorktreeStoreState::Local { fs },
}
}
@@ -139,6 +143,7 @@ impl WorktreeStore {
worktrees_reordered: false,
scanning_enabled: true,
retain_worktrees,
+ initial_scan_complete: watch::channel_with(true),
state: WorktreeStoreState::Remote {
upstream_client,
upstream_project_id,
@@ -174,6 +179,57 @@ impl WorktreeStore {
pub fn disable_scanner(&mut self) {
self.scanning_enabled = false;
+ *self.initial_scan_complete.0.borrow_mut() = true;
+ }
+
+ /// Returns a future that resolves when all visible worktrees have completed
+ /// their initial scan (entries populated, git repos detected).
+ pub fn wait_for_initial_scan(&self) -> impl Future<Output = ()> + use<> {
+ let mut rx = self.initial_scan_complete.1.clone();
+ async move {
+ let mut done = *rx.borrow();
+ while !done {
+ if let Some(value) = rx.recv().await {
+ done = value;
+ } else {
+ break;
+ }
+ }
+ }
+ }
+
+ /// Returns whether all visible worktrees have completed their initial scan.
+ pub fn initial_scan_completed(&self) -> bool {
+ *self.initial_scan_complete.1.borrow()
+ }
+
+ /// Checks whether all visible worktrees have completed their initial scan
+ /// and no worktree creations are pending, and updates the watch channel accordingly.
+ fn update_initial_scan_state(&mut self, cx: &App) {
+ let complete = self.loading_worktrees.is_empty()
+ && self
+ .visible_worktrees(cx)
+ .all(|wt| wt.read(cx).completed_scan_id() >= 1);
+ *self.initial_scan_complete.0.borrow_mut() = complete;
+ }
+
+ /// Spawns a detached task that waits for a worktree's initial scan to complete,
+ /// then rechecks and updates the aggregate initial scan state.
+ fn observe_worktree_scan_completion(
+ &mut self,
+ worktree: &Entity<Worktree>,
+ cx: &mut Context<Self>,
+ ) {
+ let await_scan = worktree.update(cx, |worktree, _cx| worktree.wait_for_snapshot(1));
+ cx.spawn(async move |this, cx| {
+ await_scan.await.ok();
+ this.update(cx, |this, cx| {
+ this.update_initial_scan_state(cx);
+ })
+ .ok();
+ anyhow::Ok(())
+ })
+ .detach();
}
/// Iterates through all worktrees, including ones that don't appear in the project panel
@@ -554,12 +610,22 @@ impl WorktreeStore {
self.loading_worktrees
.insert(abs_path.clone(), task.shared());
+
+ if visible && self.scanning_enabled {
+ *self.initial_scan_complete.0.borrow_mut() = false;
+ }
}
let task = self.loading_worktrees.get(&abs_path).unwrap().clone();
cx.spawn(async move |this, cx| {
let result = task.await;
- this.update(cx, |this, _| this.loading_worktrees.remove(&abs_path))
- .ok();
+ this.update(cx, |this, cx| {
+ this.loading_worktrees.remove(&abs_path);
+ if !visible || !this.scanning_enabled || result.is_err() {
+ this.update_initial_scan_state(cx);
+ }
+ })
+ .ok();
+
match result {
Ok(worktree) => {
if !is_via_collab {
@@ -578,6 +644,13 @@ impl WorktreeStore {
);
});
}
+
+ this.update(cx, |this, cx| {
+ if this.scanning_enabled && visible {
+ this.observe_worktree_scan_completion(&worktree, cx);
+ }
+ })
+ .ok();
}
Ok(worktree)
}
@@ -768,6 +841,7 @@ impl WorktreeStore {
false
}
});
+ self.update_initial_scan_state(cx);
self.send_project_updates(cx);
}
@@ -76,7 +76,7 @@ use std::{
path::{Path, PathBuf},
rc::Rc,
str::FromStr,
- sync::{Arc, OnceLock},
+ sync::{Arc, OnceLock, atomic},
task::Poll,
time::Duration,
};
@@ -3758,6 +3758,266 @@ async fn test_diagnostics_from_multiple_language_servers(cx: &mut gpui::TestAppC
});
}
+#[gpui::test]
+async fn test_diagnostic_summaries_cleared_on_worktree_entry_removal(
+ cx: &mut gpui::TestAppContext,
+) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/dir"), json!({ "a.rs": "one", "b.rs": "two" }))
+ .await;
+
+ let project = Project::test(fs.clone(), [Path::new(path!("/dir"))], cx).await;
+ let lsp_store = project.read_with(cx, |project, _| project.lsp_store());
+
+ lsp_store.update(cx, |lsp_store, cx| {
+ lsp_store
+ .update_diagnostic_entries(
+ LanguageServerId(0),
+ Path::new(path!("/dir/a.rs")).to_owned(),
+ None,
+ None,
+ vec![DiagnosticEntry {
+ range: Unclipped(PointUtf16::new(0, 0))..Unclipped(PointUtf16::new(0, 3)),
+ diagnostic: Diagnostic {
+ severity: DiagnosticSeverity::ERROR,
+ is_primary: true,
+ message: "error in a".to_string(),
+ source_kind: DiagnosticSourceKind::Pushed,
+ ..Diagnostic::default()
+ },
+ }],
+ cx,
+ )
+ .unwrap();
+ lsp_store
+ .update_diagnostic_entries(
+ LanguageServerId(0),
+ Path::new(path!("/dir/b.rs")).to_owned(),
+ None,
+ None,
+ vec![DiagnosticEntry {
+ range: Unclipped(PointUtf16::new(0, 0))..Unclipped(PointUtf16::new(0, 3)),
+ diagnostic: Diagnostic {
+ severity: DiagnosticSeverity::WARNING,
+ is_primary: true,
+ message: "warning in b".to_string(),
+ source_kind: DiagnosticSourceKind::Pushed,
+ ..Diagnostic::default()
+ },
+ }],
+ cx,
+ )
+ .unwrap();
+
+ assert_eq!(
+ lsp_store.diagnostic_summary(false, cx),
+ DiagnosticSummary {
+ error_count: 1,
+ warning_count: 1,
+ }
+ );
+ });
+
+ fs.remove_file(path!("/dir/a.rs").as_ref(), Default::default())
+ .await
+ .unwrap();
+ cx.executor().run_until_parked();
+
+ lsp_store.update(cx, |lsp_store, cx| {
+ assert_eq!(
+ lsp_store.diagnostic_summary(false, cx),
+ DiagnosticSummary {
+ error_count: 0,
+ warning_count: 1,
+ },
+ );
+ });
+}
+
+#[gpui::test]
+async fn test_diagnostic_summaries_cleared_on_server_restart(cx: &mut gpui::TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/dir"), json!({ "a.rs": "x" })).await;
+
+ let project = Project::test(fs, [path!("/dir").as_ref()], cx).await;
+
+ let language_registry = project.read_with(cx, |project, _| project.languages().clone());
+ language_registry.add(rust_lang());
+ let mut fake_servers = language_registry.register_fake_lsp("Rust", FakeLspAdapter::default());
+
+ let (buffer, _handle) = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer_with_lsp(path!("/dir/a.rs"), cx)
+ })
+ .await
+ .unwrap();
+
+ let fake_server = fake_servers.next().await.unwrap();
+ fake_server.notify::<lsp::notification::PublishDiagnostics>(lsp::PublishDiagnosticsParams {
+ uri: Uri::from_file_path(path!("/dir/a.rs")).unwrap(),
+ version: None,
+ diagnostics: vec![lsp::Diagnostic {
+ range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 1)),
+ severity: Some(lsp::DiagnosticSeverity::ERROR),
+ message: "error before restart".to_string(),
+ ..Default::default()
+ }],
+ });
+ cx.executor().run_until_parked();
+
+ project.update(cx, |project, cx| {
+ assert_eq!(
+ project.diagnostic_summary(false, cx),
+ DiagnosticSummary {
+ error_count: 1,
+ warning_count: 0,
+ }
+ );
+ });
+
+ let mut events = cx.events(&project);
+
+ project.update(cx, |project, cx| {
+ project.restart_language_servers_for_buffers(vec![buffer.clone()], HashSet::default(), cx);
+ });
+ cx.executor().run_until_parked();
+
+ let mut received_diagnostics_updated = false;
+ while let Some(Some(event)) =
+ futures::FutureExt::now_or_never(futures::StreamExt::next(&mut events))
+ {
+ if matches!(event, Event::DiagnosticsUpdated { .. }) {
+ received_diagnostics_updated = true;
+ }
+ }
+ assert!(
+ received_diagnostics_updated,
+ "DiagnosticsUpdated event should be emitted when a language server is stopped"
+ );
+
+ project.update(cx, |project, cx| {
+ assert_eq!(
+ project.diagnostic_summary(false, cx),
+ DiagnosticSummary {
+ error_count: 0,
+ warning_count: 0,
+ }
+ );
+ });
+}
+
+#[gpui::test]
+async fn test_diagnostic_summaries_cleared_on_buffer_reload(cx: &mut gpui::TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/dir"), json!({ "a.rs": "one two three" }))
+ .await;
+
+ let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
+
+ let language_registry = project.read_with(cx, |project, _| project.languages().clone());
+ language_registry.add(rust_lang());
+ let pull_count = Arc::new(atomic::AtomicUsize::new(0));
+ let closure_pull_count = pull_count.clone();
+ let mut fake_servers = language_registry.register_fake_lsp(
+ "Rust",
+ FakeLspAdapter {
+ capabilities: lsp::ServerCapabilities {
+ diagnostic_provider: Some(lsp::DiagnosticServerCapabilities::Options(
+ lsp::DiagnosticOptions {
+ identifier: Some("test-reload".to_string()),
+ inter_file_dependencies: true,
+ workspace_diagnostics: false,
+ work_done_progress_options: Default::default(),
+ },
+ )),
+ ..lsp::ServerCapabilities::default()
+ },
+ initializer: Some(Box::new(move |fake_server| {
+ let pull_count = closure_pull_count.clone();
+ fake_server.set_request_handler::<lsp::request::DocumentDiagnosticRequest, _, _>(
+ move |_, _| {
+ let pull_count = pull_count.clone();
+ async move {
+ pull_count.fetch_add(1, atomic::Ordering::SeqCst);
+ Ok(lsp::DocumentDiagnosticReportResult::Report(
+ lsp::DocumentDiagnosticReport::Full(
+ lsp::RelatedFullDocumentDiagnosticReport {
+ related_documents: None,
+ full_document_diagnostic_report:
+ lsp::FullDocumentDiagnosticReport {
+ result_id: None,
+ items: Vec::new(),
+ },
+ },
+ ),
+ ))
+ }
+ },
+ );
+ })),
+ ..FakeLspAdapter::default()
+ },
+ );
+
+ let (_buffer, _handle) = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer_with_lsp(path!("/dir/a.rs"), cx)
+ })
+ .await
+ .unwrap();
+
+ let fake_server = fake_servers.next().await.unwrap();
+ cx.executor().run_until_parked();
+
+ // Publish initial diagnostics via the fake server.
+ fake_server.notify::<lsp::notification::PublishDiagnostics>(lsp::PublishDiagnosticsParams {
+ uri: Uri::from_file_path(path!("/dir/a.rs")).unwrap(),
+ version: None,
+ diagnostics: vec![lsp::Diagnostic {
+ range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 3)),
+ severity: Some(lsp::DiagnosticSeverity::ERROR),
+ message: "error in a".to_string(),
+ ..Default::default()
+ }],
+ });
+ cx.executor().run_until_parked();
+
+ project.update(cx, |project, cx| {
+ assert_eq!(
+ project.diagnostic_summary(false, cx),
+ DiagnosticSummary {
+ error_count: 1,
+ warning_count: 0,
+ }
+ );
+ });
+
+ let pulls_before = pull_count.load(atomic::Ordering::SeqCst);
+
+ // Change the file on disk. The FS event triggers buffer reload,
+ // which in turn triggers pull_diagnostics_for_buffer.
+ fs.save(
+ path!("/dir/a.rs").as_ref(),
+ &"fixed content".into(),
+ LineEnding::Unix,
+ )
+ .await
+ .unwrap();
+ cx.executor().run_until_parked();
+
+ let pulls_after = pull_count.load(atomic::Ordering::SeqCst);
+ assert!(
+ pulls_after > pulls_before,
+ "Expected document diagnostic pull after buffer reload (before={pulls_before}, after={pulls_after})"
+ );
+}
+
#[gpui::test]
async fn test_edits_from_lsp2_with_past_version(cx: &mut gpui::TestAppContext) {
init_test(cx);
@@ -11623,6 +11883,77 @@ async fn test_undo_encoding_change(cx: &mut gpui::TestAppContext) {
});
}
+#[gpui::test]
+async fn test_initial_scan_complete(cx: &mut gpui::TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ path!("/root"),
+ json!({
+ "a": {
+ ".git": {},
+ ".zed": {
+ "tasks.json": r#"[{"label": "task-a", "command": "echo a"}]"#
+ },
+ "src": { "main.rs": "" }
+ },
+ "b": {
+ ".git": {},
+ ".zed": {
+ "tasks.json": r#"[{"label": "task-b", "command": "echo b"}]"#
+ },
+ "src": { "lib.rs": "" }
+ },
+ }),
+ )
+ .await;
+
+ let repos_created = Rc::new(RefCell::new(Vec::new()));
+ let _observe = {
+ let repos_created = repos_created.clone();
+ cx.update(|cx| {
+ cx.observe_new::<Repository>(move |repo, _, cx| {
+ repos_created.borrow_mut().push(cx.entity().downgrade());
+ let _ = repo;
+ })
+ })
+ };
+
+ let project = Project::test(
+ fs.clone(),
+ [path!("/root/a").as_ref(), path!("/root/b").as_ref()],
+ cx,
+ )
+ .await;
+
+ let scan_complete = project.read_with(cx, |project, cx| project.wait_for_initial_scan(cx));
+ scan_complete.await;
+
+ project.read_with(cx, |project, cx| {
+ assert!(
+ project.worktree_store().read(cx).initial_scan_completed(),
+ "Expected initial scan to be completed after awaiting wait_for_initial_scan"
+ );
+ });
+
+ let created_repos_len = repos_created.borrow().len();
+ assert_eq!(
+ created_repos_len, 2,
+ "Expected 2 repositories to be created during scan, got {}",
+ created_repos_len
+ );
+
+ project.read_with(cx, |project, cx| {
+ let git_store = project.git_store().read(cx);
+ assert_eq!(
+ git_store.repositories().len(),
+ 2,
+ "Expected 2 repositories in GitStore"
+ );
+ });
+}
+
pub fn init_test(cx: &mut gpui::TestAppContext) {
zlog::init_test();
@@ -47,6 +47,7 @@ impl SidebarRecentProjects {
workspaces: Vec::new(),
filtered_workspaces: Vec::new(),
selected_index: 0,
+ has_any_non_local_projects: false,
focus_handle: cx.focus_handle(),
};
@@ -122,6 +123,7 @@ pub struct SidebarRecentProjectsDelegate {
)>,
filtered_workspaces: Vec<StringMatch>,
selected_index: usize,
+ has_any_non_local_projects: bool,
focus_handle: FocusHandle,
}
@@ -135,6 +137,9 @@ impl SidebarRecentProjectsDelegate {
DateTime<Utc>,
)>,
) {
+ self.has_any_non_local_projects = workspaces
+ .iter()
+ .any(|(_, location, _, _)| !matches!(location, SerializedWorkspaceLocation::Local));
self.workspaces = workspaces;
}
}
@@ -383,7 +388,9 @@ impl PickerDelegate for SidebarRecentProjectsDelegate {
h_flex()
.gap_3()
.flex_grow()
- .child(Icon::new(icon).color(Color::Muted))
+ .when(self.has_any_non_local_projects, |this| {
+ this.child(Icon::new(icon).color(Color::Muted))
+ })
.child(highlighted_match.render(window, cx)),
)
.tooltip(Tooltip::text(tooltip_path))
@@ -12,7 +12,7 @@ workspace = true
path = "src/rope.rs"
[dependencies]
-arrayvec = "0.7.1"
+heapless.workspace = true
log.workspace = true
rayon.workspace = true
sum_tree.workspace = true
@@ -1,5 +1,5 @@
use crate::{OffsetUtf16, Point, PointUtf16, TextSummary, Unclipped};
-use arrayvec::ArrayString;
+use heapless::String as ArrayString;
use std::{cmp, ops::Range};
use sum_tree::Bias;
use unicode_segmentation::GraphemeCursor;
@@ -29,7 +29,7 @@ pub struct Chunk {
newlines: Bitmap,
/// If bit[i] is set, then the character at index i is an ascii tab.
tabs: Bitmap,
- pub text: ArrayString<MAX_BASE>,
+ pub text: ArrayString<MAX_BASE, u8>,
}
#[inline(always)]
@@ -47,7 +47,11 @@ impl Chunk {
#[inline(always)]
pub fn new(text: &str) -> Self {
- let text = ArrayString::from(text).unwrap();
+ let text = {
+ let mut buf = ArrayString::new();
+ buf.push_str(text).unwrap();
+ buf
+ };
const CHUNK_SIZE: usize = 8;
@@ -118,7 +122,7 @@ impl Chunk {
self.chars_utf16 |= slice.chars_utf16 << base_ix;
self.newlines |= slice.newlines << base_ix;
self.tabs |= slice.tabs << base_ix;
- self.text.push_str(slice.text);
+ self.text.push_str(slice.text).unwrap();
}
#[inline(always)]
@@ -137,9 +141,9 @@ impl Chunk {
self.newlines = slice.newlines | (self.newlines << shift);
self.tabs = slice.tabs | (self.tabs << shift);
- let mut new_text = ArrayString::<MAX_BASE>::new();
- new_text.push_str(slice.text);
- new_text.push_str(&self.text);
+ let mut new_text = ArrayString::<MAX_BASE, u8>::new();
+ new_text.push_str(slice.text).unwrap();
+ new_text.push_str(&self.text).unwrap();
self.text = new_text;
}
@@ -4,7 +4,7 @@ mod point;
mod point_utf16;
mod unclipped;
-use arrayvec::ArrayVec;
+use heapless::Vec as ArrayVec;
use rayon::iter::{IntoParallelIterator, ParallelIterator as _};
use std::{
cmp, fmt, io, mem,
@@ -184,7 +184,7 @@ impl Rope {
return self.push_large(text);
}
// 16 is enough as otherwise we will hit the branch above
- let mut new_chunks = ArrayVec::<_, NUM_CHUNKS>::new();
+ let mut new_chunks = ArrayVec::<_, NUM_CHUNKS, u8>::new();
while !text.is_empty() {
let mut split_ix = cmp::min(chunk::MAX_BASE, text.len());
@@ -192,7 +192,7 @@ impl Rope {
split_ix -= 1;
}
let (chunk, remainder) = text.split_at(split_ix);
- new_chunks.push(chunk);
+ new_chunks.push(chunk).unwrap();
text = remainder;
}
self.chunks
@@ -976,7 +976,9 @@ impl BufferSearchBar {
if deploy.focus {
let mut handle = self.query_editor.focus_handle(cx);
let mut select_query = true;
- if deploy.replace_enabled && handle.is_focused(window) {
+
+ let has_seed_text = self.query_suggestion(window, cx).is_some();
+ if deploy.replace_enabled && has_seed_text {
handle = self.replacement_editor.focus_handle(cx);
select_query = false;
};
@@ -3188,6 +3190,47 @@ mod tests {
.await;
}
+ #[gpui::test]
+ async fn test_deploy_replace_focuses_replacement_editor(cx: &mut TestAppContext) {
+ init_globals(cx);
+ let (editor, search_bar, cx) = init_test(cx);
+
+ editor.update_in(cx, |editor, window, cx| {
+ editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| {
+ s.select_display_ranges([
+ DisplayPoint::new(DisplayRow(0), 8)..DisplayPoint::new(DisplayRow(0), 16)
+ ])
+ });
+ });
+
+ search_bar.update_in(cx, |search_bar, window, cx| {
+ search_bar.deploy(
+ &Deploy {
+ focus: true,
+ replace_enabled: true,
+ selection_search_enabled: false,
+ },
+ window,
+ cx,
+ );
+ });
+ cx.run_until_parked();
+
+ search_bar.update_in(cx, |search_bar, window, cx| {
+ assert!(
+ search_bar
+ .replacement_editor
+ .focus_handle(cx)
+ .is_focused(window),
+ "replacement editor should be focused when deploying replace with a selection",
+ );
+ assert!(
+ !search_bar.query_editor.focus_handle(cx).is_focused(window),
+ "search editor should not be focused when replacement editor is focused",
+ );
+ });
+ }
+
#[perf]
#[gpui::test]
async fn test_find_matches_in_selections_singleton_buffer_multiple_selections(
@@ -4,7 +4,7 @@ use fs::Fs;
use gpui::{
Action, ActionBuildError, App, InvalidKeystrokeError, KEYSTROKE_PARSE_EXPECTED_MESSAGE,
KeyBinding, KeyBindingContextPredicate, KeyBindingMetaIndex, KeybindingKeystroke, Keystroke,
- NoAction, SharedString, generate_list_of_all_registered_actions, register_action,
+ NoAction, SharedString, Unbind, generate_list_of_all_registered_actions, register_action,
};
use schemars::{JsonSchema, json_schema};
use serde::Deserialize;
@@ -73,6 +73,10 @@ pub struct KeymapSection {
/// on macOS. See the documentation for more details.
#[serde(default)]
use_key_equivalents: bool,
+ /// This keymap section's unbindings, as a JSON object mapping keystrokes to actions. These are
+ /// parsed before `bindings`, so bindings later in the same section can still take precedence.
+ #[serde(default)]
+ unbind: Option<IndexMap<String, UnbindTargetAction>>,
/// This keymap section's bindings, as a JSON object mapping keystrokes to actions. The
/// keystrokes key is a string representing a sequence of keystrokes to type, where the
/// keystrokes are separated by whitespace. Each keystroke is a sequence of modifiers (`ctrl`,
@@ -135,6 +139,20 @@ impl JsonSchema for KeymapAction {
}
}
+#[derive(Debug, Deserialize, Default, Clone)]
+#[serde(transparent)]
+pub struct UnbindTargetAction(Value);
+
+impl JsonSchema for UnbindTargetAction {
+ fn schema_name() -> Cow<'static, str> {
+ "UnbindTargetAction".into()
+ }
+
+ fn json_schema(_: &mut schemars::SchemaGenerator) -> schemars::Schema {
+ json_schema!(true)
+ }
+}
+
#[derive(Debug)]
#[must_use]
pub enum KeymapFileLoadResult {
@@ -231,6 +249,7 @@ impl KeymapFile {
for KeymapSection {
context,
use_key_equivalents,
+ unbind,
bindings,
unrecognized_fields,
} in keymap_file.0.iter()
@@ -244,7 +263,7 @@ impl KeymapFile {
// Leading space is to separate from the message indicating which section
// the error occurred in.
errors.push((
- context,
+ context.clone(),
format!(" Parse error in section `context` field: {}", err),
));
continue;
@@ -263,6 +282,38 @@ impl KeymapFile {
.unwrap();
}
+ if let Some(unbind) = unbind {
+ for (keystrokes, action) in unbind {
+ let result = Self::load_unbinding(
+ keystrokes,
+ action,
+ context_predicate.clone(),
+ *use_key_equivalents,
+ cx,
+ );
+ match result {
+ Ok(key_binding) => {
+ key_bindings.push(key_binding);
+ }
+ Err(err) => {
+ let mut lines = err.lines();
+ let mut indented_err = lines.next().unwrap().to_string();
+ for line in lines {
+ indented_err.push_str(" ");
+ indented_err.push_str(line);
+ indented_err.push_str("\n");
+ }
+ write!(
+ section_errors,
+ "\n\n- In unbind {}, {indented_err}",
+ MarkdownInlineCode(&format!("\"{}\"", keystrokes))
+ )
+ .unwrap();
+ }
+ }
+ }
+ }
+
if let Some(bindings) = bindings {
for (keystrokes, action) in bindings {
let result = Self::load_keybinding(
@@ -296,7 +347,7 @@ impl KeymapFile {
}
if !section_errors.is_empty() {
- errors.push((context, section_errors))
+ errors.push((context.clone(), section_errors))
}
}
@@ -332,7 +383,17 @@ impl KeymapFile {
use_key_equivalents: bool,
cx: &App,
) -> std::result::Result<KeyBinding, String> {
- let (action, action_input_string) = Self::build_keymap_action(action, cx)?;
+ Self::load_keybinding_action_value(keystrokes, &action.0, context, use_key_equivalents, cx)
+ }
+
+ fn load_keybinding_action_value(
+ keystrokes: &str,
+ action: &Value,
+ context: Option<Rc<KeyBindingContextPredicate>>,
+ use_key_equivalents: bool,
+ cx: &App,
+ ) -> std::result::Result<KeyBinding, String> {
+ let (action, action_input_string) = Self::build_keymap_action_value(action, cx)?;
let key_binding = match KeyBinding::load(
keystrokes,
@@ -362,23 +423,70 @@ impl KeymapFile {
}
}
+ fn load_unbinding(
+ keystrokes: &str,
+ action: &UnbindTargetAction,
+ context: Option<Rc<KeyBindingContextPredicate>>,
+ use_key_equivalents: bool,
+ cx: &App,
+ ) -> std::result::Result<KeyBinding, String> {
+ let key_binding = Self::load_keybinding_action_value(
+ keystrokes,
+ &action.0,
+ context,
+ use_key_equivalents,
+ cx,
+ )?;
+
+ if key_binding.action().partial_eq(&NoAction) {
+ return Err("expected action name string or [name, input] array.".to_string());
+ }
+
+ if key_binding.action().name() == Unbind::name_for_type() {
+ return Err(format!(
+ "can't use {} as an unbind target.",
+ MarkdownInlineCode(&format!("\"{}\"", Unbind::name_for_type()))
+ ));
+ }
+
+ KeyBinding::load(
+ keystrokes,
+ Box::new(Unbind(key_binding.action().name().into())),
+ key_binding.predicate(),
+ use_key_equivalents,
+ key_binding.action_input(),
+ cx.keyboard_mapper().as_ref(),
+ )
+ .map_err(|InvalidKeystrokeError { keystroke }| {
+ format!(
+ "invalid keystroke {}. {}",
+ MarkdownInlineCode(&format!("\"{}\"", &keystroke)),
+ KEYSTROKE_PARSE_EXPECTED_MESSAGE
+ )
+ })
+ }
+
pub fn parse_action(
action: &KeymapAction,
) -> Result<Option<(&String, Option<&Value>)>, String> {
- let name_and_input = match &action.0 {
+ Self::parse_action_value(&action.0)
+ }
+
+ fn parse_action_value(action: &Value) -> Result<Option<(&String, Option<&Value>)>, String> {
+ let name_and_input = match action {
Value::Array(items) => {
if items.len() != 2 {
return Err(format!(
"expected two-element array of `[name, input]`. \
Instead found {}.",
- MarkdownInlineCode(&action.0.to_string())
+ MarkdownInlineCode(&action.to_string())
));
}
let serde_json::Value::String(ref name) = items[0] else {
return Err(format!(
"expected two-element array of `[name, input]`, \
but the first element is not a string in {}.",
- MarkdownInlineCode(&action.0.to_string())
+ MarkdownInlineCode(&action.to_string())
));
};
Some((name, Some(&items[1])))
@@ -389,7 +497,7 @@ impl KeymapFile {
return Err(format!(
"expected two-element array of `[name, input]`. \
Instead found {}.",
- MarkdownInlineCode(&action.0.to_string())
+ MarkdownInlineCode(&action.to_string())
));
}
};
@@ -400,7 +508,14 @@ impl KeymapFile {
action: &KeymapAction,
cx: &App,
) -> std::result::Result<(Box<dyn Action>, Option<String>), String> {
- let (build_result, action_input_string) = match Self::parse_action(action)? {
+ Self::build_keymap_action_value(&action.0, cx)
+ }
+
+ fn build_keymap_action_value(
+ action: &Value,
+ cx: &App,
+ ) -> std::result::Result<(Box<dyn Action>, Option<String>), String> {
+ let (build_result, action_input_string) = match Self::parse_action_value(action)? {
Some((name, action_input)) if name.as_str() == ActionSequence::name_for_type() => {
match action_input {
Some(action_input) => (
@@ -583,9 +698,15 @@ impl KeymapFile {
"minItems": 2,
"maxItems": 2
});
- let mut keymap_action_alternatives = vec![empty_action_name, empty_action_name_with_input];
+ let mut keymap_action_alternatives = vec![
+ empty_action_name.clone(),
+ empty_action_name_with_input.clone(),
+ ];
+ let mut unbind_target_action_alternatives =
+ vec![empty_action_name, empty_action_name_with_input];
let mut empty_schema_action_names = vec![];
+ let mut empty_schema_unbind_target_action_names = vec![];
for (name, action_schema) in action_schemas.into_iter() {
let deprecation = if name == NoAction.name() {
Some("null")
@@ -593,6 +714,9 @@ impl KeymapFile {
deprecations.get(name).copied()
};
+ let include_in_unbind_target_schema =
+ name != NoAction.name() && name != Unbind::name_for_type();
+
// Add an alternative for plain action names.
let mut plain_action = json_schema!({
"type": "string",
@@ -607,7 +731,10 @@ impl KeymapFile {
if let Some(description) = &description {
add_description(&mut plain_action, description);
}
- keymap_action_alternatives.push(plain_action);
+ keymap_action_alternatives.push(plain_action.clone());
+ if include_in_unbind_target_schema {
+ unbind_target_action_alternatives.push(plain_action);
+ }
// Add an alternative for actions with data specified as a [name, data] array.
//
@@ -633,9 +760,15 @@ impl KeymapFile {
"minItems": 2,
"maxItems": 2
});
- keymap_action_alternatives.push(action_with_input);
+ keymap_action_alternatives.push(action_with_input.clone());
+ if include_in_unbind_target_schema {
+ unbind_target_action_alternatives.push(action_with_input);
+ }
} else {
empty_schema_action_names.push(name);
+ if include_in_unbind_target_schema {
+ empty_schema_unbind_target_action_names.push(name);
+ }
}
}
@@ -659,20 +792,44 @@ impl KeymapFile {
keymap_action_alternatives.push(actions_with_empty_input);
}
+ if !empty_schema_unbind_target_action_names.is_empty() {
+ let action_names = json_schema!({ "enum": empty_schema_unbind_target_action_names });
+ let no_properties_allowed = json_schema!({
+ "type": "object",
+ "additionalProperties": false
+ });
+ let mut actions_with_empty_input = json_schema!({
+ "type": "array",
+ "items": [action_names, no_properties_allowed],
+ "minItems": 2,
+ "maxItems": 2
+ });
+ add_deprecation(
+ &mut actions_with_empty_input,
+ "This action does not take input - just the action name string should be used."
+ .to_string(),
+ );
+ unbind_target_action_alternatives.push(actions_with_empty_input);
+ }
+
// Placing null first causes json-language-server to default assuming actions should be
// null, so place it last.
keymap_action_alternatives.push(json_schema!({
"type": "null"
}));
- // The `KeymapSection` schema will reference the `KeymapAction` schema by name, so setting
- // the definition of `KeymapAction` results in the full action schema being used.
generator.definitions_mut().insert(
KeymapAction::schema_name().to_string(),
json!({
"anyOf": keymap_action_alternatives
}),
);
+ generator.definitions_mut().insert(
+ UnbindTargetAction::schema_name().to_string(),
+ json!({
+ "anyOf": unbind_target_action_alternatives
+ }),
+ );
generator.root_schema_for::<KeymapFile>().to_value()
}
@@ -1260,7 +1417,8 @@ impl Action for ActionSequence {
#[cfg(test)]
mod tests {
- use gpui::{DummyKeyboardMapper, KeybindingKeystroke, Keystroke};
+ use gpui::{Action, App, DummyKeyboardMapper, KeybindingKeystroke, Keystroke, Unbind};
+ use serde_json::Value;
use unindent::Unindent;
use crate::{
@@ -1268,6 +1426,8 @@ mod tests {
keymap_file::{KeybindUpdateOperation, KeybindUpdateTarget},
};
+ gpui::actions!(test_keymap_file, [StringAction, InputAction]);
+
#[test]
fn can_deserialize_keymap_with_trailing_comma() {
let json = indoc::indoc! {"[
@@ -1283,6 +1443,191 @@ mod tests {
KeymapFile::parse(json).unwrap();
}
+ #[gpui::test]
+ fn keymap_section_unbinds_are_loaded_before_bindings(cx: &mut App) {
+ let key_bindings = match KeymapFile::load(
+ indoc::indoc! {r#"
+ [
+ {
+ "unbind": {
+ "ctrl-a": "test_keymap_file::StringAction",
+ "ctrl-b": ["test_keymap_file::InputAction", {}]
+ },
+ "bindings": {
+ "ctrl-c": "test_keymap_file::StringAction"
+ }
+ }
+ ]
+ "#},
+ cx,
+ ) {
+ crate::keymap_file::KeymapFileLoadResult::Success { key_bindings } => key_bindings,
+ crate::keymap_file::KeymapFileLoadResult::SomeFailedToLoad {
+ error_message, ..
+ } => {
+ panic!("{error_message}");
+ }
+ crate::keymap_file::KeymapFileLoadResult::JsonParseFailure { error } => {
+ panic!("JSON parse error: {error}");
+ }
+ };
+
+ assert_eq!(key_bindings.len(), 3);
+ assert!(
+ key_bindings[0]
+ .action()
+ .partial_eq(&Unbind("test_keymap_file::StringAction".into()))
+ );
+ assert_eq!(key_bindings[0].action_input(), None);
+ assert!(
+ key_bindings[1]
+ .action()
+ .partial_eq(&Unbind("test_keymap_file::InputAction".into()))
+ );
+ assert_eq!(
+ key_bindings[1]
+ .action_input()
+ .as_ref()
+ .map(ToString::to_string),
+ Some("{}".to_string())
+ );
+ assert_eq!(
+ key_bindings[2].action().name(),
+ "test_keymap_file::StringAction"
+ );
+ }
+
+ #[gpui::test]
+ fn keymap_unbind_loads_valid_target_action_with_input(cx: &mut App) {
+ let key_bindings = match KeymapFile::load(
+ indoc::indoc! {r#"
+ [
+ {
+ "unbind": {
+ "ctrl-a": ["test_keymap_file::InputAction", {}]
+ }
+ }
+ ]
+ "#},
+ cx,
+ ) {
+ crate::keymap_file::KeymapFileLoadResult::Success { key_bindings } => key_bindings,
+ other => panic!("expected Success, got {other:?}"),
+ };
+
+ assert_eq!(key_bindings.len(), 1);
+ assert!(
+ key_bindings[0]
+ .action()
+ .partial_eq(&Unbind("test_keymap_file::InputAction".into()))
+ );
+ assert_eq!(
+ key_bindings[0]
+ .action_input()
+ .as_ref()
+ .map(ToString::to_string),
+ Some("{}".to_string())
+ );
+ }
+
+ #[gpui::test]
+ fn keymap_unbind_rejects_null(cx: &mut App) {
+ match KeymapFile::load(
+ indoc::indoc! {r#"
+ [
+ {
+ "unbind": {
+ "ctrl-a": null
+ }
+ }
+ ]
+ "#},
+ cx,
+ ) {
+ crate::keymap_file::KeymapFileLoadResult::SomeFailedToLoad {
+ key_bindings,
+ error_message,
+ } => {
+ assert!(key_bindings.is_empty());
+ assert!(
+ error_message
+ .0
+ .contains("expected action name string or [name, input] array.")
+ );
+ }
+ other => panic!("expected SomeFailedToLoad, got {other:?}"),
+ }
+ }
+
+ #[gpui::test]
+ fn keymap_unbind_rejects_unbind_action(cx: &mut App) {
+ match KeymapFile::load(
+ indoc::indoc! {r#"
+ [
+ {
+ "unbind": {
+ "ctrl-a": ["zed::Unbind", "test_keymap_file::StringAction"]
+ }
+ }
+ ]
+ "#},
+ cx,
+ ) {
+ crate::keymap_file::KeymapFileLoadResult::SomeFailedToLoad {
+ key_bindings,
+ error_message,
+ } => {
+ assert!(key_bindings.is_empty());
+ assert!(
+ error_message
+ .0
+ .contains("can't use `\"zed::Unbind\"` as an unbind target.")
+ );
+ }
+ other => panic!("expected SomeFailedToLoad, got {other:?}"),
+ }
+ }
+
+ #[test]
+ fn keymap_schema_for_unbind_excludes_null_and_unbind_action() {
+ fn schema_allows(schema: &Value, expected: &Value) -> bool {
+ match schema {
+ Value::Object(object) => {
+ if object.get("const") == Some(expected) {
+ return true;
+ }
+ if object.get("type") == Some(&Value::String("null".to_string()))
+ && expected == &Value::Null
+ {
+ return true;
+ }
+ object.values().any(|value| schema_allows(value, expected))
+ }
+ Value::Array(items) => items.iter().any(|value| schema_allows(value, expected)),
+ _ => false,
+ }
+ }
+
+ let schema = KeymapFile::generate_json_schema_from_inventory();
+ let unbind_schema = schema
+ .pointer("/$defs/UnbindTargetAction")
+ .expect("missing UnbindTargetAction schema");
+
+ assert!(!schema_allows(unbind_schema, &Value::Null));
+ assert!(!schema_allows(
+ unbind_schema,
+ &Value::String(Unbind::name_for_type().to_string())
+ ));
+ assert!(schema_allows(
+ unbind_schema,
+ &Value::String("test_keymap_file::StringAction".to_string())
+ ));
+ assert!(schema_allows(
+ unbind_schema,
+ &Value::String("test_keymap_file::InputAction".to_string())
+ ));
+ }
+
#[track_caller]
fn check_keymap_update(
input: impl ToString,
@@ -16,6 +16,7 @@ pub struct AllLanguageModelSettingsContent {
pub lmstudio: Option<LmStudioSettingsContent>,
pub mistral: Option<MistralSettingsContent>,
pub ollama: Option<OllamaSettingsContent>,
+ pub opencode: Option<OpenCodeSettingsContent>,
pub open_router: Option<OpenRouterSettingsContent>,
pub openai: Option<OpenAiSettingsContent>,
pub openai_compatible: Option<HashMap<Arc<str>, OpenAiCompatibleSettingsContent>>,
@@ -144,6 +145,24 @@ impl Default for KeepAlive {
}
}
+#[with_fallible_options]
+#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema, MergeFrom)]
+pub struct OpenCodeSettingsContent {
+ pub api_url: Option<String>,
+ pub available_models: Option<Vec<OpenCodeAvailableModel>>,
+}
+
+#[with_fallible_options]
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema, MergeFrom)]
+pub struct OpenCodeAvailableModel {
+ pub name: String,
+ pub display_name: Option<String>,
+ pub max_tokens: u64,
+ pub max_output_tokens: Option<u64>,
+ /// The API protocol to use for this model: "anthropic", "openai_responses", "openai_chat", or "google".
+ pub protocol: String,
+}
+
#[with_fallible_options]
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema, MergeFrom)]
pub struct LmStudioSettingsContent {
@@ -9,6 +9,7 @@ mod project;
mod serde_helper;
mod terminal;
mod theme;
+mod title_bar;
mod workspace;
pub use agent::*;
@@ -26,6 +27,7 @@ pub use serde_helper::{
use settings_json::parse_json_with_comments;
pub use terminal::*;
pub use theme::*;
+pub use title_bar::*;
pub use workspace::*;
use collections::{HashMap, IndexMap};
@@ -316,43 +318,6 @@ impl strum::VariantNames for BaseKeymapContent {
];
}
-#[with_fallible_options]
-#[derive(Clone, PartialEq, Default, Serialize, Deserialize, JsonSchema, MergeFrom, Debug)]
-pub struct TitleBarSettingsContent {
- /// Whether to show the branch icon beside branch switcher in the title bar.
- ///
- /// Default: false
- pub show_branch_icon: Option<bool>,
- /// Whether to show onboarding banners in the title bar.
- ///
- /// Default: true
- pub show_onboarding_banner: Option<bool>,
- /// Whether to show user avatar in the title bar.
- ///
- /// Default: true
- pub show_user_picture: Option<bool>,
- /// Whether to show the branch name button in the titlebar.
- ///
- /// Default: true
- pub show_branch_name: Option<bool>,
- /// Whether to show the project host and name in the titlebar.
- ///
- /// Default: true
- pub show_project_items: Option<bool>,
- /// Whether to show the sign in button in the title bar.
- ///
- /// Default: true
- pub show_sign_in: Option<bool>,
- /// Whether to show the user menu button in the title bar.
- ///
- /// Default: true
- pub show_user_menu: Option<bool>,
- /// Whether to show the menus in the title bar.
- ///
- /// Default: false
- pub show_menus: Option<bool>,
-}
-
/// Configuration of audio in Zed.
#[with_fallible_options]
#[derive(Clone, PartialEq, Default, Serialize, Deserialize, JsonSchema, MergeFrom, Debug)]
@@ -0,0 +1,124 @@
+use gpui::WindowButtonLayout;
+use schemars::{JsonSchema, Schema, SchemaGenerator, json_schema};
+use serde::{Deserialize, Serialize};
+use settings_macros::{MergeFrom, with_fallible_options};
+
+/// The layout of window control buttons as represented by user settings.
+///
+/// Custom layout strings use the GNOME `button-layout` format (e.g.
+/// `"close:minimize,maximize"`).
+#[derive(
+ Clone,
+ PartialEq,
+ Debug,
+ Serialize,
+ Deserialize,
+ JsonSchema,
+ MergeFrom,
+ Default,
+ strum::EnumDiscriminants,
+)]
+#[strum_discriminants(derive(strum::VariantArray, strum::VariantNames, strum::FromRepr))]
+#[schemars(schema_with = "window_button_layout_schema")]
+#[serde(from = "String", into = "String")]
+pub enum WindowButtonLayoutContent {
+ /// Follow the system/desktop configuration.
+ #[default]
+ PlatformDefault,
+ /// Use Zed's built-in standard layout, regardless of system config.
+ Standard,
+ /// A raw GNOME-style layout string.
+ Custom(String),
+}
+
+impl WindowButtonLayoutContent {
+ #[cfg(any(target_os = "linux", target_os = "freebsd"))]
+ pub fn into_layout(self) -> Option<WindowButtonLayout> {
+ use util::ResultExt;
+
+ match self {
+ Self::PlatformDefault => None,
+ Self::Standard => Some(WindowButtonLayout::linux_default()),
+ Self::Custom(layout) => WindowButtonLayout::parse(&layout).log_err(),
+ }
+ }
+
+ #[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
+ pub fn into_layout(self) -> Option<WindowButtonLayout> {
+ None
+ }
+}
+
+fn window_button_layout_schema(_: &mut SchemaGenerator) -> Schema {
+ json_schema!({
+ "anyOf": [
+ { "enum": ["platform_default", "standard"] },
+ { "type": "string" }
+ ]
+ })
+}
+
+impl From<WindowButtonLayoutContent> for String {
+ fn from(value: WindowButtonLayoutContent) -> Self {
+ match value {
+ WindowButtonLayoutContent::PlatformDefault => "platform_default".to_string(),
+ WindowButtonLayoutContent::Standard => "standard".to_string(),
+ WindowButtonLayoutContent::Custom(s) => s,
+ }
+ }
+}
+
+impl From<String> for WindowButtonLayoutContent {
+ fn from(layout_string: String) -> Self {
+ match layout_string.as_str() {
+ "platform_default" => Self::PlatformDefault,
+ "standard" => Self::Standard,
+ _ => Self::Custom(layout_string),
+ }
+ }
+}
+
+#[with_fallible_options]
+#[derive(Clone, PartialEq, Default, Serialize, Deserialize, JsonSchema, MergeFrom, Debug)]
+pub struct TitleBarSettingsContent {
+ /// Whether to show the branch icon beside branch switcher in the title bar.
+ ///
+ /// Default: false
+ pub show_branch_icon: Option<bool>,
+ /// Whether to show onboarding banners in the title bar.
+ ///
+ /// Default: true
+ pub show_onboarding_banner: Option<bool>,
+ /// Whether to show user avatar in the title bar.
+ ///
+ /// Default: true
+ pub show_user_picture: Option<bool>,
+ /// Whether to show the branch name button in the titlebar.
+ ///
+ /// Default: true
+ pub show_branch_name: Option<bool>,
+ /// Whether to show the project host and name in the titlebar.
+ ///
+ /// Default: true
+ pub show_project_items: Option<bool>,
+ /// Whether to show the sign in button in the title bar.
+ ///
+ /// Default: true
+ pub show_sign_in: Option<bool>,
+ /// Whether to show the user menu button in the title bar.
+ ///
+ /// Default: true
+ pub show_user_menu: Option<bool>,
+ /// Whether to show the menus in the title bar.
+ ///
+ /// Default: false
+ pub show_menus: Option<bool>,
+ /// The layout of window control buttons in the title bar (Linux only).
+ ///
+ /// This can be set to "platform_default" to follow the system configuration, or
+ /// "standard" to use Zed's built-in layout. For custom layouts, use a
+ /// GNOME-style layout string like "close:minimize,maximize".
+ ///
+ /// Default: "platform_default"
+ pub button_layout: Option<WindowButtonLayoutContent>,
+}
@@ -3481,7 +3481,7 @@ fn window_and_layout_page() -> SettingsPage {
]
}
- fn title_bar_section() -> [SettingsPageItem; 9] {
+ fn title_bar_section() -> [SettingsPageItem; 10] {
[
SettingsPageItem::SectionHeader("Title Bar"),
SettingsPageItem::SettingItem(SettingItem {
@@ -3648,6 +3648,122 @@ fn window_and_layout_page() -> SettingsPage {
metadata: None,
files: USER,
}),
+ SettingsPageItem::DynamicItem(DynamicItem {
+ discriminant: SettingItem {
+ files: USER,
+ title: "Button Layout",
+ description:
+ "(Linux only) choose how window control buttons are laid out in the titlebar.",
+ field: Box::new(SettingField {
+ json_path: Some("title_bar.button_layout$"),
+ pick: |settings_content| {
+ Some(
+ &dynamic_variants::<settings::WindowButtonLayoutContent>()[settings_content
+ .title_bar
+ .as_ref()?
+ .button_layout
+ .as_ref()?
+ .discriminant()
+ as usize],
+ )
+ },
+ write: |settings_content, value| {
+ let Some(value) = value else {
+ settings_content
+ .title_bar
+ .get_or_insert_default()
+ .button_layout = None;
+ return;
+ };
+
+ let current_custom_layout = settings_content
+ .title_bar
+ .as_ref()
+ .and_then(|title_bar| title_bar.button_layout.as_ref())
+ .and_then(|button_layout| match button_layout {
+ settings::WindowButtonLayoutContent::Custom(layout) => {
+ Some(layout.clone())
+ }
+ _ => None,
+ });
+
+ let button_layout = match value {
+ settings::WindowButtonLayoutContentDiscriminants::PlatformDefault => {
+ settings::WindowButtonLayoutContent::PlatformDefault
+ }
+ settings::WindowButtonLayoutContentDiscriminants::Standard => {
+ settings::WindowButtonLayoutContent::Standard
+ }
+ settings::WindowButtonLayoutContentDiscriminants::Custom => {
+ settings::WindowButtonLayoutContent::Custom(
+ current_custom_layout.unwrap_or_else(|| {
+ "close:minimize,maximize".to_string()
+ }),
+ )
+ }
+ };
+
+ settings_content
+ .title_bar
+ .get_or_insert_default()
+ .button_layout = Some(button_layout);
+ },
+ }),
+ metadata: None,
+ },
+ pick_discriminant: |settings_content| {
+ Some(
+ settings_content
+ .title_bar
+ .as_ref()?
+ .button_layout
+ .as_ref()?
+ .discriminant() as usize,
+ )
+ },
+ fields: dynamic_variants::<settings::WindowButtonLayoutContent>()
+ .into_iter()
+ .map(|variant| match variant {
+ settings::WindowButtonLayoutContentDiscriminants::PlatformDefault => {
+ vec![]
+ }
+ settings::WindowButtonLayoutContentDiscriminants::Standard => vec![],
+ settings::WindowButtonLayoutContentDiscriminants::Custom => vec![
+ SettingItem {
+ files: USER,
+ title: "Custom Button Layout",
+ description:
+ "GNOME-style layout string such as \"close:minimize,maximize\".",
+ field: Box::new(SettingField {
+ json_path: Some("title_bar.button_layout"),
+ pick: |settings_content| match settings_content
+ .title_bar
+ .as_ref()?
+ .button_layout
+ .as_ref()?
+ {
+ settings::WindowButtonLayoutContent::Custom(layout) => {
+ Some(layout)
+ }
+ _ => DEFAULT_EMPTY_STRING,
+ },
+ write: |settings_content, value| {
+ settings_content
+ .title_bar
+ .get_or_insert_default()
+ .button_layout = value
+ .map(settings::WindowButtonLayoutContent::Custom);
+ },
+ }),
+ metadata: Some(Box::new(SettingsFieldMetadata {
+ placeholder: Some("close:minimize,maximize"),
+ ..Default::default()
+ })),
+ },
+ ],
+ })
+ .collect(),
+ }),
]
}
@@ -545,6 +545,7 @@ fn init_renderers(cx: &mut App) {
.add_basic_renderer::<settings::EditPredictionsMode>(render_dropdown)
.add_basic_renderer::<settings::RelativeLineNumbers>(render_dropdown)
.add_basic_renderer::<settings::WindowDecorations>(render_dropdown)
+ .add_basic_renderer::<settings::WindowButtonLayoutContentDiscriminants>(render_dropdown)
.add_basic_renderer::<settings::FontSize>(render_editable_number_field)
.add_basic_renderer::<settings::OllamaModelName>(render_ollama_model_picker)
.add_basic_renderer::<settings::SemanticTokens>(render_dropdown)
@@ -5,7 +5,9 @@ use agent_ui::thread_metadata_store::{SidebarThreadMetadataStore, ThreadMetadata
use agent_ui::threads_archive_view::{
ThreadsArchiveView, ThreadsArchiveViewEvent, format_history_entry_timestamp,
};
-use agent_ui::{Agent, AgentPanel, AgentPanelEvent, NewThread, RemoveSelectedThread};
+use agent_ui::{
+ Agent, AgentPanel, AgentPanelEvent, DEFAULT_THREAD_TITLE, NewThread, RemoveSelectedThread,
+};
use chrono::Utc;
use editor::Editor;
use feature_flags::{AgentV2FeatureFlag, FeatureFlagViewExt as _};
@@ -29,8 +31,7 @@ use std::sync::Arc;
use theme::ActiveTheme;
use ui::{
AgentThreadStatus, CommonAnimationExt, ContextMenu, Divider, HighlightedLabel, KeyBinding,
- ListItem, PopoverMenu, PopoverMenuHandle, Tab, ThreadItem, TintColor, Tooltip, WithScrollbar,
- prelude::*,
+ PopoverMenu, PopoverMenuHandle, Tab, ThreadItem, TintColor, Tooltip, WithScrollbar, prelude::*,
};
use util::ResultExt as _;
use util::path_list::PathList;
@@ -110,6 +111,7 @@ struct ThreadEntry {
is_title_generating: bool,
highlight_positions: Vec<usize>,
worktree_name: Option<SharedString>,
+ worktree_full_path: Option<SharedString>,
worktree_highlight_positions: Vec<usize>,
diff_stats: DiffStats,
}
@@ -127,12 +129,12 @@ enum ListEntry {
Thread(ThreadEntry),
ViewMore {
path_list: PathList,
- remaining_count: usize,
is_fully_expanded: bool,
},
NewThread {
path_list: PathList,
workspace: Entity<Workspace>,
+ is_active_draft: bool,
},
}
@@ -471,8 +473,15 @@ impl Sidebar {
cx.subscribe_in(
agent_panel,
window,
- |this, _agent_panel, event: &AgentPanelEvent, _window, cx| match event {
+ |this, agent_panel, event: &AgentPanelEvent, _window, cx| match event {
AgentPanelEvent::ActiveViewChanged => {
+ let is_new_draft = agent_panel
+ .read(cx)
+ .active_conversation_view()
+ .is_some_and(|cv| cv.read(cx).parent_id(cx).is_none());
+ if is_new_draft {
+ this.focused_thread = None;
+ }
this.observe_draft_editor(cx);
this.update_entries(cx);
}
@@ -485,16 +494,19 @@ impl Sidebar {
}
fn observe_docks(&mut self, workspace: &Entity<Workspace>, cx: &mut Context<Self>) {
- let workspace = workspace.clone();
let docks: Vec<_> = workspace
.read(cx)
.all_docks()
.into_iter()
.cloned()
.collect();
+ let workspace = workspace.downgrade();
for dock in docks {
let workspace = workspace.clone();
cx.observe(&dock, move |this, _dock, cx| {
+ let Some(workspace) = workspace.upgrade() else {
+ return;
+ };
if !this.is_active_workspace(&workspace, cx) {
return;
}
@@ -519,7 +531,7 @@ impl Sidebar {
ws.read(cx).panel::<AgentPanel>(cx)
})
.and_then(|panel| {
- let cv = panel.read(cx).active_conversation()?;
+ let cv = panel.read(cx).active_conversation_view()?;
let tv = cv.read(cx).active_thread()?;
Some(tv.read(cx).message_editor.clone())
})
@@ -534,7 +546,7 @@ impl Sidebar {
let mw = self.multi_workspace.upgrade()?;
let workspace = mw.read(cx).workspace();
let panel = workspace.read(cx).panel::<AgentPanel>(cx)?;
- let conversation_view = panel.read(cx).active_conversation()?;
+ let conversation_view = panel.read(cx).active_conversation_view()?;
let thread_view = conversation_view.read(cx).active_thread()?;
let raw = thread_view.read(cx).message_editor.read(cx).text(cx);
let cleaned = Self::clean_mention_links(&raw);
@@ -592,7 +604,9 @@ impl Sidebar {
let icon = thread_view_ref.agent_icon;
let icon_from_external_svg = thread_view_ref.agent_icon_from_external_svg.clone();
- let title = thread.title();
+ let title = thread
+ .title()
+ .unwrap_or_else(|| DEFAULT_THREAD_TITLE.into());
let is_native = thread_view_ref.as_native_thread(cx).is_some();
let is_title_generating = is_native && thread.has_provisional_title();
let session_id = thread.session_id().clone();
@@ -642,6 +656,19 @@ impl Sidebar {
let query = self.filter_editor.read(cx).text(cx);
+ // Re-derive agent_panel_visible from the active workspace so it stays
+ // correct after workspace switches.
+ self.agent_panel_visible = active_workspace
+ .as_ref()
+ .map_or(false, |ws| AgentPanel::is_visible(ws, cx));
+
+ // Derive active_thread_is_draft BEFORE focused_thread so we can
+ // use it as a guard below.
+ self.active_thread_is_draft = active_workspace
+ .as_ref()
+ .and_then(|ws| ws.read(cx).panel::<AgentPanel>(cx))
+ .map_or(false, |panel| panel.read(cx).active_thread_is_draft(cx));
+
// Derive focused_thread from the active workspace's agent panel.
    // Only update when the panel gives us a positive signal — if the
// panel returns None (e.g. still loading after a thread activation),
@@ -652,24 +679,13 @@ impl Sidebar {
.and_then(|panel| {
panel
.read(cx)
- .active_conversation()
+ .active_conversation_view()
.and_then(|cv| cv.read(cx).parent_id(cx))
});
- if panel_focused.is_some() {
+ if panel_focused.is_some() && !self.active_thread_is_draft {
self.focused_thread = panel_focused;
}
- // Re-derive agent_panel_visible from the active workspace so it stays
- // correct after workspace switches.
- self.agent_panel_visible = active_workspace
- .as_ref()
- .map_or(false, |ws| AgentPanel::is_visible(ws, cx));
-
- self.active_thread_is_draft = active_workspace
- .as_ref()
- .and_then(|ws| ws.read(cx).panel::<AgentPanel>(cx))
- .map_or(false, |panel| panel.read(cx).active_thread_is_draft(cx));
-
let previous = mem::take(&mut self.contents);
let old_statuses: HashMap<acp::SessionId, AgentThreadStatus> = previous
@@ -755,6 +771,10 @@ impl Sidebar {
.iter()
.any(|ws| !workspace_path_list(ws, cx).paths().is_empty());
+ let active_ws_index = active_workspace
+ .as_ref()
+ .and_then(|active| workspaces.iter().position(|ws| ws == active));
+
for (ws_index, workspace) in workspaces.iter().enumerate() {
if absorbed.contains_key(&ws_index) {
continue;
@@ -820,6 +840,7 @@ impl Sidebar {
is_title_generating: false,
highlight_positions: Vec::new(),
worktree_name: None,
+ worktree_full_path: None,
worktree_highlight_positions: Vec::new(),
diff_stats: DiffStats::default(),
});
@@ -908,6 +929,9 @@ impl Sidebar {
is_title_generating: false,
highlight_positions: Vec::new(),
worktree_name: Some(worktree_name.clone()),
+ worktree_full_path: Some(
+ worktree_path.display().to_string().into(),
+ ),
worktree_highlight_positions: Vec::new(),
diff_stats: DiffStats::default(),
});
@@ -952,9 +976,7 @@ impl Sidebar {
ThreadEntryWorkspace::Closed(_) => false,
};
- if thread.is_background && thread.status == AgentThreadStatus::Completed {
- notified_threads.insert(session_id.clone());
- } else if thread.status == AgentThreadStatus::Completed
+ if thread.status == AgentThreadStatus::Completed
&& !is_thread_workspace_active
&& old_statuses.get(session_id) == Some(&AgentThreadStatus::Running)
{
@@ -1031,6 +1053,19 @@ impl Sidebar {
entries.push(thread.into());
}
} else {
+ let thread_count = threads.len();
+ let is_draft_for_workspace = self.agent_panel_visible
+ && self.active_thread_is_draft
+ && self.focused_thread.is_none()
+ && active_ws_index.is_some_and(|active_idx| {
+ active_idx == ws_index
+ || absorbed
+ .get(&active_idx)
+ .is_some_and(|(main_idx, _)| *main_idx == ws_index)
+ });
+
+ let show_new_thread_entry = thread_count == 0 || is_draft_for_workspace;
+
project_header_indices.push(entries.len());
entries.push(ListEntry::ProjectHeader {
path_list: path_list.clone(),
@@ -1045,10 +1080,13 @@ impl Sidebar {
continue;
}
- entries.push(ListEntry::NewThread {
- path_list: path_list.clone(),
- workspace: workspace.clone(),
- });
+ if show_new_thread_entry {
+ entries.push(ListEntry::NewThread {
+ path_list: path_list.clone(),
+ workspace: workspace.clone(),
+ is_active_draft: is_draft_for_workspace,
+ });
+ }
let total = threads.len();
@@ -1093,7 +1131,6 @@ impl Sidebar {
if total > DEFAULT_THREADS_SHOWN {
entries.push(ListEntry::ViewMore {
path_list: path_list.clone(),
- remaining_count: total.saturating_sub(visible),
is_fully_expanded,
});
}
@@ -1192,20 +1229,15 @@ impl Sidebar {
ListEntry::Thread(thread) => self.render_thread(ix, thread, is_selected, cx),
ListEntry::ViewMore {
path_list,
- remaining_count,
is_fully_expanded,
- } => self.render_view_more(
- ix,
- path_list,
- *remaining_count,
- *is_fully_expanded,
- is_selected,
- cx,
- ),
+ } => self.render_view_more(ix, path_list, *is_fully_expanded, is_selected, cx),
ListEntry::NewThread {
path_list,
workspace,
- } => self.render_new_thread(ix, path_list, workspace, is_selected, cx),
+ is_active_draft,
+ } => {
+ self.render_new_thread(ix, path_list, workspace, *is_active_draft, is_selected, cx)
+ }
};
if is_group_header_after_first {
@@ -1244,6 +1276,13 @@ impl Sidebar {
IconName::ChevronDown
};
+ let has_new_thread_entry = self
+ .contents
+ .entries
+ .get(ix + 1)
+ .is_some_and(|entry| matches!(entry, ListEntry::NewThread { .. }));
+ let show_new_thread_button = !has_new_thread_entry && !self.has_filter_query(cx);
+
let workspace_for_remove = workspace.clone();
let workspace_for_menu = workspace.clone();
@@ -1266,10 +1305,27 @@ impl Sidebar {
.into_any_element()
};
- ListItem::new(id)
- .height(Tab::content_height(cx))
- .group_name(group_name)
- .focused(is_selected)
+ let color = cx.theme().colors();
+ let hover_color = color
+ .element_active
+ .blend(color.element_background.opacity(0.2));
+
+ h_flex()
+ .id(id)
+ .group(&group_name)
+ .h(Tab::content_height(cx))
+ .w_full()
+ .px_1p5()
+ .border_1()
+ .map(|this| {
+ if is_selected {
+ this.border_color(color.border_focused)
+ } else {
+ this.border_color(gpui::transparent_black())
+ }
+ })
+ .justify_between()
+ .hover(|s| s.bg(hover_color))
.child(
h_flex()
.relative()
@@ -1280,7 +1336,7 @@ impl Sidebar {
h_flex().size_4().flex_none().justify_center().child(
Icon::new(disclosure_icon)
.size(IconSize::Small)
- .color(Color::Custom(cx.theme().colors().icon_muted.opacity(0.6))),
+ .color(Color::Custom(cx.theme().colors().icon_muted.opacity(0.5))),
),
)
.child(label)
@@ -1310,11 +1366,13 @@ impl Sidebar {
)
}),
)
- .end_hover_gradient_overlay(true)
- .end_slot({
+ .child({
+ let workspace_for_new_thread = workspace.clone();
+ let path_list_for_new_thread = path_list.clone();
+
h_flex()
.when(self.project_header_menu_ix != Some(ix), |this| {
- this.visible_on_hover("list_item")
+ this.visible_on_hover(group_name)
})
.on_mouse_down(gpui::MouseButton::Left, |_, _, cx| {
cx.stop_propagation();
@@ -1366,6 +1424,30 @@ impl Sidebar {
)),
)
})
+ .when(show_new_thread_button, |this| {
+ this.child(
+ IconButton::new(
+ SharedString::from(format!(
+ "{id_prefix}project-header-new-thread-{ix}",
+ )),
+ IconName::Plus,
+ )
+ .icon_size(IconSize::Small)
+ .icon_color(Color::Muted)
+ .tooltip(Tooltip::text("New Thread"))
+ .on_click(cx.listener({
+ let workspace_for_new_thread = workspace_for_new_thread.clone();
+ let path_list_for_new_thread = path_list_for_new_thread.clone();
+ move |this, _, window, cx| {
+ // Uncollapse the group if collapsed so
+ // the new-thread entry becomes visible.
+ this.collapsed_groups.remove(&path_list_for_new_thread);
+ this.selection = None;
+ this.create_new_thread(&workspace_for_new_thread, window, cx);
+ }
+ })),
+ )
+ })
})
.on_click(cx.listener(move |this, _, window, cx| {
this.selection = None;
@@ -1468,7 +1550,7 @@ impl Sidebar {
let workspace_for_add = workspace.clone();
let multi_workspace_for_add = multi_workspace.clone();
- menu.separator().entry(
+ let menu = menu.separator().entry(
"Add Folder to Project",
Some(Box::new(AddFolderToProject)),
move |window, cx| {
@@ -1481,7 +1563,37 @@ impl Sidebar {
workspace.add_folder_to_project(&AddFolderToProject, window, cx);
});
},
- )
+ );
+
+ let workspace_count = multi_workspace
+ .upgrade()
+ .map_or(0, |mw| mw.read(cx).workspaces().len());
+ if workspace_count > 1 {
+ let workspace_for_move = workspace.clone();
+ let multi_workspace_for_move = multi_workspace.clone();
+ menu.entry(
+ "Move to New Window",
+ Some(Box::new(
+ zed_actions::agents_sidebar::MoveWorkspaceToNewWindow,
+ )),
+ move |window, cx| {
+ if let Some(mw) = multi_workspace_for_move.upgrade() {
+ mw.update(cx, |multi_workspace, cx| {
+ if let Some(index) = multi_workspace
+ .workspaces()
+ .iter()
+ .position(|w| *w == workspace_for_move)
+ {
+ multi_workspace
+ .move_workspace_to_new_window(index, window, cx);
+ }
+ });
+ }
+ },
+ )
+ } else {
+ menu
+ }
});
let this = this.clone();
@@ -1579,7 +1691,7 @@ impl Sidebar {
let color = cx.theme().colors();
let background = color
.title_bar_background
- .blend(color.panel_background.opacity(0.8));
+ .blend(color.panel_background.opacity(0.2));
let element = v_flex()
.absolute()
@@ -2414,17 +2526,21 @@ impl Sidebar {
ThreadItem::new(id, title)
.icon(thread.icon)
+ .status(thread.status)
.when_some(thread.icon_from_external_svg.clone(), |this, svg| {
this.custom_icon_from_external_svg(svg)
})
.when_some(thread.worktree_name.clone(), |this, name| {
- this.worktree(name)
+ let this = this.worktree(name);
+ match thread.worktree_full_path.clone() {
+ Some(path) => this.worktree_full_path(path),
+ None => this,
+ }
})
.worktree_highlight_positions(thread.worktree_highlight_positions.clone())
.when_some(timestamp, |this, ts| this.timestamp(ts))
.highlight_positions(thread.highlight_positions.to_vec())
- .status(thread.status)
- .generating_title(thread.is_title_generating)
+ .title_generating(thread.is_title_generating)
.notified(has_notification)
.when(thread.diff_stats.lines_added > 0, |this| {
this.added(thread.diff_stats.lines_added as usize)
@@ -2587,7 +2703,6 @@ impl Sidebar {
&self,
ix: usize,
path_list: &PathList,
- remaining_count: usize,
is_fully_expanded: bool,
is_selected: bool,
cx: &mut Context<Self>,
@@ -2595,23 +2710,15 @@ impl Sidebar {
let path_list = path_list.clone();
let id = SharedString::from(format!("view-more-{}", ix));
- let icon = if is_fully_expanded {
- IconName::ListCollapse
- } else {
- IconName::Plus
- };
-
let label: SharedString = if is_fully_expanded {
"Collapse".into()
- } else if remaining_count > 0 {
- format!("View More ({})", remaining_count).into()
} else {
"View More".into()
};
ThreadItem::new(id, label)
- .icon(icon)
.focused(is_selected)
+ .icon_visible(false)
.title_label_color(Color::Muted)
.on_click(cx.listener(move |this, _, _window, cx| {
this.selection = None;
@@ -2694,21 +2801,17 @@ impl Sidebar {
ix: usize,
_path_list: &PathList,
workspace: &Entity<Workspace>,
+ is_active_draft: bool,
is_selected: bool,
cx: &mut Context<Self>,
) -> AnyElement {
- let is_active = self.agent_panel_visible
- && self.active_thread_is_draft
- && self
- .multi_workspace
- .upgrade()
- .map_or(false, |mw| mw.read(cx).workspace() == workspace);
+ let is_active = is_active_draft && self.agent_panel_visible && self.active_thread_is_draft;
let label: SharedString = if is_active {
self.active_draft_text(cx)
- .unwrap_or_else(|| "New Thread".into())
+ .unwrap_or_else(|| DEFAULT_THREAD_TITLE.into())
} else {
- "New Thread".into()
+ DEFAULT_THREAD_TITLE.into()
};
let workspace = workspace.clone();
@@ -2716,9 +2819,9 @@ impl Sidebar {
let thread_item = ThreadItem::new(id, label)
.icon(IconName::Plus)
+ .icon_color(Color::Custom(cx.theme().colors().icon_muted.opacity(0.8)))
.selected(is_active)
.focused(is_selected)
- .title_label_color(Color::Custom(cx.theme().colors().text.opacity(0.85)))
.when(!is_active, |this| {
this.on_click(cx.listener(move |this, _, window, cx| {
this.selection = None;
@@ -2993,11 +3096,11 @@ impl Render for Sidebar {
let _titlebar_height = ui::utils::platform_title_bar_height(window);
let ui_font = theme::setup_ui_font(window, cx);
let sticky_header = self.render_sticky_header(window, cx);
- let bg = cx
- .theme()
- .colors()
+
+ let color = cx.theme().colors();
+ let bg = color
.title_bar_background
- .blend(cx.theme().colors().panel_background.opacity(0.8));
+ .blend(color.panel_background.opacity(0.32));
let no_open_projects = !self.contents.has_open_projects;
let no_search_results = self.contents.entries.is_empty();
@@ -3031,7 +3134,7 @@ impl Render for Sidebar {
.w(self.width)
.bg(bg)
.border_r_1()
- .border_color(cx.theme().colors().border)
+ .border_color(color.border)
.map(|this| match &self.view {
SidebarView::ThreadList => this
.child(self.render_sidebar_header(no_open_projects, window, cx))
@@ -3306,14 +3409,12 @@ mod tests {
)
}
ListEntry::ViewMore {
- remaining_count,
- is_fully_expanded,
- ..
+ is_fully_expanded, ..
} => {
if *is_fully_expanded {
format!(" - Collapse{}", selected)
} else {
- format!(" + View More ({}){}", remaining_count, selected)
+ format!(" + View More{}", selected)
}
}
ListEntry::NewThread { .. } => {
@@ -3363,6 +3464,27 @@ mod tests {
assert_eq!(Sidebar::clean_mention_links(""), "");
}
+ #[gpui::test]
+ async fn test_entities_released_on_window_close(cx: &mut TestAppContext) {
+ let project = init_test_project("/my-project", cx).await;
+ let (multi_workspace, cx) =
+ cx.add_window_view(|window, cx| MultiWorkspace::test_new(project, window, cx));
+ let sidebar = setup_sidebar(&multi_workspace, cx);
+
+ let weak_workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().downgrade());
+ let weak_sidebar = sidebar.downgrade();
+ let weak_multi_workspace = multi_workspace.downgrade();
+
+ drop(sidebar);
+ drop(multi_workspace);
+ cx.update(|window, _cx| window.remove_window());
+ cx.run_until_parked();
+
+ weak_multi_workspace.assert_released();
+ weak_sidebar.assert_released();
+ weak_workspace.assert_released();
+ }
+
#[gpui::test]
async fn test_single_workspace_no_threads(cx: &mut TestAppContext) {
let project = init_test_project("/my-project", cx).await;
@@ -3411,7 +3533,6 @@ mod tests {
visible_entries_as_strings(&sidebar, cx),
vec![
"v [my-project]",
- " [+ New Thread]",
" Fix crash in project panel",
" Add inline diff view",
]
@@ -3443,7 +3564,7 @@ mod tests {
assert_eq!(
visible_entries_as_strings(&sidebar, cx),
- vec!["v [project-a]", " [+ New Thread]", " Thread A1"]
+ vec!["v [project-a]", " Thread A1"]
);
// Add a second workspace
@@ -3454,7 +3575,7 @@ mod tests {
assert_eq!(
visible_entries_as_strings(&sidebar, cx),
- vec!["v [project-a]", " [+ New Thread]", " Thread A1",]
+ vec!["v [project-a]", " Thread A1",]
);
// Remove the second workspace
@@ -3465,7 +3586,7 @@ mod tests {
assert_eq!(
visible_entries_as_strings(&sidebar, cx),
- vec!["v [project-a]", " [+ New Thread]", " Thread A1"]
+ vec!["v [project-a]", " Thread A1"]
);
}
@@ -3486,13 +3607,12 @@ mod tests {
visible_entries_as_strings(&sidebar, cx),
vec![
"v [my-project]",
- " [+ New Thread]",
" Thread 12",
" Thread 11",
" Thread 10",
" Thread 9",
" Thread 8",
- " + View More (7)",
+ " + View More",
]
);
}
@@ -3511,23 +3631,23 @@ mod tests {
multi_workspace.update_in(cx, |_, _window, cx| cx.notify());
cx.run_until_parked();
- // Initially shows NewThread + 5 threads + View More (12 remaining)
+ // Initially shows 5 threads + View More
let entries = visible_entries_as_strings(&sidebar, cx);
- assert_eq!(entries.len(), 8); // header + NewThread + 5 threads + View More
- assert!(entries.iter().any(|e| e.contains("View More (12)")));
+ assert_eq!(entries.len(), 7); // header + 5 threads + View More
+ assert!(entries.iter().any(|e| e.contains("View More")));
// Focus and navigate to View More, then confirm to expand by one batch
open_and_focus_sidebar(&sidebar, cx);
- for _ in 0..8 {
+ for _ in 0..7 {
cx.dispatch_action(SelectNext);
}
cx.dispatch_action(Confirm);
cx.run_until_parked();
- // Now shows NewThread + 10 threads + View More (7 remaining)
+ // Now shows 10 threads + View More
let entries = visible_entries_as_strings(&sidebar, cx);
- assert_eq!(entries.len(), 13); // header + NewThread + 10 threads + View More
- assert!(entries.iter().any(|e| e.contains("View More (7)")));
+ assert_eq!(entries.len(), 12); // header + 10 threads + View More
+ assert!(entries.iter().any(|e| e.contains("View More")));
// Expand again by one batch
sidebar.update_in(cx, |s, _window, cx| {
@@ -3537,10 +3657,10 @@ mod tests {
});
cx.run_until_parked();
- // Now shows NewThread + 15 threads + View More (2 remaining)
+ // Now shows 15 threads + View More
let entries = visible_entries_as_strings(&sidebar, cx);
- assert_eq!(entries.len(), 18); // header + NewThread + 15 threads + View More
- assert!(entries.iter().any(|e| e.contains("View More (2)")));
+ assert_eq!(entries.len(), 17); // header + 15 threads + View More
+ assert!(entries.iter().any(|e| e.contains("View More")));
// Expand one more time - should show all 17 threads with Collapse button
sidebar.update_in(cx, |s, _window, cx| {
@@ -3552,7 +3672,7 @@ mod tests {
// All 17 threads shown with Collapse button
let entries = visible_entries_as_strings(&sidebar, cx);
- assert_eq!(entries.len(), 20); // header + NewThread + 17 threads + Collapse
+ assert_eq!(entries.len(), 19); // header + 17 threads + Collapse
assert!(!entries.iter().any(|e| e.contains("View More")));
assert!(entries.iter().any(|e| e.contains("Collapse")));
@@ -3563,10 +3683,10 @@ mod tests {
});
cx.run_until_parked();
- // Back to initial state: NewThread + 5 threads + View More (12 remaining)
+ // Back to initial state: 5 threads + View More
let entries = visible_entries_as_strings(&sidebar, cx);
- assert_eq!(entries.len(), 8); // header + NewThread + 5 threads + View More
- assert!(entries.iter().any(|e| e.contains("View More (12)")));
+ assert_eq!(entries.len(), 7); // header + 5 threads + View More
+ assert!(entries.iter().any(|e| e.contains("View More")));
}
#[gpui::test]
@@ -3584,7 +3704,7 @@ mod tests {
assert_eq!(
visible_entries_as_strings(&sidebar, cx),
- vec!["v [my-project]", " [+ New Thread]", " Thread 1"]
+ vec!["v [my-project]", " Thread 1"]
);
// Collapse
@@ -3606,7 +3726,7 @@ mod tests {
assert_eq!(
visible_entries_as_strings(&sidebar, cx),
- vec!["v [my-project]", " [+ New Thread]", " Thread 1"]
+ vec!["v [my-project]", " Thread 1"]
);
}
@@ -3636,7 +3756,6 @@ mod tests {
has_running_threads: false,
waiting_thread_count: 0,
},
- // Thread with default (Completed) status, not active
ListEntry::Thread(ThreadEntry {
agent: Agent::NativeAgent,
session_info: acp_thread::AgentSessionInfo {
@@ -3656,6 +3775,7 @@ mod tests {
is_title_generating: false,
highlight_positions: Vec::new(),
worktree_name: None,
+ worktree_full_path: None,
worktree_highlight_positions: Vec::new(),
diff_stats: DiffStats::default(),
}),
@@ -3679,6 +3799,7 @@ mod tests {
is_title_generating: false,
highlight_positions: Vec::new(),
worktree_name: None,
+ worktree_full_path: None,
worktree_highlight_positions: Vec::new(),
diff_stats: DiffStats::default(),
}),
@@ -3702,6 +3823,7 @@ mod tests {
is_title_generating: false,
highlight_positions: Vec::new(),
worktree_name: None,
+ worktree_full_path: None,
worktree_highlight_positions: Vec::new(),
diff_stats: DiffStats::default(),
}),
@@ -3725,6 +3847,7 @@ mod tests {
is_title_generating: false,
highlight_positions: Vec::new(),
worktree_name: None,
+ worktree_full_path: None,
worktree_highlight_positions: Vec::new(),
diff_stats: DiffStats::default(),
}),
@@ -3748,13 +3871,13 @@ mod tests {
is_title_generating: false,
highlight_positions: Vec::new(),
worktree_name: None,
+ worktree_full_path: None,
worktree_highlight_positions: Vec::new(),
diff_stats: DiffStats::default(),
}),
// View More entry
ListEntry::ViewMore {
path_list: expanded_path.clone(),
- remaining_count: 42,
is_fully_expanded: false,
},
// Collapsed project header
@@ -3767,6 +3890,7 @@ mod tests {
waiting_thread_count: 0,
},
];
+
// Select the Running thread (index 2)
s.selection = Some(2);
});
@@ -3780,7 +3904,7 @@ mod tests {
" Error thread * (error)",
" Waiting thread (waiting)",
" Notified thread * (!)",
- " + View More (42)",
+ " + View More",
"> [collapsed-project]",
]
);
@@ -3824,7 +3948,7 @@ mod tests {
multi_workspace.update_in(cx, |_, _window, cx| cx.notify());
cx.run_until_parked();
- // Entries: [header, new_thread, thread3, thread2, thread1]
+ // Entries: [header, thread3, thread2, thread1]
// Focusing the sidebar does not set a selection; select_next/select_previous
// handle None gracefully by starting from the first or last entry.
open_and_focus_sidebar(&sidebar, cx);
@@ -3844,9 +3968,6 @@ mod tests {
cx.dispatch_action(SelectNext);
assert_eq!(sidebar.read_with(cx, |s, _| s.selection), Some(3));
- cx.dispatch_action(SelectNext);
- assert_eq!(sidebar.read_with(cx, |s, _| s.selection), Some(4));
-
// At the end, wraps back to first entry
cx.dispatch_action(SelectNext);
assert_eq!(sidebar.read_with(cx, |s, _| s.selection), Some(0));
@@ -3858,13 +3979,8 @@ mod tests {
assert_eq!(sidebar.read_with(cx, |s, _| s.selection), Some(2));
cx.dispatch_action(SelectNext);
assert_eq!(sidebar.read_with(cx, |s, _| s.selection), Some(3));
- cx.dispatch_action(SelectNext);
- assert_eq!(sidebar.read_with(cx, |s, _| s.selection), Some(4));
// Move back up
- cx.dispatch_action(SelectPrevious);
- assert_eq!(sidebar.read_with(cx, |s, _| s.selection), Some(3));
-
cx.dispatch_action(SelectPrevious);
assert_eq!(sidebar.read_with(cx, |s, _| s.selection), Some(2));
@@ -3895,7 +4011,7 @@ mod tests {
// SelectLast jumps to the end
cx.dispatch_action(SelectLast);
- assert_eq!(sidebar.read_with(cx, |s, _| s.selection), Some(4));
+ assert_eq!(sidebar.read_with(cx, |s, _| s.selection), Some(3));
// SelectFirst jumps to the beginning
cx.dispatch_action(SelectFirst);
@@ -3948,7 +4064,7 @@ mod tests {
assert_eq!(
visible_entries_as_strings(&sidebar, cx),
- vec!["v [my-project]", " [+ New Thread]", " Thread 1"]
+ vec!["v [my-project]", " Thread 1"]
);
// Focus the sidebar and select the header (index 0)
@@ -3972,11 +4088,7 @@ mod tests {
assert_eq!(
visible_entries_as_strings(&sidebar, cx),
- vec![
- "v [my-project] <== selected",
- " [+ New Thread]",
- " Thread 1",
- ]
+ vec!["v [my-project] <== selected", " Thread 1",]
);
}
@@ -3992,17 +4104,17 @@ mod tests {
multi_workspace.update_in(cx, |_, _window, cx| cx.notify());
cx.run_until_parked();
- // Should show header + NewThread + 5 threads + "View More (3)"
+ // Should show header + 5 threads + "View More"
let entries = visible_entries_as_strings(&sidebar, cx);
- assert_eq!(entries.len(), 8);
- assert!(entries.iter().any(|e| e.contains("View More (3)")));
+ assert_eq!(entries.len(), 7);
+ assert!(entries.iter().any(|e| e.contains("View More")));
- // Focus sidebar (selection starts at None), then navigate down to the "View More" entry (index 7)
+ // Focus sidebar (selection starts at None), then navigate down to the "View More" entry (index 6)
open_and_focus_sidebar(&sidebar, cx);
- for _ in 0..8 {
+ for _ in 0..7 {
cx.dispatch_action(SelectNext);
}
- assert_eq!(sidebar.read_with(cx, |s, _| s.selection), Some(7));
+ assert_eq!(sidebar.read_with(cx, |s, _| s.selection), Some(6));
// Confirm on "View More" to expand
cx.dispatch_action(Confirm);
@@ -4010,7 +4122,7 @@ mod tests {
// All 8 threads should now be visible with a "Collapse" button
let entries = visible_entries_as_strings(&sidebar, cx);
- assert_eq!(entries.len(), 11); // header + NewThread + 8 threads + Collapse button
+ assert_eq!(entries.len(), 10); // header + 8 threads + Collapse button
assert!(!entries.iter().any(|e| e.contains("View More")));
assert!(entries.iter().any(|e| e.contains("Collapse")));
}
@@ -4029,7 +4141,7 @@ mod tests {
assert_eq!(
visible_entries_as_strings(&sidebar, cx),
- vec!["v [my-project]", " [+ New Thread]", " Thread 1"]
+ vec!["v [my-project]", " Thread 1"]
);
// Focus sidebar and manually select the header (index 0). Press left to collapse.
@@ -4052,11 +4164,7 @@ mod tests {
assert_eq!(
visible_entries_as_strings(&sidebar, cx),
- vec![
- "v [my-project] <== selected",
- " [+ New Thread]",
- " Thread 1",
- ]
+ vec!["v [my-project] <== selected", " Thread 1",]
);
// Press right again on already-expanded header moves selection down
@@ -4080,16 +4188,11 @@ mod tests {
open_and_focus_sidebar(&sidebar, cx);
cx.dispatch_action(SelectNext);
cx.dispatch_action(SelectNext);
- cx.dispatch_action(SelectNext);
- assert_eq!(sidebar.read_with(cx, |s, _| s.selection), Some(2));
+ assert_eq!(sidebar.read_with(cx, |s, _| s.selection), Some(1));
assert_eq!(
visible_entries_as_strings(&sidebar, cx),
- vec![
- "v [my-project]",
- " [+ New Thread]",
- " Thread 1 <== selected",
- ]
+ vec!["v [my-project]", " Thread 1 <== selected",]
);
// Pressing left on a child collapses the parent group and selects it
@@ -4110,7 +4213,7 @@ mod tests {
cx.add_window_view(|window, cx| MultiWorkspace::test_new(project, window, cx));
let sidebar = setup_sidebar(&multi_workspace, cx);
- // Even an empty project has the header and a new thread button
+ // An empty project has the header and a new thread button.
assert_eq!(
visible_entries_as_strings(&sidebar, cx),
vec!["v [empty-project]", " [+ New Thread]"]
@@ -4149,12 +4252,11 @@ mod tests {
multi_workspace.update_in(cx, |_, _window, cx| cx.notify());
cx.run_until_parked();
- // Focus sidebar (selection starts at None), navigate down to the thread (index 2)
+ // Focus sidebar (selection starts at None), navigate down to the thread (index 1)
open_and_focus_sidebar(&sidebar, cx);
cx.dispatch_action(SelectNext);
cx.dispatch_action(SelectNext);
- cx.dispatch_action(SelectNext);
- assert_eq!(sidebar.read_with(cx, |s, _| s.selection), Some(2));
+ assert_eq!(sidebar.read_with(cx, |s, _| s.selection), Some(1));
// Collapse the group, which removes the thread from the list
cx.dispatch_action(SelectParent);
@@ -4254,15 +4356,10 @@ mod tests {
cx.run_until_parked();
let mut entries = visible_entries_as_strings(&sidebar, cx);
- entries[2..].sort();
+ entries[1..].sort();
assert_eq!(
entries,
- vec![
- "v [my-project]",
- " [+ New Thread]",
- " Hello *",
- " Hello * (running)",
- ]
+ vec!["v [my-project]", " Hello *", " Hello * (running)",]
);
}
@@ -4303,7 +4400,7 @@ mod tests {
// Thread A is still running; no notification yet.
assert_eq!(
visible_entries_as_strings(&sidebar, cx),
- vec!["v [project-a]", " [+ New Thread]", " Hello * (running)",]
+ vec!["v [project-a]", " Hello * (running)",]
);
    // Complete thread A's turn (transition Running → Completed).
@@ -4313,7 +4410,7 @@ mod tests {
// The completed background thread shows a notification indicator.
assert_eq!(
visible_entries_as_strings(&sidebar, cx),
- vec!["v [project-a]", " [+ New Thread]", " Hello * (!)",]
+ vec!["v [project-a]", " Hello * (!)",]
);
}
@@ -4356,7 +4453,6 @@ mod tests {
visible_entries_as_strings(&sidebar, cx),
vec![
"v [my-project]",
- " [+ New Thread]",
" Fix crash in project panel",
" Add inline diff view",
" Refactor settings module",
@@ -14,7 +14,7 @@ path = "src/sum_tree.rs"
doctest = false
[dependencies]
-arrayvec = "0.7.1"
+heapless.workspace = true
rayon.workspace = true
log.workspace = true
ztracing.workspace = true
@@ -1,5 +1,5 @@
use super::*;
-use arrayvec::ArrayVec;
+use heapless::Vec as ArrayVec;
use std::{cmp::Ordering, mem, sync::Arc};
use ztracing::instrument;
@@ -29,7 +29,7 @@ impl<T: Item + fmt::Debug, D: fmt::Debug> fmt::Debug for StackEntry<'_, T, D> {
#[derive(Clone)]
pub struct Cursor<'a, 'b, T: Item, D> {
tree: &'a SumTree<T>,
- stack: ArrayVec<StackEntry<'a, T, D>, 16>,
+ stack: ArrayVec<StackEntry<'a, T, D>, 16, u8>,
pub position: D,
did_seek: bool,
at_end: bool,
@@ -53,7 +53,7 @@ where
pub struct Iter<'a, T: Item> {
tree: &'a SumTree<T>,
- stack: ArrayVec<StackEntry<'a, T, ()>, 16>,
+ stack: ArrayVec<StackEntry<'a, T, ()>, 16, u8>,
}
impl<'a, 'b, T, D> Cursor<'a, 'b, T, D>
@@ -231,11 +231,13 @@ where
self.position = D::zero(self.cx);
self.at_end = self.tree.is_empty();
if !self.tree.is_empty() {
- self.stack.push(StackEntry {
- tree: self.tree,
- index: self.tree.0.child_summaries().len() as u32,
- position: D::from_summary(self.tree.summary(), self.cx),
- });
+ self.stack
+ .push(StackEntry {
+ tree: self.tree,
+ index: self.tree.0.child_summaries().len() as u32,
+ position: D::from_summary(self.tree.summary(), self.cx),
+ })
+ .unwrap_oob();
}
}
@@ -267,11 +269,13 @@ where
Node::Internal { child_trees, .. } => {
if descending {
let tree = &child_trees[entry.index()];
- self.stack.push(StackEntry {
- position: D::zero(self.cx),
- tree,
- index: tree.0.child_summaries().len() as u32 - 1,
- })
+ self.stack
+ .push(StackEntry {
+ position: D::zero(self.cx),
+ tree,
+ index: tree.0.child_summaries().len() as u32 - 1,
+ })
+ .unwrap_oob();
}
}
Node::Leaf { .. } => {
@@ -297,11 +301,13 @@ where
if self.stack.is_empty() {
if !self.at_end {
- self.stack.push(StackEntry {
- tree: self.tree,
- index: 0,
- position: D::zero(self.cx),
- });
+ self.stack
+ .push(StackEntry {
+ tree: self.tree,
+ index: 0,
+ position: D::zero(self.cx),
+ })
+ .unwrap_oob();
descend = true;
}
self.did_seek = true;
@@ -361,11 +367,13 @@ where
if let Some(subtree) = new_subtree {
descend = true;
- self.stack.push(StackEntry {
- tree: subtree,
- index: 0,
- position: self.position.clone(),
- });
+ self.stack
+ .push(StackEntry {
+ tree: subtree,
+ index: 0,
+ position: self.position.clone(),
+ })
+ .unwrap_oob();
} else {
descend = false;
self.stack.pop();
@@ -467,11 +475,13 @@ where
if !self.did_seek {
self.did_seek = true;
- self.stack.push(StackEntry {
- tree: self.tree,
- index: 0,
- position: D::zero(self.cx),
- });
+ self.stack
+ .push(StackEntry {
+ tree: self.tree,
+ index: 0,
+ position: D::zero(self.cx),
+ })
+ .unwrap_oob();
}
let mut ascending = false;
@@ -503,11 +513,13 @@ where
entry.index += 1;
entry.position = self.position.clone();
} else {
- self.stack.push(StackEntry {
- tree: child_tree,
- index: 0,
- position: self.position.clone(),
- });
+ self.stack
+ .push(StackEntry {
+ tree: child_tree,
+ index: 0,
+ position: self.position.clone(),
+ })
+ .unwrap_oob();
ascending = false;
continue 'outer;
}
@@ -578,11 +590,13 @@ impl<'a, T: Item> Iterator for Iter<'a, T> {
let mut descend = false;
if self.stack.is_empty() {
- self.stack.push(StackEntry {
- tree: self.tree,
- index: 0,
- position: (),
- });
+ self.stack
+ .push(StackEntry {
+ tree: self.tree,
+ index: 0,
+ position: (),
+ })
+ .unwrap_oob();
descend = true;
}
@@ -611,11 +625,13 @@ impl<'a, T: Item> Iterator for Iter<'a, T> {
if let Some(subtree) = new_subtree {
descend = true;
- self.stack.push(StackEntry {
- tree: subtree,
- index: 0,
- position: (),
- });
+ self.stack
+ .push(StackEntry {
+ tree: subtree,
+ index: 0,
+ position: (),
+ })
+ .unwrap_oob();
} else {
descend = false;
self.stack.pop();
@@ -748,8 +764,8 @@ trait SeekAggregate<'a, T: Item> {
struct SliceSeekAggregate<T: Item> {
tree: SumTree<T>,
- leaf_items: ArrayVec<T, { 2 * TREE_BASE }>,
- leaf_item_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }>,
+ leaf_items: ArrayVec<T, { 2 * TREE_BASE }, u8>,
+ leaf_item_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }, u8>,
leaf_summary: T::Summary,
}
@@ -786,8 +802,8 @@ impl<T: Item> SeekAggregate<'_, T> for SliceSeekAggregate<T> {
summary: &T::Summary,
cx: <T::Summary as Summary>::Context<'_>,
) {
- self.leaf_items.push(item.clone());
- self.leaf_item_summaries.push(summary.clone());
+ self.leaf_items.push(item.clone()).unwrap_oob();
+ self.leaf_item_summaries.push(summary.clone()).unwrap_oob();
Summary::add_summary(&mut self.leaf_summary, summary, cx);
}
fn push_tree(
@@ -3,8 +3,8 @@ mod cursor;
pub mod property_test;
mod tree_map;
-use arrayvec::ArrayVec;
pub use cursor::{Cursor, FilterCursor, Iter};
+use heapless::Vec as ArrayVec;
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator as _};
use std::marker::PhantomData;
use std::mem;
@@ -17,6 +17,17 @@ pub const TREE_BASE: usize = 2;
#[cfg(not(test))]
pub const TREE_BASE: usize = 6;
+// Helper for when we cannot use ArrayVec::<T>::push().unwrap() as T doesn't impl Debug
+trait CapacityResultExt {
+ fn unwrap_oob(self);
+}
+
+impl<T> CapacityResultExt for Result<(), T> {
+ fn unwrap_oob(self) {
+ self.unwrap_or_else(|_| panic!("item should fit into fixed size ArrayVec"))
+ }
+}
+
/// An item that can be stored in a [`SumTree`]
///
/// Must be summarized by a type that implements [`Summary`]
@@ -243,8 +254,9 @@ impl<T: Item> SumTree<T> {
let mut iter = iter.into_iter().fuse().peekable();
while iter.peek().is_some() {
- let items: ArrayVec<T, { 2 * TREE_BASE }> = iter.by_ref().take(2 * TREE_BASE).collect();
- let item_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }> =
+ let items: ArrayVec<T, { 2 * TREE_BASE }, u8> =
+ iter.by_ref().take(2 * TREE_BASE).collect();
+ let item_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }, u8> =
items.iter().map(|item| item.summary(cx)).collect();
let mut summary = item_summaries[0].clone();
@@ -284,8 +296,8 @@ impl<T: Item> SumTree<T> {
};
let child_summary = child_node.summary();
<T::Summary as Summary>::add_summary(summary, child_summary, cx);
- child_summaries.push(child_summary.clone());
- child_trees.push(child_node);
+ child_summaries.push(child_summary.clone()).unwrap_oob();
+ child_trees.push(child_node.clone()).unwrap_oob();
if child_trees.len() == 2 * TREE_BASE {
parent_nodes.extend(current_parent_node.take());
@@ -315,8 +327,8 @@ impl<T: Item> SumTree<T> {
.into_par_iter()
.chunks(2 * TREE_BASE)
.map(|items| {
- let items: ArrayVec<T, { 2 * TREE_BASE }> = items.into_iter().collect();
- let item_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }> =
+ let items: ArrayVec<T, { 2 * TREE_BASE }, u8> = items.into_iter().collect();
+ let item_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }, u8> =
items.iter().map(|item| item.summary(cx)).collect();
let mut summary = item_summaries[0].clone();
for item_summary in &item_summaries[1..] {
@@ -337,9 +349,9 @@ impl<T: Item> SumTree<T> {
.into_par_iter()
.chunks(2 * TREE_BASE)
.map(|child_nodes| {
- let child_trees: ArrayVec<SumTree<T>, { 2 * TREE_BASE }> =
+ let child_trees: ArrayVec<SumTree<T>, { 2 * TREE_BASE }, u8> =
child_nodes.into_iter().collect();
- let child_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }> = child_trees
+ let child_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }, u8> = child_trees
.iter()
.map(|child_tree| child_tree.summary().clone())
.collect();
@@ -798,14 +810,16 @@ impl<T: Item> SumTree<T> {
<T::Summary as Summary>::add_summary(summary, other_node.summary(), cx);
let height_delta = *height - other_node.height();
- let mut summaries_to_append = ArrayVec::<T::Summary, { 2 * TREE_BASE }>::new();
- let mut trees_to_append = ArrayVec::<SumTree<T>, { 2 * TREE_BASE }>::new();
+ let mut summaries_to_append = ArrayVec::<T::Summary, { 2 * TREE_BASE }, u8>::new();
+ let mut trees_to_append = ArrayVec::<SumTree<T>, { 2 * TREE_BASE }, u8>::new();
if height_delta == 0 {
summaries_to_append.extend(other_node.child_summaries().iter().cloned());
trees_to_append.extend(other_node.child_trees().iter().cloned());
} else if height_delta == 1 && !other_node.is_underflowing() {
- summaries_to_append.push(other_node.summary().clone());
- trees_to_append.push(other)
+ summaries_to_append
+ .push(other_node.summary().clone())
+ .unwrap_oob();
+ trees_to_append.push(other).unwrap_oob();
} else {
let tree_to_append = child_trees
.last_mut()
@@ -815,15 +829,17 @@ impl<T: Item> SumTree<T> {
child_trees.last().unwrap().0.summary().clone();
if let Some(split_tree) = tree_to_append {
- summaries_to_append.push(split_tree.0.summary().clone());
- trees_to_append.push(split_tree);
+ summaries_to_append
+ .push(split_tree.0.summary().clone())
+ .unwrap_oob();
+ trees_to_append.push(split_tree).unwrap_oob();
}
}
let child_count = child_trees.len() + trees_to_append.len();
if child_count > 2 * TREE_BASE {
- let left_summaries: ArrayVec<_, { 2 * TREE_BASE }>;
- let right_summaries: ArrayVec<_, { 2 * TREE_BASE }>;
+ let left_summaries: ArrayVec<_, { 2 * TREE_BASE }, u8>;
+ let right_summaries: ArrayVec<_, { 2 * TREE_BASE }, u8>;
let left_trees;
let right_trees;
@@ -868,7 +884,7 @@ impl<T: Item> SumTree<T> {
let left_items;
let right_items;
let left_summaries;
- let right_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }>;
+ let right_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }, u8>;
let midpoint = (child_count + child_count % 2) / 2;
{
@@ -933,8 +949,10 @@ impl<T: Item> SumTree<T> {
*child_summaries.first_mut().unwrap() = first.summary().clone();
if let Some(tree) = res {
if child_trees.len() < 2 * TREE_BASE {
- child_summaries.insert(0, tree.summary().clone());
- child_trees.insert(0, tree);
+ child_summaries
+ .insert(0, tree.summary().clone())
+ .unwrap_oob();
+ child_trees.insert(0, tree).unwrap_oob();
None
} else {
let new_child_summaries = {
@@ -1016,7 +1034,7 @@ impl<T: Item> SumTree<T> {
.iter()
.chain(child_summaries.iter())
.cloned();
- let left_summaries: ArrayVec<_, { 2 * TREE_BASE }> =
+ let left_summaries: ArrayVec<_, { 2 * TREE_BASE }, u8> =
all_summaries.by_ref().take(midpoint).collect();
*child_summaries = all_summaries.collect();
@@ -1065,7 +1083,7 @@ impl<T: Item> SumTree<T> {
.iter()
.chain(item_summaries.iter())
.cloned();
- let left_summaries: ArrayVec<_, { 2 * TREE_BASE }> =
+ let left_summaries: ArrayVec<_, { 2 * TREE_BASE }, u8> =
all_summaries.by_ref().take(midpoint).collect();
*item_summaries = all_summaries.collect();
@@ -1088,11 +1106,11 @@ impl<T: Item> SumTree<T> {
) -> Self {
let height = left.0.height() + 1;
let mut child_summaries = ArrayVec::new();
- child_summaries.push(left.0.summary().clone());
- child_summaries.push(right.0.summary().clone());
+ child_summaries.push(left.0.summary().clone()).unwrap_oob();
+ child_summaries.push(right.0.summary().clone()).unwrap_oob();
let mut child_trees = ArrayVec::new();
- child_trees.push(left);
- child_trees.push(right);
+ child_trees.push(left).unwrap_oob();
+ child_trees.push(right).unwrap_oob();
SumTree(Arc::new(Node::Internal {
height,
summary: sum(child_summaries.iter(), cx),
@@ -1252,13 +1270,13 @@ pub enum Node<T: Item> {
Internal {
height: u8,
summary: T::Summary,
- child_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }>,
- child_trees: ArrayVec<SumTree<T>, { 2 * TREE_BASE }>,
+ child_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }, u8>,
+ child_trees: ArrayVec<SumTree<T>, { 2 * TREE_BASE }, u8>,
},
Leaf {
summary: T::Summary,
- items: ArrayVec<T, { 2 * TREE_BASE }>,
- item_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }>,
+ items: ArrayVec<T, { 2 * TREE_BASE }, u8>,
+ item_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }, u8>,
},
}
@@ -1323,14 +1341,14 @@ impl<T: Item> Node<T> {
}
}
- fn child_trees(&self) -> &ArrayVec<SumTree<T>, { 2 * TREE_BASE }> {
+ fn child_trees(&self) -> &ArrayVec<SumTree<T>, { 2 * TREE_BASE }, u8> {
match self {
Node::Internal { child_trees, .. } => child_trees,
Node::Leaf { .. } => panic!("Leaf nodes have no child trees"),
}
}
- fn items(&self) -> &ArrayVec<T, { 2 * TREE_BASE }> {
+ fn items(&self) -> &ArrayVec<T, { 2 * TREE_BASE }, u8> {
match self {
Node::Leaf { items, .. } => items,
Node::Internal { .. } => panic!("Internal nodes have no items"),
@@ -207,11 +207,16 @@ impl TerminalBounds {
}
pub fn num_lines(&self) -> usize {
- (self.bounds.size.height / self.line_height).floor() as usize
+ // Tolerance to prevent f32 precision from losing a row:
+ // `N * line_height / line_height` can be N-epsilon, which floor()
+ // would round down, pushing the first line into invisible scrollback.
+ let raw = self.bounds.size.height / self.line_height;
+ raw.next_up().floor() as usize
}
pub fn num_columns(&self) -> usize {
- (self.bounds.size.width / self.cell_width).floor() as usize
+ let raw = self.bounds.size.width / self.cell_width;
+ raw.next_up().floor() as usize
}
pub fn height(&self) -> Pixels {
@@ -3364,5 +3369,59 @@ mod tests {
scroll_by(-1);
}
}
+
+ #[test]
+ fn test_num_lines_float_precision() {
+ let line_heights = [
+ 20.1f32, 16.7, 18.3, 22.9, 14.1, 15.6, 17.8, 19.4, 21.3, 23.7,
+ ];
+ for &line_height in &line_heights {
+ for n in 1..=100 {
+ let height = n as f32 * line_height;
+ let bounds = TerminalBounds::new(
+ px(line_height),
+ px(8.0),
+ Bounds {
+ origin: Point::default(),
+ size: Size {
+ width: px(800.0),
+ height: px(height),
+ },
+ },
+ );
+ assert_eq!(
+ bounds.num_lines(),
+ n,
+ "num_lines() should be {n} for height={height}, line_height={line_height}"
+ );
+ }
+ }
+ }
+
+ #[test]
+ fn test_num_columns_float_precision() {
+ let cell_widths = [8.1f32, 7.3, 9.7, 6.9, 10.1];
+ for &cell_width in &cell_widths {
+ for n in 1..=200 {
+ let width = n as f32 * cell_width;
+ let bounds = TerminalBounds::new(
+ px(20.0),
+ px(cell_width),
+ Bounds {
+ origin: Point::default(),
+ size: Size {
+ width: px(width),
+ height: px(400.0),
+ },
+ },
+ );
+ assert_eq!(
+ bounds.num_columns(),
+ n,
+ "num_columns() should be {n} for width={width}, cell_width={cell_width}"
+ );
+ }
+ }
+ }
}
}
@@ -162,6 +162,7 @@ pub struct TitleBar {
impl Render for TitleBar {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let title_bar_settings = *TitleBarSettings::get_global(cx);
+ let button_layout = title_bar_settings.button_layout;
let show_menus = show_menus(cx);
@@ -266,6 +267,7 @@ impl Render for TitleBar {
if show_menus {
self.platform_titlebar.update(cx, |this, _| {
+ this.set_button_layout(button_layout);
this.set_children(
self.application_menu
.clone()
@@ -293,6 +295,7 @@ impl Render for TitleBar {
.into_any_element()
} else {
self.platform_titlebar.update(cx, |this, _| {
+ this.set_button_layout(button_layout);
this.set_children(children);
});
self.platform_titlebar.clone().into_any_element()
@@ -360,6 +363,7 @@ impl TitleBar {
}),
);
subscriptions.push(cx.observe(&user_store, |_a, _, cx| cx.notify()));
+ subscriptions.push(cx.observe_button_layout_changed(window, |_, _, cx| cx.notify()));
if let Some(trusted_worktrees) = TrustedWorktrees::try_get_global(cx) {
subscriptions.push(cx.subscribe(&trusted_worktrees, |_, _, _, cx| {
cx.notify();
@@ -908,14 +912,7 @@ impl TitleBar {
};
let branch_name = branch_name?;
- let button_text = if let Some(worktree_name) = linked_worktree_name {
- format!("{}/{}", worktree_name, branch_name)
- } else {
- branch_name
- };
-
let settings = TitleBarSettings::get_global(cx);
-
let effective_repository = Some(repository);
Some(
@@ -931,21 +928,42 @@ impl TitleBar {
))
})
.trigger_with_tooltip(
- Button::new("project_branch_trigger", button_text)
+ ButtonLike::new("project_branch_trigger")
.selected_style(ButtonStyle::Tinted(TintColor::Accent))
- .label_size(LabelSize::Small)
- .color(Color::Muted)
- .when(settings.show_branch_icon, |branch_button| {
- let (icon, icon_color) = icon_info;
- branch_button.start_icon(
- Icon::new(icon).size(IconSize::Indicator).color(icon_color),
- )
- }),
+ .child(
+ h_flex()
+ .gap_0p5()
+ .when(settings.show_branch_icon, |this| {
+ let (icon, icon_color) = icon_info;
+ this.child(
+ Icon::new(icon).size(IconSize::XSmall).color(icon_color),
+ )
+ })
+ .when_some(linked_worktree_name.as_ref(), |this, worktree_name| {
+ this.child(
+ Label::new(worktree_name)
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ )
+ .child(
+ Label::new("/").size(LabelSize::Small).color(
+ Color::Custom(
+ cx.theme().colors().text_muted.opacity(0.4),
+ ),
+ ),
+ )
+ })
+ .child(
+ Label::new(branch_name)
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ ),
+ ),
move |_window, cx| {
Tooltip::with_meta(
- "Recent Branches",
+ "Git Switcher",
Some(&zed_actions::git::Branch),
- "Local branches only",
+ "Worktrees, Branches, and Stashes",
cx,
)
},
@@ -1,3 +1,4 @@
+use gpui::WindowButtonLayout;
use settings::{RegisterSetting, Settings, SettingsContent};
#[derive(Copy, Clone, Debug, RegisterSetting)]
@@ -10,6 +11,7 @@ pub struct TitleBarSettings {
pub show_sign_in: bool,
pub show_user_menu: bool,
pub show_menus: bool,
+ pub button_layout: Option<WindowButtonLayout>,
}
impl Settings for TitleBarSettings {
@@ -24,6 +26,7 @@ impl Settings for TitleBarSettings {
show_sign_in: content.show_sign_in.unwrap(),
show_user_menu: content.show_user_menu.unwrap(),
show_menus: content.show_menus.unwrap(),
+ button_layout: content.button_layout.unwrap_or_default().into_layout(),
}
}
}
@@ -22,25 +22,26 @@ pub enum AgentThreadStatus {
pub struct ThreadItem {
id: ElementId,
icon: IconName,
+ icon_color: Option<Color>,
+ icon_visible: bool,
custom_icon_from_external_svg: Option<SharedString>,
title: SharedString,
+ title_label_color: Option<Color>,
+ title_generating: bool,
+ highlight_positions: Vec<usize>,
timestamp: SharedString,
notified: bool,
status: AgentThreadStatus,
- generating_title: bool,
selected: bool,
focused: bool,
hovered: bool,
- docked_right: bool,
added: Option<usize>,
removed: Option<usize>,
worktree: Option<SharedString>,
- highlight_positions: Vec<usize>,
+ worktree_full_path: Option<SharedString>,
worktree_highlight_positions: Vec<usize>,
on_click: Option<Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>>,
on_hover: Box<dyn Fn(&bool, &mut Window, &mut App) + 'static>,
- title_label_color: Option<Color>,
- title_label_size: Option<LabelSize>,
action_slot: Option<AnyElement>,
tooltip: Option<Box<dyn Fn(&mut Window, &mut App) -> AnyView + 'static>>,
}
@@ -50,25 +51,26 @@ impl ThreadItem {
Self {
id: id.into(),
icon: IconName::ZedAgent,
+ icon_color: None,
+ icon_visible: true,
custom_icon_from_external_svg: None,
title: title.into(),
+ title_label_color: None,
+ title_generating: false,
+ highlight_positions: Vec::new(),
timestamp: "".into(),
notified: false,
status: AgentThreadStatus::default(),
- generating_title: false,
selected: false,
focused: false,
hovered: false,
- docked_right: false,
added: None,
removed: None,
worktree: None,
- highlight_positions: Vec::new(),
+ worktree_full_path: None,
worktree_highlight_positions: Vec::new(),
on_click: None,
on_hover: Box::new(|_, _, _| {}),
- title_label_color: None,
- title_label_size: None,
action_slot: None,
tooltip: None,
}
@@ -84,6 +86,16 @@ impl ThreadItem {
self
}
+ pub fn icon_color(mut self, color: Color) -> Self {
+ self.icon_color = Some(color);
+ self
+ }
+
+ pub fn icon_visible(mut self, visible: bool) -> Self {
+ self.icon_visible = visible;
+ self
+ }
+
pub fn custom_icon_from_external_svg(mut self, svg: impl Into<SharedString>) -> Self {
self.custom_icon_from_external_svg = Some(svg.into());
self
@@ -99,8 +111,18 @@ impl ThreadItem {
self
}
- pub fn generating_title(mut self, generating: bool) -> Self {
- self.generating_title = generating;
+ pub fn title_generating(mut self, generating: bool) -> Self {
+ self.title_generating = generating;
+ self
+ }
+
+ pub fn title_label_color(mut self, color: Color) -> Self {
+ self.title_label_color = Some(color);
+ self
+ }
+
+ pub fn highlight_positions(mut self, positions: Vec<usize>) -> Self {
+ self.highlight_positions = positions;
self
}
@@ -124,18 +146,13 @@ impl ThreadItem {
self
}
- pub fn docked_right(mut self, docked_right: bool) -> Self {
- self.docked_right = docked_right;
- self
- }
-
pub fn worktree(mut self, worktree: impl Into<SharedString>) -> Self {
self.worktree = Some(worktree.into());
self
}
- pub fn highlight_positions(mut self, positions: Vec<usize>) -> Self {
- self.highlight_positions = positions;
+ pub fn worktree_full_path(mut self, worktree_full_path: impl Into<SharedString>) -> Self {
+ self.worktree_full_path = Some(worktree_full_path.into());
self
}
@@ -162,16 +179,6 @@ impl ThreadItem {
self
}
- pub fn title_label_color(mut self, color: Color) -> Self {
- self.title_label_color = Some(color);
- self
- }
-
- pub fn title_label_size(mut self, size: LabelSize) -> Self {
- self.title_label_size = Some(size);
- self
- }
-
pub fn action_slot(mut self, element: impl IntoElement) -> Self {
self.action_slot = Some(element.into_any_element());
self
@@ -186,6 +193,26 @@ impl ThreadItem {
impl RenderOnce for ThreadItem {
fn render(self, _: &mut Window, cx: &mut App) -> impl IntoElement {
let color = cx.theme().colors();
+ let base_bg = color
+ .title_bar_background
+ .blend(color.panel_background.opacity(0.2));
+
+ let base_bg = if self.selected {
+ color.element_active
+ } else {
+ base_bg
+ };
+
+ let hover_color = color
+ .element_active
+ .blend(color.element_background.opacity(0.2));
+
+ let gradient_overlay = GradientFade::new(base_bg, hover_color, hover_color)
+ .width(px(64.0))
+ .right(px(-10.0))
+ .gradient_stop(0.75)
+ .group_name("thread-item");
+
let dot_separator = || {
Label::new("β’")
.size(LabelSize::Small)
@@ -194,25 +221,26 @@ impl RenderOnce for ThreadItem {
};
let icon_id = format!("icon-{}", self.id);
+ let icon_visible = self.icon_visible;
let icon_container = || {
h_flex()
.id(icon_id.clone())
.size_4()
.flex_none()
.justify_center()
+ .when(!icon_visible, |this| this.invisible())
};
+ let icon_color = self.icon_color.unwrap_or(Color::Muted);
let agent_icon = if let Some(custom_svg) = self.custom_icon_from_external_svg {
Icon::from_external_svg(custom_svg)
- .color(Color::Muted)
+ .color(icon_color)
.size(IconSize::Small)
} else {
- Icon::new(self.icon)
- .color(Color::Muted)
- .size(IconSize::Small)
+ Icon::new(self.icon).color(icon_color).size(IconSize::Small)
};
let decoration = |icon: IconDecorationKind, color: Hsla| {
- IconDecoration::new(icon, cx.theme().colors().surface_background, cx)
+ IconDecoration::new(icon, base_bg, cx)
.color(color)
.position(gpui::Point {
x: px(-2.),
@@ -264,10 +292,9 @@ impl RenderOnce for ThreadItem {
let title = self.title;
let highlight_positions = self.highlight_positions;
- let title_label_size = self.title_label_size.unwrap_or(LabelSize::Default);
- let title_label = if self.generating_title {
+
+ let title_label = if self.title_generating {
Label::new(title)
- .size(title_label_size)
.color(Color::Muted)
.with_animation(
"generating-title",
@@ -278,66 +305,38 @@ impl RenderOnce for ThreadItem {
)
.into_any_element()
} else if highlight_positions.is_empty() {
- let label = Label::new(title).size(title_label_size);
- let label = if let Some(color) = self.title_label_color {
- label.color(color)
- } else {
- label
- };
- label.into_any_element()
- } else {
- let label = HighlightedLabel::new(title, highlight_positions).size(title_label_size);
- let label = if let Some(color) = self.title_label_color {
- label.color(color)
- } else {
- label
- };
- label.into_any_element()
- };
-
- let b_bg = color
- .title_bar_background
- .blend(color.panel_background.opacity(0.8));
-
- let base_bg = if self.selected {
- color.element_active
+ Label::new(title)
+ .when_some(self.title_label_color, |label, color| label.color(color))
+ .into_any_element()
} else {
- b_bg
+ HighlightedLabel::new(title, highlight_positions)
+ .when_some(self.title_label_color, |label, color| label.color(color))
+ .into_any_element()
};
- let gradient_overlay =
- GradientFade::new(base_bg, color.element_hover, color.element_active)
- .width(px(64.0))
- .right(px(-10.0))
- .gradient_stop(0.75)
- .group_name("thread-item");
-
let has_diff_stats = self.added.is_some() || self.removed.is_some();
+ let diff_stat_id = self.id.clone();
let added_count = self.added.unwrap_or(0);
let removed_count = self.removed.unwrap_or(0);
- let diff_stat_id = self.id.clone();
+
let has_worktree = self.worktree.is_some();
let has_timestamp = !self.timestamp.is_empty();
let timestamp = self.timestamp;
v_flex()
.id(self.id.clone())
+ .cursor_pointer()
.group("thread-item")
.relative()
.overflow_hidden()
- .cursor_pointer()
.w_full()
.py_1()
.px_1p5()
.when(self.selected, |s| s.bg(color.element_active))
.border_1()
.border_color(gpui::transparent_black())
- .when(self.focused, |s| {
- s.when(self.docked_right, |s| s.border_r_2())
- .border_color(color.border_focused)
- })
- .hover(|s| s.bg(color.element_hover))
- .active(|s| s.bg(color.element_active))
+ .when(self.focused, |s| s.border_color(color.border_focused))
+ .hover(|s| s.bg(hover_color))
.on_hover(self.on_hover)
.child(
h_flex()
@@ -358,15 +357,11 @@ impl RenderOnce for ThreadItem {
.child(gradient_overlay)
.when(self.hovered, |this| {
this.when_some(self.action_slot, |this, slot| {
- let overlay = GradientFade::new(
- base_bg,
- color.element_hover,
- color.element_active,
- )
- .width(px(64.0))
- .right(px(6.))
- .gradient_stop(0.75)
- .group_name("thread-item");
+ let overlay = GradientFade::new(base_bg, hover_color, hover_color)
+ .width(px(64.0))
+ .right(px(6.))
+ .gradient_stop(0.75)
+ .group_name("thread-item");
this.child(
h_flex()
@@ -380,57 +375,56 @@ impl RenderOnce for ThreadItem {
})
}),
)
- .when_some(self.worktree, |this, worktree| {
- let worktree_highlight_positions = self.worktree_highlight_positions;
- let worktree_label = if worktree_highlight_positions.is_empty() {
- Label::new(worktree)
- .size(LabelSize::Small)
- .color(Color::Muted)
- .into_any_element()
- } else {
- HighlightedLabel::new(worktree, worktree_highlight_positions)
- .size(LabelSize::Small)
- .color(Color::Muted)
- .into_any_element()
- };
+ .when(has_worktree || has_diff_stats || has_timestamp, |this| {
+ let worktree_full_path = self.worktree_full_path.clone().unwrap_or_default();
+ let worktree_label = self.worktree.map(|worktree| {
+ let positions = self.worktree_highlight_positions;
+ if positions.is_empty() {
+ Label::new(worktree)
+ .size(LabelSize::Small)
+ .color(Color::Muted)
+ .into_any_element()
+ } else {
+ HighlightedLabel::new(worktree, positions)
+ .size(LabelSize::Small)
+ .color(Color::Muted)
+ .into_any_element()
+ }
+ });
this.child(
h_flex()
.min_w_0()
.gap_1p5()
.child(icon_container()) // Icon Spacing
- .child(worktree_label)
- .when(has_diff_stats || has_timestamp, |this| {
- this.child(dot_separator())
- })
- .when(has_diff_stats, |this| {
+ .when_some(worktree_label, |this, label| {
this.child(
- DiffStat::new(diff_stat_id.clone(), added_count, removed_count)
- .tooltip("Unreviewed changes"),
+ h_flex()
+ .id(format!("{}-worktree", self.id.clone()))
+ .gap_1()
+ .child(
+ Icon::new(IconName::GitWorktree)
+ .size(IconSize::XSmall)
+ .color(Color::Muted),
+ )
+ .child(label)
+ .tooltip(move |_, cx| {
+ Tooltip::with_meta(
+ "Thread Running in a Local Git Worktree",
+ None,
+ worktree_full_path.clone(),
+ cx,
+ )
+ }),
)
})
- .when(has_diff_stats && has_timestamp, |this| {
+ .when(has_worktree && (has_diff_stats || has_timestamp), |this| {
this.child(dot_separator())
})
- .when(has_timestamp, |this| {
- this.child(
- Label::new(timestamp.clone())
- .size(LabelSize::Small)
- .color(Color::Muted),
- )
- }),
- )
- })
- .when(!has_worktree && (has_diff_stats || has_timestamp), |this| {
- this.child(
- h_flex()
- .min_w_0()
- .gap_1p5()
- .child(icon_container()) // Icon Spacing
.when(has_diff_stats, |this| {
this.child(
DiffStat::new(diff_stat_id, added_count, removed_count)
- .tooltip("Unreviewed Changes"),
+ .tooltip("Unreviewed changes"),
)
})
.when(has_diff_stats && has_timestamp, |this| {
@@ -583,18 +577,6 @@ impl Component for ThreadItem {
)
.into_any_element(),
),
- single_example(
- "Focused + Docked Right",
- container()
- .child(
- ThreadItem::new("ti-7b", "Focused with right dock border")
- .icon(IconName::AiClaude)
- .timestamp("1w")
- .focused(true)
- .docked_right(true),
- )
- .into_any_element(),
- ),
single_example(
"Selected + Focused",
container()
@@ -4,7 +4,7 @@ use component::{Component, ComponentScope, example_group_with_title, single_exam
use gpui::{AnyElement, AnyView, ClickEvent, MouseButton, MouseDownEvent, Pixels, px};
use smallvec::SmallVec;
-use crate::{Disclosure, GradientFade, prelude::*};
+use crate::{Disclosure, prelude::*};
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Default)]
pub enum ListItemSpacing {
@@ -31,9 +31,6 @@ pub struct ListItem {
/// A slot for content that appears on hover after the children
/// It will obscure the `end_slot` when visible.
end_hover_slot: Option<AnyElement>,
- /// When true, renders a gradient fade overlay before the `end_hover_slot`
- /// to smoothly truncate overflowing content.
- end_hover_gradient_overlay: bool,
toggle: Option<bool>,
inset: bool,
on_click: Option<Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>>,
@@ -65,7 +62,6 @@ impl ListItem {
start_slot: None,
end_slot: None,
end_hover_slot: None,
- end_hover_gradient_overlay: false,
toggle: None,
inset: false,
on_click: None,
@@ -174,11 +170,6 @@ impl ListItem {
self
}
- pub fn end_hover_gradient_overlay(mut self, show: bool) -> Self {
- self.end_hover_gradient_overlay = show;
- self
- }
-
pub fn outlined(mut self) -> Self {
self.outlined = true;
self
@@ -232,21 +223,6 @@ impl ParentElement for ListItem {
impl RenderOnce for ListItem {
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
- let color = cx.theme().colors();
-
- let base_bg = if self.selected {
- color.element_active
- } else {
- color.panel_background
- };
-
- let end_hover_gradient_overlay =
- GradientFade::new(base_bg, color.element_hover, color.element_active)
- .width(px(96.0))
- .when_some(self.group_name.clone(), |fade, group| {
- fade.group_name(group)
- });
-
h_flex()
.id(self.id)
.when_some(self.group_name, |this, group| this.group(group))
@@ -382,9 +358,6 @@ impl RenderOnce for ListItem {
.right(DynamicSpacing::Base06.rems(cx))
.top_0()
.visible_on_hover("list_item")
- .when(self.end_hover_gradient_overlay, |this| {
- this.child(end_hover_gradient_overlay)
- })
.child(end_hover_slot),
)
}),
@@ -162,7 +162,7 @@ impl RenderOnce for ModalHeader {
children.insert(
0,
Headline::new(headline)
- .size(HeadlineSize::XSmall)
+ .size(HeadlineSize::Small)
.color(Color::Muted)
.into_any_element(),
);
@@ -1726,7 +1726,15 @@ fn generate_commands(_: &App) -> Vec<VimCommand> {
)
.range(wrap_count),
VimCommand::new(("j", "oin"), JoinLines).range(select_range),
- VimCommand::new(("reflow", ""), Rewrap).range(select_range),
+ VimCommand::new(("reflow", ""), Rewrap { line_length: None })
+ .range(select_range)
+ .args(|_action, args| {
+ args.parse::<usize>().map_or(None, |length| {
+ Some(Box::new(Rewrap {
+ line_length: Some(length),
+ }))
+ })
+ }),
VimCommand::new(("fo", "ld"), editor::actions::FoldSelectedRanges).range(act_on_range),
VimCommand::new(("foldo", "pen"), editor::actions::UnfoldLines)
.bang(editor::actions::UnfoldRecursive)
@@ -3550,7 +3558,7 @@ mod test {
cx.set_state(
indoc! {"
- Λ0123456789 0123456789 0123456789 0123456789
+ Λ0123456789 0123456789
"},
Mode::Normal,
);
@@ -3560,8 +3568,6 @@ mod test {
cx.assert_state(
indoc! {"
- 0123456789
- 0123456789
0123456789
Λ0123456789
"},
@@ -3570,22 +3576,59 @@ mod test {
cx.set_state(
indoc! {"
- Β«0123456789 0123456789ΛΒ»
- 0123456789 0123456789
+ Λ0123456789 0123456789
"},
Mode::VisualLine,
);
- cx.simulate_keystrokes(": reflow");
+ cx.simulate_keystrokes("shift-v : reflow");
cx.simulate_keystrokes("enter");
cx.assert_state(
indoc! {"
- Λ0123456789
0123456789
- 0123456789 0123456789
+ Λ0123456789
"},
Mode::Normal,
);
+
+ cx.set_state(
+ indoc! {"
+ Λ0123 4567 0123 4567
+ "},
+ Mode::VisualLine,
+ );
+
+ cx.simulate_keystrokes(": reflow space 7");
+ cx.simulate_keystrokes("enter");
+
+ cx.assert_state(
+ indoc! {"
+ Λ0123
+ 4567
+ 0123
+ 4567
+ "},
+ Mode::Normal,
+ );
+
+ // Assert that, if `:reflow` is invoked with an invalid argument, it
+ // does not actually have any effect in the buffer's contents.
+ cx.set_state(
+ indoc! {"
+ Λ0123 4567 0123 4567
+ "},
+ Mode::VisualLine,
+ );
+
+ cx.simulate_keystrokes(": reflow space a");
+ cx.simulate_keystrokes("enter");
+
+ cx.assert_state(
+ indoc! {"
+ Λ0123 4567 0123 4567
+ "},
+ Mode::VisualLine,
+ );
}
}
@@ -1,19 +1,20 @@
use crate::{Vim, motion::Motion, object::Object, state::Mode};
use collections::HashMap;
use editor::{Bias, Editor, RewrapOptions, SelectionEffects, display_map::ToDisplayPoint};
-use gpui::{Context, Window, actions};
+use gpui::{Action, Context, Window};
use language::SelectionGoal;
+use schemars::JsonSchema;
+use serde::Deserialize;
-actions!(
- vim,
- [
- /// Rewraps the selected text to fit within the line width.
- Rewrap
- ]
-);
+/// Rewraps the selected text to fit within the line width.
+#[derive(Clone, Deserialize, JsonSchema, PartialEq, Action)]
+#[action(namespace = vim)]
+pub(crate) struct Rewrap {
+ pub line_length: Option<usize>,
+}
pub(crate) fn register(editor: &mut Editor, cx: &mut Context<Vim>) {
- Vim::action(editor, cx, |vim, _: &Rewrap, window, cx| {
+ Vim::action(editor, cx, |vim, action: &Rewrap, window, cx| {
vim.record_current_action(cx);
Vim::take_count(cx);
Vim::take_forced_motion(cx);
@@ -24,6 +25,7 @@ pub(crate) fn register(editor: &mut Editor, cx: &mut Context<Vim>) {
editor.rewrap_impl(
RewrapOptions {
override_language_settings: true,
+ line_length: action.line_length,
..Default::default()
},
cx,
@@ -11,8 +11,10 @@ use project::Project;
use settings::Settings;
use std::future::Future;
use std::path::PathBuf;
+use std::sync::Arc;
use ui::prelude::*;
use util::ResultExt;
+use zed_actions::agents_sidebar::MoveWorkspaceToNewWindow;
const SIDEBAR_RESIZE_HANDLE_SIZE: Pixels = px(6.0);
@@ -30,6 +32,10 @@ actions!(
CloseWorkspaceSidebar,
/// Moves focus to or from the workspace sidebar without closing it.
FocusWorkspaceSidebar,
+ /// Switches to the next workspace.
+ NextWorkspace,
+ /// Switches to the previous workspace.
+ PreviousWorkspace,
]
);
@@ -405,6 +411,29 @@ impl MultiWorkspace {
cx.notify();
}
+ fn cycle_workspace(&mut self, delta: isize, window: &mut Window, cx: &mut Context<Self>) {
+ let count = self.workspaces.len() as isize;
+ if count <= 1 {
+ return;
+ }
+ let current = self.active_workspace_index as isize;
+ let next = ((current + delta).rem_euclid(count)) as usize;
+ self.activate_index(next, window, cx);
+ }
+
+ fn next_workspace(&mut self, _: &NextWorkspace, window: &mut Window, cx: &mut Context<Self>) {
+ self.cycle_workspace(1, window, cx);
+ }
+
+ fn previous_workspace(
+ &mut self,
+ _: &PreviousWorkspace,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ self.cycle_workspace(-1, window, cx);
+ }
+
fn serialize(&mut self, cx: &mut App) {
let window_id = self.window_id;
let state = crate::persistence::model::MultiWorkspaceState {
@@ -609,9 +638,14 @@ impl MultiWorkspace {
})
}
- pub fn remove_workspace(&mut self, index: usize, window: &mut Window, cx: &mut Context<Self>) {
+ pub fn remove_workspace(
+ &mut self,
+ index: usize,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) -> Option<Entity<Workspace>> {
if self.workspaces.len() <= 1 || index >= self.workspaces.len() {
- return;
+ return None;
}
let removed_workspace = self.workspaces.remove(index);
@@ -622,6 +656,16 @@ impl MultiWorkspace {
self.active_workspace_index -= 1;
}
+ // Clear session_id and cancel any in-flight serialization on the
+ // removed workspace. Without this, a pending throttle timer from
+ // `serialize_workspace` could fire and write the old session_id
+ // back to the DB, resurrecting the workspace on next launch.
+ removed_workspace.update(cx, |workspace, _cx| {
+ workspace.session_id.take();
+ workspace._schedule_serialize_workspace.take();
+ workspace._serialize_workspace_task.take();
+ });
+
if let Some(workspace_id) = removed_workspace.read(cx).database_id() {
let db = crate::persistence::WorkspaceDb::global(cx);
self.pending_removal_tasks.retain(|task| !task.is_ready());
@@ -642,6 +686,49 @@ impl MultiWorkspace {
));
cx.emit(MultiWorkspaceEvent::ActiveWorkspaceChanged);
cx.notify();
+
+ Some(removed_workspace)
+ }
+
+ pub fn move_workspace_to_new_window(
+ &mut self,
+ index: usize,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ if self.workspaces.len() <= 1 || index >= self.workspaces.len() {
+ return;
+ }
+
+ let Some(workspace) = self.remove_workspace(index, window, cx) else {
+ return;
+ };
+
+ let app_state: Arc<crate::AppState> = workspace.read(cx).app_state().clone();
+
+ cx.defer(move |cx| {
+ let options = (app_state.build_window_options)(None, cx);
+
+ let Ok(window) = cx.open_window(options, |window, cx| {
+ cx.new(|cx| MultiWorkspace::new(workspace, window, cx))
+ }) else {
+ return;
+ };
+
+ let _ = window.update(cx, |_, window, _| {
+ window.activate_window();
+ });
+ });
+ }
+
+ fn move_active_workspace_to_new_window(
+ &mut self,
+ _: &MoveWorkspaceToNewWindow,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ let index = self.active_workspace_index;
+ self.move_workspace_to_new_window(index, window, cx);
}
pub fn open_project(
@@ -760,6 +847,9 @@ impl Render for MultiWorkspace {
this.focus_sidebar(window, cx);
},
))
+ .on_action(cx.listener(Self::next_workspace))
+ .on_action(cx.listener(Self::previous_workspace))
+ .on_action(cx.listener(Self::move_active_workspace_to_new_window))
})
.when(
self.sidebar_open() && self.multi_workspace_enabled(cx),
@@ -28,7 +28,8 @@ pub use crate::notifications::NotificationFrame;
pub use dock::Panel;
pub use multi_workspace::{
CloseWorkspaceSidebar, DraggedSidebar, FocusWorkspaceSidebar, MultiWorkspace,
- MultiWorkspaceEvent, Sidebar, SidebarHandle, ToggleWorkspaceSidebar,
+ MultiWorkspaceEvent, NextWorkspace, PreviousWorkspace, Sidebar, SidebarHandle,
+ ToggleWorkspaceSidebar,
};
pub use path_list::{PathList, SerializedPathList};
pub use toast_layer::{ToastAction, ToastLayer, ToastView};
@@ -128,6 +128,7 @@ pub struct LocalWorktree {
scan_requests_tx: channel::Sender<ScanRequest>,
path_prefixes_to_scan_tx: channel::Sender<PathPrefixScanRequest>,
is_scanning: (watch::Sender<bool>, watch::Receiver<bool>),
+ snapshot_subscriptions: VecDeque<(usize, oneshot::Sender<()>)>,
_background_scanner_tasks: Vec<Task<()>>,
update_observer: Option<UpdateObservationState>,
fs: Arc<dyn Fs>,
@@ -470,6 +471,7 @@ impl Worktree {
next_entry_id,
snapshot,
is_scanning: watch::channel_with(true),
+ snapshot_subscriptions: Default::default(),
update_observer: None,
scan_requests_tx,
path_prefixes_to_scan_tx,
@@ -714,6 +716,16 @@ impl Worktree {
}
}
+ pub fn wait_for_snapshot(
+ &mut self,
+ scan_id: usize,
+ ) -> impl Future<Output = Result<()>> + use<> {
+ match self {
+ Worktree::Local(this) => this.wait_for_snapshot(scan_id).boxed(),
+ Worktree::Remote(this) => this.wait_for_snapshot(scan_id).boxed(),
+ }
+ }
+
#[cfg(feature = "test-support")]
pub fn has_update_observer(&self) -> bool {
match self {
@@ -1170,6 +1182,15 @@ impl LocalWorktree {
if !repo_changes.is_empty() {
cx.emit(Event::UpdatedGitRepositories(repo_changes));
}
+
+ while let Some((scan_id, _)) = self.snapshot_subscriptions.front() {
+ if self.snapshot.completed_scan_id >= *scan_id {
+ let (_, tx) = self.snapshot_subscriptions.pop_front().unwrap();
+ tx.send(()).ok();
+ } else {
+ break;
+ }
+ }
}
fn changed_repos(
@@ -1286,6 +1307,28 @@ impl LocalWorktree {
}
}
+ pub fn wait_for_snapshot(
+ &mut self,
+ scan_id: usize,
+ ) -> impl Future<Output = Result<()>> + use<> {
+ let (tx, rx) = oneshot::channel();
+ if self.snapshot.completed_scan_id >= scan_id {
+ tx.send(()).ok();
+ } else {
+ match self
+ .snapshot_subscriptions
+ .binary_search_by_key(&scan_id, |probe| probe.0)
+ {
+ Ok(ix) | Err(ix) => self.snapshot_subscriptions.insert(ix, (scan_id, tx)),
+ }
+ }
+
+ async move {
+ rx.await?;
+ Ok(())
+ }
+ }
+
pub fn snapshot(&self) -> LocalSnapshot {
self.snapshot.clone()
}
@@ -7,12 +7,14 @@ fn main() {
// Add rpaths for libraries that webrtc-sys dlopens at runtime.
// This is mostly required for hosts with non-standard SO installation
// locations such as NixOS.
- let dlopened_libs = ["libva", "libva-drm"];
+ let dlopened_libs = ["libva", "libva-drm", "egl"];
let mut rpath_dirs = std::collections::BTreeSet::new();
for lib in &dlopened_libs {
if let Some(libdir) = pkg_config::get_variable(lib, "libdir").ok() {
rpath_dirs.insert(libdir);
+ } else {
+ eprintln!("zed build.rs: {lib} not found in pkg-config's path");
}
}
@@ -786,6 +786,8 @@ pub mod agents_sidebar {
[
/// Moves focus to the sidebar's search/filter editor.
FocusSidebarFilter,
+ /// Moves the active workspace to a new window.
+ MoveWorkspaceToNewWindow,
]
);
}
@@ -3,10 +3,14 @@ use anyhow::{Context as _, Result, anyhow};
pub const MARKER_TAG_PREFIX: &str = "<|marker_";
pub const MARKER_TAG_SUFFIX: &str = "|>";
pub const RELATIVE_MARKER_TAG_PREFIX: &str = "<|marker";
-const MIN_BLOCK_LINES: usize = 3;
-const MAX_BLOCK_LINES: usize = 8;
+const V0316_MIN_BLOCK_LINES: usize = 3;
+const V0316_MAX_BLOCK_LINES: usize = 8;
+const V0318_MIN_BLOCK_LINES: usize = 6;
+const V0318_MAX_BLOCK_LINES: usize = 16;
+const MAX_NUDGE_LINES: usize = 5;
pub const V0316_END_MARKER: &str = "<[endβofβsentence]>";
pub const V0317_END_MARKER: &str = "<[endβofβsentence]>";
+pub const V0318_END_MARKER: &str = "<[endβofβsentence]>";
pub fn marker_tag(number: usize) -> String {
format!("{MARKER_TAG_PREFIX}{number}{MARKER_TAG_SUFFIX}")
@@ -22,71 +26,104 @@ pub fn marker_tag_relative(delta: isize) -> String {
}
}
+struct LineInfo {
+ start: usize,
+ is_blank: bool,
+ is_good_start: bool,
+}
+
+fn collect_line_info(text: &str) -> Vec<LineInfo> {
+ let mut lines = Vec::new();
+ let mut offset = 0;
+ for line in text.split('\n') {
+ let trimmed = line.trim();
+ let is_blank = trimmed.is_empty();
+ let is_good_start = !is_blank && !is_structural_tail(trimmed);
+ lines.push(LineInfo {
+ start: offset,
+ is_blank,
+ is_good_start,
+ });
+ offset += line.len() + 1;
+ }
+ // split('\n') on "abc\n" yields ["abc", ""] — drop the phantom trailing
+ // empty element when the text ends with '\n'.
+ if text.ends_with('\n') && lines.len() > 1 {
+ lines.pop();
+ }
+ lines
+}
+
+fn is_structural_tail(trimmed_line: &str) -> bool {
+ if trimmed_line.starts_with(&['}', ']', ')']) {
+ return true;
+ }
+ matches!(
+ trimmed_line.trim_end_matches(';'),
+ "break" | "continue" | "return" | "throw" | "end"
+ )
+}
+
+/// Starting from line `from`, scan up to `MAX_NUDGE_LINES` forward to find a
+/// line with `is_good_start`. Returns `None` if no suitable line is found.
+fn skip_to_good_start(lines: &[LineInfo], from: usize) -> Option<usize> {
+ (from..lines.len().min(from + MAX_NUDGE_LINES)).find(|&i| lines[i].is_good_start)
+}
+
/// Compute byte offsets within `editable_text` where marker boundaries should
/// be placed.
///
/// Returns a sorted `Vec<usize>` that always starts with `0` and ends with
/// `editable_text.len()`. Interior offsets are placed at line boundaries
/// (right after a `\n`), preferring blank-line boundaries when available and
-/// respecting `MIN_BLOCK_LINES` / `MAX_BLOCK_LINES` constraints.
-pub fn compute_marker_offsets(editable_text: &str) -> Vec<usize> {
+/// respecting `min_block_lines` / `max_block_lines` constraints.
+fn compute_marker_offsets_with_limits(
+ editable_text: &str,
+ min_block_lines: usize,
+ max_block_lines: usize,
+) -> Vec<usize> {
if editable_text.is_empty() {
return vec![0, 0];
}
+ let lines = collect_line_info(editable_text);
let mut offsets = vec![0usize];
- let mut lines_since_last_marker = 0usize;
- let mut byte_offset = 0usize;
-
- for line in editable_text.split('\n') {
- let line_end = byte_offset + line.len() + 1;
- let is_past_end = line_end > editable_text.len();
- let actual_line_end = line_end.min(editable_text.len());
- lines_since_last_marker += 1;
-
- let is_blank = line.trim().is_empty();
-
- if !is_past_end && lines_since_last_marker >= MIN_BLOCK_LINES {
- if is_blank {
- // Blank-line boundary found. We'll place the marker when we
- // find the next non-blank line (handled below).
- } else if lines_since_last_marker >= MAX_BLOCK_LINES {
- offsets.push(actual_line_end);
- lines_since_last_marker = 0;
- }
- }
+ let mut last_boundary_line = 0;
+ let mut i = 0;
+
+ while i < lines.len() {
+ let gap = i - last_boundary_line;
- // Non-blank line immediately following blank line(s): split here so
- // the new block starts with this line.
- if !is_blank && byte_offset > 0 && lines_since_last_marker >= MIN_BLOCK_LINES {
- let before = &editable_text[..byte_offset];
- let has_preceding_blank_line = before
- .strip_suffix('\n')
- .map(|stripped| {
- let last_line = match stripped.rfind('\n') {
- Some(pos) => &stripped[pos + 1..],
- None => stripped,
- };
- last_line.trim().is_empty()
- })
- .unwrap_or(false);
-
- if has_preceding_blank_line {
- offsets.push(byte_offset);
- lines_since_last_marker = 1;
+ // Blank-line split: non-blank line following blank line(s) with enough
+ // accumulated lines.
+ if gap >= min_block_lines && !lines[i].is_blank && i > 0 && lines[i - 1].is_blank {
+ let target = if lines[i].is_good_start {
+ i
+ } else {
+ skip_to_good_start(&lines, i).unwrap_or(i)
+ };
+ if lines.len() - target >= min_block_lines
+ && lines[target].start > *offsets.last().unwrap_or(&0)
+ {
+ offsets.push(lines[target].start);
+ last_boundary_line = target;
+ i = target + 1;
+ continue;
}
}
- byte_offset = actual_line_end;
-
- // Re-check after blank-line logic since lines_since_last_marker may
- // have been reset.
- if !is_past_end && lines_since_last_marker >= MAX_BLOCK_LINES {
- if *offsets.last().unwrap_or(&0) != actual_line_end {
- offsets.push(actual_line_end);
- lines_since_last_marker = 0;
+ // Hard cap: too many lines without a split.
+ if gap >= max_block_lines {
+ let target = skip_to_good_start(&lines, i).unwrap_or(i);
+ if lines[target].start > *offsets.last().unwrap_or(&0) {
+ offsets.push(lines[target].start);
+ last_boundary_line = target;
+ i = target + 1;
+ continue;
}
}
+
+ i += 1;
}
let end = editable_text.len();
@@ -97,6 +134,15 @@ pub fn compute_marker_offsets(editable_text: &str) -> Vec<usize> {
offsets
}
+/// Compute byte offsets within `editable_text` for the V0316/V0317 block sizing rules.
+pub fn compute_marker_offsets(editable_text: &str) -> Vec<usize> {
+ compute_marker_offsets_with_limits(editable_text, V0316_MIN_BLOCK_LINES, V0316_MAX_BLOCK_LINES)
+}
+
+pub fn compute_marker_offsets_v0318(editable_text: &str) -> Vec<usize> {
+ compute_marker_offsets_with_limits(editable_text, V0318_MIN_BLOCK_LINES, V0318_MAX_BLOCK_LINES)
+}
+
/// Write the editable region content with marker tags, inserting the cursor
/// marker at the given offset within the editable text.
pub fn write_editable_with_markers(
@@ -267,27 +313,8 @@ pub fn encode_from_old_and_new(
}
let marker_offsets = compute_marker_offsets(old_editable);
-
- let common_prefix = old_editable
- .bytes()
- .zip(new_editable.bytes())
- .take_while(|(a, b)| a == b)
- .count();
-
- let old_remaining = old_editable.len() - common_prefix;
- let new_remaining = new_editable.len() - common_prefix;
- let max_suffix = old_remaining.min(new_remaining);
- let common_suffix = old_editable.as_bytes()[old_editable.len() - max_suffix..]
- .iter()
- .rev()
- .zip(
- new_editable.as_bytes()[new_editable.len() - max_suffix..]
- .iter()
- .rev(),
- )
- .take_while(|(a, b)| a == b)
- .count();
-
+ let (common_prefix, common_suffix) =
+ common_prefix_suffix(old_editable.as_bytes(), new_editable.as_bytes());
let change_end_in_old = old_editable.len() - common_suffix;
let start_marker_idx = marker_offsets
@@ -380,55 +407,24 @@ pub fn extract_editable_region_from_markers(text: &str) -> Option<String> {
Some(result)
}
-struct MarkerTag {
- number: usize,
- tag_start: usize,
- tag_end: usize,
-}
-
-struct RelativeMarkerTag {
- delta: isize,
+struct ParsedTag {
+ value: isize,
tag_start: usize,
tag_end: usize,
}
-fn collect_marker_tags(text: &str) -> Vec<MarkerTag> {
- let mut markers = Vec::new();
+fn collect_tags(text: &str, prefix: &str, parse: fn(&str) -> Option<isize>) -> Vec<ParsedTag> {
+ let mut tags = Vec::new();
let mut search_from = 0;
- while let Some(rel_pos) = text[search_from..].find(MARKER_TAG_PREFIX) {
+ while let Some(rel_pos) = text[search_from..].find(prefix) {
let tag_start = search_from + rel_pos;
- let num_start = tag_start + MARKER_TAG_PREFIX.len();
- if let Some(suffix_rel) = text[num_start..].find(MARKER_TAG_SUFFIX) {
- let num_end = num_start + suffix_rel;
- if let Ok(number) = text[num_start..num_end].parse::<usize>() {
- let tag_end = num_end + MARKER_TAG_SUFFIX.len();
- markers.push(MarkerTag {
- number,
- tag_start,
- tag_end,
- });
- search_from = tag_end;
- continue;
- }
- }
- search_from = tag_start + MARKER_TAG_PREFIX.len();
- }
- markers
-}
-
-fn collect_relative_marker_tags(text: &str) -> Vec<RelativeMarkerTag> {
- let mut markers = Vec::new();
- let mut search_from = 0;
- while let Some(rel_pos) = text[search_from..].find(RELATIVE_MARKER_TAG_PREFIX) {
- let tag_start = search_from + rel_pos;
- let payload_start = tag_start + RELATIVE_MARKER_TAG_PREFIX.len();
+ let payload_start = tag_start + prefix.len();
if let Some(suffix_rel) = text[payload_start..].find(MARKER_TAG_SUFFIX) {
let payload_end = payload_start + suffix_rel;
- let payload = &text[payload_start..payload_end];
- if let Ok(delta) = payload.parse::<isize>() {
+ if let Some(value) = parse(&text[payload_start..payload_end]) {
let tag_end = payload_end + MARKER_TAG_SUFFIX.len();
- markers.push(RelativeMarkerTag {
- delta,
+ tags.push(ParsedTag {
+ value,
tag_start,
tag_end,
});
@@ -436,9 +432,21 @@ fn collect_relative_marker_tags(text: &str) -> Vec<RelativeMarkerTag> {
continue;
}
}
- search_from = tag_start + RELATIVE_MARKER_TAG_PREFIX.len();
+ search_from = tag_start + prefix.len();
}
- markers
+ tags
+}
+
+fn collect_marker_tags(text: &str) -> Vec<ParsedTag> {
+ collect_tags(text, MARKER_TAG_PREFIX, |s| {
+ s.parse::<usize>().ok().map(|n| n as isize)
+ })
+}
+
+fn collect_relative_marker_tags(text: &str) -> Vec<ParsedTag> {
+ collect_tags(text, RELATIVE_MARKER_TAG_PREFIX, |s| {
+ s.parse::<isize>().ok()
+ })
}
pub fn nearest_marker_number(cursor_offset: Option<usize>, marker_offsets: &[usize]) -> usize {
@@ -459,21 +467,87 @@ fn cursor_block_index(cursor_offset: Option<usize>, marker_offsets: &[usize]) ->
.unwrap_or_else(|| marker_offsets.len().saturating_sub(2))
}
-/// Write the editable region content with V0317 byte-exact marker tags, where
-/// marker numbers are relative to the cursor block.
-pub fn write_editable_with_markers_v0317(
+fn common_prefix_suffix(a: &[u8], b: &[u8]) -> (usize, usize) {
+ let prefix = a.iter().zip(b.iter()).take_while(|(x, y)| x == y).count();
+ let remaining_a = a.len() - prefix;
+ let remaining_b = b.len() - prefix;
+ let max_suffix = remaining_a.min(remaining_b);
+ let suffix = a[a.len() - max_suffix..]
+ .iter()
+ .rev()
+ .zip(b[b.len() - max_suffix..].iter().rev())
+ .take_while(|(x, y)| x == y)
+ .count();
+ (prefix, suffix)
+}
+
+/// Map a byte offset from old span coordinates to new span coordinates,
+/// using common prefix/suffix within the span for accuracy.
+fn map_boundary_offset(
+ old_rel: usize,
+ old_span_len: usize,
+ new_span_len: usize,
+ span_common_prefix: usize,
+ span_common_suffix: usize,
+) -> usize {
+ if old_rel <= span_common_prefix {
+ old_rel
+ } else if old_rel >= old_span_len - span_common_suffix {
+ new_span_len - (old_span_len - old_rel)
+ } else {
+ let old_changed_start = span_common_prefix;
+ let old_changed_len = old_span_len
+ .saturating_sub(span_common_prefix)
+ .saturating_sub(span_common_suffix);
+ let new_changed_start = span_common_prefix;
+ let new_changed_len = new_span_len
+ .saturating_sub(span_common_prefix)
+ .saturating_sub(span_common_suffix);
+
+ if old_changed_len == 0 {
+ new_changed_start
+ } else {
+ new_changed_start + ((old_rel - old_changed_start) * new_changed_len / old_changed_len)
+ }
+ }
+}
+
+fn snap_to_line_start(text: &str, offset: usize) -> usize {
+ let bounded = offset.min(text.len());
+ let bounded = text.floor_char_boundary(bounded);
+
+ if bounded >= text.len() {
+ return text.len();
+ }
+
+ if bounded == 0 || text.as_bytes().get(bounded - 1) == Some(&b'\n') {
+ return bounded;
+ }
+
+ if let Some(next_nl_rel) = text[bounded..].find('\n') {
+ let next = bounded + next_nl_rel + 1;
+ return text.floor_char_boundary(next.min(text.len()));
+ }
+
+ let prev_start = text[..bounded].rfind('\n').map(|idx| idx + 1).unwrap_or(0);
+ text.floor_char_boundary(prev_start)
+}
+
+/// Write the editable region content with byte-exact marker tags, inserting the
+/// cursor marker at the given offset within the editable text.
+///
+/// The `tag_for_index` closure maps a boundary index to the marker tag string.
+fn write_editable_with_markers_impl(
output: &mut String,
editable_text: &str,
cursor_offset_in_editable: usize,
cursor_marker: &str,
+ marker_offsets: &[usize],
+ tag_for_index: impl Fn(usize) -> String,
) {
- let marker_offsets = compute_marker_offsets(editable_text);
- let anchor_idx = cursor_block_index(Some(cursor_offset_in_editable), &marker_offsets);
let mut cursor_placed = false;
-
for (i, &offset) in marker_offsets.iter().enumerate() {
- let marker_delta = i as isize - anchor_idx as isize;
- output.push_str(&marker_tag_relative(marker_delta));
+ output.push_str(&tag_for_index(i));
if let Some(&next_offset) = marker_offsets.get(i + 1) {
let block = &editable_text[offset..next_offset];
@@ -493,11 +567,6 @@ pub fn write_editable_with_markers_v0317(
}
}
-/// Write the editable region content with V0316 byte-exact marker tags.
-///
-/// Unlike the V0306 version, markers are pure delimiters with no newline
-/// padding. The content between markers is the exact bytes from the editable
-/// text.
pub fn write_editable_with_markers_v0316(
output: &mut String,
editable_text: &str,
@@ -505,103 +574,93 @@ pub fn write_editable_with_markers_v0316(
cursor_marker: &str,
) {
let marker_offsets = compute_marker_offsets(editable_text);
- let mut cursor_placed = false;
- for (i, &offset) in marker_offsets.iter().enumerate() {
- let marker_num = i + 1;
- output.push_str(&marker_tag(marker_num));
+ write_editable_with_markers_impl(
+ output,
+ editable_text,
+ cursor_offset_in_editable,
+ cursor_marker,
+ &marker_offsets,
+ |i| marker_tag(i + 1),
+ );
+}
- if let Some(&next_offset) = marker_offsets.get(i + 1) {
- let block = &editable_text[offset..next_offset];
- if !cursor_placed
- && cursor_offset_in_editable >= offset
- && cursor_offset_in_editable <= next_offset
- {
- cursor_placed = true;
- let cursor_in_block = cursor_offset_in_editable - offset;
- output.push_str(&block[..cursor_in_block]);
- output.push_str(cursor_marker);
- output.push_str(&block[cursor_in_block..]);
- } else {
- output.push_str(block);
- }
- }
- }
+pub fn write_editable_with_markers_v0317(
+ output: &mut String,
+ editable_text: &str,
+ cursor_offset_in_editable: usize,
+ cursor_marker: &str,
+) {
+ let marker_offsets = compute_marker_offsets(editable_text);
+ let anchor_idx = cursor_block_index(Some(cursor_offset_in_editable), &marker_offsets);
+ write_editable_with_markers_impl(
+ output,
+ editable_text,
+ cursor_offset_in_editable,
+ cursor_marker,
+ &marker_offsets,
+ |i| marker_tag_relative(i as isize - anchor_idx as isize),
+ );
}
-/// Parse V0316 model output and reconstruct the full new editable region.
-///
-/// V0316 differences from V0306:
-/// - No newline stripping or normalization (byte-exact content).
-/// - The no-edit signal is `start_num == end_num` (any repeated marker).
-/// - Intermediate marker tags are used for block-level extraction.
-pub fn apply_marker_span_v0316(old_editable: &str, output: &str) -> Result<String> {
- let markers = collect_marker_tags(output);
+pub fn write_editable_with_markers_v0318(
+ output: &mut String,
+ editable_text: &str,
+ cursor_offset_in_editable: usize,
+ cursor_marker: &str,
+) {
+ let marker_offsets = compute_marker_offsets_v0318(editable_text);
+ write_editable_with_markers_impl(
+ output,
+ editable_text,
+ cursor_offset_in_editable,
+ cursor_marker,
+ &marker_offsets,
+ |i| marker_tag(i + 1),
+ );
+}
- if markers.is_empty() {
+/// Parse byte-exact model output and reconstruct the full new editable region.
+///
+/// `resolve_boundary` maps a parsed tag value to an absolute byte offset in
+/// old_editable, given the marker_offsets. Returns `(start_byte, end_byte)` or
+/// an error.
+fn apply_marker_span_impl(
+ old_editable: &str,
+ tags: &[ParsedTag],
+ output: &str,
+ resolve_boundaries: impl Fn(isize, isize) -> Result<(usize, usize)>,
+) -> Result<String> {
+ if tags.is_empty() {
return Err(anyhow!("no marker tags found in output"));
}
-
- if markers.len() == 1 {
+ if tags.len() == 1 {
return Err(anyhow!(
"only one marker tag found in output, expected at least two"
));
}
- let start_num = markers
- .first()
- .map(|marker| marker.number)
- .context("missing first marker")?;
- let end_num = markers
- .last()
- .map(|marker| marker.number)
- .context("missing last marker")?;
+ let start_value = tags[0].value;
+ let end_value = tags[tags.len() - 1].value;
- // No-edit signal: start_num == end_num
- if start_num == end_num {
+ if start_value == end_value {
return Ok(old_editable.to_string());
}
- // Validate monotonically increasing with no gaps
- let expected_nums: Vec<usize> = (start_num..=end_num).collect();
- let actual_nums: Vec<usize> = markers.iter().map(|m| m.number).collect();
- if actual_nums != expected_nums {
- eprintln!(
- "V0316 marker sequence validation failed: expected {:?}, got {:?}. Attempting best-effort parse.",
- expected_nums, actual_nums
- );
- }
-
- let marker_offsets = compute_marker_offsets(old_editable);
-
- let start_idx = start_num
- .checked_sub(1)
- .context("marker numbers are 1-indexed")?;
- let end_idx = end_num
- .checked_sub(1)
- .context("marker numbers are 1-indexed")?;
-
- let start_byte = *marker_offsets
- .get(start_idx)
- .context("start marker number out of range")?;
- let end_byte = *marker_offsets
- .get(end_idx)
- .context("end marker number out of range")?;
+ let (start_byte, end_byte) = resolve_boundaries(start_value, end_value)?;
if start_byte > end_byte {
return Err(anyhow!("start marker must come before end marker"));
}
- // Extract byte-exact content between consecutive markers
let mut new_content = String::new();
- for i in 0..markers.len() - 1 {
- let content_start = markers[i].tag_end;
- let content_end = markers[i + 1].tag_start;
+ for i in 0..tags.len() - 1 {
+ let content_start = tags[i].tag_end;
+ let content_end = tags[i + 1].tag_start;
if content_start <= content_end {
new_content.push_str(&output[content_start..content_end]);
}
}
- // Splice into old_editable
let mut result = String::new();
result.push_str(&old_editable[..start_byte]);
result.push_str(&new_content);
@@ -610,134 +669,146 @@ pub fn apply_marker_span_v0316(old_editable: &str, output: &str) -> Result<Strin
Ok(result)
}
-/// Parse V0317 model output and reconstruct the full new editable region.
-///
-/// V0317 differences from V0316:
-/// - Marker ids are relative to the cursor block (e.g. -2, -1, 0, +1, +2).
-/// - No-edit signal is any repeated relative marker tag.
+pub fn apply_marker_span_v0316(old_editable: &str, output: &str) -> Result<String> {
+ let tags = collect_marker_tags(output);
+
+ // Validate monotonically increasing with no gaps (best-effort warning)
+ if tags.len() >= 2 {
+ let start_num = tags[0].value;
+ let end_num = tags[tags.len() - 1].value;
+ if start_num != end_num {
+ let expected: Vec<isize> = (start_num..=end_num).collect();
+ let actual: Vec<isize> = tags.iter().map(|t| t.value).collect();
+ if actual != expected {
+ eprintln!(
+ "V0316 marker sequence validation failed: expected {:?}, got {:?}. Attempting best-effort parse.",
+ expected, actual
+ );
+ }
+ }
+ }
+
+ let marker_offsets = compute_marker_offsets(old_editable);
+ apply_marker_span_impl(old_editable, &tags, output, |start_val, end_val| {
+ let start_idx = (start_val as usize)
+ .checked_sub(1)
+ .context("marker numbers are 1-indexed")?;
+ let end_idx = (end_val as usize)
+ .checked_sub(1)
+ .context("marker numbers are 1-indexed")?;
+ let start_byte = *marker_offsets
+ .get(start_idx)
+ .context("start marker number out of range")?;
+ let end_byte = *marker_offsets
+ .get(end_idx)
+ .context("end marker number out of range")?;
+ Ok((start_byte, end_byte))
+ })
+}
+
pub fn apply_marker_span_v0317(
old_editable: &str,
output: &str,
cursor_offset_in_old: Option<usize>,
) -> Result<String> {
- let markers = collect_relative_marker_tags(output);
-
- if markers.is_empty() {
- return Err(anyhow!("no marker tags found in output"));
- }
-
- if markers.len() == 1 {
- return Err(anyhow!(
- "only one marker tag found in output, expected at least two"
- ));
- }
-
+ let tags = collect_relative_marker_tags(output);
let marker_offsets = compute_marker_offsets(old_editable);
let anchor_idx = cursor_block_index(cursor_offset_in_old, &marker_offsets);
- let start_delta = markers
- .first()
- .map(|marker| marker.delta)
- .context("missing first marker")?;
- let end_delta = markers
- .last()
- .map(|marker| marker.delta)
- .context("missing last marker")?;
-
- if start_delta == end_delta {
- return Ok(old_editable.to_string());
- }
-
- let start_idx_isize = anchor_idx as isize + start_delta;
- let end_idx_isize = anchor_idx as isize + end_delta;
- if start_idx_isize < 0 || end_idx_isize < 0 {
- return Err(anyhow!("relative marker maps before first marker"));
- }
-
- let start_idx = usize::try_from(start_idx_isize).context("invalid start marker index")?;
- let end_idx = usize::try_from(end_idx_isize).context("invalid end marker index")?;
-
- let start_byte = *marker_offsets
- .get(start_idx)
- .context("start marker number out of range")?;
- let end_byte = *marker_offsets
- .get(end_idx)
- .context("end marker number out of range")?;
-
- if start_byte > end_byte {
- return Err(anyhow!("start marker must come before end marker"));
- }
+ apply_marker_span_impl(old_editable, &tags, output, |start_delta, end_delta| {
+ let start_idx_signed = anchor_idx as isize + start_delta;
+ let end_idx_signed = anchor_idx as isize + end_delta;
+ if start_idx_signed < 0 || end_idx_signed < 0 {
+ return Err(anyhow!("relative marker maps before first marker"));
+ }
+ let start_idx = usize::try_from(start_idx_signed).context("invalid start marker index")?;
+ let end_idx = usize::try_from(end_idx_signed).context("invalid end marker index")?;
+ let start_byte = *marker_offsets
+ .get(start_idx)
+ .context("start marker number out of range")?;
+ let end_byte = *marker_offsets
+ .get(end_idx)
+ .context("end marker number out of range")?;
+ Ok((start_byte, end_byte))
+ })
+}
- let mut new_content = String::new();
- for i in 0..markers.len() - 1 {
- let content_start = markers[i].tag_end;
- let content_end = markers[i + 1].tag_start;
- if content_start <= content_end {
- new_content.push_str(&output[content_start..content_end]);
+pub fn apply_marker_span_v0318(old_editable: &str, output: &str) -> Result<String> {
+ let tags = collect_marker_tags(output);
+
+ if tags.len() >= 2 {
+ let start_num = tags[0].value;
+ let end_num = tags[tags.len() - 1].value;
+ if start_num != end_num {
+ let expected: Vec<isize> = (start_num..=end_num).collect();
+ let actual: Vec<isize> = tags.iter().map(|t| t.value).collect();
+ if actual != expected {
+ eprintln!(
+ "V0318 marker sequence validation failed: expected {:?}, got {:?}. Attempting best-effort parse.",
+ expected, actual
+ );
+ }
}
}
- let mut result = String::new();
- result.push_str(&old_editable[..start_byte]);
- result.push_str(&new_content);
- result.push_str(&old_editable[end_byte..]);
-
- Ok(result)
+ let marker_offsets = compute_marker_offsets_v0318(old_editable);
+ apply_marker_span_impl(old_editable, &tags, output, |start_val, end_val| {
+ let start_idx = (start_val as usize)
+ .checked_sub(1)
+ .context("marker numbers are 1-indexed")?;
+ let end_idx = (end_val as usize)
+ .checked_sub(1)
+ .context("marker numbers are 1-indexed")?;
+ let start_byte = *marker_offsets
+ .get(start_idx)
+ .context("start marker number out of range")?;
+ let end_byte = *marker_offsets
+ .get(end_idx)
+ .context("end marker number out of range")?;
+ Ok((start_byte, end_byte))
+ })
}
-/// Encode the V0316 training target from old and new editable text.
+/// Encode the training target from old and new editable text.
///
-/// V0316 differences from V0306:
-/// - No-edit signal: `<|marker_C|><|marker_C|>{end_marker}` where C is nearest
-/// to cursor.
-/// - All intermediate markers are emitted with byte-exact content.
-/// - No newline padding around marker tags.
-pub fn encode_from_old_and_new_v0316(
+/// Shared implementation for V0316, V0317, and V0318. The `tag_for_block_idx`
+/// closure maps a block index to the appropriate marker tag string.
+/// `no_edit_tag` is the marker tag to repeat when there are no edits.
+fn encode_from_old_and_new_impl(
old_editable: &str,
new_editable: &str,
cursor_offset_in_new: Option<usize>,
cursor_marker: &str,
end_marker: &str,
+ no_edit_tag: &str,
+ marker_offsets: &[usize],
+ tag_for_block_idx: impl Fn(usize) -> String,
) -> Result<String> {
- let marker_offsets = compute_marker_offsets(old_editable);
-
if old_editable == new_editable {
- let marker_num = nearest_marker_number(cursor_offset_in_new, &marker_offsets);
- let tag = marker_tag(marker_num);
- return Ok(format!("{tag}{tag}{end_marker}"));
+ return Ok(format!("{no_edit_tag}{no_edit_tag}{end_marker}"));
}
- let common_prefix = old_editable
- .bytes()
- .zip(new_editable.bytes())
- .take_while(|(a, b)| a == b)
- .count();
-
- let old_remaining = old_editable.len() - common_prefix;
- let new_remaining = new_editable.len() - common_prefix;
- let max_suffix = old_remaining.min(new_remaining);
- let common_suffix = old_editable.as_bytes()[old_editable.len() - max_suffix..]
- .iter()
- .rev()
- .zip(
- new_editable.as_bytes()[new_editable.len() - max_suffix..]
- .iter()
- .rev(),
- )
- .take_while(|(a, b)| a == b)
- .count();
-
+ let (common_prefix, common_suffix) =
+ common_prefix_suffix(old_editable.as_bytes(), new_editable.as_bytes());
let change_end_in_old = old_editable.len() - common_suffix;
- let start_marker_idx = marker_offsets
+ let mut start_marker_idx = marker_offsets
.iter()
.rposition(|&offset| offset <= common_prefix)
.unwrap_or(0);
- let end_marker_idx = marker_offsets
+ let mut end_marker_idx = marker_offsets
.iter()
.position(|&offset| offset >= change_end_in_old)
.unwrap_or(marker_offsets.len() - 1);
+ if start_marker_idx == end_marker_idx {
+ if end_marker_idx < marker_offsets.len().saturating_sub(1) {
+ end_marker_idx += 1;
+ } else if start_marker_idx > 0 {
+ start_marker_idx -= 1;
+ }
+ }
+
let old_start = marker_offsets[start_marker_idx];
let old_end = marker_offsets[end_marker_idx];
@@ -749,40 +820,19 @@ pub fn encode_from_old_and_new_v0316(
let new_span = &new_editable[new_start..new_end];
let old_span = &old_editable[old_start..old_end];
- // Compute common prefix/suffix within the span for accurate boundary mapping
- let span_common_prefix = old_span
- .bytes()
- .zip(new_span.bytes())
- .take_while(|(a, b)| a == b)
- .count();
-
- let span_old_remaining = old_span.len() - span_common_prefix;
- let span_new_remaining = new_span.len() - span_common_prefix;
- let span_max_suffix = span_old_remaining.min(span_new_remaining);
- let span_common_suffix = old_span.as_bytes()[old_span.len() - span_max_suffix..]
- .iter()
- .rev()
- .zip(
- new_span.as_bytes()[new_span.len() - span_max_suffix..]
- .iter()
- .rev(),
- )
- .take_while(|(a, b)| a == b)
- .count();
+ let (span_common_prefix, span_common_suffix) =
+ common_prefix_suffix(old_span.as_bytes(), new_span.as_bytes());
let mut result = String::new();
let mut prev_new_rel = 0usize;
let mut cursor_placed = false;
for block_idx in start_marker_idx..end_marker_idx {
- let marker_num = block_idx + 1;
- result.push_str(&marker_tag(marker_num));
+ result.push_str(&tag_for_block_idx(block_idx));
let new_rel_end = if block_idx + 1 == end_marker_idx {
- // Last block: extends to end of new span
new_span.len()
} else {
- // Map the intermediate boundary from old to new coordinates
let old_rel = marker_offsets[block_idx + 1] - old_start;
let mapped = map_boundary_offset(
old_rel,
@@ -791,13 +841,10 @@ pub fn encode_from_old_and_new_v0316(
span_common_prefix,
span_common_suffix,
);
- // Ensure char boundary safety and monotonicity
- new_span.floor_char_boundary(mapped)
+ snap_to_line_start(new_span, mapped)
};
- // Ensure monotonicity (each block gets at least zero content)
let new_rel_end = new_rel_end.max(prev_new_rel);
-
let block_content = &new_span[prev_new_rel..new_rel_end];
if !cursor_placed {
@@ -821,19 +868,33 @@ pub fn encode_from_old_and_new_v0316(
prev_new_rel = new_rel_end;
}
- // Final closing marker
- let end_marker_num = end_marker_idx + 1;
- result.push_str(&marker_tag(end_marker_num));
+ result.push_str(&tag_for_block_idx(end_marker_idx));
result.push_str(end_marker);
Ok(result)
}
-/// Encode the V0317 training target from old and new editable text.
-///
-/// V0317 differences from V0316:
-/// - Marker ids are relative to cursor block (..., -2, -1, 0, +1, +2, ...).
-/// - No-edit signal: repeated cursor-relative marker.
+pub fn encode_from_old_and_new_v0316(
+ old_editable: &str,
+ new_editable: &str,
+ cursor_offset_in_new: Option<usize>,
+ cursor_marker: &str,
+ end_marker: &str,
+) -> Result<String> {
+ let marker_offsets = compute_marker_offsets(old_editable);
+ let no_edit_tag = marker_tag(nearest_marker_number(cursor_offset_in_new, &marker_offsets));
+ encode_from_old_and_new_impl(
+ old_editable,
+ new_editable,
+ cursor_offset_in_new,
+ cursor_marker,
+ end_marker,
+ &no_edit_tag,
+ &marker_offsets,
+ |block_idx| marker_tag(block_idx + 1),
+ )
+}
+
pub fn encode_from_old_and_new_v0317(
old_editable: &str,
new_editable: &str,
@@ -843,157 +904,38 @@ pub fn encode_from_old_and_new_v0317(
) -> Result<String> {
let marker_offsets = compute_marker_offsets(old_editable);
let anchor_idx = cursor_block_index(cursor_offset_in_new, &marker_offsets);
-
- if old_editable == new_editable {
- let tag = marker_tag_relative(0);
- return Ok(format!("{tag}{tag}{end_marker}"));
- }
-
- let common_prefix = old_editable
- .bytes()
- .zip(new_editable.bytes())
- .take_while(|(a, b)| a == b)
- .count();
-
- let old_remaining = old_editable.len() - common_prefix;
- let new_remaining = new_editable.len() - common_prefix;
- let max_suffix = old_remaining.min(new_remaining);
- let common_suffix = old_editable.as_bytes()[old_editable.len() - max_suffix..]
- .iter()
- .rev()
- .zip(
- new_editable.as_bytes()[new_editable.len() - max_suffix..]
- .iter()
- .rev(),
- )
- .take_while(|(a, b)| a == b)
- .count();
-
- let change_end_in_old = old_editable.len() - common_suffix;
-
- let start_marker_idx = marker_offsets
- .iter()
- .rposition(|&offset| offset <= common_prefix)
- .unwrap_or(0);
- let end_marker_idx = marker_offsets
- .iter()
- .position(|&offset| offset >= change_end_in_old)
- .unwrap_or(marker_offsets.len() - 1);
-
- let old_start = marker_offsets[start_marker_idx];
- let old_end = marker_offsets[end_marker_idx];
-
- let new_start = old_start;
- let new_end = new_editable
- .len()
- .saturating_sub(old_editable.len().saturating_sub(old_end));
-
- let new_span = &new_editable[new_start..new_end];
- let old_span = &old_editable[old_start..old_end];
-
- let span_common_prefix = old_span
- .bytes()
- .zip(new_span.bytes())
- .take_while(|(a, b)| a == b)
- .count();
-
- let span_old_remaining = old_span.len() - span_common_prefix;
- let span_new_remaining = new_span.len() - span_common_prefix;
- let span_max_suffix = span_old_remaining.min(span_new_remaining);
- let span_common_suffix = old_span.as_bytes()[old_span.len() - span_max_suffix..]
- .iter()
- .rev()
- .zip(
- new_span.as_bytes()[new_span.len() - span_max_suffix..]
- .iter()
- .rev(),
- )
- .take_while(|(a, b)| a == b)
- .count();
-
- let mut result = String::new();
- let mut prev_new_rel = 0usize;
- let mut cursor_placed = false;
-
- for block_idx in start_marker_idx..end_marker_idx {
- let marker_delta = block_idx as isize - anchor_idx as isize;
- result.push_str(&marker_tag_relative(marker_delta));
-
- let new_rel_end = if block_idx + 1 == end_marker_idx {
- new_span.len()
- } else {
- let old_rel = marker_offsets[block_idx + 1] - old_start;
- let mapped = map_boundary_offset(
- old_rel,
- old_span.len(),
- new_span.len(),
- span_common_prefix,
- span_common_suffix,
- );
- new_span.floor_char_boundary(mapped)
- };
-
- let new_rel_end = new_rel_end.max(prev_new_rel);
- let block_content = &new_span[prev_new_rel..new_rel_end];
-
- if !cursor_placed {
- if let Some(cursor_offset) = cursor_offset_in_new {
- let abs_start = new_start + prev_new_rel;
- let abs_end = new_start + new_rel_end;
- if cursor_offset >= abs_start && cursor_offset <= abs_end {
- cursor_placed = true;
- let cursor_in_block = cursor_offset - abs_start;
- let bounded = cursor_in_block.min(block_content.len());
- result.push_str(&block_content[..bounded]);
- result.push_str(cursor_marker);
- result.push_str(&block_content[bounded..]);
- prev_new_rel = new_rel_end;
- continue;
- }
- }
- }
-
- result.push_str(block_content);
- prev_new_rel = new_rel_end;
- }
-
- let end_marker_delta = end_marker_idx as isize - anchor_idx as isize;
- result.push_str(&marker_tag_relative(end_marker_delta));
- result.push_str(end_marker);
-
- Ok(result)
+ let no_edit_tag = marker_tag_relative(0);
+ encode_from_old_and_new_impl(
+ old_editable,
+ new_editable,
+ cursor_offset_in_new,
+ cursor_marker,
+ end_marker,
+ &no_edit_tag,
+ &marker_offsets,
+ |block_idx| marker_tag_relative(block_idx as isize - anchor_idx as isize),
+ )
}
-/// Map a byte offset from old span coordinates to new span coordinates,
-/// using common prefix/suffix within the span for accuracy.
-fn map_boundary_offset(
- old_rel: usize,
- old_span_len: usize,
- new_span_len: usize,
- span_common_prefix: usize,
- span_common_suffix: usize,
-) -> usize {
- if old_rel <= span_common_prefix {
- old_rel
- } else if old_rel >= old_span_len - span_common_suffix {
- new_span_len - (old_span_len - old_rel)
- } else {
- // Within the changed region: proportional mapping
- let old_changed_start = span_common_prefix;
- let old_changed_len = old_span_len
- .saturating_sub(span_common_prefix)
- .saturating_sub(span_common_suffix);
- let new_changed_start = span_common_prefix;
- let new_changed_len = new_span_len
- .saturating_sub(span_common_prefix)
- .saturating_sub(span_common_suffix);
-
- if old_changed_len == 0 {
- new_changed_start
- } else {
- new_changed_start + ((old_rel - old_changed_start) * new_changed_len / old_changed_len)
- }
- }
+pub fn encode_from_old_and_new_v0318(
+ old_editable: &str,
+ new_editable: &str,
+ cursor_offset_in_new: Option<usize>,
+ cursor_marker: &str,
+ end_marker: &str,
+) -> Result<String> {
+ let marker_offsets = compute_marker_offsets_v0318(old_editable);
+ let no_edit_tag = marker_tag(nearest_marker_number(cursor_offset_in_new, &marker_offsets));
+ encode_from_old_and_new_impl(
+ old_editable,
+ new_editable,
+ cursor_offset_in_new,
+ cursor_marker,
+ end_marker,
+ &no_edit_tag,
+ &marker_offsets,
+ |block_idx| marker_tag(block_idx + 1),
+ )
}
#[cfg(test)]
@@ -91,6 +91,8 @@ pub enum ZetaFormat {
V0306SeedMultiRegions,
/// Byte-exact marker spans; all intermediate markers emitted; repeated marker means no-edit.
V0316SeedMultiRegions,
+ /// V0316 with larger block sizes.
+ V0318SeedMultiRegions,
/// V0316, but marker numbers are relative to the cursor block (e.g. -1, -0, +1).
V0317SeedMultiRegions,
}
@@ -242,6 +244,18 @@ pub fn special_tokens_for_format(format: ZetaFormat) -> &'static [&'static str]
];
TOKENS
}
+ ZetaFormat::V0318SeedMultiRegions => {
+ static TOKENS: &[&str] = &[
+ seed_coder::FIM_SUFFIX,
+ seed_coder::FIM_PREFIX,
+ seed_coder::FIM_MIDDLE,
+ seed_coder::FILE_MARKER,
+ multi_region::V0318_END_MARKER,
+ CURSOR_MARKER,
+ multi_region::MARKER_TAG_PREFIX,
+ ];
+ TOKENS
+ }
ZetaFormat::V0317SeedMultiRegions => {
static TOKENS: &[&str] = &[
seed_coder::FIM_SUFFIX,
@@ -283,6 +297,7 @@ pub fn token_limits_for_format(format: ZetaFormat) -> (usize, usize) {
| ZetaFormat::v0226Hashline
| ZetaFormat::V0306SeedMultiRegions
| ZetaFormat::V0316SeedMultiRegions
+ | ZetaFormat::V0318SeedMultiRegions
| ZetaFormat::V0317SeedMultiRegions
| ZetaFormat::V0304SeedNoEdits => (350, 150),
ZetaFormat::V0304VariableEdit => (1024, 0),
@@ -303,6 +318,7 @@ pub fn stop_tokens_for_format(format: ZetaFormat) -> &'static [&'static str] {
| ZetaFormat::V0306SeedMultiRegions
| ZetaFormat::V0304SeedNoEdits => &[],
ZetaFormat::V0316SeedMultiRegions => &[multi_region::V0316_END_MARKER],
+ ZetaFormat::V0318SeedMultiRegions => &[multi_region::V0318_END_MARKER],
ZetaFormat::V0317SeedMultiRegions => &[multi_region::V0317_END_MARKER],
}
}
@@ -328,6 +344,7 @@ pub fn excerpt_ranges_for_format(
| ZetaFormat::V0304SeedNoEdits
| ZetaFormat::V0306SeedMultiRegions
| ZetaFormat::V0316SeedMultiRegions
+ | ZetaFormat::V0318SeedMultiRegions
| ZetaFormat::V0317SeedMultiRegions => (
ranges.editable_350.clone(),
ranges.editable_350_context_150.clone(),
@@ -419,6 +436,14 @@ pub fn write_cursor_excerpt_section_for_format(
cursor_offset,
));
}
+ ZetaFormat::V0318SeedMultiRegions => {
+ prompt.push_str(&build_v0318_cursor_prefix(
+ path,
+ context,
+ editable_range,
+ cursor_offset,
+ ));
+ }
ZetaFormat::V0317SeedMultiRegions => {
prompt.push_str(&build_v0317_cursor_prefix(
path,
@@ -486,6 +511,33 @@ fn build_v0316_cursor_prefix(
section
}
+fn build_v0318_cursor_prefix(
+ path: &Path,
+ context: &str,
+ editable_range: &Range<usize>,
+ cursor_offset: usize,
+) -> String {
+ let mut section = String::new();
+ let path_str = path.to_string_lossy();
+ write!(section, "{}{}\n", seed_coder::FILE_MARKER, path_str).ok();
+
+ section.push_str(&context[..editable_range.start]);
+
+ let editable_text = &context[editable_range.clone()];
+ let cursor_in_editable = cursor_offset - editable_range.start;
+ multi_region::write_editable_with_markers_v0318(
+ &mut section,
+ editable_text,
+ cursor_in_editable,
+ CURSOR_MARKER,
+ );
+
+ if !section.ends_with('\n') {
+ section.push('\n');
+ }
+ section
+}
+
fn build_v0317_cursor_prefix(
path: &Path,
context: &str,
@@ -551,6 +603,7 @@ pub fn format_prompt_with_budget_for_format(
| ZetaFormat::V0304SeedNoEdits
| ZetaFormat::V0306SeedMultiRegions
| ZetaFormat::V0316SeedMultiRegions
+ | ZetaFormat::V0318SeedMultiRegions
| ZetaFormat::V0317SeedMultiRegions => {
let mut cursor_section = String::new();
write_cursor_excerpt_section_for_format(
@@ -649,6 +702,7 @@ pub fn max_edit_event_count_for_format(format: &ZetaFormat) -> usize {
| ZetaFormat::V0304VariableEdit
| ZetaFormat::V0306SeedMultiRegions
| ZetaFormat::V0316SeedMultiRegions
+ | ZetaFormat::V0318SeedMultiRegions
| ZetaFormat::V0317SeedMultiRegions => 6,
}
}
@@ -671,6 +725,7 @@ pub fn get_prefill_for_format(
ZetaFormat::V0304SeedNoEdits
| ZetaFormat::V0306SeedMultiRegions
| ZetaFormat::V0316SeedMultiRegions
+ | ZetaFormat::V0318SeedMultiRegions
| ZetaFormat::V0317SeedMultiRegions => String::new(),
}
}
@@ -684,6 +739,7 @@ pub fn output_end_marker_for_format(format: ZetaFormat) -> Option<&'static str>
| ZetaFormat::V0304SeedNoEdits
| ZetaFormat::V0306SeedMultiRegions => Some(seed_coder::END_MARKER),
ZetaFormat::V0316SeedMultiRegions => Some(multi_region::V0316_END_MARKER),
+ ZetaFormat::V0318SeedMultiRegions => Some(multi_region::V0318_END_MARKER),
ZetaFormat::V0317SeedMultiRegions => Some(multi_region::V0317_END_MARKER),
ZetaFormat::V0112MiddleAtEnd
| ZetaFormat::V0113Ordered
@@ -727,6 +783,22 @@ pub fn encode_patch_as_output_for_format(
Ok(None)
}
}
+ ZetaFormat::V0318SeedMultiRegions => {
+ let empty_patch = patch.lines().count() <= 3;
+ if empty_patch {
+ let marker_offsets =
+ multi_region::compute_marker_offsets_v0318(old_editable_region);
+ let marker_num =
+ multi_region::nearest_marker_number(cursor_offset, &marker_offsets);
+ let tag = multi_region::marker_tag(marker_num);
+ Ok(Some(format!(
+ "{tag}{tag}{}",
+ multi_region::V0318_END_MARKER
+ )))
+ } else {
+ Ok(None)
+ }
+ }
ZetaFormat::V0317SeedMultiRegions => {
let empty_patch = patch.lines().count() <= 3;
if empty_patch {
@@ -797,6 +869,10 @@ pub fn parse_zeta2_model_output(
editable_range_in_context,
multi_region::apply_marker_span_v0316(old_editable_region, output)?,
),
+ ZetaFormat::V0318SeedMultiRegions => (
+ editable_range_in_context,
+ multi_region::apply_marker_span_v0318(old_editable_region, output)?,
+ ),
ZetaFormat::V0317SeedMultiRegions => (
editable_range_in_context,
multi_region::apply_marker_span_v0317(
@@ -48,8 +48,6 @@ description = "Example extension"
repository = "https://github.com/your-name/my-zed-extension"
```
-> **Note:** If you are working on a theme extension with the intent to publish it later, suffix your theme extension ID with `-theme`. Otherwise, this may be raised during [extension publishing](#publishing-your-extension).
-
In addition to this, there are several other optional files and directories that can be used to add functionality to a Zed extension. An example directory structure of an extension that provides all capabilities is as follows:
```
@@ -144,7 +142,24 @@ Your license file should be at the root of your extension repository. Any filena
> This license requirement applies only to your extension code itself (the code that gets compiled into the extension binary).
> It does not apply to any tools your extension may download or interact with, such as language servers or other external dependencies.
-> If your repository contains both extension code and other projects (like a language server), you are not required to relicense those other projectsβonly the extension code needs to be one of the aforementioned accepted licenses.
+> If your repository contains both extension code and other projects (like a language server), you are not required to relicense those other projects β only the extension code needs to be one of the aforementioned accepted licenses.
+
+## Extension Publishing Prerequisites
+
+Before publishing your extension, make sure that you have chosen a unique extension ID for your extension in the [extension manifest](#directory-structure-of-a-zed-extension).
+This will be the primary identifier for your extension and cannot be changed after your extension has been published.
+Also, ensure that you have filled out all the required fields in the manifest.
+
+Furthermore, please make sure that your extension fulfills the following preconditions before you move on to publishing your extension:
+
+- Extension IDs and names must not contain the words `zed`, `Zed` or `extension`, since they are all Zed extensions.
+- Your extension ID should provide some information on what your extension tries to accomplish. E.g. for themes, it should be suffixed with `-theme`, snippet extensions should be suffixed with `-snippets` and so on. An exception to this rule is extensions that provide support for languages or popular tooling that people would expect to find under that ID. You can take a look at the list of [existing extensions](https://github.com/zed-industries/extensions/blob/main/extensions.toml) to get a grasp on how this is usually enforced.
+- Extensions should provide something that is not yet available in the marketplace as opposed to fixing something that could be resolved within an existing extension. For example, if you find that an existing extension's support for a language server is not functioning properly, first try contributing a fix to the existing extension as opposed to submitting a new extension immediately.
+ - If you receive no response or reaction within the upstream repository within a reasonable amount of time, feel free to submit a pull request that aims to fix said issue. Please ensure that you provide your previous efforts within the pull request to the extensions repository for adding your extension. Zed maintainers will then decide on how to proceed on a case by case basis.
+- Extensions that intend to provide a language, debugger or MCP server must not ship the language server as part of the extension. Instead, the extension should either download the language server or check for the availability of the language server in the user's environment using the APIs provided by the [Zed Rust Extension API](https://docs.rs/zed_extension_api/latest/zed_extension_api/).
+- Themes and icon themes should not be published as part of extensions that provide other features, e.g. language support. Instead, they should be published as distinct extensions. This also applies to themes and icon themes that live in the same repository as another extension.
+
+Note that non-compliance will be raised during the publishing process by reviewers and delay the release of your extension.
## Publishing your extension
@@ -152,13 +167,15 @@ To publish an extension, open a PR to [the `zed-industries/extensions` repo](htt
In your PR, do the following:
-1. Add your extension as a Git submodule within the `extensions/` directory
+1. Add your extension as a Git submodule within the `extensions/` directory under the `extensions/{extension-id}` path
```sh
-git submodule add https://github.com/your-username/foobar-zed.git extensions/foobar
-git add extensions/foobar
+git submodule add https://github.com/your-username/foobar-zed.git extensions/my-extension
+git add extensions/my-extension
```
+> **Note:** Your extension must live under the `extensions/{extension-id}` path, where `{extension-id}` matches the ID declared in your extension manifest.
+
> All extension submodules must use HTTPS URLs and not SSH URLS (`git@github.com`).
2. Add a new entry to the top-level `extensions.toml` file containing your extension:
@@ -169,14 +186,21 @@ submodule = "extensions/my-extension"
version = "0.0.1"
```
-> If your extension is in a subdirectory within the submodule you can use the `path` field to point to where the extension resides.
+If your extension is in a subdirectory within the submodule, you can use the `path` field to point to where the extension resides:
+
+```toml
+[my-extension]
+submodule = "extensions/my-extension"
+path = "packages/zed"
+version = "0.0.1"
+```
+
+> Note that the [required extension license](#extension-license-requirements) must reside at the specified path, a license at the root of the repository will not work. However, you are free to symlink an existing license within the repository or choose an alternative license from the list of accepted licenses for the extension code.
3. Run `pnpm sort-extensions` to ensure `extensions.toml` and `.gitmodules` are sorted
Once your PR is merged, the extension will be packaged and published to the Zed extension registry.
-> Extension IDs and names should not contain `zed` or `Zed`, since they are all Zed extensions.
-
## Updating an extension
To update an extension, open a PR to [the `zed-industries/extensions` repo](https://github.com/zed-industries/extensions).
@@ -4627,7 +4627,8 @@ Run the {#action theme_selector::Toggle} action in the command palette to see a
"show_user_picture": true,
"show_user_menu": true,
"show_sign_in": true,
- "show_menus": false
+ "show_menus": false,
+ "button_layout": "platform_default"
}
}
```
@@ -4642,6 +4643,7 @@ Run the {#action theme_selector::Toggle} action in the command palette to see a
- `show_user_menu`: Whether to show the user menu button in the titlebar (the one that displays your avatar by default and contains options like Settings, Keymap, Themes, etc.)
- `show_sign_in`: Whether to show the sign in button in the titlebar
- `show_menus`: Whether to show the menus in the titlebar
+- `button_layout`: The layout of window control buttons in the title bar (Linux only). Can be set to `"platform_default"` to follow the system setting, `"standard"` to use Zed's built-in layout, or a custom format like `"close:minimize,maximize"`
## Vim
@@ -77,7 +77,6 @@ let
builtins.elem firstComp topLevelIncludes;
craneLib = crane.overrideToolchain rustToolchain;
- gpu-lib = if withGLES then libglvnd else vulkan-loader;
commonArgs =
let
zedCargoLock = builtins.fromTOML (builtins.readFile ../crates/zed/Cargo.toml);
@@ -179,7 +178,8 @@ let
libva
libxkbcommon
wayland
- gpu-lib
+ libglvnd
+ vulkan-loader
xorg.libX11
xorg.libxcb
libdrm
@@ -236,7 +236,8 @@ let
# about them that's special is that they're manually dlopened at runtime
NIX_LDFLAGS = lib.optionalString stdenv'.hostPlatform.isLinux "-rpath ${
lib.makeLibraryPath [
- gpu-lib
+ libglvnd
+ vulkan-loader
wayland
libva
]
@@ -245,7 +246,7 @@ let
NIX_OUTPATH_USED_AS_RANDOM_SEED = "norebuilds";
};
- # prevent nix from removing the "unused" wayland/gpu-lib rpaths
+ # prevent nix from removing the "unused" wayland rpaths
dontPatchELF = stdenv'.hostPlatform.isLinux;
# TODO: try craneLib.cargoNextest separate output
@@ -33,7 +33,7 @@ fn style() -> NamedJob {
.add_step(steps::cache_rust_dependencies_namespace())
.map(steps::install_linux_dependencies)
.add_step(steps::cargo_fmt())
- .add_step(steps::clippy(Platform::Linux)),
+ .add_step(steps::clippy(Platform::Linux, None)),
))
}
@@ -16,9 +16,9 @@ pub(crate) fn release() -> Workflow {
let macos_tests = run_tests::run_platform_tests_no_filter(Platform::Mac);
let linux_tests = run_tests::run_platform_tests_no_filter(Platform::Linux);
let windows_tests = run_tests::run_platform_tests_no_filter(Platform::Windows);
- let macos_clippy = run_tests::clippy(Platform::Mac);
- let linux_clippy = run_tests::clippy(Platform::Linux);
- let windows_clippy = run_tests::clippy(Platform::Windows);
+ let macos_clippy = run_tests::clippy(Platform::Mac, None);
+ let linux_clippy = run_tests::clippy(Platform::Linux, None);
+ let windows_clippy = run_tests::clippy(Platform::Windows, None);
let check_scripts = run_tests::check_scripts();
let create_draft_release = create_draft_release();
@@ -18,7 +18,7 @@ pub fn release_nightly() -> Workflow {
let style = check_style();
// run only on windows as that's our fastest platform right now.
let tests = run_platform_tests_no_filter(Platform::Windows);
- let clippy_job = clippy(Platform::Windows);
+ let clippy_job = clippy(Platform::Windows, None);
let nightly = Some(ReleaseChannel::Nightly);
let bundle = ReleaseBundleJobs {
@@ -15,7 +15,7 @@ use crate::tasks::workflows::{
};
use super::{
- runners::{self, Platform},
+ runners::{self, Arch, Platform},
steps::{self, FluentBuilder, NamedJob, named, release_job},
};
@@ -48,9 +48,10 @@ pub(crate) fn run_tests() -> Workflow {
let mut jobs = vec![
orchestrate,
check_style(),
- should_run_tests.guard(clippy(Platform::Windows)),
- should_run_tests.guard(clippy(Platform::Linux)),
- should_run_tests.guard(clippy(Platform::Mac)),
+ should_run_tests.guard(clippy(Platform::Windows, None)),
+ should_run_tests.guard(clippy(Platform::Linux, None)),
+ should_run_tests.guard(clippy(Platform::Mac, None)),
+ should_run_tests.guard(clippy(Platform::Mac, Some(Arch::X86_64))),
should_run_tests.guard(run_platform_tests(Platform::Windows)),
should_run_tests.guard(run_platform_tests(Platform::Linux)),
should_run_tests.guard(run_platform_tests(Platform::Mac)),
@@ -489,7 +490,12 @@ fn check_workspace_binaries() -> NamedJob {
))
}
-pub(crate) fn clippy(platform: Platform) -> NamedJob {
+pub(crate) fn clippy(platform: Platform, arch: Option<Arch>) -> NamedJob {
+ let target = arch.map(|arch| match (platform, arch) {
+ (Platform::Mac, Arch::X86_64) => "x86_64-apple-darwin",
+ (Platform::Mac, Arch::AARCH64) => "aarch64-apple-darwin",
+ _ => unimplemented!("cross-arch clippy not supported for {platform}/{arch}"),
+ });
let runner = match platform {
Platform::Windows => runners::WINDOWS_DEFAULT,
Platform::Linux => runners::LINUX_DEFAULT,
@@ -507,16 +513,20 @@ pub(crate) fn clippy(platform: Platform) -> NamedJob {
platform == Platform::Linux,
steps::install_linux_dependencies,
)
+ .when_some(target, |this, target| {
+ this.add_step(steps::install_rustup_target(target))
+ })
.add_step(steps::setup_sccache(platform))
- .add_step(steps::clippy(platform))
+ .add_step(steps::clippy(platform, target))
.add_step(steps::show_sccache_stats(platform));
if platform == Platform::Linux {
job = use_clang(job);
}
- NamedJob {
- name: format!("clippy_{platform}"),
- job,
- }
+ let name = match arch {
+ Some(arch) => format!("clippy_{platform}_{arch}"),
+ None => format!("clippy_{platform}"),
+ };
+ NamedJob { name, job }
}
pub(crate) fn run_platform_tests(platform: Platform) -> NamedJob {
@@ -211,13 +211,20 @@ pub fn clear_target_dir_if_large(platform: Platform) -> Step<Run> {
}
}
-pub fn clippy(platform: Platform) -> Step<Run> {
+pub fn clippy(platform: Platform, target: Option<&str>) -> Step<Run> {
match platform {
Platform::Windows => named::pwsh("./script/clippy.ps1"),
- _ => named::bash("./script/clippy"),
+ _ => match target {
+ Some(target) => named::bash(format!("./script/clippy --target {target}")),
+ None => named::bash("./script/clippy"),
+ },
}
}
+pub fn install_rustup_target(target: &str) -> Step<Run> {
+ named::bash(format!("rustup target add {target}"))
+}
+
pub fn cache_rust_dependencies_namespace() -> Step<Use> {
named::uses("namespacelabs", "nscloud-cache-action", "v1")
.add_with(("cache", "rust"))