From 7d8fe66f1b9357458b726a96b1ad27a7326f60e2 Mon Sep 17 00:00:00 2001 From: Saketh <126517689+SAKETH11111@users.noreply.github.com> Date: Mon, 6 Apr 2026 13:48:26 -0500 Subject: [PATCH 01/21] workspace: Keep restricted mode modal actions visible (#53124) Closes #52586 ## Summary - cap the restricted project list height inside the security modal and make it scroll - cap the modal body content height so the action buttons stay reachable on smaller screens - add a regression test that reproduces the overflow scenario with many restricted projects in a constrained window ## Validation - manually reproduced the overflow by opening 60 untrusted projects in a 720x620 window before the fix - cargo test -p workspace test_security_modal_project_list_scrolls_when_many_projects_are_restricted - cargo check -p workspace Release Notes: - Fixed restricted mode dialogs overflowing past the window when many unrecognized projects are open. --------- Co-authored-by: Danilo Leal --- crates/workspace/src/security_modal.rs | 97 +++++++++++++++++--------- 1 file changed, 65 insertions(+), 32 deletions(-) diff --git a/crates/workspace/src/security_modal.rs b/crates/workspace/src/security_modal.rs index 664aa891550cecdd602d54bfca579d04e03f33dc..2130a1d1eca3d33651a057d32a252718270f89f8 100644 --- a/crates/workspace/src/security_modal.rs +++ b/crates/workspace/src/security_modal.rs @@ -7,7 +7,7 @@ use std::{ }; use collections::{HashMap, HashSet}; -use gpui::{DismissEvent, EventEmitter, FocusHandle, Focusable, WeakEntity}; +use gpui::{DismissEvent, EventEmitter, FocusHandle, Focusable, ScrollHandle, WeakEntity}; use project::{ WorktreeId, @@ -17,7 +17,8 @@ use project::{ use smallvec::SmallVec; use theme::ActiveTheme; use ui::{ - AlertModal, Checkbox, FluentBuilder, KeyBinding, ListBulletItem, ToggleState, prelude::*, + AlertModal, Checkbox, FluentBuilder, KeyBinding, ListBulletItem, ToggleState, WithScrollbar, + prelude::*, }; use crate::{DismissDecision, ModalView, ToggleWorktreeSecurity}; @@ -29,6 +30,7 @@ pub struct SecurityModal { worktree_store: WeakEntity, remote_host: Option, focus_handle: FocusHandle, + project_list_scroll_handle: ScrollHandle, trusted: Option, } @@ -63,16 +65,17 @@ impl ModalView for SecurityModal { } impl Render for SecurityModal { - fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { if self.restricted_paths.is_empty() { self.dismiss(cx); return v_flex().into_any_element(); } - let header_label = if self.restricted_paths.len() == 1 { - "Unrecognized Project" + let restricted_count = self.restricted_paths.len(); + let header_label: SharedString = if restricted_count == 1 { + "Unrecognized Project".into() } else { - "Unrecognized Projects" + format!("Unrecognized Projects ({})", restricted_count).into() }; let trust_label = self.build_trust_label(); @@ -102,32 +105,61 @@ impl Render for SecurityModal { .child(Icon::new(IconName::Warning).color(Color::Warning)) .child(Label::new(header_label)), ) - .children(self.restricted_paths.values().filter_map(|restricted_path| { - let abs_path = if restricted_path.is_file { - restricted_path.abs_path.parent() - } else { - Some(restricted_path.abs_path.as_ref()) - }?; - let label = match &restricted_path.host { - Some(remote_host) => match &remote_host.user_name { - Some(user_name) => format!( - "{} ({}@{})", - self.shorten_path(abs_path).display(), - user_name, - remote_host.host_identifier - ), - None => format!( - "{} ({})", - self.shorten_path(abs_path).display(), - remote_host.host_identifier - ), - }, - None => self.shorten_path(abs_path).display().to_string(), - }; - Some(h_flex() - .pl(IconSize::default().rems() + rems(0.5)) - .child(Label::new(label).color(Color::Muted))) - })), + .child( + div() + .size_full() + .vertical_scrollbar_for(&self.project_list_scroll_handle, window, cx) + .child( + v_flex() + .id("paths_container") + .max_h_24() + .overflow_y_scroll() + .track_scroll(&self.project_list_scroll_handle) + .children( + self.restricted_paths.values().filter_map( + |restricted_path| { + let abs_path = if restricted_path.is_file { + restricted_path.abs_path.parent() + } else { + Some(restricted_path.abs_path.as_ref()) + }?; + let label = match &restricted_path.host { + Some(remote_host) => { + match &remote_host.user_name { + Some(user_name) => format!( + "{} ({}@{})", + self.shorten_path(abs_path) + .display(), + user_name, + remote_host.host_identifier + ), + None => format!( + "{} ({})", + self.shorten_path(abs_path) + .display(), + remote_host.host_identifier + ), + } + } + None => self + .shorten_path(abs_path) + .display() + .to_string(), + }; + Some( + h_flex() + .pl( + IconSize::default().rems() + rems(0.5), + ) + .child( + Label::new(label).color(Color::Muted), + ), + ) + }, + ), + ), + ), + ), ) .child( v_flex() @@ -219,6 +251,7 @@ impl SecurityModal { remote_host: remote_host.map(|host| host.into()), restricted_paths: HashMap::default(), focus_handle: cx.focus_handle(), + project_list_scroll_handle: ScrollHandle::new(), trust_parents: false, home_dir: std::env::home_dir(), trusted: None, From d1b1f258e51fb2aa7cef8d4c2b346cf1b784ac6e Mon Sep 17 00:00:00 2001 From: Xin Zhao Date: Tue, 7 Apr 2026 02:51:12 +0800 Subject: [PATCH 02/21] git_graph: Fix commit hover misalignment after fractional scrolling (#53218) Self-Review Checklist: - [x] I've reviewed my own diff for quality, security, and reliability - [x] Unsafe blocks (if any) have justifying comments - [x] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) - [x] Tests cover the new/changed behavior - [x] Performance impact has been considered and is acceptable Closes #53199 Mathematically, `floor(A) + floor(B) != floor(A + B)`. The original code calculated the hovered row by applying `.floor()` to the scrolled offset and local offset separately before adding them together, which incorrectly dropped fractional sub-pixels and caused an off-by-one targeting error. Release Notes: - N/A --------- Co-authored-by: Anthony Eid --- crates/git_graph/src/git_graph.rs | 77 +++++++++++++++++++++++++++++-- 1 file changed, 74 insertions(+), 3 deletions(-) diff --git a/crates/git_graph/src/git_graph.rs b/crates/git_graph/src/git_graph.rs index 83cd01eda5c509583f24fd424426d20a55bbfbed..aa5f6bc6e1293cfd057baa0c5e9f77819da71086 100644 --- a/crates/git_graph/src/git_graph.rs +++ b/crates/git_graph/src/git_graph.rs @@ -2394,9 +2394,8 @@ impl GitGraph { let local_y = position_y - canvas_bounds.origin.y; if local_y >= px(0.) && local_y < canvas_bounds.size.height { - let row_in_viewport = (local_y / self.row_height).floor() as usize; - let scroll_rows = (scroll_offset_y / self.row_height).floor() as usize; - let absolute_row = scroll_rows + row_in_viewport; + let absolute_y = local_y + scroll_offset_y; + let absolute_row = (absolute_y / self.row_height).floor() as usize; if absolute_row < self.graph_data.commits.len() { return Some(absolute_row); @@ -4006,4 +4005,76 @@ mod tests { }); assert_eq!(reloaded_shas, vec![updated_head, updated_stash]); } + + #[gpui::test] + async fn test_git_graph_row_at_position_rounding(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + Path::new("/project"), + serde_json::json!({ + ".git": {}, + "file.txt": "content", + }), + ) + .await; + + let mut rng = StdRng::seed_from_u64(42); + let commits = generate_random_commit_dag(&mut rng, 10, false); + fs.set_graph_commits(Path::new("/project/.git"), commits.clone()); + + let project = Project::test(fs.clone(), [Path::new("/project")], cx).await; + cx.run_until_parked(); + + let repository = project.read_with(cx, |project, cx| { + project + .active_repository(cx) + .expect("should have a repository") + }); + + let (multi_workspace, cx) = cx.add_window_view(|window, cx| { + workspace::MultiWorkspace::test_new(project.clone(), window, cx) + }); + + let workspace_weak = + multi_workspace.read_with(&*cx, |multi, _| multi.workspace().downgrade()); + + let git_graph = cx.new_window_entity(|window, cx| { + GitGraph::new( + repository.read(cx).id, + project.read(cx).git_store().clone(), + workspace_weak, + window, + cx, + ) + }); + cx.run_until_parked(); + + git_graph.update(cx, |graph, cx| { + assert!( + graph.graph_data.commits.len() >= 10, + "graph should load dummy commits" + ); + + graph.row_height = px(20.0); + let origin_y = px(100.0); + graph.graph_canvas_bounds.set(Some(Bounds { + origin: point(px(0.0), origin_y), + size: gpui::size(px(100.0), px(1000.0)), + })); + + graph.table_interaction_state.update(cx, |state, _| { + state.set_scroll_offset(point(px(0.0), px(-15.0))) + }); + let pos_y = origin_y + px(10.0); + let absolute_calc_row = graph.row_at_position(pos_y, cx); + + assert_eq!( + absolute_calc_row, + Some(1), + "Row calculation should yield absolute row exactly" + ); + }); + } } From 136e91a7d325fcafe8534b56e5ef2e3bf6111b04 Mon Sep 17 00:00:00 2001 From: Mikayla Maki Date: Mon, 6 Apr 2026 11:51:24 -0700 Subject: [PATCH 03/21] Fix a bug where legacy threads would be spuriously opened in a main workspace (#53260) Self-Review Checklist: - [x] I've reviewed my own diff for quality, security, and reliability - [x] Unsafe blocks (if any) have justifying comments - [x] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) - [x] Tests cover the new/changed behavior - [x] Performance impact has been considered and is acceptable Release Notes: - N/A --- crates/sidebar/src/sidebar_tests.rs | 114 ++++++++++++++++++++++++ crates/workspace/src/multi_workspace.rs | 2 +- 2 files changed, 115 insertions(+), 1 deletion(-) diff --git a/crates/sidebar/src/sidebar_tests.rs b/crates/sidebar/src/sidebar_tests.rs index cf1ee8a0f524d9d94edf83c24ecea900f3261fb8..a50c5dadbdbff77ccadd81dd96196a469e920e87 100644 --- a/crates/sidebar/src/sidebar_tests.rs +++ b/crates/sidebar/src/sidebar_tests.rs @@ -4759,6 +4759,120 @@ async fn test_linked_worktree_workspace_shows_main_worktree_threads(cx: &mut Tes ); } +#[gpui::test] +async fn test_legacy_thread_with_canonical_path_opens_main_repo_workspace(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + + fs.insert_tree( + "/project", + serde_json::json!({ + ".git": { + "worktrees": { + "feature-a": { + "commondir": "../../", + "HEAD": "ref: refs/heads/feature-a", + }, + }, + }, + "src": {}, + }), + ) + .await; + + fs.insert_tree( + "/wt-feature-a", + serde_json::json!({ + ".git": "gitdir: /project/.git/worktrees/feature-a", + "src": {}, + }), + ) + .await; + + fs.add_linked_worktree_for_repo( + Path::new("/project/.git"), + false, + git::repository::Worktree { + path: PathBuf::from("/wt-feature-a"), + ref_name: Some("refs/heads/feature-a".into()), + sha: "abc".into(), + is_main: false, + }, + ) + .await; + + cx.update(|cx| ::set_global(fs.clone(), cx)); + + // Only a linked worktree workspace is open — no workspace for /project. + let worktree_project = project::Project::test(fs.clone(), ["/wt-feature-a".as_ref()], cx).await; + worktree_project + .update(cx, |p, cx| p.git_scans_complete(cx)) + .await; + + let (multi_workspace, cx) = cx.add_window_view(|window, cx| { + MultiWorkspace::test_new(worktree_project.clone(), window, cx) + }); + let sidebar = setup_sidebar(&multi_workspace, cx); + + // Save a legacy thread: folder_paths = main repo, main_worktree_paths = empty. + let legacy_session = acp::SessionId::new(Arc::from("legacy-main-thread")); + cx.update(|_, cx| { + let metadata = ThreadMetadata { + session_id: legacy_session.clone(), + agent_id: agent::ZED_AGENT_ID.clone(), + title: "Legacy Main Thread".into(), + updated_at: chrono::TimeZone::with_ymd_and_hms(&Utc, 2024, 1, 1, 0, 0, 0).unwrap(), + created_at: None, + folder_paths: PathList::new(&[PathBuf::from("/project")]), + main_worktree_paths: PathList::default(), + archived: false, + }; + ThreadMetadataStore::global(cx).update(cx, |store, cx| store.save_manually(metadata, cx)); + }); + cx.run_until_parked(); + + multi_workspace.update_in(cx, |_, _window, cx| cx.notify()); + cx.run_until_parked(); + + // The legacy thread should appear in the sidebar under the project group. + let entries = visible_entries_as_strings(&sidebar, cx); + assert!( + entries.iter().any(|e| e.contains("Legacy Main Thread")), + "legacy thread should be visible: {entries:?}", + ); + + // Verify only 1 workspace before clicking. + assert_eq!( + multi_workspace.read_with(cx, |mw, _| mw.workspaces().len()), + 1, + ); + + // Focus and select the legacy thread, then confirm. + open_and_focus_sidebar(&sidebar, cx); + let thread_index = sidebar.read_with(cx, |sidebar, _| { + sidebar + .contents + .entries + .iter() + .position(|e| e.session_id().is_some_and(|id| id == &legacy_session)) + .expect("legacy thread should be in entries") + }); + sidebar.update_in(cx, |sidebar, _window, _cx| { + sidebar.selection = Some(thread_index); + }); + cx.dispatch_action(Confirm); + cx.run_until_parked(); + + let new_workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()); + let new_path_list = + new_workspace.read_with(cx, |_, cx| workspace_path_list(&new_workspace, cx)); + assert_eq!( + new_path_list, + PathList::new(&[PathBuf::from("/project")]), + "the new workspace should be for the main repo, not the linked worktree", + ); +} + mod property_test { use super::*; diff --git a/crates/workspace/src/multi_workspace.rs b/crates/workspace/src/multi_workspace.rs index dc6060b70a0eeeebc1168113c2c9eb1ba2ddd251..72cc133f83aece0c6ea68b19bea53b0f5ee65755 100644 --- a/crates/workspace/src/multi_workspace.rs +++ b/crates/workspace/src/multi_workspace.rs @@ -649,7 +649,7 @@ impl MultiWorkspace { if let Some(workspace) = self .workspaces .iter() - .find(|ws| ws.read(cx).project_group_key(cx).path_list() == &path_list) + .find(|ws| PathList::new(&ws.read(cx).root_paths(cx)) == path_list) .cloned() { self.activate(workspace.clone(), window, cx); From dee42503c7024cb7e5374798379b9e12480000ce Mon Sep 17 00:00:00 2001 From: Peter Siegel <33677897+yeetypete@users.noreply.github.com> Date: Mon, 6 Apr 2026 21:13:17 +0200 Subject: [PATCH 04/21] dev_container: Preserve build context for docker-compose Dockerfiles (#53140) When a Docker Compose service specifies a build context, the generated override file was replacing it with an empty context directory. This meant Dockerfiles that reference files relative to their build context (e.g. `COPY . /app`) would fail. The fix preserves the original build context from the compose service, falling back to the empty context directory only when no context was specified. Self-Review Checklist: - [x] I've reviewed my own diff for quality, security, and reliability - [ ] Unsafe blocks (if any) have justifying comments - [x] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) - [x] Tests cover the new/changed behavior - [x] Performance impact has been considered and is acceptable Release Notes: - Fixed docker compose Dockerfile build context not being preserved in dev_container integration. --------- Co-authored-by: KyleBarton --- .../src/devcontainer_manifest.rs | 29 ++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/crates/dev_container/src/devcontainer_manifest.rs b/crates/dev_container/src/devcontainer_manifest.rs index e3a09ae548b68bb4d589d8a214ca1ba5daa9cfa4..5ef82fa3eb2a3ac5d13810e0f6102bec4f42295a 100644 --- a/crates/dev_container/src/devcontainer_manifest.rs +++ b/crates/dev_container/src/devcontainer_manifest.rs @@ -883,7 +883,13 @@ RUN sed -i -E 's/((^|\s)PATH=)([^\$]*)$/\1\${{PATH:-\3}}/g' /etc/profile || true labels: None, build: Some(DockerComposeServiceBuild { context: Some( - features_build_info.empty_context_dir.display().to_string(), + main_service + .build + .as_ref() + .and_then(|b| b.context.clone()) + .unwrap_or_else(|| { + features_build_info.empty_context_dir.display().to_string() + }), ), dockerfile: Some(dockerfile_path.display().to_string()), args: Some(build_args), @@ -3546,6 +3552,27 @@ ENV DOCKER_BUILDKIT=1 "# ); + let build_override = files + .iter() + .find(|f| { + f.file_name() + .is_some_and(|s| s.display().to_string() == "docker_compose_build.json") + }) + .expect("to be found"); + let build_override = test_dependencies.fs.load(build_override).await.unwrap(); + let build_config: DockerComposeConfig = + serde_json_lenient::from_str(&build_override).unwrap(); + let build_context = build_config + .services + .get("app") + .and_then(|s| s.build.as_ref()) + .and_then(|b| b.context.clone()) + .expect("build override should have a context"); + assert_eq!( + build_context, ".", + "build override should preserve the original build context from docker-compose.yml" + ); + let runtime_override = files .iter() .find(|f| { From 1823be56781cfa9ec8d6e809ff6cf2987da6dfa0 Mon Sep 17 00:00:00 2001 From: Danilo Leal <67129314+danilo-leal@users.noreply.github.com> Date: Mon, 6 Apr 2026 16:19:00 -0300 Subject: [PATCH 05/21] agent_ui: Fix "scroll to" buttons (#53232) Follow-up to https://github.com/zed-industries/zed/pull/53101 In the process of fixing the thread view's scroll experience, we for got to turn off the follow state tail for functions that power the scroll buttons in the agent panel. Release Notes: - N/A --- crates/gpui/src/elements/list.rs | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/crates/gpui/src/elements/list.rs b/crates/gpui/src/elements/list.rs index 5525f5c17d2ad33e1ce9696afded1cea5447020c..5a88d81c18db5e790b7bbed0fb9def23bc973e14 100644 --- a/crates/gpui/src/elements/list.rs +++ b/crates/gpui/src/elements/list.rs @@ -462,6 +462,13 @@ impl ListState { let current_offset = self.logical_scroll_top(); let state = &mut *self.0.borrow_mut(); + + if distance < px(0.) { + if let FollowState::Tail { is_following } = &mut state.follow_state { + *is_following = false; + } + } + let mut cursor = state.items.cursor::(()); cursor.seek(&Count(current_offset.item_ix), Bias::Right); @@ -536,6 +543,12 @@ impl ListState { scroll_top.offset_in_item = px(0.); } + if scroll_top.item_ix < item_count { + if let FollowState::Tail { is_following } = &mut state.follow_state { + *is_following = false; + } + } + state.logical_scroll_top = Some(scroll_top); } From 383b0a7afb3d09ea39dad2bf9bf50d50e3395758 Mon Sep 17 00:00:00 2001 From: Danilo Leal <67129314+danilo-leal@users.noreply.github.com> Date: Mon, 6 Apr 2026 16:19:10 -0300 Subject: [PATCH 06/21] settings_ui: Recategorize some panel settings (#53243) Was looking around the panels page in the settings UI and noticed there was a standalone "Auto Open Files" section. That felt a bit out of place because those settings are really project panel-specific. So this PR moves them under the project panels section of the panels page. Release Notes: - N/A --- crates/settings_ui/src/page_data.rs | 71 +++++++++++++---------------- 1 file changed, 32 insertions(+), 39 deletions(-) diff --git a/crates/settings_ui/src/page_data.rs b/crates/settings_ui/src/page_data.rs index bacfd227d83933d3ebd9b2d8836bbe19958acf2b..20a0c53534988a873b3b3f6e393eefd5bb0b3f7c 100644 --- a/crates/settings_ui/src/page_data.rs +++ b/crates/settings_ui/src/page_data.rs @@ -4433,7 +4433,7 @@ fn window_and_layout_page() -> SettingsPage { } fn panels_page() -> SettingsPage { - fn project_panel_section() -> [SettingsPageItem; 24] { + fn project_panel_section() -> [SettingsPageItem; 28] { [ SettingsPageItem::SectionHeader("Project Panel"), SettingsPageItem::SettingItem(SettingItem { @@ -4914,31 +4914,25 @@ fn panels_page() -> SettingsPage { files: USER, }), SettingsPageItem::SettingItem(SettingItem { - title: "Hidden Files", - description: "Globs to match files that will be considered \"hidden\" and can be hidden from the project panel.", - field: Box::new( - SettingField { - json_path: Some("worktree.hidden_files"), - pick: |settings_content| { - settings_content.project.worktree.hidden_files.as_ref() - }, - write: |settings_content, value| { - settings_content.project.worktree.hidden_files = value; - }, - } - .unimplemented(), - ), + title: "Sort Mode", + description: "Sort order for entries in the project panel.", + field: Box::new(SettingField { + json_path: Some("project_panel.sort_mode"), + pick: |settings_content| { + settings_content.project_panel.as_ref()?.sort_mode.as_ref() + }, + write: |settings_content, value| { + settings_content + .project_panel + .get_or_insert_default() + .sort_mode = value; + }, + }), metadata: None, files: USER, }), - ] - } - - fn auto_open_files_section() -> [SettingsPageItem; 5] { - [ - SettingsPageItem::SectionHeader("Auto Open Files"), SettingsPageItem::SettingItem(SettingItem { - title: "On Create", + title: "Auto Open Files On Create", description: "Whether to automatically open newly created files in the editor.", field: Box::new(SettingField { json_path: Some("project_panel.auto_open.on_create"), @@ -4964,7 +4958,7 @@ fn panels_page() -> SettingsPage { files: USER, }), SettingsPageItem::SettingItem(SettingItem { - title: "On Paste", + title: "Auto Open Files On Paste", description: "Whether to automatically open files after pasting or duplicating them.", field: Box::new(SettingField { json_path: Some("project_panel.auto_open.on_paste"), @@ -4990,7 +4984,7 @@ fn panels_page() -> SettingsPage { files: USER, }), SettingsPageItem::SettingItem(SettingItem { - title: "On Drop", + title: "Auto Open Files On Drop", description: "Whether to automatically open files dropped from external sources.", field: Box::new(SettingField { json_path: Some("project_panel.auto_open.on_drop"), @@ -5016,20 +5010,20 @@ fn panels_page() -> SettingsPage { files: USER, }), SettingsPageItem::SettingItem(SettingItem { - title: "Sort Mode", - description: "Sort order for entries in the project panel.", - field: Box::new(SettingField { - pick: |settings_content| { - settings_content.project_panel.as_ref()?.sort_mode.as_ref() - }, - write: |settings_content, value| { - settings_content - .project_panel - .get_or_insert_default() - .sort_mode = value; - }, - json_path: Some("project_panel.sort_mode"), - }), + title: "Hidden Files", + description: "Globs to match files that will be considered \"hidden\" and can be hidden from the project panel.", + field: Box::new( + SettingField { + json_path: Some("worktree.hidden_files"), + pick: |settings_content| { + settings_content.project.worktree.hidden_files.as_ref() + }, + write: |settings_content, value| { + settings_content.project.worktree.hidden_files = value; + }, + } + .unimplemented(), + ), metadata: None, files: USER, }), @@ -5807,7 +5801,6 @@ fn panels_page() -> SettingsPage { title: "Panels", items: concat_sections![ project_panel_section(), - auto_open_files_section(), terminal_panel_section(), outline_panel_section(), git_panel_section(), From 7adbee0cdc438a0252e6649914aa7fa063780902 Mon Sep 17 00:00:00 2001 From: Danilo Leal <67129314+danilo-leal@users.noreply.github.com> Date: Mon, 6 Apr 2026 16:34:36 -0300 Subject: [PATCH 07/21] sidebar: Fix behavior of "Remove Project" button (#53242) - Fix an issue where the "remove project" button, available in the header's ellipsis menu, wouldn't do anything if the sidebar contained only one project - Fix another issue where attempting to remove a project when the sidebar has more than one project wouldn't actually remove it. This is fixed by cleaning up the project group keys after its been already removed. Release Notes: - N/A --- crates/workspace/src/multi_workspace.rs | 85 ++++++++++++++++++++----- 1 file changed, 70 insertions(+), 15 deletions(-) diff --git a/crates/workspace/src/multi_workspace.rs b/crates/workspace/src/multi_workspace.rs index 72cc133f83aece0c6ea68b19bea53b0f5ee65755..a0c5eaabc629073dd9a46ac1b5073ddfbd26bd28 100644 --- a/crates/workspace/src/multi_workspace.rs +++ b/crates/workspace/src/multi_workspace.rs @@ -6,9 +6,7 @@ use gpui::{ ManagedView, MouseButton, Pixels, Render, Subscription, Task, Tiling, Window, WindowId, actions, deferred, px, }; -#[cfg(any(test, feature = "test-support"))] -use project::Project; -use project::{DirectoryLister, DisableAiSettings, ProjectGroupKey}; +use project::{DirectoryLister, DisableAiSettings, Project, ProjectGroupKey}; use settings::Settings; pub use settings::SidebarSide; use std::future::Future; @@ -468,6 +466,9 @@ impl MultiWorkspace { } pub fn add_project_group_key(&mut self, project_group_key: ProjectGroupKey) { + if project_group_key.path_list().paths().is_empty() { + return; + } if self.project_group_keys.contains(&project_group_key) { return; } @@ -1040,26 +1041,80 @@ impl MultiWorkspace { let Some(index) = self.workspaces.iter().position(|w| w == workspace) else { return false; }; + + let old_key = workspace.read(cx).project_group_key(cx); + if self.workspaces.len() <= 1 { - return false; - } + let has_worktrees = workspace.read(cx).visible_worktrees(cx).next().is_some(); + + if !has_worktrees { + return false; + } + + let old_workspace = workspace.clone(); + let old_entity_id = old_workspace.entity_id(); - let removed_workspace = self.workspaces.remove(index); + let app_state = old_workspace.read(cx).app_state().clone(); + + let project = Project::local( + app_state.client.clone(), + app_state.node_runtime.clone(), + app_state.user_store.clone(), + app_state.languages.clone(), + app_state.fs.clone(), + None, + project::LocalProjectFlags::default(), + cx, + ); + + let new_workspace = cx.new(|cx| Workspace::new(None, project, app_state, window, cx)); + + self.workspaces[0] = new_workspace.clone(); + self.active_workspace_index = 0; + + Self::subscribe_to_workspace(&new_workspace, window, cx); + + self.sync_sidebar_to_workspace(&new_workspace, cx); + + let weak_self = cx.weak_entity(); - if self.active_workspace_index >= self.workspaces.len() { - self.active_workspace_index = self.workspaces.len() - 1; - } else if self.active_workspace_index > index { - self.active_workspace_index -= 1; + new_workspace.update(cx, |workspace, cx| { + workspace.set_multi_workspace(weak_self, cx); + }); + + self.detach_workspace(&old_workspace, cx); + + cx.emit(MultiWorkspaceEvent::WorkspaceRemoved(old_entity_id)); + cx.emit(MultiWorkspaceEvent::WorkspaceAdded(new_workspace)); + cx.emit(MultiWorkspaceEvent::ActiveWorkspaceChanged); + } else { + let removed_workspace = self.workspaces.remove(index); + + if self.active_workspace_index >= self.workspaces.len() { + self.active_workspace_index = self.workspaces.len() - 1; + } else if self.active_workspace_index > index { + self.active_workspace_index -= 1; + } + + self.detach_workspace(&removed_workspace, cx); + + cx.emit(MultiWorkspaceEvent::WorkspaceRemoved( + removed_workspace.entity_id(), + )); + cx.emit(MultiWorkspaceEvent::ActiveWorkspaceChanged); } - self.detach_workspace(&removed_workspace, cx); + let key_still_in_use = self + .workspaces + .iter() + .any(|ws| ws.read(cx).project_group_key(cx) == old_key); + + if !key_still_in_use { + self.project_group_keys.retain(|k| k != &old_key); + } self.serialize(cx); self.focus_active_workspace(window, cx); - cx.emit(MultiWorkspaceEvent::WorkspaceRemoved( - removed_workspace.entity_id(), - )); - cx.emit(MultiWorkspaceEvent::ActiveWorkspaceChanged); cx.notify(); true From 2b4901d8c36bb43e22253f3c65f472322ff85d41 Mon Sep 17 00:00:00 2001 From: mgabor <9047995+mgabor3141@users.noreply.github.com> Date: Mon, 6 Apr 2026 21:39:16 +0200 Subject: [PATCH 08/21] workspace: Handle double-click on pinned tab row empty space (#51592) When `tab_bar.show_pinned_tabs_in_separate_row` is enabled, double-clicking the empty space on the unpinned tab row creates a new tab, but double-clicking the empty space on the pinned tab row does nothing. Add the same `on_click` double-click handler to the pinned tab bar drop target so both rows behave consistently. Release Notes: - Fixed double-clicking empty space in the pinned tab row not opening a new tab when `show_pinned_tabs_in_separate_row` is enabled. --------- Co-authored-by: Joseph T. Lyons --- crates/workspace/src/pane.rs | 90 +++++++++++++++++++++++++++++++----- 1 file changed, 78 insertions(+), 12 deletions(-) diff --git a/crates/workspace/src/pane.rs b/crates/workspace/src/pane.rs index 92f0781f82234ce79d47db08785b6592fb53f566..27cc96ae80a010db2dd5357a9a0bc037ca762875 100644 --- a/crates/workspace/src/pane.rs +++ b/crates/workspace/src/pane.rs @@ -3670,6 +3670,11 @@ impl Pane { this.drag_split_direction = None; this.handle_external_paths_drop(paths, window, cx) })) + .on_click(cx.listener(move |this, event: &ClickEvent, window, cx| { + if event.click_count() == 2 { + window.dispatch_action(this.double_click_dispatch_action.boxed_clone(), cx); + } + })) } pub fn render_menu_overlay(menu: &Entity) -> Div { @@ -4917,14 +4922,17 @@ impl Render for DraggedTab { #[cfg(test)] mod tests { - use std::{cell::Cell, iter::zip, num::NonZero}; + use std::{cell::Cell, iter::zip, num::NonZero, rc::Rc}; use super::*; use crate::{ Member, item::test::{TestItem, TestProjectItem}, }; - use gpui::{AppContext, Axis, TestAppContext, VisualTestContext, size}; + use gpui::{ + AppContext, Axis, Modifiers, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent, + TestAppContext, VisualTestContext, size, + }; use project::FakeFs; use settings::SettingsStore; use theme::LoadThemes; @@ -6649,8 +6657,6 @@ mod tests { #[gpui::test] async fn test_drag_tab_to_middle_tab_with_mouse_events(cx: &mut TestAppContext) { - use gpui::{Modifiers, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent}; - init_test(cx); let fs = FakeFs::new(cx.executor()); @@ -6702,8 +6708,6 @@ mod tests { async fn test_drag_pinned_tab_when_show_pinned_tabs_in_separate_row_enabled( cx: &mut TestAppContext, ) { - use gpui::{Modifiers, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent}; - init_test(cx); set_pinned_tabs_separate_row(cx, true); let fs = FakeFs::new(cx.executor()); @@ -6779,8 +6783,6 @@ mod tests { async fn test_drag_unpinned_tab_when_show_pinned_tabs_in_separate_row_enabled( cx: &mut TestAppContext, ) { - use gpui::{Modifiers, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent}; - init_test(cx); set_pinned_tabs_separate_row(cx, true); let fs = FakeFs::new(cx.executor()); @@ -6833,8 +6835,6 @@ mod tests { async fn test_drag_mixed_tabs_when_show_pinned_tabs_in_separate_row_enabled( cx: &mut TestAppContext, ) { - use gpui::{Modifiers, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent}; - init_test(cx); set_pinned_tabs_separate_row(cx, true); let fs = FakeFs::new(cx.executor()); @@ -6900,8 +6900,6 @@ mod tests { #[gpui::test] async fn test_middle_click_pinned_tab_does_not_close(cx: &mut TestAppContext) { - use gpui::{Modifiers, MouseButton, MouseDownEvent, MouseUpEvent}; - init_test(cx); let fs = FakeFs::new(cx.executor()); @@ -6971,6 +6969,74 @@ mod tests { assert_item_labels(&pane, ["A*!"], cx); } + #[gpui::test] + async fn test_double_click_pinned_tab_bar_empty_space_creates_new_tab(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + + let project = Project::test(fs, None, cx).await; + let (workspace, cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + let pane = workspace.read_with(cx, |workspace, _| workspace.active_pane().clone()); + + // The real NewFile handler lives in editor::init, which isn't initialized + // in workspace tests. Register a global action handler that sets a flag so + // we can verify the action is dispatched without depending on the editor crate. + // TODO: If editor::init is ever available in workspace tests, remove this + // flag and assert the resulting tab bar state directly instead. + let new_file_dispatched = Rc::new(Cell::new(false)); + cx.update(|_, cx| { + let new_file_dispatched = new_file_dispatched.clone(); + cx.on_action(move |_: &NewFile, _cx| { + new_file_dispatched.set(true); + }); + }); + + set_pinned_tabs_separate_row(cx, true); + + let item_a = add_labeled_item(&pane, "A", false, cx); + add_labeled_item(&pane, "B", false, cx); + + pane.update_in(cx, |pane, window, cx| { + let ix = pane + .index_for_item_id(item_a.item_id()) + .expect("item A should exist"); + pane.pin_tab_at(ix, window, cx); + }); + assert_item_labels(&pane, ["A!", "B*"], cx); + cx.run_until_parked(); + + let pinned_drop_target_bounds = cx + .debug_bounds("pinned_tabs_border") + .expect("pinned_tabs_border should have debug bounds"); + + cx.simulate_event(MouseDownEvent { + position: pinned_drop_target_bounds.center(), + button: MouseButton::Left, + modifiers: Modifiers::default(), + click_count: 2, + first_mouse: false, + }); + + cx.run_until_parked(); + + cx.simulate_event(MouseUpEvent { + position: pinned_drop_target_bounds.center(), + button: MouseButton::Left, + modifiers: Modifiers::default(), + click_count: 2, + }); + + cx.run_until_parked(); + + // TODO: If editor::init is ever available in workspace tests, replace this + // with an assert_item_labels check that verifies a new tab is actually created. + assert!( + new_file_dispatched.get(), + "Double-clicking pinned tab bar empty space should dispatch the new file action" + ); + } + #[gpui::test] async fn test_add_item_with_new_item(cx: &mut TestAppContext) { init_test(cx); From ec832cade62a0248ac7cee6b9d45367ea9ab9f7a Mon Sep 17 00:00:00 2001 From: Danilo Leal <67129314+danilo-leal@users.noreply.github.com> Date: Mon, 6 Apr 2026 17:19:29 -0300 Subject: [PATCH 09/21] rules_library: Fix hover selecting active rule (#53264) Closes https://github.com/zed-industries/zed/issues/53159 Recently, we changed the behavior of pickers so that hovering matches would also select them. This makes sense for most pickers that are used as "regular" pickers, but we have some in Zed that are not. A great example of one is the rules library, which sort of became way less usable with this behavior. So, this PR introduces a simple bool trait method to the picker so that we can turn this behavior off whenever necessary. The rules library kicks off as the only instance of it being turned off. Release Notes: - Fix navigation within the rules library making it so hovering the sidebar doesn't activate the visible rule. --- crates/picker/src/picker.rs | 17 +++++++++++------ crates/rules_library/src/rules_library.rs | 4 ++++ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/crates/picker/src/picker.rs b/crates/picker/src/picker.rs index 1e529cd53f2d2527af8525886d11dbcddbf33a34..eba5b3096194fe8a3379efeb9b230a6004cd2e36 100644 --- a/crates/picker/src/picker.rs +++ b/crates/picker/src/picker.rs @@ -121,6 +121,9 @@ pub trait PickerDelegate: Sized + 'static { ) -> bool { true } + fn select_on_hover(&self) -> bool { + true + } // Allows binding some optional effect to when the selection changes. fn selected_index_changed( @@ -788,12 +791,14 @@ impl Picker { this.handle_click(ix, event.modifiers.platform, window, cx) }), ) - .on_hover(cx.listener(move |this, hovered: &bool, window, cx| { - if *hovered { - this.set_selected_index(ix, None, false, window, cx); - cx.notify(); - } - })) + .when(self.delegate.select_on_hover(), |this| { + this.on_hover(cx.listener(move |this, hovered: &bool, window, cx| { + if *hovered { + this.set_selected_index(ix, None, false, window, cx); + cx.notify(); + } + })) + }) .children(self.delegate.render_match( ix, ix == self.delegate.selected_index(), diff --git a/crates/rules_library/src/rules_library.rs b/crates/rules_library/src/rules_library.rs index 7e5a56f22d48c4d51f60d7d200dc8384582beb23..425f7d2aa3d9e9259fe005a0e15dee10e4e4baf1 100644 --- a/crates/rules_library/src/rules_library.rs +++ b/crates/rules_library/src/rules_library.rs @@ -225,6 +225,10 @@ impl PickerDelegate for RulePickerDelegate { } } + fn select_on_hover(&self) -> bool { + false + } + fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc { "Search…".into() } From d2257dbc3991bc7863bf276579b48bc91edf3262 Mon Sep 17 00:00:00 2001 From: Finn Evers Date: Mon, 6 Apr 2026 22:24:33 +0200 Subject: [PATCH 10/21] compliance: Initialize compliance checks (#53231) Release Notes: - N/A --- .github/workflows/compliance_check.yml | 55 ++ .github/workflows/release.yml | 84 +++ Cargo.lock | 148 +++- Cargo.toml | 3 + tooling/compliance/Cargo.toml | 38 + tooling/compliance/LICENSE-GPL | 1 + tooling/compliance/src/checks.rs | 647 ++++++++++++++++++ tooling/compliance/src/git.rs | 591 ++++++++++++++++ tooling/compliance/src/github.rs | 424 ++++++++++++ tooling/compliance/src/lib.rs | 4 + tooling/compliance/src/report.rs | 446 ++++++++++++ tooling/xtask/Cargo.toml | 6 +- tooling/xtask/src/main.rs | 2 + tooling/xtask/src/tasks.rs | 1 + tooling/xtask/src/tasks/compliance.rs | 135 ++++ tooling/xtask/src/tasks/workflows.rs | 2 + .../src/tasks/workflows/compliance_check.rs | 66 ++ tooling/xtask/src/tasks/workflows/release.rs | 113 ++- 18 files changed, 2756 insertions(+), 10 deletions(-) create mode 100644 .github/workflows/compliance_check.yml create mode 100644 tooling/compliance/Cargo.toml create mode 120000 tooling/compliance/LICENSE-GPL create mode 100644 tooling/compliance/src/checks.rs create mode 100644 tooling/compliance/src/git.rs create mode 100644 tooling/compliance/src/github.rs create mode 100644 tooling/compliance/src/lib.rs create mode 100644 tooling/compliance/src/report.rs create mode 100644 tooling/xtask/src/tasks/compliance.rs create mode 100644 tooling/xtask/src/tasks/workflows/compliance_check.rs diff --git a/.github/workflows/compliance_check.yml b/.github/workflows/compliance_check.yml new file mode 100644 index 0000000000000000000000000000000000000000..f09c460c233b04e78df01e7828b4def737dec16e --- /dev/null +++ b/.github/workflows/compliance_check.yml @@ -0,0 +1,55 @@ +# Generated from xtask::workflows::compliance_check +# Rebuild with `cargo xtask workflows`. +name: compliance_check +env: + CARGO_TERM_COLOR: always +on: + schedule: + - cron: 30 17 * * 2 +jobs: + scheduled_compliance_check: + if: (github.repository_owner == 'zed-industries' || github.repository_owner == 'zed-extensions') + runs-on: namespace-profile-2x4-ubuntu-2404 + steps: + - name: steps::checkout_repo + uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd + with: + clean: false + fetch-depth: 0 + - name: steps::cache_rust_dependencies_namespace + uses: namespacelabs/nscloud-cache-action@a90bb5d4b27522ce881c6e98eebd7d7e6d1653f9 + with: + cache: rust + path: ~/.rustup + - id: determine-version + name: compliance_check::scheduled_compliance_check + run: | + VERSION=$(sed -n 's/^version = "\(.*\)"/\1/p' crates/zed/Cargo.toml | tr -d '[:space:]') + if [ -z "$VERSION" ]; then + echo "Could not determine version from crates/zed/Cargo.toml" + exit 1 + fi + TAG="v${VERSION}-pre" + echo "Checking compliance for $TAG" + echo "tag=$TAG" >> "$GITHUB_OUTPUT" + - id: run-compliance-check + name: compliance_check::scheduled_compliance_check::run_compliance_check + run: cargo xtask compliance "$LATEST_TAG" --branch main --report-path target/compliance-report + env: + LATEST_TAG: ${{ steps.determine-version.outputs.tag }} + GITHUB_APP_ID: ${{ secrets.ZED_ZIPPY_APP_ID }} + GITHUB_APP_KEY: ${{ secrets.ZED_ZIPPY_APP_PRIVATE_KEY }} + - name: compliance_check::scheduled_compliance_check::send_failure_slack_notification + if: failure() + run: | + MESSAGE="⚠️ Scheduled compliance check failed for upcoming preview release $LATEST_TAG: There are PRs with missing reviews." + + curl -X POST -H 'Content-type: application/json' \ + --data "$(jq -n --arg text "$MESSAGE" '{"text": $text}')" \ + "$SLACK_WEBHOOK" + env: + SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK_WORKFLOW_FAILURES }} + LATEST_TAG: ${{ steps.determine-version.outputs.tag }} +defaults: + run: + shell: bash -euxo pipefail {0} diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 35efafcfcd97c0139f8225ce7b15a05946c385ad..1401144ab3abda17dd4f526edd42166d37a47a49 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -293,6 +293,51 @@ jobs: env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} timeout-minutes: 60 + compliance_check: + if: (github.repository_owner == 'zed-industries' || github.repository_owner == 'zed-extensions') + runs-on: namespace-profile-16x32-ubuntu-2204 + env: + COMPLIANCE_FILE_PATH: compliance.md + steps: + - name: steps::checkout_repo + uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd + with: + clean: false + fetch-depth: 0 + ref: ${{ github.ref }} + - name: steps::cache_rust_dependencies_namespace + uses: namespacelabs/nscloud-cache-action@a90bb5d4b27522ce881c6e98eebd7d7e6d1653f9 + with: + cache: rust + path: ~/.rustup + - id: run-compliance-check + name: release::compliance_check::run_compliance_check + run: cargo xtask compliance "$GITHUB_REF_NAME" --report-path "$COMPLIANCE_FILE_OUTPUT" + env: + GITHUB_APP_ID: ${{ secrets.ZED_ZIPPY_APP_ID }} + GITHUB_APP_KEY: ${{ secrets.ZED_ZIPPY_APP_PRIVATE_KEY }} + - name: release::compliance_check::send_compliance_slack_notification + if: always() + run: | + if [ "$COMPLIANCE_OUTCOME" == "success" ]; then + STATUS="✅ Compliance check passed for $GITHUB_REF_NAME" + else + STATUS="❌ Compliance check failed for $GITHUB_REF_NAME" + fi + + REPORT_CONTENT="" + if [ -f "$COMPLIANCE_FILE_OUTPUT" ]; then + REPORT_CONTENT=$(cat "$REPORT_FILE") + fi + + MESSAGE=$(printf "%s\n\n%s" "$STATUS" "$REPORT_CONTENT") + + curl -X POST -H 'Content-type: application/json' \ + --data "$(jq -n --arg text "$MESSAGE" '{"text": $text}')" \ + "$SLACK_WEBHOOK" + env: + SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK_WORKFLOW_FAILURES }} + COMPLIANCE_OUTCOME: ${{ steps.run-compliance-check.outcome }} bundle_linux_aarch64: needs: - run_tests_linux @@ -613,6 +658,45 @@ jobs: echo "All expected assets are present in the release." env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: steps::checkout_repo + uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd + with: + clean: false + fetch-depth: 0 + ref: ${{ github.ref }} + - name: steps::cache_rust_dependencies_namespace + uses: namespacelabs/nscloud-cache-action@a90bb5d4b27522ce881c6e98eebd7d7e6d1653f9 + with: + cache: rust + path: ~/.rustup + - id: run-post-upload-compliance-check + name: release::validate_release_assets::run_post_upload_compliance_check + run: cargo xtask compliance "$GITHUB_REF_NAME" --report-path target/compliance-report + env: + GITHUB_APP_ID: ${{ secrets.ZED_ZIPPY_APP_ID }} + GITHUB_APP_KEY: ${{ secrets.ZED_ZIPPY_APP_PRIVATE_KEY }} + - name: release::validate_release_assets::send_post_upload_compliance_notification + if: always() + run: | + if [ -z "$COMPLIANCE_OUTCOME" ] || [ "$COMPLIANCE_OUTCOME" == "skipped" ]; then + echo "Compliance check was skipped, not sending notification" + exit 0 + fi + + TAG="$GITHUB_REF_NAME" + + if [ "$COMPLIANCE_OUTCOME" == "success" ]; then + MESSAGE="✅ Post-upload compliance re-check passed for $TAG" + else + MESSAGE="❌ Post-upload compliance re-check failed for $TAG" + fi + + curl -X POST -H 'Content-type: application/json' \ + --data "$(jq -n --arg text "$MESSAGE" '{"text": $text}')" \ + "$SLACK_WEBHOOK" + env: + SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK_WORKFLOW_FAILURES }} + COMPLIANCE_OUTCOME: ${{ steps.run-post-upload-compliance-check.outcome }} auto_release_preview: needs: - validate_release_assets diff --git a/Cargo.lock b/Cargo.lock index 279fcec10f1efb4c3174bfdd8e28192cda2f6a0c..f7597693960b2c9e66121794f9c99cdb8d6ddcea 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -677,6 +677,15 @@ dependencies = [ "derive_arbitrary", ] +[[package]] +name = "arc-swap" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a07d1f37ff60921c83bdfc7407723bdefe89b44b98a9b772f225c8f9d67141a6" +dependencies = [ + "rustversion", +] + [[package]] name = "arg_enum_proc_macro" version = "0.3.4" @@ -2530,6 +2539,16 @@ dependencies = [ "serde", ] +[[package]] +name = "cargo-platform" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87a0c0e6148f11f01f32650a2ea02d532b2ad4e81d8bd41e6e565b5adc5e6082" +dependencies = [ + "serde", + "serde_core", +] + [[package]] name = "cargo_metadata" version = "0.19.2" @@ -2537,7 +2556,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd5eb614ed4c27c5d706420e4320fbe3216ab31fa1c33cd8246ac36dae4479ba" dependencies = [ "camino", - "cargo-platform", + "cargo-platform 0.1.9", + "semver", + "serde", + "serde_json", + "thiserror 2.0.17", +] + +[[package]] +name = "cargo_metadata" +version = "0.23.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef987d17b0a113becdd19d3d0022d04d7ef41f9efe4f3fb63ac44ba61df3ade9" +dependencies = [ + "camino", + "cargo-platform 0.3.2", "semver", "serde", "serde_json", @@ -3284,6 +3317,25 @@ dependencies = [ "workspace", ] +[[package]] +name = "compliance" +version = "0.1.0" +dependencies = [ + "anyhow", + "async-trait", + "derive_more", + "futures 0.3.32", + "indoc", + "itertools 0.14.0", + "jsonwebtoken", + "octocrab", + "regex", + "semver", + "serde", + "serde_json", + "tokio", +] + [[package]] name = "component" version = "0.1.0" @@ -8324,6 +8376,7 @@ dependencies = [ "http 1.3.1", "hyper 1.7.0", "hyper-util", + "log", "rustls 0.23.33", "rustls-native-certs 0.8.2", "rustls-pki-types", @@ -8332,6 +8385,19 @@ dependencies = [ "tower-service", ] +[[package]] +name = "hyper-timeout" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b90d566bffbce6a75bd8b09a05aa8c2cb1fabb6cb348f8840c9e4c90a0d83b0" +dependencies = [ + "hyper 1.7.0", + "hyper-util", + "pin-project-lite", + "tokio", + "tower-service", +] + [[package]] name = "hyper-tls" version = "0.5.0" @@ -11380,6 +11446,48 @@ dependencies = [ "memchr", ] +[[package]] +name = "octocrab" +version = "0.49.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63f6687a23731011d0117f9f4c3cdabaa7b5e42ca671f42b5cc0657c492540e3" +dependencies = [ + "arc-swap", + "async-trait", + "base64 0.22.1", + "bytes 1.11.1", + "cargo_metadata 0.23.1", + "cfg-if", + "chrono", + "either", + "futures 0.3.32", + "futures-core", + "futures-util", + "getrandom 0.2.16", + "http 1.3.1", + "http-body 1.0.1", + "http-body-util", + "hyper 1.7.0", + "hyper-rustls 0.27.7", + "hyper-timeout", + "hyper-util", + "jsonwebtoken", + "once_cell", + "percent-encoding", + "pin-project", + "secrecy", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "snafu", + "tokio", + "tower 0.5.2", + "tower-http 0.6.6", + "url", + "web-time", +] + [[package]] name = "ollama" version = "0.1.0" @@ -15381,6 +15489,15 @@ dependencies = [ "zeroize", ] +[[package]] +name = "secrecy" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a" +dependencies = [ + "zeroize", +] + [[package]] name = "security-framework" version = "2.11.1" @@ -16085,6 +16202,27 @@ version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0f7a918bd2a9951d18ee6e48f076843e8e73a9a5d22cf05bcd4b7a81bdd04e17" +[[package]] +name = "snafu" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e84b3f4eacbf3a1ce05eac6763b4d629d60cbc94d632e4092c54ade71f1e1a2" +dependencies = [ + "snafu-derive", +] + +[[package]] +name = "snafu-derive" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1c97747dbf44bb1ca44a561ece23508e99cb592e862f22222dcf42f51d1e451" +dependencies = [ + "heck 0.5.0", + "proc-macro2", + "quote", + "syn 2.0.117", +] + [[package]] name = "snippet" version = "0.1.0" @@ -18089,8 +18227,10 @@ dependencies = [ "pin-project-lite", "sync_wrapper 1.0.2", "tokio", + "tokio-util", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -18128,6 +18268,7 @@ dependencies = [ "tower 0.5.2", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -19974,6 +20115,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" dependencies = [ "js-sys", + "serde", "wasm-bindgen", ] @@ -21711,9 +21853,10 @@ dependencies = [ "annotate-snippets", "anyhow", "backtrace", - "cargo_metadata", + "cargo_metadata 0.19.2", "cargo_toml", "clap", + "compliance", "gh-workflow", "indexmap", "indoc", @@ -21723,6 +21866,7 @@ dependencies = [ "serde_json", "serde_yaml", "strum 0.27.2", + "tokio", "toml 0.8.23", "toml_edit 0.22.27", ] diff --git a/Cargo.toml b/Cargo.toml index 81bbb1176ddddcc117fc9082586cbc08dbb95d61..a800a6c9b276c5f30d6b6eca2f9f0f660f28b02d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -242,6 +242,7 @@ members = [ # Tooling # + "tooling/compliance", "tooling/perf", "tooling/xtask", ] @@ -289,6 +290,7 @@ collab_ui = { path = "crates/collab_ui" } collections = { path = "crates/collections", version = "0.1.0" } command_palette = { path = "crates/command_palette" } command_palette_hooks = { path = "crates/command_palette_hooks" } +compliance = { path = "tooling/compliance" } component = { path = "crates/component" } component_preview = { path = "crates/component_preview" } context_server = { path = "crates/context_server" } @@ -547,6 +549,7 @@ derive_more = { version = "2.1.1", features = [ "add_assign", "deref", "deref_mut", + "display", "from_str", "mul", "mul_assign", diff --git a/tooling/compliance/Cargo.toml b/tooling/compliance/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..9b1ade359daa4b7a02beff861c94e01fff071f84 --- /dev/null +++ b/tooling/compliance/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "compliance" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[features] +octo-client = ["dep:octocrab", "dep:jsonwebtoken", "dep:futures", "dep:tokio"] + +[dependencies] +anyhow.workspace = true +async-trait.workspace = true +derive_more.workspace = true +futures = { workspace = true, optional = true } +itertools.workspace = true +jsonwebtoken = { version = "10.2", features = ["use_pem"], optional = true } +octocrab = { version = "0.49", default-features = false, features = [ + "default-client", + "jwt-aws-lc-rs", + "retry", + "rustls", + "rustls-aws-lc-rs", + "stream", + "timeout" +], optional = true } +regex.workspace = true +semver.workspace = true +serde.workspace = true +serde_json.workspace = true +tokio = { workspace = true, optional = true } + +[dev-dependencies] +indoc.workspace = true +tokio = { workspace = true, features = ["rt", "macros"] } diff --git a/tooling/compliance/LICENSE-GPL b/tooling/compliance/LICENSE-GPL new file mode 120000 index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4 --- /dev/null +++ b/tooling/compliance/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/tooling/compliance/src/checks.rs b/tooling/compliance/src/checks.rs new file mode 100644 index 0000000000000000000000000000000000000000..a0623fbbbc179edf9f5b6d777b3116ff498f0265 --- /dev/null +++ b/tooling/compliance/src/checks.rs @@ -0,0 +1,647 @@ +use std::{fmt, ops::Not as _}; + +use itertools::Itertools as _; + +use crate::{ + git::{CommitDetails, CommitList}, + github::{ + CommitAuthor, GitHubClient, GitHubUser, GithubLogin, PullRequestComment, PullRequestData, + PullRequestReview, ReviewState, + }, + report::Report, +}; + +const ZED_ZIPPY_COMMENT_APPROVAL_PATTERN: &str = "@zed-zippy approve"; +const ZED_ZIPPY_GROUP_APPROVAL: &str = "@zed-industries/approved"; + +#[derive(Debug)] +pub enum ReviewSuccess { + ApprovingComment(Vec), + CoAuthored(Vec), + ExternalMergedContribution { merged_by: GitHubUser }, + PullRequestReviewed(Vec), +} + +impl ReviewSuccess { + pub(crate) fn reviewers(&self) -> anyhow::Result { + let reviewers = match self { + Self::CoAuthored(authors) => authors.iter().map(ToString::to_string).collect_vec(), + Self::PullRequestReviewed(reviews) => reviews + .iter() + .filter_map(|review| review.user.as_ref()) + .map(|user| format!("@{}", user.login)) + .collect_vec(), + Self::ApprovingComment(comments) => comments + .iter() + .map(|comment| format!("@{}", comment.user.login)) + .collect_vec(), + Self::ExternalMergedContribution { merged_by } => { + vec![format!("@{}", merged_by.login)] + } + }; + + let reviewers = reviewers.into_iter().unique().collect_vec(); + + reviewers + .is_empty() + .not() + .then(|| reviewers.join(", ")) + .ok_or_else(|| anyhow::anyhow!("Expected at least one reviewer")) + } +} + +impl fmt::Display for ReviewSuccess { + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::CoAuthored(_) => formatter.write_str("Co-authored by an organization member"), + Self::PullRequestReviewed(_) => { + formatter.write_str("Approved by an organization review") + } + Self::ApprovingComment(_) => { + formatter.write_str("Approved by an organization approval comment") + } + Self::ExternalMergedContribution { .. } => { + formatter.write_str("External merged contribution") + } + } + } +} + +#[derive(Debug)] +pub enum ReviewFailure { + // todo: We could still query the GitHub API here to search for one + NoPullRequestFound, + Unreviewed, + UnableToDetermineReviewer, + Other(anyhow::Error), +} + +impl fmt::Display for ReviewFailure { + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::NoPullRequestFound => formatter.write_str("No pull request found"), + Self::Unreviewed => formatter + .write_str("No qualifying organization approval found for the pull request"), + Self::UnableToDetermineReviewer => formatter.write_str("Could not determine reviewer"), + Self::Other(error) => write!(formatter, "Failed to inspect review state: {error}"), + } + } +} + +pub(crate) type ReviewResult = Result; + +impl> From for ReviewFailure { + fn from(err: E) -> Self { + Self::Other(anyhow::anyhow!(err)) + } +} + +pub struct Reporter<'a> { + commits: CommitList, + github_client: &'a GitHubClient, +} + +impl<'a> Reporter<'a> { + pub fn new(commits: CommitList, github_client: &'a GitHubClient) -> Self { + Self { + commits, + github_client, + } + } + + /// Method that checks every commit for compliance + async fn check_commit(&self, commit: &CommitDetails) -> Result { + let Some(pr_number) = commit.pr_number() else { + return Err(ReviewFailure::NoPullRequestFound); + }; + + let pull_request = self.github_client.get_pull_request(pr_number).await?; + + if let Some(approval) = self.check_pull_request_approved(&pull_request).await? { + return Ok(approval); + } + + if let Some(approval) = self + .check_approving_pull_request_comment(&pull_request) + .await? + { + return Ok(approval); + } + + if let Some(approval) = self.check_commit_co_authors(commit).await? { + return Ok(approval); + } + + // if let Some(approval) = self.check_external_merged_pr(pr_number).await? { + // return Ok(approval); + // } + + Err(ReviewFailure::Unreviewed) + } + + async fn check_commit_co_authors( + &self, + commit: &CommitDetails, + ) -> Result, ReviewFailure> { + if commit.co_authors().is_some() + && let Some(commit_authors) = self + .github_client + .get_commit_authors([commit.sha()]) + .await? + .get(commit.sha()) + .and_then(|authors| authors.co_authors()) + { + let mut org_co_authors = Vec::new(); + for co_author in commit_authors { + if let Some(github_login) = co_author.user() + && self + .github_client + .check_org_membership(github_login) + .await? + { + org_co_authors.push(co_author.clone()); + } + } + + Ok(org_co_authors + .is_empty() + .not() + .then_some(ReviewSuccess::CoAuthored(org_co_authors))) + } else { + Ok(None) + } + } + + #[allow(unused)] + async fn check_external_merged_pr( + &self, + pull_request: PullRequestData, + ) -> Result, ReviewFailure> { + if let Some(user) = pull_request.user + && self + .github_client + .check_org_membership(&GithubLogin::new(user.login)) + .await? + .not() + { + pull_request.merged_by.map_or( + Err(ReviewFailure::UnableToDetermineReviewer), + |merged_by| { + Ok(Some(ReviewSuccess::ExternalMergedContribution { + merged_by, + })) + }, + ) + } else { + Ok(None) + } + } + + async fn check_pull_request_approved( + &self, + pull_request: &PullRequestData, + ) -> Result, ReviewFailure> { + let pr_reviews = self + .github_client + .get_pull_request_reviews(pull_request.number) + .await?; + + if !pr_reviews.is_empty() { + let mut org_approving_reviews = Vec::new(); + for review in pr_reviews { + if let Some(github_login) = review.user.as_ref() + && pull_request + .user + .as_ref() + .is_none_or(|pr_user| pr_user.login != github_login.login) + && review + .state + .is_some_and(|state| state == ReviewState::Approved) + && self + .github_client + .check_org_membership(&GithubLogin::new(github_login.login.clone())) + .await? + { + org_approving_reviews.push(review); + } + } + + Ok(org_approving_reviews + .is_empty() + .not() + .then_some(ReviewSuccess::PullRequestReviewed(org_approving_reviews))) + } else { + Ok(None) + } + } + + async fn check_approving_pull_request_comment( + &self, + pull_request: &PullRequestData, + ) -> Result, ReviewFailure> { + let other_comments = self + .github_client + .get_pull_request_comments(pull_request.number) + .await?; + + if !other_comments.is_empty() { + let mut org_approving_comments = Vec::new(); + + for comment in other_comments { + if pull_request + .user + .as_ref() + .is_some_and(|pr_author| pr_author.login != comment.user.login) + && comment.body.as_ref().is_some_and(|body| { + body.contains(ZED_ZIPPY_COMMENT_APPROVAL_PATTERN) + || body.contains(ZED_ZIPPY_GROUP_APPROVAL) + }) + && self + .github_client + .check_org_membership(&GithubLogin::new(comment.user.login.clone())) + .await? + { + org_approving_comments.push(comment); + } + } + + Ok(org_approving_comments + .is_empty() + .not() + .then_some(ReviewSuccess::ApprovingComment(org_approving_comments))) + } else { + Ok(None) + } + } + + pub async fn generate_report(mut self) -> anyhow::Result { + let mut report = Report::new(); + + let commits_to_check = std::mem::take(&mut self.commits); + let total_commits = commits_to_check.len(); + + for (i, commit) in commits_to_check.into_iter().enumerate() { + println!( + "Checking commit {:?} ({current}/{total})", + commit.sha().short(), + current = i + 1, + total = total_commits + ); + + let review_result = self.check_commit(&commit).await; + + if let Err(err) = &review_result { + println!("Commit {:?} failed review: {:?}", commit.sha().short(), err); + } + + report.add(commit, review_result); + } + + Ok(report) + } +} + +#[cfg(test)] +mod tests { + use std::rc::Rc; + use std::str::FromStr; + + use crate::git::{CommitDetails, CommitList, CommitSha}; + use crate::github::{ + AuthorsForCommits, GitHubApiClient, GitHubClient, GitHubUser, GithubLogin, + PullRequestComment, PullRequestData, PullRequestReview, ReviewState, + }; + + use super::{Reporter, ReviewFailure, ReviewSuccess}; + + struct MockGitHubApi { + pull_request: PullRequestData, + reviews: Vec, + comments: Vec, + commit_authors_json: serde_json::Value, + org_members: Vec, + } + + #[async_trait::async_trait(?Send)] + impl GitHubApiClient for MockGitHubApi { + async fn get_pull_request(&self, _pr_number: u64) -> anyhow::Result { + Ok(self.pull_request.clone()) + } + + async fn get_pull_request_reviews( + &self, + _pr_number: u64, + ) -> anyhow::Result> { + Ok(self.reviews.clone()) + } + + async fn get_pull_request_comments( + &self, + _pr_number: u64, + ) -> anyhow::Result> { + Ok(self.comments.clone()) + } + + async fn get_commit_authors( + &self, + _commit_shas: &[&CommitSha], + ) -> anyhow::Result { + serde_json::from_value(self.commit_authors_json.clone()).map_err(Into::into) + } + + async fn check_org_membership(&self, login: &GithubLogin) -> anyhow::Result { + Ok(self + .org_members + .iter() + .any(|member| member == login.as_str())) + } + + async fn ensure_pull_request_has_label( + &self, + _label: &str, + _pr_number: u64, + ) -> anyhow::Result<()> { + Ok(()) + } + } + + fn make_commit( + sha: &str, + author_name: &str, + author_email: &str, + title: &str, + body: &str, + ) -> CommitDetails { + let formatted = format!( + "{sha}|field-delimiter|{author_name}|field-delimiter|{author_email}|field-delimiter|\ + {title}|body-delimiter|{body}|commit-delimiter|" + ); + CommitList::from_str(&formatted) + .expect("test commit should parse") + .into_iter() + .next() + .expect("should have one commit") + } + + fn review(login: &str, state: ReviewState) -> PullRequestReview { + PullRequestReview { + user: Some(GitHubUser { + login: login.to_owned(), + }), + state: Some(state), + } + } + + fn comment(login: &str, body: &str) -> PullRequestComment { + PullRequestComment { + user: GitHubUser { + login: login.to_owned(), + }, + body: Some(body.to_owned()), + } + } + + struct TestScenario { + pull_request: PullRequestData, + reviews: Vec, + comments: Vec, + commit_authors_json: serde_json::Value, + org_members: Vec, + commit: CommitDetails, + } + + impl TestScenario { + fn single_commit() -> Self { + Self { + pull_request: PullRequestData { + number: 1234, + user: Some(GitHubUser { + login: "alice".to_owned(), + }), + merged_by: None, + }, + reviews: vec![], + comments: vec![], + commit_authors_json: serde_json::json!({}), + org_members: vec![], + commit: make_commit( + "abc12345abc12345", + "Alice", + "alice@test.com", + "Fix thing (#1234)", + "", + ), + } + } + + fn with_reviews(mut self, reviews: Vec) -> Self { + self.reviews = reviews; + self + } + + fn with_comments(mut self, comments: Vec) -> Self { + self.comments = comments; + self + } + + fn with_org_members(mut self, members: Vec<&str>) -> Self { + self.org_members = members.into_iter().map(str::to_owned).collect(); + self + } + + fn with_commit_authors_json(mut self, json: serde_json::Value) -> Self { + self.commit_authors_json = json; + self + } + + fn with_commit(mut self, commit: CommitDetails) -> Self { + self.commit = commit; + self + } + + async fn run_scenario(self) -> Result { + let mock = MockGitHubApi { + pull_request: self.pull_request, + reviews: self.reviews, + comments: self.comments, + commit_authors_json: self.commit_authors_json, + org_members: self.org_members, + }; + let client = GitHubClient::new(Rc::new(mock)); + let reporter = Reporter::new(CommitList::default(), &client); + reporter.check_commit(&self.commit).await + } + } + + #[tokio::test] + async fn approved_review_by_org_member_succeeds() { + let result = TestScenario::single_commit() + .with_reviews(vec![review("bob", ReviewState::Approved)]) + .with_org_members(vec!["bob"]) + .run_scenario() + .await; + assert!(matches!(result, Ok(ReviewSuccess::PullRequestReviewed(_)))); + } + + #[tokio::test] + async fn non_approved_review_state_is_not_accepted() { + let result = TestScenario::single_commit() + .with_reviews(vec![review("bob", ReviewState::Other)]) + .with_org_members(vec!["bob"]) + .run_scenario() + .await; + assert!(matches!(result, Err(ReviewFailure::Unreviewed))); + } + + #[tokio::test] + async fn review_by_non_org_member_is_not_accepted() { + let result = TestScenario::single_commit() + .with_reviews(vec![review("bob", ReviewState::Approved)]) + .run_scenario() + .await; + assert!(matches!(result, Err(ReviewFailure::Unreviewed))); + } + + #[tokio::test] + async fn pr_author_own_approval_review_is_rejected() { + let result = TestScenario::single_commit() + .with_reviews(vec![review("alice", ReviewState::Approved)]) + .with_org_members(vec!["alice"]) + .run_scenario() + .await; + assert!(matches!(result, Err(ReviewFailure::Unreviewed))); + } + + #[tokio::test] + async fn pr_author_own_approval_comment_is_rejected() { + let result = TestScenario::single_commit() + .with_comments(vec![comment("alice", "@zed-zippy approve")]) + .with_org_members(vec!["alice"]) + .run_scenario() + .await; + assert!(matches!(result, Err(ReviewFailure::Unreviewed))); + } + + #[tokio::test] + async fn approval_comment_by_org_member_succeeds() { + let result = TestScenario::single_commit() + .with_comments(vec![comment("bob", "@zed-zippy approve")]) + .with_org_members(vec!["bob"]) + .run_scenario() + .await; + assert!(matches!(result, Ok(ReviewSuccess::ApprovingComment(_)))); + } + + #[tokio::test] + async fn group_approval_comment_by_org_member_succeeds() { + let result = TestScenario::single_commit() + .with_comments(vec![comment("bob", "@zed-industries/approved")]) + .with_org_members(vec!["bob"]) + .run_scenario() + .await; + assert!(matches!(result, Ok(ReviewSuccess::ApprovingComment(_)))); + } + + #[tokio::test] + async fn comment_without_approval_pattern_is_not_accepted() { + let result = TestScenario::single_commit() + .with_comments(vec![comment("bob", "looks good")]) + .with_org_members(vec!["bob"]) + .run_scenario() + .await; + assert!(matches!(result, Err(ReviewFailure::Unreviewed))); + } + + #[tokio::test] + async fn commit_without_pr_number_is_no_pr_found() { + let result = TestScenario::single_commit() + .with_commit(make_commit( + "abc12345abc12345", + "Alice", + "alice@test.com", + "Fix thing without PR number", + "", + )) + .run_scenario() + .await; + assert!(matches!(result, Err(ReviewFailure::NoPullRequestFound))); + } + + #[tokio::test] + async fn pr_review_takes_precedence_over_comment() { + let result = TestScenario::single_commit() + .with_reviews(vec![review("bob", ReviewState::Approved)]) + .with_comments(vec![comment("charlie", "@zed-zippy approve")]) + .with_org_members(vec!["bob", "charlie"]) + .run_scenario() + .await; + assert!(matches!(result, Ok(ReviewSuccess::PullRequestReviewed(_)))); + } + + #[tokio::test] + async fn comment_takes_precedence_over_co_author() { + let result = TestScenario::single_commit() + .with_comments(vec![comment("bob", "@zed-zippy approve")]) + .with_commit_authors_json(serde_json::json!({ + "abc12345abc12345": { + "author": { + "name": "Alice", + "email": "alice@test.com", + "user": { "login": "alice" } + }, + "authors": [{ + "name": "Charlie", + "email": "charlie@test.com", + "user": { "login": "charlie" } + }] + } + })) + .with_commit(make_commit( + "abc12345abc12345", + "Alice", + "alice@test.com", + "Fix thing (#1234)", + "Co-authored-by: Charlie ", + )) + .with_org_members(vec!["bob", "charlie"]) + .run_scenario() + .await; + assert!(matches!(result, Ok(ReviewSuccess::ApprovingComment(_)))); + } + + #[tokio::test] + async fn co_author_org_member_succeeds() { + let result = TestScenario::single_commit() + .with_commit_authors_json(serde_json::json!({ + "abc12345abc12345": { + "author": { + "name": "Alice", + "email": "alice@test.com", + "user": { "login": "alice" } + }, + "authors": [{ + "name": "Bob", + "email": "bob@test.com", + "user": { "login": "bob" } + }] + } + })) + .with_commit(make_commit( + "abc12345abc12345", + "Alice", + "alice@test.com", + "Fix thing (#1234)", + "Co-authored-by: Bob ", + )) + .with_org_members(vec!["bob"]) + .run_scenario() + .await; + assert!(matches!(result, Ok(ReviewSuccess::CoAuthored(_)))); + } + + #[tokio::test] + async fn no_reviews_no_comments_no_coauthors_is_unreviewed() { + let result = TestScenario::single_commit().run_scenario().await; + assert!(matches!(result, Err(ReviewFailure::Unreviewed))); + } +} diff --git a/tooling/compliance/src/git.rs b/tooling/compliance/src/git.rs new file mode 100644 index 0000000000000000000000000000000000000000..fa2cb725712de82526d4ce717c2ec3dc97d22885 --- /dev/null +++ b/tooling/compliance/src/git.rs @@ -0,0 +1,591 @@ +#![allow(clippy::disallowed_methods, reason = "This is only used in xtasks")] +use std::{ + fmt::{self, Debug}, + ops::Not, + process::Command, + str::FromStr, + sync::LazyLock, +}; + +use anyhow::{Context, Result, anyhow}; +use derive_more::{Deref, DerefMut, FromStr}; + +use itertools::Itertools; +use regex::Regex; +use semver::Version; +use serde::Deserialize; + +pub trait Subcommand { + type ParsedOutput: FromStr; + + fn args(&self) -> impl IntoIterator; +} + +#[derive(Deref, DerefMut)] +pub struct GitCommand { + #[deref] + #[deref_mut] + subcommand: G, +} + +impl GitCommand { + #[must_use] + pub fn run(subcommand: G) -> Result { + Self { subcommand }.run_impl() + } + + fn run_impl(self) -> Result { + let command_output = Command::new("git") + .args(self.subcommand.args()) + .output() + .context("Failed to spawn command")?; + + if command_output.status.success() { + String::from_utf8(command_output.stdout) + .map_err(|_| anyhow!("Invalid UTF8")) + .and_then(|s| { + G::ParsedOutput::from_str(s.trim()) + .map_err(|e| anyhow!("Failed to parse from string: {e:?}")) + }) + } else { + anyhow::bail!( + "Command failed with exit code {}, stderr: {}", + command_output.status.code().unwrap_or_default(), + String::from_utf8(command_output.stderr).unwrap_or_default() + ) + } + } +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub enum ReleaseChannel { + Stable, + Preview, +} + +impl ReleaseChannel { + pub(crate) fn tag_suffix(&self) -> &'static str { + match self { + ReleaseChannel::Stable => "", + ReleaseChannel::Preview => "-pre", + } + } +} + +#[derive(Debug, Clone)] +pub struct VersionTag(Version, ReleaseChannel); + +impl VersionTag { + pub fn parse(input: &str) -> Result { + // Being a bit more lenient for human inputs + let version = input.strip_prefix('v').unwrap_or(input); + + let (version_str, channel) = version + .strip_suffix("-pre") + .map_or((version, ReleaseChannel::Stable), |version_str| { + (version_str, ReleaseChannel::Preview) + }); + + Version::parse(version_str) + .map(|version| Self(version, channel)) + .map_err(|_| anyhow::anyhow!("Failed to parse version from tag!")) + } + + pub fn version(&self) -> &Version { + &self.0 + } +} + +impl ToString for VersionTag { + fn to_string(&self) -> String { + format!( + "v{version}{channel_suffix}", + version = self.0, + channel_suffix = self.1.tag_suffix() + ) + } +} + +#[derive(Debug, Deref, FromStr, PartialEq, Eq, Hash, Deserialize)] +pub struct CommitSha(pub(crate) String); + +impl CommitSha { + pub fn short(&self) -> &str { + self.0.as_str().split_at(8).0 + } +} + +#[derive(Debug)] +pub struct CommitDetails { + sha: CommitSha, + author: Committer, + title: String, + body: String, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Committer { + name: String, + email: String, +} + +impl Committer { + pub fn new(name: &str, email: &str) -> Self { + Self { + name: name.to_owned(), + email: email.to_owned(), + } + } +} + +impl fmt::Display for Committer { + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(formatter, "{} ({})", self.name, self.email) + } +} + +impl CommitDetails { + const BODY_DELIMITER: &str = "|body-delimiter|"; + const COMMIT_DELIMITER: &str = "|commit-delimiter|"; + const FIELD_DELIMITER: &str = "|field-delimiter|"; + const FORMAT_STRING: &str = "%H|field-delimiter|%an|field-delimiter|%ae|field-delimiter|%s|body-delimiter|%b|commit-delimiter|"; + + fn parse(line: &str, body: &str) -> Result { + let Some([sha, author_name, author_email, title]) = + line.splitn(4, Self::FIELD_DELIMITER).collect_array() + else { + return Err(anyhow!("Failed to parse commit fields from input {line}")); + }; + + Ok(CommitDetails { + sha: CommitSha(sha.to_owned()), + author: Committer::new(author_name, author_email), + title: title.to_owned(), + body: body.to_owned(), + }) + } + + pub fn pr_number(&self) -> Option { + // Since we use squash merge, all commit titles end with the '(#12345)' pattern. + // While we could strictly speaking index into this directly, go for a slightly + // less prone approach to errors + const PATTERN: &str = " (#"; + self.title + .rfind(PATTERN) + .and_then(|location| { + self.title[location..] + .find(')') + .map(|relative_end| location + PATTERN.len()..location + relative_end) + }) + .and_then(|range| self.title[range].parse().ok()) + } + + pub(crate) fn co_authors(&self) -> Option> { + static CO_AUTHOR_REGEX: LazyLock = + LazyLock::new(|| Regex::new(r"Co-authored-by: (.+) <(.+)>").unwrap()); + + let mut co_authors = Vec::new(); + + for cap in CO_AUTHOR_REGEX.captures_iter(&self.body.as_ref()) { + let Some((name, email)) = cap + .get(1) + .map(|m| m.as_str()) + .zip(cap.get(2).map(|m| m.as_str())) + else { + continue; + }; + co_authors.push(Committer::new(name, email)); + } + + co_authors.is_empty().not().then_some(co_authors) + } + + pub(crate) fn author(&self) -> &Committer { + &self.author + } + + pub(crate) fn title(&self) -> &str { + &self.title + } + + pub(crate) fn sha(&self) -> &CommitSha { + &self.sha + } +} + +#[derive(Debug, Deref, Default, DerefMut)] +pub struct CommitList(Vec); + +impl CommitList { + pub fn range(&self) -> Option { + self.0 + .first() + .zip(self.0.last()) + .map(|(first, last)| format!("{}..{}", first.sha().0, last.sha().0)) + } +} + +impl IntoIterator for CommitList { + type IntoIter = std::vec::IntoIter; + type Item = CommitDetails; + + fn into_iter(self) -> std::vec::IntoIter { + self.0.into_iter() + } +} + +impl FromStr for CommitList { + type Err = anyhow::Error; + + fn from_str(input: &str) -> Result { + Ok(CommitList( + input + .split(CommitDetails::COMMIT_DELIMITER) + .filter(|commit_details| !commit_details.is_empty()) + .map(|commit_details| { + let (line, body) = commit_details + .trim() + .split_once(CommitDetails::BODY_DELIMITER) + .expect("Missing body delimiter"); + CommitDetails::parse(line, body) + .expect("Parsing from the output should succeed") + }) + .collect(), + )) + } +} + +pub struct GetVersionTags; + +impl Subcommand for GetVersionTags { + type ParsedOutput = VersionTagList; + + fn args(&self) -> impl IntoIterator { + ["tag", "-l", "v*"].map(ToOwned::to_owned) + } +} + +pub struct VersionTagList(Vec); + +impl VersionTagList { + pub fn sorted(mut self) -> Self { + self.0.sort_by(|a, b| a.version().cmp(b.version())); + self + } + + pub fn find_previous_minor_version(&self, version_tag: &VersionTag) -> Option<&VersionTag> { + self.0 + .iter() + .take_while(|tag| tag.version() < version_tag.version()) + .collect_vec() + .into_iter() + .rev() + .find(|tag| { + (tag.version().major < version_tag.version().major + || (tag.version().major == version_tag.version().major + && tag.version().minor < version_tag.version().minor)) + && tag.version().patch == 0 + }) + } +} + +impl FromStr for VersionTagList { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + let version_tags = s.lines().flat_map(VersionTag::parse).collect_vec(); + + version_tags + .is_empty() + .not() + .then_some(Self(version_tags)) + .ok_or_else(|| anyhow::anyhow!("No version tags found")) + } +} + +pub struct CommitsFromVersionToHead { + version_tag: VersionTag, + branch: String, +} + +impl CommitsFromVersionToHead { + pub fn new(version_tag: VersionTag, branch: String) -> Self { + Self { + version_tag, + branch, + } + } +} + +impl Subcommand for CommitsFromVersionToHead { + type ParsedOutput = CommitList; + + fn args(&self) -> impl IntoIterator { + [ + "log".to_string(), + format!("--pretty=format:{}", CommitDetails::FORMAT_STRING), + format!( + "{version}..{branch}", + version = self.version_tag.to_string(), + branch = self.branch + ), + ] + } +} + +pub struct NoOutput; + +impl FromStr for NoOutput { + type Err = anyhow::Error; + + fn from_str(_: &str) -> Result { + Ok(NoOutput) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use indoc::indoc; + + #[test] + fn parse_stable_version_tag() { + let tag = VersionTag::parse("v0.172.8").unwrap(); + assert_eq!(tag.version().major, 0); + assert_eq!(tag.version().minor, 172); + assert_eq!(tag.version().patch, 8); + assert_eq!(tag.1, ReleaseChannel::Stable); + } + + #[test] + fn parse_preview_version_tag() { + let tag = VersionTag::parse("v0.172.1-pre").unwrap(); + assert_eq!(tag.version().major, 0); + assert_eq!(tag.version().minor, 172); + assert_eq!(tag.version().patch, 1); + assert_eq!(tag.1, ReleaseChannel::Preview); + } + + #[test] + fn parse_version_tag_without_v_prefix() { + let tag = VersionTag::parse("0.172.8").unwrap(); + assert_eq!(tag.version().major, 0); + assert_eq!(tag.version().minor, 172); + assert_eq!(tag.version().patch, 8); + } + + #[test] + fn parse_invalid_version_tag() { + let result = VersionTag::parse("vConradTest"); + assert!(result.is_err()); + } + + #[test] + fn version_tag_stable_roundtrip() { + let tag = VersionTag::parse("v0.172.8").unwrap(); + assert_eq!(tag.to_string(), "v0.172.8"); + } + + #[test] + fn version_tag_preview_roundtrip() { + let tag = VersionTag::parse("v0.172.1-pre").unwrap(); + assert_eq!(tag.to_string(), "v0.172.1-pre"); + } + + #[test] + fn sorted_orders_by_semver() { + let input = indoc! {" + v0.172.8 + v0.170.1 + v0.171.4 + v0.170.2 + v0.172.11 + v0.171.3 + v0.172.9 + "}; + let list = VersionTagList::from_str(input).unwrap().sorted(); + for window in list.0.windows(2) { + assert!( + window[0].version() <= window[1].version(), + "{} should come before {}", + window[0].to_string(), + window[1].to_string() + ); + } + assert_eq!(list.0[0].to_string(), "v0.170.1"); + assert_eq!(list.0[list.0.len() - 1].to_string(), "v0.172.11"); + } + + #[test] + fn find_previous_minor_for_173_returns_172() { + let input = indoc! {" + v0.170.1 + v0.170.2 + v0.171.3 + v0.171.4 + v0.172.0 + v0.172.8 + v0.172.9 + v0.172.11 + "}; + let list = VersionTagList::from_str(input).unwrap().sorted(); + let target = VersionTag::parse("v0.173.0").unwrap(); + let previous = list.find_previous_minor_version(&target).unwrap(); + assert_eq!(previous.version().major, 0); + assert_eq!(previous.version().minor, 172); + assert_eq!(previous.version().patch, 0); + } + + #[test] + fn find_previous_minor_skips_same_minor() { + let input = indoc! {" + v0.172.8 + v0.172.9 + v0.172.11 + "}; + let list = VersionTagList::from_str(input).unwrap().sorted(); + let target = VersionTag::parse("v0.172.8").unwrap(); + assert!(list.find_previous_minor_version(&target).is_none()); + } + + #[test] + fn find_previous_minor_with_major_version_gap() { + let input = indoc! {" + v0.172.0 + v0.172.9 + v0.172.11 + "}; + let list = VersionTagList::from_str(input).unwrap().sorted(); + let target = VersionTag::parse("v1.0.0").unwrap(); + let previous = list.find_previous_minor_version(&target).unwrap(); + assert_eq!(previous.to_string(), "v0.172.0"); + } + + #[test] + fn find_previous_minor_requires_zero_patch_version() { + let input = indoc! {" + v0.172.1 + v0.172.9 + v0.172.11 + "}; + let list = VersionTagList::from_str(input).unwrap().sorted(); + let target = VersionTag::parse("v1.0.0").unwrap(); + assert!(list.find_previous_minor_version(&target).is_none()); + } + + #[test] + fn parse_tag_list_from_real_tags() { + let input = indoc! {" + v0.9999-temporary + vConradTest + v0.172.8 + "}; + let list = VersionTagList::from_str(input).unwrap(); + assert_eq!(list.0.len(), 1); + assert_eq!(list.0[0].to_string(), "v0.172.8"); + } + + #[test] + fn parse_empty_tag_list_fails() { + let result = VersionTagList::from_str(""); + assert!(result.is_err()); + } + + #[test] + fn pr_number_from_squash_merge_title() { + let line = format!( + "abc123{d}Author Name{d}author@email.com{d}Add cool feature (#12345)", + d = CommitDetails::FIELD_DELIMITER + ); + let commit = CommitDetails::parse(&line, "").unwrap(); + assert_eq!(commit.pr_number(), Some(12345)); + } + + #[test] + fn pr_number_missing() { + let line = format!( + "abc123{d}Author Name{d}author@email.com{d}Some commit without PR ref", + d = CommitDetails::FIELD_DELIMITER + ); + let commit = CommitDetails::parse(&line, "").unwrap(); + assert_eq!(commit.pr_number(), None); + } + + #[test] + fn pr_number_takes_last_match() { + let line = format!( + "abc123{d}Author Name{d}author@email.com{d}Fix (#123) and refactor (#456)", + d = CommitDetails::FIELD_DELIMITER + ); + let commit = CommitDetails::parse(&line, "").unwrap(); + assert_eq!(commit.pr_number(), Some(456)); + } + + #[test] + fn co_authors_parsed_from_body() { + let line = format!( + "abc123{d}Author Name{d}author@email.com{d}Some title", + d = CommitDetails::FIELD_DELIMITER + ); + let body = indoc! {" + Co-authored-by: Alice Smith + Co-authored-by: Bob Jones + "}; + let commit = CommitDetails::parse(&line, body).unwrap(); + let co_authors = commit.co_authors().unwrap(); + assert_eq!(co_authors.len(), 2); + assert_eq!( + co_authors[0], + Committer::new("Alice Smith", "alice@example.com") + ); + assert_eq!( + co_authors[1], + Committer::new("Bob Jones", "bob@example.com") + ); + } + + #[test] + fn no_co_authors_returns_none() { + let line = format!( + "abc123{d}Author Name{d}author@email.com{d}Some title", + d = CommitDetails::FIELD_DELIMITER + ); + let commit = CommitDetails::parse(&line, "").unwrap(); + assert!(commit.co_authors().is_none()); + } + + #[test] + fn commit_sha_short_returns_first_8_chars() { + let sha = CommitSha("abcdef1234567890abcdef1234567890abcdef12".into()); + assert_eq!(sha.short(), "abcdef12"); + } + + #[test] + fn parse_commit_list_from_git_log_format() { + let fd = CommitDetails::FIELD_DELIMITER; + let bd = CommitDetails::BODY_DELIMITER; + let cd = CommitDetails::COMMIT_DELIMITER; + + let input = format!( + "sha111{fd}Alice{fd}alice@test.com{fd}First commit (#100){bd}First body{cd}sha222{fd}Bob{fd}bob@test.com{fd}Second commit (#200){bd}Second body{cd}" + ); + + let list = CommitList::from_str(&input).unwrap(); + assert_eq!(list.0.len(), 2); + + assert_eq!(list.0[0].sha().0, "sha111"); + assert_eq!( + list.0[0].author(), + &Committer::new("Alice", "alice@test.com") + ); + assert_eq!(list.0[0].title(), "First commit (#100)"); + assert_eq!(list.0[0].pr_number(), Some(100)); + assert_eq!(list.0[0].body, "First body"); + + assert_eq!(list.0[1].sha().0, "sha222"); + assert_eq!(list.0[1].author(), &Committer::new("Bob", "bob@test.com")); + assert_eq!(list.0[1].title(), "Second commit (#200)"); + assert_eq!(list.0[1].pr_number(), Some(200)); + assert_eq!(list.0[1].body, "Second body"); + } +} diff --git a/tooling/compliance/src/github.rs b/tooling/compliance/src/github.rs new file mode 100644 index 0000000000000000000000000000000000000000..ebd2f2c75f5d0083632a8f70e3ea9dd2680d4eb5 --- /dev/null +++ b/tooling/compliance/src/github.rs @@ -0,0 +1,424 @@ +use std::{collections::HashMap, fmt, ops::Not, rc::Rc}; + +use anyhow::Result; +use derive_more::Deref; +use serde::Deserialize; + +use crate::git::CommitSha; + +pub const PR_REVIEW_LABEL: &str = "PR state:needs review"; + +#[derive(Debug, Clone)] +pub struct GitHubUser { + pub login: String, +} + +#[derive(Debug, Clone)] +pub struct PullRequestData { + pub number: u64, + pub user: Option, + pub merged_by: Option, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ReviewState { + Approved, + Other, +} + +#[derive(Debug, Clone)] +pub struct PullRequestReview { + pub user: Option, + pub state: Option, +} + +#[derive(Debug, Clone)] +pub struct PullRequestComment { + pub user: GitHubUser, + pub body: Option, +} + +#[derive(Debug, Deserialize, Clone, Deref, PartialEq, Eq)] +pub struct GithubLogin { + login: String, +} + +impl GithubLogin { + pub(crate) fn new(login: String) -> Self { + Self { login } + } +} + +impl fmt::Display for GithubLogin { + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(formatter, "@{}", self.login) + } +} + +#[derive(Debug, Deserialize, Clone)] +pub struct CommitAuthor { + name: String, + email: String, + user: Option, +} + +impl CommitAuthor { + pub(crate) fn user(&self) -> Option<&GithubLogin> { + self.user.as_ref() + } +} + +impl PartialEq for CommitAuthor { + fn eq(&self, other: &Self) -> bool { + self.user.as_ref().zip(other.user.as_ref()).map_or_else( + || self.email == other.email || self.name == other.name, + |(l, r)| l == r, + ) + } +} + +impl fmt::Display for CommitAuthor { + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.user.as_ref() { + Some(user) => write!(formatter, "{} ({user})", self.name), + None => write!(formatter, "{} ({})", self.name, self.email), + } + } +} + +#[derive(Debug, Deserialize)] +pub struct CommitAuthors { + #[serde(rename = "author")] + primary_author: CommitAuthor, + #[serde(rename = "authors")] + co_authors: Vec, +} + +impl CommitAuthors { + pub fn co_authors(&self) -> Option> { + self.co_authors.is_empty().not().then(|| { + self.co_authors + .iter() + .filter(|co_author| *co_author != &self.primary_author) + }) + } +} + +#[derive(Debug, Deserialize, Deref)] +pub struct AuthorsForCommits(HashMap); + +#[async_trait::async_trait(?Send)] +pub trait GitHubApiClient { + async fn get_pull_request(&self, pr_number: u64) -> Result; + async fn get_pull_request_reviews(&self, pr_number: u64) -> Result>; + async fn get_pull_request_comments(&self, pr_number: u64) -> Result>; + async fn get_commit_authors(&self, commit_shas: &[&CommitSha]) -> Result; + async fn check_org_membership(&self, login: &GithubLogin) -> Result; + async fn ensure_pull_request_has_label(&self, label: &str, pr_number: u64) -> Result<()>; +} + +pub struct GitHubClient { + api: Rc, +} + +impl GitHubClient { + pub fn new(api: Rc) -> Self { + Self { api } + } + + #[cfg(feature = "octo-client")] + pub async fn for_app(app_id: u64, app_private_key: &str) -> Result { + let client = OctocrabClient::new(app_id, app_private_key).await?; + Ok(Self::new(Rc::new(client))) + } + + pub async fn get_pull_request(&self, pr_number: u64) -> Result { + self.api.get_pull_request(pr_number).await + } + + pub async fn get_pull_request_reviews(&self, pr_number: u64) -> Result> { + self.api.get_pull_request_reviews(pr_number).await + } + + pub async fn get_pull_request_comments( + &self, + pr_number: u64, + ) -> Result> { + self.api.get_pull_request_comments(pr_number).await + } + + pub async fn get_commit_authors<'a>( + &self, + commit_shas: impl IntoIterator, + ) -> Result { + let shas: Vec<&CommitSha> = commit_shas.into_iter().collect(); + self.api.get_commit_authors(&shas).await + } + + pub async fn check_org_membership(&self, login: &GithubLogin) -> Result { + self.api.check_org_membership(login).await + } + + pub async fn add_label_to_pull_request(&self, label: &str, pr_number: u64) -> Result<()> { + self.api + .ensure_pull_request_has_label(label, pr_number) + .await + } +} + +#[cfg(feature = "octo-client")] +mod octo_client { + use anyhow::{Context, Result}; + use futures::TryStreamExt as _; + use itertools::Itertools; + use jsonwebtoken::EncodingKey; + use octocrab::{ + Octocrab, Page, models::pulls::ReviewState as OctocrabReviewState, + service::middleware::cache::mem::InMemoryCache, + }; + use serde::de::DeserializeOwned; + use tokio::pin; + + use crate::git::CommitSha; + + use super::{ + AuthorsForCommits, GitHubApiClient, GitHubUser, GithubLogin, PullRequestComment, + PullRequestData, PullRequestReview, ReviewState, + }; + + const PAGE_SIZE: u8 = 100; + const ORG: &str = "zed-industries"; + const REPO: &str = "zed"; + + pub struct OctocrabClient { + client: Octocrab, + } + + impl OctocrabClient { + pub async fn new(app_id: u64, app_private_key: &str) -> Result { + let octocrab = Octocrab::builder() + .cache(InMemoryCache::new()) + .app( + app_id.into(), + EncodingKey::from_rsa_pem(app_private_key.as_bytes())?, + ) + .build()?; + + let installations = octocrab + .apps() + .installations() + .send() + .await + .context("Failed to fetch installations")? + .take_items(); + + let installation_id = installations + .into_iter() + .find(|installation| installation.account.login == ORG) + .context("Could not find Zed repository in installations")? + .id; + + let client = octocrab.installation(installation_id)?; + Ok(Self { client }) + } + + fn build_co_authors_query<'a>(shas: impl IntoIterator) -> String { + const FRAGMENT: &str = r#" + ... on Commit { + author { + name + email + user { login } + } + authors(first: 10) { + nodes { + name + email + user { login } + } + } + } + "#; + + let objects: String = shas + .into_iter() + .map(|commit_sha| { + format!( + "commit{sha}: object(oid: \"{sha}\") {{ {FRAGMENT} }}", + sha = **commit_sha + ) + }) + .join("\n"); + + format!("{{ repository(owner: \"{ORG}\", name: \"{REPO}\") {{ {objects} }} }}") + .replace("\n", "") + } + + async fn graphql( + &self, + query: &serde_json::Value, + ) -> octocrab::Result { + self.client.graphql(query).await + } + + async fn get_all( + &self, + page: Page, + ) -> octocrab::Result> { + self.get_filtered(page, |_| true).await + } + + async fn get_filtered( + &self, + page: Page, + predicate: impl Fn(&T) -> bool, + ) -> octocrab::Result> { + let stream = page.into_stream(&self.client); + pin!(stream); + + let mut results = Vec::new(); + + while let Some(item) = stream.try_next().await? + && predicate(&item) + { + results.push(item); + } + + Ok(results) + } + } + + #[async_trait::async_trait(?Send)] + impl GitHubApiClient for OctocrabClient { + async fn get_pull_request(&self, pr_number: u64) -> Result { + let pr = self.client.pulls(ORG, REPO).get(pr_number).await?; + Ok(PullRequestData { + number: pr.number, + user: pr.user.map(|user| GitHubUser { login: user.login }), + merged_by: pr.merged_by.map(|user| GitHubUser { login: user.login }), + }) + } + + async fn get_pull_request_reviews(&self, pr_number: u64) -> Result> { + let page = self + .client + .pulls(ORG, REPO) + .list_reviews(pr_number) + .per_page(PAGE_SIZE) + .send() + .await?; + + let reviews = self.get_all(page).await?; + + Ok(reviews + .into_iter() + .map(|review| PullRequestReview { + user: review.user.map(|user| GitHubUser { login: user.login }), + state: review.state.map(|state| match state { + OctocrabReviewState::Approved => ReviewState::Approved, + _ => ReviewState::Other, + }), + }) + .collect()) + } + + async fn get_pull_request_comments( + &self, + pr_number: u64, + ) -> Result> { + let page = self + .client + .issues(ORG, REPO) + .list_comments(pr_number) + .per_page(PAGE_SIZE) + .send() + .await?; + + let comments = self.get_all(page).await?; + + Ok(comments + .into_iter() + .map(|comment| PullRequestComment { + user: GitHubUser { + login: comment.user.login, + }, + body: comment.body, + }) + .collect()) + } + + async fn get_commit_authors( + &self, + commit_shas: &[&CommitSha], + ) -> Result { + let query = Self::build_co_authors_query(commit_shas.iter().copied()); + let query = serde_json::json!({ "query": query }); + let mut response = self.graphql::(&query).await?; + + response + .get_mut("data") + .and_then(|data| data.get_mut("repository")) + .and_then(|repo| repo.as_object_mut()) + .ok_or_else(|| anyhow::anyhow!("Unexpected response format!")) + .and_then(|commit_data| { + let mut response_map = serde_json::Map::with_capacity(commit_data.len()); + + for (key, value) in commit_data.iter_mut() { + let key_without_prefix = key.strip_prefix("commit").unwrap_or(key); + if let Some(authors) = value.get_mut("authors") { + if let Some(nodes) = authors.get("nodes") { + *authors = nodes.clone(); + } + } + + response_map.insert(key_without_prefix.to_owned(), value.clone()); + } + + serde_json::from_value(serde_json::Value::Object(response_map)) + .context("Failed to deserialize commit authors") + }) + } + + async fn check_org_membership(&self, login: &GithubLogin) -> Result { + let page = self + .client + .orgs(ORG) + .list_members() + .per_page(PAGE_SIZE) + .send() + .await?; + + let members = self.get_all(page).await?; + + Ok(members + .into_iter() + .any(|member| member.login == login.as_str())) + } + + async fn ensure_pull_request_has_label(&self, label: &str, pr_number: u64) -> Result<()> { + if self + .get_filtered( + self.client + .issues(ORG, REPO) + .list_labels_for_issue(pr_number) + .per_page(PAGE_SIZE) + .send() + .await?, + |pr_label| pr_label.name == label, + ) + .await + .is_ok_and(|l| l.is_empty()) + { + self.client + .issues(ORG, REPO) + .add_labels(pr_number, &[label.to_owned()]) + .await?; + } + + Ok(()) + } + } +} + +#[cfg(feature = "octo-client")] +pub use octo_client::OctocrabClient; diff --git a/tooling/compliance/src/lib.rs b/tooling/compliance/src/lib.rs new file mode 100644 index 0000000000000000000000000000000000000000..9476412c6d6d1f56b1396bf5d700924549c707da --- /dev/null +++ b/tooling/compliance/src/lib.rs @@ -0,0 +1,4 @@ +pub mod checks; +pub mod git; +pub mod github; +pub mod report; diff --git a/tooling/compliance/src/report.rs b/tooling/compliance/src/report.rs new file mode 100644 index 0000000000000000000000000000000000000000..16df145394726b97382884fbdfdc3164c0029786 --- /dev/null +++ b/tooling/compliance/src/report.rs @@ -0,0 +1,446 @@ +use std::{ + fs::{self, File}, + io::{BufWriter, Write}, + path::Path, +}; + +use anyhow::Context as _; +use derive_more::Display; +use itertools::{Either, Itertools}; + +use crate::{ + checks::{ReviewFailure, ReviewResult, ReviewSuccess}, + git::CommitDetails, +}; + +const PULL_REQUEST_BASE_URL: &str = "https://github.com/zed-industries/zed/pull"; + +#[derive(Debug)] +pub struct ReportEntry { + pub commit: CommitDetails, + reason: R, +} + +impl ReportEntry { + fn commit_cell(&self) -> String { + let title = escape_markdown_link_text(self.commit.title()); + + match self.commit.pr_number() { + Some(pr_number) => format!("[{title}]({PULL_REQUEST_BASE_URL}/{pr_number})"), + None => escape_markdown_table_text(self.commit.title()), + } + } + + fn pull_request_cell(&self) -> String { + self.commit + .pr_number() + .map(|pr_number| format!("#{pr_number}")) + .unwrap_or_else(|| "—".to_owned()) + } + + fn author_cell(&self) -> String { + escape_markdown_table_text(&self.commit.author().to_string()) + } + + fn reason_cell(&self) -> String { + escape_markdown_table_text(&self.reason.to_string()) + } +} + +impl ReportEntry { + fn issue_kind(&self) -> IssueKind { + match self.reason { + ReviewFailure::Other(_) => IssueKind::Error, + _ => IssueKind::NotReviewed, + } + } +} + +impl ReportEntry { + fn reviewers_cell(&self) -> String { + match &self.reason.reviewers() { + Ok(reviewers) => escape_markdown_table_text(&reviewers), + Err(_) => "—".to_owned(), + } + } +} + +#[derive(Debug, Default)] +pub struct ReportSummary { + pub pull_requests: usize, + pub reviewed: usize, + pub not_reviewed: usize, + pub errors: usize, +} + +pub enum ReportReviewSummary { + MissingReviews, + MissingReviewsWithErrors, + NoIssuesFound, +} + +impl ReportSummary { + fn from_entries(entries: &[ReportEntry]) -> Self { + Self { + pull_requests: entries + .iter() + .filter_map(|entry| entry.commit.pr_number()) + .unique() + .count(), + reviewed: entries.iter().filter(|entry| entry.reason.is_ok()).count(), + not_reviewed: entries + .iter() + .filter(|entry| { + matches!( + entry.reason, + Err(ReviewFailure::NoPullRequestFound | ReviewFailure::Unreviewed) + ) + }) + .count(), + errors: entries + .iter() + .filter(|entry| matches!(entry.reason, Err(ReviewFailure::Other(_)))) + .count(), + } + } + + pub fn review_summary(&self) -> ReportReviewSummary { + match self.not_reviewed { + 0 if self.errors == 0 => ReportReviewSummary::NoIssuesFound, + 1.. if self.errors == 0 => ReportReviewSummary::MissingReviews, + _ => ReportReviewSummary::MissingReviewsWithErrors, + } + } + + fn has_errors(&self) -> bool { + self.errors > 0 + } +} + +#[derive(Clone, Copy, Debug, Display, PartialEq, Eq, PartialOrd, Ord)] +enum IssueKind { + #[display("Error")] + Error, + #[display("Not reviewed")] + NotReviewed, +} + +#[derive(Debug, Default)] +pub struct Report { + entries: Vec>, +} + +impl Report { + pub fn new() -> Self { + Self::default() + } + + pub fn add(&mut self, commit: CommitDetails, result: ReviewResult) { + self.entries.push(ReportEntry { + commit, + reason: result, + }); + } + + pub fn errors(&self) -> impl Iterator> { + self.entries.iter().filter(|entry| entry.reason.is_err()) + } + + pub fn summary(&self) -> ReportSummary { + ReportSummary::from_entries(&self.entries) + } + + pub fn write_markdown(self, path: impl AsRef) -> anyhow::Result<()> { + let path = path.as_ref(); + + if let Some(parent) = path + .parent() + .filter(|parent| !parent.as_os_str().is_empty()) + { + fs::create_dir_all(parent).with_context(|| { + format!( + "Failed to create parent directory for markdown report at {}", + path.display() + ) + })?; + } + + let summary = self.summary(); + let (successes, mut issues): (Vec<_>, Vec<_>) = + self.entries + .into_iter() + .partition_map(|entry| match entry.reason { + Ok(success) => Either::Left(ReportEntry { + reason: success, + commit: entry.commit, + }), + Err(fail) => Either::Right(ReportEntry { + reason: fail, + commit: entry.commit, + }), + }); + + issues.sort_by_key(|entry| entry.issue_kind()); + + let file = File::create(path) + .with_context(|| format!("Failed to create markdown report at {}", path.display()))?; + let mut writer = BufWriter::new(file); + + writeln!(writer, "# Compliance report")?; + writeln!(writer)?; + writeln!(writer, "## Overview")?; + writeln!(writer)?; + writeln!(writer, "- PRs: {}", summary.pull_requests)?; + writeln!(writer, "- Reviewed: {}", summary.reviewed)?; + writeln!(writer, "- Not reviewed: {}", summary.not_reviewed)?; + if summary.has_errors() { + writeln!(writer, "- Errors: {}", summary.errors)?; + } + writeln!(writer)?; + + write_issue_table(&mut writer, &issues, &summary)?; + write_success_table(&mut writer, &successes)?; + + writer + .flush() + .with_context(|| format!("Failed to flush markdown report to {}", path.display())) + } +} + +fn write_issue_table( + writer: &mut impl Write, + issues: &[ReportEntry], + summary: &ReportSummary, +) -> std::io::Result<()> { + if summary.has_errors() { + writeln!(writer, "## Errors and unreviewed commits")?; + } else { + writeln!(writer, "## Unreviewed commits")?; + } + writeln!(writer)?; + + if issues.is_empty() { + if summary.has_errors() { + writeln!(writer, "No errors or unreviewed commits found.")?; + } else { + writeln!(writer, "No unreviewed commits found.")?; + } + writeln!(writer)?; + return Ok(()); + } + + writeln!(writer, "| Commit | PR | Author | Outcome | Reason |")?; + writeln!(writer, "| --- | --- | --- | --- | --- |")?; + + for entry in issues { + let issue_kind = entry.issue_kind(); + writeln!( + writer, + "| {} | {} | {} | {} | {} |", + entry.commit_cell(), + entry.pull_request_cell(), + entry.author_cell(), + issue_kind, + entry.reason_cell(), + )?; + } + + writeln!(writer)?; + Ok(()) +} + +fn write_success_table( + writer: &mut impl Write, + successful_entries: &[ReportEntry], +) -> std::io::Result<()> { + writeln!(writer, "## Successful commits")?; + writeln!(writer)?; + + if successful_entries.is_empty() { + writeln!(writer, "No successful commits found.")?; + writeln!(writer)?; + return Ok(()); + } + + writeln!(writer, "| Commit | PR | Author | Reviewers | Reason |")?; + writeln!(writer, "| --- | --- | --- | --- | --- |")?; + + for entry in successful_entries { + writeln!( + writer, + "| {} | {} | {} | {} | {} |", + entry.commit_cell(), + entry.pull_request_cell(), + entry.author_cell(), + entry.reviewers_cell(), + entry.reason_cell(), + )?; + } + + writeln!(writer)?; + Ok(()) +} + +fn escape_markdown_link_text(input: &str) -> String { + escape_markdown_table_text(input) + .replace('[', r"\[") + .replace(']', r"\]") +} + +fn escape_markdown_table_text(input: &str) -> String { + input + .replace('\\', r"\\") + .replace('|', r"\|") + .replace('\r', "") + .replace('\n', "
") +} + +#[cfg(test)] +mod tests { + use std::str::FromStr; + + use crate::{ + checks::{ReviewFailure, ReviewSuccess}, + git::{CommitDetails, CommitList}, + github::{GitHubUser, PullRequestReview, ReviewState}, + }; + + use super::{Report, ReportReviewSummary}; + + fn make_commit( + sha: &str, + author_name: &str, + author_email: &str, + title: &str, + body: &str, + ) -> CommitDetails { + let formatted = format!( + "{sha}|field-delimiter|{author_name}|field-delimiter|{author_email}|field-delimiter|{title}|body-delimiter|{body}|commit-delimiter|" + ); + CommitList::from_str(&formatted) + .expect("test commit should parse") + .into_iter() + .next() + .expect("should have one commit") + } + + fn reviewed() -> ReviewSuccess { + ReviewSuccess::PullRequestReviewed(vec![PullRequestReview { + user: Some(GitHubUser { + login: "reviewer".to_owned(), + }), + state: Some(ReviewState::Approved), + }]) + } + + #[test] + fn report_summary_counts_are_accurate() { + let mut report = Report::new(); + + report.add( + make_commit( + "aaa", + "Alice", + "alice@test.com", + "Reviewed commit (#100)", + "", + ), + Ok(reviewed()), + ); + report.add( + make_commit("bbb", "Bob", "bob@test.com", "Unreviewed commit (#200)", ""), + Err(ReviewFailure::Unreviewed), + ); + report.add( + make_commit("ccc", "Carol", "carol@test.com", "No PR commit", ""), + Err(ReviewFailure::NoPullRequestFound), + ); + report.add( + make_commit("ddd", "Dave", "dave@test.com", "Error commit (#300)", ""), + Err(ReviewFailure::Other(anyhow::anyhow!("some error"))), + ); + + let summary = report.summary(); + assert_eq!(summary.pull_requests, 3); + assert_eq!(summary.reviewed, 1); + assert_eq!(summary.not_reviewed, 2); + assert_eq!(summary.errors, 1); + } + + #[test] + fn report_summary_all_reviewed_is_no_issues() { + let mut report = Report::new(); + + report.add( + make_commit("aaa", "Alice", "alice@test.com", "First (#100)", ""), + Ok(reviewed()), + ); + report.add( + make_commit("bbb", "Bob", "bob@test.com", "Second (#200)", ""), + Ok(reviewed()), + ); + + let summary = report.summary(); + assert!(matches!( + summary.review_summary(), + ReportReviewSummary::NoIssuesFound + )); + } + + #[test] + fn report_summary_missing_reviews_only() { + let mut report = Report::new(); + + report.add( + make_commit("aaa", "Alice", "alice@test.com", "Reviewed (#100)", ""), + Ok(reviewed()), + ); + report.add( + make_commit("bbb", "Bob", "bob@test.com", "Unreviewed (#200)", ""), + Err(ReviewFailure::Unreviewed), + ); + + let summary = report.summary(); + assert!(matches!( + summary.review_summary(), + ReportReviewSummary::MissingReviews + )); + } + + #[test] + fn report_summary_errors_and_missing_reviews() { + let mut report = Report::new(); + + report.add( + make_commit("aaa", "Alice", "alice@test.com", "Unreviewed (#100)", ""), + Err(ReviewFailure::Unreviewed), + ); + report.add( + make_commit("bbb", "Bob", "bob@test.com", "Errored (#200)", ""), + Err(ReviewFailure::Other(anyhow::anyhow!("check failed"))), + ); + + let summary = report.summary(); + assert!(matches!( + summary.review_summary(), + ReportReviewSummary::MissingReviewsWithErrors + )); + } + + #[test] + fn report_summary_deduplicates_pull_requests() { + let mut report = Report::new(); + + report.add( + make_commit("aaa", "Alice", "alice@test.com", "First change (#100)", ""), + Ok(reviewed()), + ); + report.add( + make_commit("bbb", "Bob", "bob@test.com", "Second change (#100)", ""), + Ok(reviewed()), + ); + + let summary = report.summary(); + assert_eq!(summary.pull_requests, 1); + } +} diff --git a/tooling/xtask/Cargo.toml b/tooling/xtask/Cargo.toml index 21090d1304ea0eab9ad70808b91f76789f2fd923..f9628dfa6390872210df9f3cc00b367d9420f522 100644 --- a/tooling/xtask/Cargo.toml +++ b/tooling/xtask/Cargo.toml @@ -15,7 +15,8 @@ backtrace.workspace = true cargo_metadata.workspace = true cargo_toml.workspace = true clap = { workspace = true, features = ["derive"] } -toml.workspace = true +compliance = { workspace = true, features = ["octo-client"] } +gh-workflow.workspace = true indoc.workspace = true indexmap.workspace = true itertools.workspace = true @@ -24,5 +25,6 @@ serde.workspace = true serde_json.workspace = true serde_yaml = "0.9.34" strum.workspace = true +tokio = { workspace = true, features = ["rt", "rt-multi-thread"] } +toml.workspace = true toml_edit.workspace = true -gh-workflow.workspace = true diff --git a/tooling/xtask/src/main.rs b/tooling/xtask/src/main.rs index 05afe3c766829137a7c2ba6e73d57638624d5e6a..c442f1c509e28172b7283c95e518eee743b7730c 100644 --- a/tooling/xtask/src/main.rs +++ b/tooling/xtask/src/main.rs @@ -15,6 +15,7 @@ struct Args { enum CliCommand { /// Runs `cargo clippy`. Clippy(tasks::clippy::ClippyArgs), + Compliance(tasks::compliance::ComplianceArgs), Licenses(tasks::licenses::LicensesArgs), /// Checks that packages conform to a set of standards. PackageConformity(tasks::package_conformity::PackageConformityArgs), @@ -31,6 +32,7 @@ fn main() -> Result<()> { match args.command { CliCommand::Clippy(args) => tasks::clippy::run_clippy(args), + CliCommand::Compliance(args) => tasks::compliance::check_compliance(args), CliCommand::Licenses(args) => tasks::licenses::run_licenses(args), CliCommand::PackageConformity(args) => { tasks::package_conformity::run_package_conformity(args) diff --git a/tooling/xtask/src/tasks.rs b/tooling/xtask/src/tasks.rs index 80f504fa0345de0d5bc71c5b44c71846f04c50bc..ea67d0abc5fcbd8e85f40251a7997bc6fbbbca1f 100644 --- a/tooling/xtask/src/tasks.rs +++ b/tooling/xtask/src/tasks.rs @@ -1,4 +1,5 @@ pub mod clippy; +pub mod compliance; pub mod licenses; pub mod package_conformity; pub mod publish_gpui; diff --git a/tooling/xtask/src/tasks/compliance.rs b/tooling/xtask/src/tasks/compliance.rs new file mode 100644 index 0000000000000000000000000000000000000000..78cc32b23f3160ae950aaa5e374071dd107ec350 --- /dev/null +++ b/tooling/xtask/src/tasks/compliance.rs @@ -0,0 +1,135 @@ +use std::path::PathBuf; + +use anyhow::{Context, Result}; +use clap::Parser; + +use compliance::{ + checks::Reporter, + git::{CommitsFromVersionToHead, GetVersionTags, GitCommand, VersionTag}, + github::GitHubClient, + report::ReportReviewSummary, +}; + +#[derive(Parser)] +pub struct ComplianceArgs { + #[arg(value_parser = VersionTag::parse)] + // The version to be on the lookout for + pub(crate) version_tag: VersionTag, + #[arg(long)] + // The markdown file to write the compliance report to + report_path: PathBuf, + #[arg(long)] + // An optional branch to use instead of the determined version branch + branch: Option, +} + +impl ComplianceArgs { + pub(crate) fn version_tag(&self) -> &VersionTag { + &self.version_tag + } + + fn version_branch(&self) -> String { + self.branch.clone().unwrap_or_else(|| { + format!( + "v{major}.{minor}.x", + major = self.version_tag().version().major, + minor = self.version_tag().version().minor + ) + }) + } +} + +async fn check_compliance_impl(args: ComplianceArgs) -> Result<()> { + let app_id = std::env::var("GITHUB_APP_ID").context("Missing GITHUB_APP_ID")?; + let key = std::env::var("GITHUB_APP_KEY").context("Missing GITHUB_APP_KEY")?; + + let tag = args.version_tag(); + + let previous_version = GitCommand::run(GetVersionTags)? + .sorted() + .find_previous_minor_version(&tag) + .cloned() + .ok_or_else(|| { + anyhow::anyhow!( + "Could not find previous version for tag {tag}", + tag = tag.to_string() + ) + })?; + + println!( + "Checking compliance for version {} with version {} as base", + tag.version(), + previous_version.version() + ); + + let commits = GitCommand::run(CommitsFromVersionToHead::new( + previous_version, + args.version_branch(), + ))?; + + let Some(range) = commits.range() else { + anyhow::bail!("No commits found to check"); + }; + + println!("Checking commit range {range}, {} total", commits.len()); + + let client = GitHubClient::for_app( + app_id.parse().context("Failed to parse app ID as int")?, + key.as_ref(), + ) + .await?; + + println!("Initialized GitHub client for app ID {app_id}"); + + let report = Reporter::new(commits, &client).generate_report().await?; + + println!( + "Generated report for version {}", + args.version_tag().to_string() + ); + + let summary = report.summary(); + + println!( + "Applying compliance labels to {} pull requests", + summary.pull_requests + ); + + for report in report.errors() { + if let Some(pr_number) = report.commit.pr_number() { + println!("Adding review label to PR {}...", pr_number); + + client + .add_label_to_pull_request(compliance::github::PR_REVIEW_LABEL, pr_number) + .await?; + } + } + + let report_path = args.report_path.with_extension("md"); + + report.write_markdown(&report_path)?; + + println!("Wrote compliance report to {}", report_path.display()); + + match summary.review_summary() { + ReportReviewSummary::MissingReviews => Err(anyhow::anyhow!( + "Compliance check failed, found {} commits not reviewed", + summary.not_reviewed + )), + ReportReviewSummary::MissingReviewsWithErrors => Err(anyhow::anyhow!( + "Compliance check failed with {} unreviewed commits and {} other issues", + summary.not_reviewed, + summary.errors + )), + ReportReviewSummary::NoIssuesFound => { + println!("No issues found, compliance check passed."); + Ok(()) + } + } +} + +pub fn check_compliance(args: ComplianceArgs) -> Result<()> { + tokio::runtime::Runtime::new() + .context("Failed to create tokio runtime") + .and_then(|handle| handle.block_on(check_compliance_impl(args))) +} diff --git a/tooling/xtask/src/tasks/workflows.rs b/tooling/xtask/src/tasks/workflows.rs index 414c0b7fd8dc2a99027d8687bcf1d4dbe9c4bb85..387c739a1ac12d4d65d11f33777525c59f05f7f2 100644 --- a/tooling/xtask/src/tasks/workflows.rs +++ b/tooling/xtask/src/tasks/workflows.rs @@ -11,6 +11,7 @@ mod autofix_pr; mod bump_patch_version; mod cherry_pick; mod compare_perf; +mod compliance_check; mod danger; mod deploy_collab; mod extension_auto_bump; @@ -197,6 +198,7 @@ pub fn run_workflows(args: GenerateWorkflowArgs) -> Result<()> { WorkflowFile::zed(bump_patch_version::bump_patch_version), WorkflowFile::zed(cherry_pick::cherry_pick), WorkflowFile::zed(compare_perf::compare_perf), + WorkflowFile::zed(compliance_check::compliance_check), WorkflowFile::zed(danger::danger), WorkflowFile::zed(deploy_collab::deploy_collab), WorkflowFile::zed(extension_bump::extension_bump), diff --git a/tooling/xtask/src/tasks/workflows/compliance_check.rs b/tooling/xtask/src/tasks/workflows/compliance_check.rs new file mode 100644 index 0000000000000000000000000000000000000000..9e2f4ae1e588c545266ec5a8246ac9781c6b668b --- /dev/null +++ b/tooling/xtask/src/tasks/workflows/compliance_check.rs @@ -0,0 +1,66 @@ +use gh_workflow::{Event, Expression, Job, Run, Schedule, Step, Workflow}; + +use crate::tasks::workflows::{ + runners, + steps::{self, CommonJobConditions, named}, + vars::{self, StepOutput}, +}; + +pub fn compliance_check() -> Workflow { + let check = scheduled_compliance_check(); + + named::workflow() + .on(Event::default().schedule([Schedule::new("30 17 * * 2")])) + .add_env(("CARGO_TERM_COLOR", "always")) + .add_job(check.name, check.job) +} + +fn scheduled_compliance_check() -> steps::NamedJob { + let determine_version_step = named::bash(indoc::indoc! {r#" + VERSION=$(sed -n 's/^version = "\(.*\)"/\1/p' crates/zed/Cargo.toml | tr -d '[:space:]') + if [ -z "$VERSION" ]; then + echo "Could not determine version from crates/zed/Cargo.toml" + exit 1 + fi + TAG="v${VERSION}-pre" + echo "Checking compliance for $TAG" + echo "tag=$TAG" >> "$GITHUB_OUTPUT" + "#}) + .id("determine-version"); + + let tag_output = StepOutput::new(&determine_version_step, "tag"); + + fn run_compliance_check(tag: &StepOutput) -> Step { + named::bash( + r#"cargo xtask compliance "$LATEST_TAG" --branch main --report-path target/compliance-report"#, + ) + .id("run-compliance-check") + .add_env(("LATEST_TAG", tag.to_string())) + .add_env(("GITHUB_APP_ID", vars::ZED_ZIPPY_APP_ID)) + .add_env(("GITHUB_APP_KEY", vars::ZED_ZIPPY_APP_PRIVATE_KEY)) + } + + fn send_failure_slack_notification(tag: &StepOutput) -> Step { + named::bash(indoc::indoc! {r#" + MESSAGE="⚠️ Scheduled compliance check failed for upcoming preview release $LATEST_TAG: There are PRs with missing reviews." + + curl -X POST -H 'Content-type: application/json' \ + --data "$(jq -n --arg text "$MESSAGE" '{"text": $text}')" \ + "$SLACK_WEBHOOK" + "#}) + .if_condition(Expression::new("failure()")) + .add_env(("SLACK_WEBHOOK", vars::SLACK_WEBHOOK_WORKFLOW_FAILURES)) + .add_env(("LATEST_TAG", tag.to_string())) + } + + named::job( + Job::default() + .with_repository_owner_guard() + .runs_on(runners::LINUX_SMALL) + .add_step(steps::checkout_repo().with_full_history()) + .add_step(steps::cache_rust_dependencies_namespace()) + .add_step(determine_version_step) + .add_step(run_compliance_check(&tag_output)) + .add_step(send_failure_slack_notification(&tag_output)), + ) +} diff --git a/tooling/xtask/src/tasks/workflows/release.rs b/tooling/xtask/src/tasks/workflows/release.rs index 4d7dc24d5e2d78cae87339877d730d3e3fb945b0..3efe3e7c5c127e8580a9ca22d2d0e1ab4e7c80e9 100644 --- a/tooling/xtask/src/tasks/workflows/release.rs +++ b/tooling/xtask/src/tasks/workflows/release.rs @@ -1,11 +1,13 @@ -use gh_workflow::{Event, Expression, Push, Run, Step, Use, Workflow, ctx::Context}; +use gh_workflow::{Event, Expression, Job, Push, Run, Step, Use, Workflow, ctx::Context}; use indoc::formatdoc; use crate::tasks::workflows::{ run_bundling::{bundle_linux, bundle_mac, bundle_windows}, run_tests, runners::{self, Arch, Platform}, - steps::{self, FluentBuilder, NamedJob, dependant_job, named, release_job}, + steps::{ + self, CommonJobConditions, FluentBuilder, NamedJob, dependant_job, named, release_job, + }, vars::{self, StepOutput, assets}, }; @@ -22,6 +24,7 @@ pub(crate) fn release() -> Workflow { let check_scripts = run_tests::check_scripts(); let create_draft_release = create_draft_release(); + let compliance = compliance_check(); let bundle = ReleaseBundleJobs { linux_aarch64: bundle_linux( @@ -92,6 +95,7 @@ pub(crate) fn release() -> Workflow { .add_job(windows_clippy.name, windows_clippy.job) .add_job(check_scripts.name, check_scripts.job) .add_job(create_draft_release.name, create_draft_release.job) + .add_job(compliance.name, compliance.job) .map(|mut workflow| { for job in bundle.into_jobs() { workflow = workflow.add_job(job.name, job.job); @@ -149,6 +153,59 @@ pub(crate) fn create_sentry_release() -> Step { .add_with(("environment", "production")) } +fn compliance_check() -> NamedJob { + fn run_compliance_check() -> Step { + named::bash( + r#"cargo xtask compliance "$GITHUB_REF_NAME" --report-path "$COMPLIANCE_FILE_OUTPUT""#, + ) + .id("run-compliance-check") + .add_env(("GITHUB_APP_ID", vars::ZED_ZIPPY_APP_ID)) + .add_env(("GITHUB_APP_KEY", vars::ZED_ZIPPY_APP_PRIVATE_KEY)) + } + + fn send_compliance_slack_notification() -> Step { + named::bash(indoc::indoc! {r#" + if [ "$COMPLIANCE_OUTCOME" == "success" ]; then + STATUS="✅ Compliance check passed for $GITHUB_REF_NAME" + else + STATUS="❌ Compliance check failed for $GITHUB_REF_NAME" + fi + + REPORT_CONTENT="" + if [ -f "$COMPLIANCE_FILE_OUTPUT" ]; then + REPORT_CONTENT=$(cat "$REPORT_FILE") + fi + + MESSAGE=$(printf "%s\n\n%s" "$STATUS" "$REPORT_CONTENT") + + curl -X POST -H 'Content-type: application/json' \ + --data "$(jq -n --arg text "$MESSAGE" '{"text": $text}')" \ + "$SLACK_WEBHOOK" + "#}) + .if_condition(Expression::new("always()")) + .add_env(("SLACK_WEBHOOK", vars::SLACK_WEBHOOK_WORKFLOW_FAILURES)) + .add_env(( + "COMPLIANCE_OUTCOME", + "${{ steps.run-compliance-check.outcome }}", + )) + } + + named::job( + Job::default() + .add_env(("COMPLIANCE_FILE_PATH", "compliance.md")) + .with_repository_owner_guard() + .runs_on(runners::LINUX_DEFAULT) + .add_step( + steps::checkout_repo() + .with_full_history() + .with_ref(Context::github().ref_()), + ) + .add_step(steps::cache_rust_dependencies_namespace()) + .add_step(run_compliance_check()) + .add_step(send_compliance_slack_notification()), + ) +} + fn validate_release_assets(deps: &[&NamedJob]) -> NamedJob { let expected_assets: Vec = assets::all().iter().map(|a| format!("\"{a}\"")).collect(); let expected_assets_json = format!("[{}]", expected_assets.join(", ")); @@ -171,10 +228,54 @@ fn validate_release_assets(deps: &[&NamedJob]) -> NamedJob { "#, }; + fn run_post_upload_compliance_check() -> Step { + named::bash( + r#"cargo xtask compliance "$GITHUB_REF_NAME" --report-path target/compliance-report"#, + ) + .id("run-post-upload-compliance-check") + .add_env(("GITHUB_APP_ID", vars::ZED_ZIPPY_APP_ID)) + .add_env(("GITHUB_APP_KEY", vars::ZED_ZIPPY_APP_PRIVATE_KEY)) + } + + fn send_post_upload_compliance_notification() -> Step { + named::bash(indoc::indoc! {r#" + if [ -z "$COMPLIANCE_OUTCOME" ] || [ "$COMPLIANCE_OUTCOME" == "skipped" ]; then + echo "Compliance check was skipped, not sending notification" + exit 0 + fi + + TAG="$GITHUB_REF_NAME" + + if [ "$COMPLIANCE_OUTCOME" == "success" ]; then + MESSAGE="✅ Post-upload compliance re-check passed for $TAG" + else + MESSAGE="❌ Post-upload compliance re-check failed for $TAG" + fi + + curl -X POST -H 'Content-type: application/json' \ + --data "$(jq -n --arg text "$MESSAGE" '{"text": $text}')" \ + "$SLACK_WEBHOOK" + "#}) + .if_condition(Expression::new("always()")) + .add_env(("SLACK_WEBHOOK", vars::SLACK_WEBHOOK_WORKFLOW_FAILURES)) + .add_env(( + "COMPLIANCE_OUTCOME", + "${{ steps.run-post-upload-compliance-check.outcome }}", + )) + } + named::job( - dependant_job(deps).runs_on(runners::LINUX_SMALL).add_step( - named::bash(&validation_script).add_env(("GITHUB_TOKEN", vars::GITHUB_TOKEN)), - ), + dependant_job(deps) + .runs_on(runners::LINUX_SMALL) + .add_step(named::bash(&validation_script).add_env(("GITHUB_TOKEN", vars::GITHUB_TOKEN))) + .add_step( + steps::checkout_repo() + .with_full_history() + .with_ref(Context::github().ref_()), + ) + .add_step(steps::cache_rust_dependencies_namespace()) + .add_step(run_post_upload_compliance_check()) + .add_step(send_post_upload_compliance_notification()), ) } @@ -255,7 +356,7 @@ fn create_draft_release() -> NamedJob { .add_step( steps::checkout_repo() .with_custom_fetch_depth(25) - .with_ref("${{ github.ref }}"), + .with_ref(Context::github().ref_()), ) .add_step(steps::script("script/determine-release-channel")) .add_step(steps::script("mkdir -p target/")) From e2bba5526aad44206abe1f54db8a593b06ae34d3 Mon Sep 17 00:00:00 2001 From: Bennet Bo Fenner Date: Mon, 6 Apr 2026 22:26:26 +0200 Subject: [PATCH 11/21] agent: Fix issue with streaming tools when model produces invalid JSON (#52891) Self-Review Checklist: - [x] I've reviewed my own diff for quality, security, and reliability - [x] Unsafe blocks (if any) have justifying comments - [x] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) - [x] Tests cover the new/changed behavior - [x] Performance impact has been considered and is acceptable Closes #ISSUE Release Notes: - N/A --- .../agent/src/tests/edit_file_thread_test.rs | 211 ++++++ crates/agent/src/tests/mod.rs | 112 ++++ crates/agent/src/tests/test_tools.rs | 67 +- crates/agent/src/thread.rs | 312 +++++---- .../src/tools/streaming_edit_file_tool.rs | 621 +++++++++++------- crates/language_model/src/fake_provider.rs | 10 + 6 files changed, 959 insertions(+), 374 deletions(-) diff --git a/crates/agent/src/tests/edit_file_thread_test.rs b/crates/agent/src/tests/edit_file_thread_test.rs index 3beb5cb0d51abc55fbf3cf0849ced248a9d1fa5c..b5ce6441e790e0b79b2798dfe0008cc74eec69b8 100644 --- a/crates/agent/src/tests/edit_file_thread_test.rs +++ b/crates/agent/src/tests/edit_file_thread_test.rs @@ -202,3 +202,214 @@ async fn test_edit_file_tool_in_thread_context(cx: &mut TestAppContext) { ); }); } + +#[gpui::test] +async fn test_streaming_edit_json_parse_error_does_not_cause_unsaved_changes( + cx: &mut TestAppContext, +) { + super::init_test(cx); + super::always_allow_tools(cx); + + // Enable the streaming edit file tool feature flag. + cx.update(|cx| { + cx.update_flags(true, vec!["streaming-edit-file-tool".to_string()]); + }); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/project"), + json!({ + "src": { + "main.rs": "fn main() {\n println!(\"Hello, world!\");\n}\n" + } + }), + ) + .await; + + let project = project::Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let project_context = cx.new(|_cx| ProjectContext::default()); + let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); + let context_server_registry = + cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + model.as_fake().set_supports_streaming_tools(true); + let fake_model = model.as_fake(); + + let thread = cx.new(|cx| { + let mut thread = crate::Thread::new( + project.clone(), + project_context, + context_server_registry, + crate::Templates::new(), + Some(model.clone()), + cx, + ); + let language_registry = project.read(cx).languages().clone(); + thread.add_tool(crate::StreamingEditFileTool::new( + project.clone(), + cx.weak_entity(), + thread.action_log().clone(), + language_registry, + )); + thread + }); + + let _events = thread + .update(cx, |thread, cx| { + thread.send( + UserMessageId::new(), + ["Write new content to src/main.rs"], + cx, + ) + }) + .unwrap(); + cx.run_until_parked(); + + let tool_use_id = "edit_1"; + let partial_1 = LanguageModelToolUse { + id: tool_use_id.into(), + name: EditFileTool::NAME.into(), + raw_input: json!({ + "display_description": "Rewrite main.rs", + "path": "project/src/main.rs", + "mode": "write" + }) + .to_string(), + input: json!({ + "display_description": "Rewrite main.rs", + "path": "project/src/main.rs", + "mode": "write" + }), + is_input_complete: false, + thought_signature: None, + }; + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(partial_1)); + cx.run_until_parked(); + + let partial_2 = LanguageModelToolUse { + id: tool_use_id.into(), + name: EditFileTool::NAME.into(), + raw_input: json!({ + "display_description": "Rewrite main.rs", + "path": "project/src/main.rs", + "mode": "write", + "content": "fn main() { /* rewritten */ }" + }) + .to_string(), + input: json!({ + "display_description": "Rewrite main.rs", + "path": "project/src/main.rs", + "mode": "write", + "content": "fn main() { /* rewritten */ }" + }), + is_input_complete: false, + thought_signature: None, + }; + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(partial_2)); + cx.run_until_parked(); + + // Now send a json parse error. At this point we have started writing content to the buffer. + fake_model.send_last_completion_stream_event( + LanguageModelCompletionEvent::ToolUseJsonParseError { + id: tool_use_id.into(), + tool_name: EditFileTool::NAME.into(), + raw_input: r#"{"display_description":"Rewrite main.rs","path":"project/src/main.rs","mode":"write","content":"fn main() { /* rewritten "#.into(), + json_parse_error: "EOF while parsing a string at line 1 column 95".into(), + }, + ); + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + // cx.executor().advance_clock(Duration::from_secs(5)); + // cx.run_until_parked(); + + assert!( + !fake_model.pending_completions().is_empty(), + "Thread should have retried after the error" + ); + + // Respond with a new, well-formed, complete edit_file tool use. + let tool_use = LanguageModelToolUse { + id: "edit_2".into(), + name: EditFileTool::NAME.into(), + raw_input: json!({ + "display_description": "Rewrite main.rs", + "path": "project/src/main.rs", + "mode": "write", + "content": "fn main() {\n println!(\"Hello, rewritten!\");\n}\n" + }) + .to_string(), + input: json!({ + "display_description": "Rewrite main.rs", + "path": "project/src/main.rs", + "mode": "write", + "content": "fn main() {\n println!(\"Hello, rewritten!\");\n}\n" + }), + is_input_complete: true, + thought_signature: None, + }; + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use)); + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + let pending_completions = fake_model.pending_completions(); + assert!( + pending_completions.len() == 1, + "Expected only the follow-up completion containing the successful tool result" + ); + + let completion = pending_completions + .into_iter() + .last() + .expect("Expected a completion containing the tool result for edit_2"); + + let tool_result = completion + .messages + .iter() + .flat_map(|msg| &msg.content) + .find_map(|content| match content { + language_model::MessageContent::ToolResult(result) + if result.tool_use_id == language_model::LanguageModelToolUseId::from("edit_2") => + { + Some(result) + } + _ => None, + }) + .expect("Should have a tool result for edit_2"); + + // Ensure that the second tool call completed successfully and edits were applied. + assert!( + !tool_result.is_error, + "Tool result should succeed, got: {:?}", + tool_result + ); + let content_text = match &tool_result.content { + language_model::LanguageModelToolResultContent::Text(t) => t.to_string(), + other => panic!("Expected text content, got: {:?}", other), + }; + assert!( + !content_text.contains("file has been modified since you last read it"), + "Did not expect a stale last-read error, got: {content_text}" + ); + assert!( + !content_text.contains("This file has unsaved changes"), + "Did not expect an unsaved-changes error, got: {content_text}" + ); + + let file_content = fs + .load(path!("/project/src/main.rs").as_ref()) + .await + .expect("file should exist"); + super::assert_eq!( + file_content, + "fn main() {\n println!(\"Hello, rewritten!\");\n}\n", + "The second edit should be applied and saved gracefully" + ); + + fake_model.end_last_completion_stream(); + cx.run_until_parked(); +} diff --git a/crates/agent/src/tests/mod.rs b/crates/agent/src/tests/mod.rs index f7b52b2573144e4c2fd378cfb19c9ee2473a37db..ff53136a0ded4bbc283fea30598d8d30e6e29709 100644 --- a/crates/agent/src/tests/mod.rs +++ b/crates/agent/src/tests/mod.rs @@ -3903,6 +3903,117 @@ async fn test_streaming_tool_completes_when_llm_stream_ends_without_final_input( }); } +#[gpui::test] +async fn test_streaming_tool_json_parse_error_is_forwarded_to_running_tool( + cx: &mut TestAppContext, +) { + init_test(cx); + always_allow_tools(cx); + + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + thread.update(cx, |thread, _cx| { + thread.add_tool(StreamingJsonErrorContextTool); + }); + + let _events = thread + .update(cx, |thread, cx| { + thread.send( + UserMessageId::new(), + ["Use the streaming_json_error_context tool"], + cx, + ) + }) + .unwrap(); + cx.run_until_parked(); + + let tool_use = LanguageModelToolUse { + id: "tool_1".into(), + name: StreamingJsonErrorContextTool::NAME.into(), + raw_input: r#"{"text": "partial"#.into(), + input: json!({"text": "partial"}), + is_input_complete: false, + thought_signature: None, + }; + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use)); + cx.run_until_parked(); + + fake_model.send_last_completion_stream_event( + LanguageModelCompletionEvent::ToolUseJsonParseError { + id: "tool_1".into(), + tool_name: StreamingJsonErrorContextTool::NAME.into(), + raw_input: r#"{"text": "partial"#.into(), + json_parse_error: "EOF while parsing a string at line 1 column 17".into(), + }, + ); + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + cx.executor().advance_clock(Duration::from_secs(5)); + cx.run_until_parked(); + + let completion = fake_model + .pending_completions() + .pop() + .expect("No running turn"); + + let tool_results: Vec<_> = completion + .messages + .iter() + .flat_map(|message| &message.content) + .filter_map(|content| match content { + MessageContent::ToolResult(result) + if result.tool_use_id == language_model::LanguageModelToolUseId::from("tool_1") => + { + Some(result) + } + _ => None, + }) + .collect(); + + assert_eq!( + tool_results.len(), + 1, + "Expected exactly 1 tool result for tool_1, got {}: {:#?}", + tool_results.len(), + tool_results + ); + + let result = tool_results[0]; + assert!(result.is_error); + let content_text = match &result.content { + language_model::LanguageModelToolResultContent::Text(text) => text.to_string(), + other => panic!("Expected text content, got {:?}", other), + }; + assert!( + content_text.contains("Saw partial text 'partial' before invalid JSON"), + "Expected tool-enriched partial context, got: {content_text}" + ); + assert!( + content_text + .contains("Error parsing input JSON: EOF while parsing a string at line 1 column 17"), + "Expected forwarded JSON parse error, got: {content_text}" + ); + assert!( + !content_text.contains("tool input was not fully received"), + "Should not contain orphaned sender error, got: {content_text}" + ); + + fake_model.send_last_completion_stream_text_chunk("Done"); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + thread.read_with(cx, |thread, _cx| { + assert!( + thread.is_turn_complete(), + "Thread should not be stuck; the turn should have completed", + ); + }); +} + /// Filters out the stop events for asserting against in tests fn stop_events(result_events: Vec>) -> Vec { result_events @@ -3959,6 +4070,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { InfiniteTool::NAME: true, CancellationAwareTool::NAME: true, StreamingEchoTool::NAME: true, + StreamingJsonErrorContextTool::NAME: true, StreamingFailingEchoTool::NAME: true, TerminalTool::NAME: true, UpdatePlanTool::NAME: true, diff --git a/crates/agent/src/tests/test_tools.rs b/crates/agent/src/tests/test_tools.rs index f36549a6c42f9e810c7794d8ec683613b6ae6933..4744204fae1213d49af92339b8847e9d1f470125 100644 --- a/crates/agent/src/tests/test_tools.rs +++ b/crates/agent/src/tests/test_tools.rs @@ -56,13 +56,12 @@ impl AgentTool for StreamingEchoTool { fn run( self: Arc, - mut input: ToolInput, + input: ToolInput, _event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { let wait_until_complete_rx = self.wait_until_complete_rx.lock().unwrap().take(); cx.spawn(async move |_cx| { - while input.recv_partial().await.is_some() {} let input = input .recv() .await @@ -75,6 +74,68 @@ impl AgentTool for StreamingEchoTool { } } +#[derive(JsonSchema, Serialize, Deserialize)] +pub struct StreamingJsonErrorContextToolInput { + /// The text to echo. + pub text: String, +} + +pub struct StreamingJsonErrorContextTool; + +impl AgentTool for StreamingJsonErrorContextTool { + type Input = StreamingJsonErrorContextToolInput; + type Output = String; + + const NAME: &'static str = "streaming_json_error_context"; + + fn supports_input_streaming() -> bool { + true + } + + fn kind() -> acp::ToolKind { + acp::ToolKind::Other + } + + fn initial_title( + &self, + _input: Result, + _cx: &mut App, + ) -> SharedString { + "Streaming JSON Error Context".into() + } + + fn run( + self: Arc, + mut input: ToolInput, + _event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { + cx.spawn(async move |_cx| { + let mut last_partial_text = None; + + loop { + match input.next().await { + Ok(ToolInputPayload::Partial(partial)) => { + if let Some(text) = partial.get("text").and_then(|value| value.as_str()) { + last_partial_text = Some(text.to_string()); + } + } + Ok(ToolInputPayload::Full(input)) => return Ok(input.text), + Ok(ToolInputPayload::InvalidJson { error_message }) => { + let partial_text = last_partial_text.unwrap_or_default(); + return Err(format!( + "Saw partial text '{partial_text}' before invalid JSON: {error_message}" + )); + } + Err(error) => { + return Err(format!("Failed to receive tool input: {error}")); + } + } + } + }) + } +} + /// A streaming tool that echoes its input, used to test streaming tool /// lifecycle (e.g. partial delivery and cleanup when the LLM stream ends /// before `is_input_complete`). @@ -119,7 +180,7 @@ impl AgentTool for StreamingFailingEchoTool { ) -> Task> { cx.spawn(async move |_cx| { for _ in 0..self.receive_chunks_until_failure { - let _ = input.recv_partial().await; + let _ = input.next().await; } Err("failed".into()) }) diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index bcb5b7b2d2f3eb8cffd5be8b70fc08fef8e9fe37..ea342e8db4e4d97d5eccc849121cd0fd2e403017 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -22,13 +22,13 @@ use client::UserStore; use cloud_api_types::Plan; use collections::{HashMap, HashSet, IndexMap}; use fs::Fs; -use futures::stream; use futures::{ FutureExt, channel::{mpsc, oneshot}, future::Shared, stream::FuturesUnordered, }; +use futures::{StreamExt, stream}; use gpui::{ App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity, }; @@ -47,7 +47,6 @@ use schemars::{JsonSchema, Schema}; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use settings::{LanguageModelSelection, Settings, ToolPermissionMode, update_settings_file}; -use smol::stream::StreamExt; use std::{ collections::BTreeMap, marker::PhantomData, @@ -2095,7 +2094,7 @@ impl Thread { this.update(cx, |this, _cx| { this.pending_message() .tool_results - .insert(tool_result.tool_use_id.clone(), tool_result); + .insert(tool_result.tool_use_id.clone(), tool_result) })?; Ok(()) } @@ -2195,15 +2194,15 @@ impl Thread { raw_input, json_parse_error, } => { - return Ok(Some(Task::ready( - self.handle_tool_use_json_parse_error_event( - id, - tool_name, - raw_input, - json_parse_error, - event_stream, - ), - ))); + return Ok(self.handle_tool_use_json_parse_error_event( + id, + tool_name, + raw_input, + json_parse_error, + event_stream, + cancellation_rx, + cx, + )); } UsageUpdate(usage) => { telemetry::event!( @@ -2304,12 +2303,12 @@ impl Thread { if !tool_use.is_input_complete { if tool.supports_input_streaming() { let running_turn = self.running_turn.as_mut()?; - if let Some(sender) = running_turn.streaming_tool_inputs.get(&tool_use.id) { + if let Some(sender) = running_turn.streaming_tool_inputs.get_mut(&tool_use.id) { sender.send_partial(tool_use.input); return None; } - let (sender, tool_input) = ToolInputSender::channel(); + let (mut sender, tool_input) = ToolInputSender::channel(); sender.send_partial(tool_use.input); running_turn .streaming_tool_inputs @@ -2331,13 +2330,13 @@ impl Thread { } } - if let Some(sender) = self + if let Some(mut sender) = self .running_turn .as_mut()? .streaming_tool_inputs .remove(&tool_use.id) { - sender.send_final(tool_use.input); + sender.send_full(tool_use.input); return None; } @@ -2410,10 +2409,12 @@ impl Thread { raw_input: Arc, json_parse_error: String, event_stream: &ThreadEventStream, - ) -> LanguageModelToolResult { + cancellation_rx: watch::Receiver, + cx: &mut Context, + ) -> Option> { let tool_use = LanguageModelToolUse { - id: tool_use_id.clone(), - name: tool_name.clone(), + id: tool_use_id, + name: tool_name, raw_input: raw_input.to_string(), input: serde_json::json!({}), is_input_complete: true, @@ -2426,14 +2427,43 @@ impl Thread { event_stream, ); - let tool_output = format!("Error parsing input JSON: {json_parse_error}"); - LanguageModelToolResult { - tool_use_id, - tool_name, - is_error: true, - content: LanguageModelToolResultContent::Text(tool_output.into()), - output: Some(serde_json::Value::String(raw_input.to_string())), + let tool = self.tool(tool_use.name.as_ref()); + + let Some(tool) = tool else { + let content = format!("No tool named {} exists", tool_use.name); + return Some(Task::ready(LanguageModelToolResult { + content: LanguageModelToolResultContent::Text(Arc::from(content)), + tool_use_id: tool_use.id, + tool_name: tool_use.name, + is_error: true, + output: None, + })); + }; + + let error_message = format!("Error parsing input JSON: {json_parse_error}"); + + if tool.supports_input_streaming() + && let Some(mut sender) = self + .running_turn + .as_mut()? + .streaming_tool_inputs + .remove(&tool_use.id) + { + sender.send_invalid_json(error_message); + return None; } + + log::debug!("Running tool {}. Received invalid JSON", tool_use.name); + let tool_input = ToolInput::invalid_json(error_message); + Some(self.run_tool( + tool, + tool_input, + tool_use.id, + tool_use.name, + event_stream, + cancellation_rx, + cx, + )) } fn send_or_update_tool_use( @@ -3114,8 +3144,7 @@ impl EventEmitter for Thread {} /// For streaming tools, partial JSON snapshots arrive via `.recv_partial()` as the LLM streams /// them, followed by the final complete input available through `.recv()`. pub struct ToolInput { - partial_rx: mpsc::UnboundedReceiver, - final_rx: oneshot::Receiver, + rx: mpsc::UnboundedReceiver>, _phantom: PhantomData, } @@ -3127,13 +3156,20 @@ impl ToolInput { } pub fn ready(value: serde_json::Value) -> Self { - let (partial_tx, partial_rx) = mpsc::unbounded(); - drop(partial_tx); - let (final_tx, final_rx) = oneshot::channel(); - final_tx.send(value).ok(); + let (tx, rx) = mpsc::unbounded(); + tx.unbounded_send(ToolInputPayload::Full(value)).ok(); Self { - partial_rx, - final_rx, + rx, + _phantom: PhantomData, + } + } + + pub fn invalid_json(error_message: String) -> Self { + let (tx, rx) = mpsc::unbounded(); + tx.unbounded_send(ToolInputPayload::InvalidJson { error_message }) + .ok(); + Self { + rx, _phantom: PhantomData, } } @@ -3147,65 +3183,89 @@ impl ToolInput { /// Wait for the final deserialized input, ignoring all partial updates. /// Non-streaming tools can use this to wait until the whole input is available. pub async fn recv(mut self) -> Result { - // Drain any remaining partials - while self.partial_rx.next().await.is_some() {} + while let Ok(value) = self.next().await { + match value { + ToolInputPayload::Full(value) => return Ok(value), + ToolInputPayload::Partial(_) => {} + ToolInputPayload::InvalidJson { error_message } => { + return Err(anyhow!(error_message)); + } + } + } + Err(anyhow!("tool input was not fully received")) + } + + pub async fn next(&mut self) -> Result> { let value = self - .final_rx + .rx + .next() .await - .map_err(|_| anyhow!("tool input was not fully received"))?; - serde_json::from_value(value).map_err(Into::into) - } + .ok_or_else(|| anyhow!("tool input was not fully received"))?; - /// Returns the next partial JSON snapshot, or `None` when input is complete. - /// Once this returns `None`, call `recv()` to get the final input. - pub async fn recv_partial(&mut self) -> Option { - self.partial_rx.next().await + Ok(match value { + ToolInputPayload::Partial(payload) => ToolInputPayload::Partial(payload), + ToolInputPayload::Full(payload) => { + ToolInputPayload::Full(serde_json::from_value(payload)?) + } + ToolInputPayload::InvalidJson { error_message } => { + ToolInputPayload::InvalidJson { error_message } + } + }) } fn cast(self) -> ToolInput { ToolInput { - partial_rx: self.partial_rx, - final_rx: self.final_rx, + rx: self.rx, _phantom: PhantomData, } } } +pub enum ToolInputPayload { + Partial(serde_json::Value), + Full(T), + InvalidJson { error_message: String }, +} + pub struct ToolInputSender { - partial_tx: mpsc::UnboundedSender, - final_tx: Option>, + has_received_final: bool, + tx: mpsc::UnboundedSender>, } impl ToolInputSender { pub(crate) fn channel() -> (Self, ToolInput) { - let (partial_tx, partial_rx) = mpsc::unbounded(); - let (final_tx, final_rx) = oneshot::channel(); + let (tx, rx) = mpsc::unbounded(); let sender = Self { - partial_tx, - final_tx: Some(final_tx), + tx, + has_received_final: false, }; let input = ToolInput { - partial_rx, - final_rx, + rx, _phantom: PhantomData, }; (sender, input) } pub(crate) fn has_received_final(&self) -> bool { - self.final_tx.is_none() + self.has_received_final } - pub(crate) fn send_partial(&self, value: serde_json::Value) { - self.partial_tx.unbounded_send(value).ok(); + pub fn send_partial(&mut self, payload: serde_json::Value) { + self.tx + .unbounded_send(ToolInputPayload::Partial(payload)) + .ok(); } - pub(crate) fn send_final(mut self, value: serde_json::Value) { - // Close the partial channel so recv_partial() returns None - self.partial_tx.close_channel(); - if let Some(final_tx) = self.final_tx.take() { - final_tx.send(value).ok(); - } + pub fn send_full(&mut self, payload: serde_json::Value) { + self.has_received_final = true; + self.tx.unbounded_send(ToolInputPayload::Full(payload)).ok(); + } + + pub fn send_invalid_json(&mut self, error_message: String) { + self.has_received_final = true; + self.tx + .unbounded_send(ToolInputPayload::InvalidJson { error_message }) + .ok(); } } @@ -4251,68 +4311,78 @@ mod tests { ) { let (thread, event_stream) = setup_thread_for_test(cx).await; - cx.update(|cx| { - thread.update(cx, |thread, _cx| { - let tool_use_id = LanguageModelToolUseId::from("test_tool_id"); - let tool_name: Arc = Arc::from("test_tool"); - let raw_input: Arc = Arc::from("{invalid json"); - let json_parse_error = "expected value at line 1 column 1".to_string(); - - // Call the function under test - let result = thread.handle_tool_use_json_parse_error_event( - tool_use_id.clone(), - tool_name.clone(), - raw_input.clone(), - json_parse_error, - &event_stream, - ); - - // Verify the result is an error - assert!(result.is_error); - assert_eq!(result.tool_use_id, tool_use_id); - assert_eq!(result.tool_name, tool_name); - assert!(matches!( - result.content, - LanguageModelToolResultContent::Text(_) - )); - - // Verify the tool use was added to the message content - { - let last_message = thread.pending_message(); - assert_eq!( - last_message.content.len(), - 1, - "Should have one tool_use in content" - ); - - match &last_message.content[0] { - AgentMessageContent::ToolUse(tool_use) => { - assert_eq!(tool_use.id, tool_use_id); - assert_eq!(tool_use.name, tool_name); - assert_eq!(tool_use.raw_input, raw_input.to_string()); - assert!(tool_use.is_input_complete); - // Should fall back to empty object for invalid JSON - assert_eq!(tool_use.input, json!({})); - } - _ => panic!("Expected ToolUse content"), - } - } - - // Insert the tool result (simulating what the caller does) - thread - .pending_message() - .tool_results - .insert(result.tool_use_id.clone(), result); + let tool_use_id = LanguageModelToolUseId::from("test_tool_id"); + let tool_name: Arc = Arc::from("test_tool"); + let raw_input: Arc = Arc::from("{invalid json"); + let json_parse_error = "expected value at line 1 column 1".to_string(); + + let (_cancellation_tx, cancellation_rx) = watch::channel(false); + + let result = cx + .update(|cx| { + thread.update(cx, |thread, cx| { + // Call the function under test + thread + .handle_tool_use_json_parse_error_event( + tool_use_id.clone(), + tool_name.clone(), + raw_input.clone(), + json_parse_error, + &event_stream, + cancellation_rx, + cx, + ) + .unwrap() + }) + }) + .await; + + // Verify the result is an error + assert!(result.is_error); + assert_eq!(result.tool_use_id, tool_use_id); + assert_eq!(result.tool_name, tool_name); + assert!(matches!( + result.content, + LanguageModelToolResultContent::Text(_) + )); - // Verify the tool result was added + thread.update(cx, |thread, _cx| { + // Verify the tool use was added to the message content + { let last_message = thread.pending_message(); assert_eq!( - last_message.tool_results.len(), + last_message.content.len(), 1, - "Should have one tool_result" + "Should have one tool_use in content" ); - assert!(last_message.tool_results.contains_key(&tool_use_id)); - }); - }); + + match &last_message.content[0] { + AgentMessageContent::ToolUse(tool_use) => { + assert_eq!(tool_use.id, tool_use_id); + assert_eq!(tool_use.name, tool_name); + assert_eq!(tool_use.raw_input, raw_input.to_string()); + assert!(tool_use.is_input_complete); + // Should fall back to empty object for invalid JSON + assert_eq!(tool_use.input, json!({})); + } + _ => panic!("Expected ToolUse content"), + } + } + + // Insert the tool result (simulating what the caller does) + thread + .pending_message() + .tool_results + .insert(result.tool_use_id.clone(), result); + + // Verify the tool result was added + let last_message = thread.pending_message(); + assert_eq!( + last_message.tool_results.len(), + 1, + "Should have one tool_result" + ); + assert!(last_message.tool_results.contains_key(&tool_use_id)); + }) } } diff --git a/crates/agent/src/tools/streaming_edit_file_tool.rs b/crates/agent/src/tools/streaming_edit_file_tool.rs index bc99515e499696e3df11101be8b813afa027c8f4..47da35bbf25ad188f3f6b98e843b2955910bb7ac 100644 --- a/crates/agent/src/tools/streaming_edit_file_tool.rs +++ b/crates/agent/src/tools/streaming_edit_file_tool.rs @@ -2,6 +2,7 @@ use super::edit_file_tool::EditFileTool; use super::restore_file_from_disk_tool::RestoreFileFromDiskTool; use super::save_file_tool::SaveFileTool; use super::tool_edit_parser::{ToolEditEvent, ToolEditParser}; +use crate::ToolInputPayload; use crate::{ AgentTool, Thread, ToolCallEventStream, ToolInput, edit_agent::{ @@ -12,7 +13,7 @@ use crate::{ use acp_thread::Diff; use action_log::ActionLog; use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields}; -use anyhow::{Context as _, Result}; +use anyhow::Result; use collections::HashSet; use futures::FutureExt as _; use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity}; @@ -188,6 +189,10 @@ pub enum StreamingEditFileToolOutput { }, Error { error: String, + #[serde(default)] + input_path: Option, + #[serde(default)] + diff: String, }, } @@ -195,6 +200,8 @@ impl StreamingEditFileToolOutput { pub fn error(error: impl Into) -> Self { Self::Error { error: error.into(), + input_path: None, + diff: String::new(), } } } @@ -215,7 +222,24 @@ impl std::fmt::Display for StreamingEditFileToolOutput { ) } } - StreamingEditFileToolOutput::Error { error } => write!(f, "{error}"), + StreamingEditFileToolOutput::Error { + error, + diff, + input_path, + } => { + write!(f, "{error}\n")?; + if let Some(input_path) = input_path + && !diff.is_empty() + { + write!( + f, + "Edited {}:\n\n```diff\n{diff}\n```", + input_path.display() + ) + } else { + write!(f, "No edits were made.") + } + } } } } @@ -233,6 +257,14 @@ pub struct StreamingEditFileTool { language_registry: Arc, } +enum EditSessionResult { + Completed(EditSession), + Failed { + error: String, + session: Option, + }, +} + impl StreamingEditFileTool { pub fn new( project: Entity, @@ -276,6 +308,158 @@ impl StreamingEditFileTool { }); } } + + async fn ensure_buffer_saved(&self, buffer: &Entity, cx: &mut AsyncApp) { + let format_on_save_enabled = buffer.read_with(cx, |buffer, cx| { + let settings = language_settings::LanguageSettings::for_buffer(buffer, cx); + settings.format_on_save != FormatOnSave::Off + }); + + if format_on_save_enabled { + self.project + .update(cx, |project, cx| { + project.format( + HashSet::from_iter([buffer.clone()]), + LspFormatTarget::Buffers, + false, + FormatTrigger::Save, + cx, + ) + }) + .await + .log_err(); + } + + self.project + .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)) + .await + .log_err(); + + self.action_log.update(cx, |log, cx| { + log.buffer_edited(buffer.clone(), cx); + }); + } + + async fn process_streaming_edits( + &self, + input: &mut ToolInput, + event_stream: &ToolCallEventStream, + cx: &mut AsyncApp, + ) -> EditSessionResult { + let mut session: Option = None; + let mut last_partial: Option = None; + + loop { + futures::select! { + payload = input.next().fuse() => { + match payload { + Ok(payload) => match payload { + ToolInputPayload::Partial(partial) => { + if let Ok(parsed) = serde_json::from_value::(partial) { + let path_complete = parsed.path.is_some() + && parsed.path.as_ref() == last_partial.as_ref().and_then(|partial| partial.path.as_ref()); + + last_partial = Some(parsed.clone()); + + if session.is_none() + && path_complete + && let StreamingEditFileToolPartialInput { + path: Some(path), + display_description: Some(display_description), + mode: Some(mode), + .. + } = &parsed + { + match EditSession::new( + PathBuf::from(path), + display_description, + *mode, + self, + event_stream, + cx, + ) + .await + { + Ok(created_session) => session = Some(created_session), + Err(error) => { + log::error!("Failed to create edit session: {}", error); + return EditSessionResult::Failed { + error, + session: None, + }; + } + } + } + + if let Some(current_session) = &mut session + && let Err(error) = current_session.process(parsed, self, event_stream, cx) + { + log::error!("Failed to process edit: {}", error); + return EditSessionResult::Failed { error, session }; + } + } + } + ToolInputPayload::Full(full_input) => { + let mut session = if let Some(session) = session { + session + } else { + match EditSession::new( + full_input.path.clone(), + &full_input.display_description, + full_input.mode, + self, + event_stream, + cx, + ) + .await + { + Ok(created_session) => created_session, + Err(error) => { + log::error!("Failed to create edit session: {}", error); + return EditSessionResult::Failed { + error, + session: None, + }; + } + } + }; + + return match session.finalize(full_input, self, event_stream, cx).await { + Ok(()) => EditSessionResult::Completed(session), + Err(error) => { + log::error!("Failed to finalize edit: {}", error); + EditSessionResult::Failed { + error, + session: Some(session), + } + } + }; + } + ToolInputPayload::InvalidJson { error_message } => { + log::error!("Received invalid JSON: {error_message}"); + return EditSessionResult::Failed { + error: error_message, + session, + }; + } + }, + Err(error) => { + return EditSessionResult::Failed { + error: format!("Failed to receive tool input: {error}"), + session, + }; + } + } + } + _ = event_stream.cancelled_by_user().fuse() => { + return EditSessionResult::Failed { + error: "Edit cancelled by user".to_string(), + session, + }; + } + } + } + } } impl AgentTool for StreamingEditFileTool { @@ -348,94 +532,40 @@ impl AgentTool for StreamingEditFileTool { cx: &mut App, ) -> Task> { cx.spawn(async move |cx: &mut AsyncApp| { - let mut state: Option = None; - let mut last_partial: Option = None; - loop { - futures::select! { - partial = input.recv_partial().fuse() => { - let Some(partial_value) = partial else { break }; - if let Ok(parsed) = serde_json::from_value::(partial_value) { - let path_complete = parsed.path.is_some() - && parsed.path.as_ref() == last_partial.as_ref().and_then(|p| p.path.as_ref()); - - last_partial = Some(parsed.clone()); - - if state.is_none() - && path_complete - && let StreamingEditFileToolPartialInput { - path: Some(path), - display_description: Some(display_description), - mode: Some(mode), - .. - } = &parsed - { - match EditSession::new( - &PathBuf::from(path), - display_description, - *mode, - &self, - &event_stream, - cx, - ) - .await - { - Ok(session) => state = Some(session), - Err(e) => { - log::error!("Failed to create edit session: {}", e); - return Err(e); - } - } - } - - if let Some(state) = &mut state { - if let Err(e) = state.process(parsed, &self, &event_stream, cx) { - log::error!("Failed to process edit: {}", e); - return Err(e); - } - } - } - } - _ = event_stream.cancelled_by_user().fuse() => { - return Err(StreamingEditFileToolOutput::error("Edit cancelled by user")); - } - } - } - let full_input = - input - .recv() - .await - .map_err(|e| { - let err = StreamingEditFileToolOutput::error(format!("Failed to receive tool input: {e}")); - log::error!("Failed to receive tool input: {e}"); - err - })?; - - let mut state = if let Some(state) = state { - state - } else { - match EditSession::new( - &full_input.path, - &full_input.display_description, - full_input.mode, - &self, - &event_stream, - cx, - ) + match self + .process_streaming_edits(&mut input, &event_stream, cx) .await - { - Ok(session) => session, - Err(e) => { - log::error!("Failed to create edit session: {}", e); - return Err(e); - } + { + EditSessionResult::Completed(session) => { + self.ensure_buffer_saved(&session.buffer, cx).await; + let (new_text, diff) = session.compute_new_text_and_diff(cx).await; + Ok(StreamingEditFileToolOutput::Success { + old_text: session.old_text.clone(), + new_text, + input_path: session.input_path, + diff, + }) } - }; - match state.finalize(full_input, &self, &event_stream, cx).await { - Ok(output) => Ok(output), - Err(e) => { - log::error!("Failed to finalize edit: {}", e); - Err(e) + EditSessionResult::Failed { + error, + session: Some(session), + } => { + self.ensure_buffer_saved(&session.buffer, cx).await; + let (_new_text, diff) = session.compute_new_text_and_diff(cx).await; + Err(StreamingEditFileToolOutput::Error { + error, + input_path: Some(session.input_path), + diff, + }) } + EditSessionResult::Failed { + error, + session: None, + } => Err(StreamingEditFileToolOutput::Error { + error, + input_path: None, + diff: String::new(), + }), } }) } @@ -472,6 +602,7 @@ impl AgentTool for StreamingEditFileTool { pub struct EditSession { abs_path: PathBuf, + input_path: PathBuf, buffer: Entity, old_text: Arc, diff: Entity, @@ -518,23 +649,21 @@ impl EditPipeline { impl EditSession { async fn new( - path: &PathBuf, + path: PathBuf, display_description: &str, mode: StreamingEditFileMode, tool: &StreamingEditFileTool, event_stream: &ToolCallEventStream, cx: &mut AsyncApp, - ) -> Result { - let project_path = cx - .update(|cx| resolve_path(mode, &path, &tool.project, cx)) - .map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?; + ) -> Result { + let project_path = cx.update(|cx| resolve_path(mode, &path, &tool.project, cx))?; let Some(abs_path) = cx.update(|cx| tool.project.read(cx).absolute_path(&project_path, cx)) else { - return Err(StreamingEditFileToolOutput::error(format!( + return Err(format!( "Worktree at '{}' does not exist", path.to_string_lossy() - ))); + )); }; event_stream.update_fields( @@ -543,13 +672,13 @@ impl EditSession { cx.update(|cx| tool.authorize(&path, &display_description, event_stream, cx)) .await - .map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?; + .map_err(|e| e.to_string())?; let buffer = tool .project .update(cx, |project, cx| project.open_buffer(project_path, cx)) .await - .map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?; + .map_err(|e| e.to_string())?; ensure_buffer_saved(&buffer, &abs_path, tool, cx)?; @@ -578,6 +707,7 @@ impl EditSession { Ok(Self { abs_path, + input_path: path, buffer, old_text, diff, @@ -594,22 +724,20 @@ impl EditSession { tool: &StreamingEditFileTool, event_stream: &ToolCallEventStream, cx: &mut AsyncApp, - ) -> Result { - let old_text = self.old_text.clone(); - + ) -> Result<(), String> { match input.mode { StreamingEditFileMode::Write => { - let content = input.content.ok_or_else(|| { - StreamingEditFileToolOutput::error("'content' field is required for write mode") - })?; + let content = input + .content + .ok_or_else(|| "'content' field is required for write mode".to_string())?; let events = self.parser.finalize_content(&content); self.process_events(&events, tool, event_stream, cx)?; } StreamingEditFileMode::Edit => { - let edits = input.edits.ok_or_else(|| { - StreamingEditFileToolOutput::error("'edits' field is required for edit mode") - })?; + let edits = input + .edits + .ok_or_else(|| "'edits' field is required for edit mode".to_string())?; let events = self.parser.finalize_edits(&edits); self.process_events(&events, tool, event_stream, cx)?; @@ -625,53 +753,15 @@ impl EditSession { } } } + Ok(()) + } - let format_on_save_enabled = self.buffer.read_with(cx, |buffer, cx| { - let settings = language_settings::LanguageSettings::for_buffer(buffer, cx); - settings.format_on_save != FormatOnSave::Off - }); - - if format_on_save_enabled { - tool.action_log.update(cx, |log, cx| { - log.buffer_edited(self.buffer.clone(), cx); - }); - - let format_task = tool.project.update(cx, |project, cx| { - project.format( - HashSet::from_iter([self.buffer.clone()]), - LspFormatTarget::Buffers, - false, - FormatTrigger::Save, - cx, - ) - }); - futures::select! { - result = format_task.fuse() => { result.log_err(); }, - _ = event_stream.cancelled_by_user().fuse() => { - return Err(StreamingEditFileToolOutput::error("Edit cancelled by user")); - } - }; - } - - let save_task = tool.project.update(cx, |project, cx| { - project.save_buffer(self.buffer.clone(), cx) - }); - futures::select! { - result = save_task.fuse() => { result.map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?; }, - _ = event_stream.cancelled_by_user().fuse() => { - return Err(StreamingEditFileToolOutput::error("Edit cancelled by user")); - } - }; - - tool.action_log.update(cx, |log, cx| { - log.buffer_edited(self.buffer.clone(), cx); - }); - + async fn compute_new_text_and_diff(&self, cx: &mut AsyncApp) -> (String, String) { let new_snapshot = self.buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); let (new_text, unified_diff) = cx .background_spawn({ let new_snapshot = new_snapshot.clone(); - let old_text = old_text.clone(); + let old_text = self.old_text.clone(); async move { let new_text = new_snapshot.text(); let diff = language::unified_diff(&old_text, &new_text); @@ -679,14 +769,7 @@ impl EditSession { } }) .await; - - let output = StreamingEditFileToolOutput::Success { - input_path: input.path, - new_text, - old_text: old_text.clone(), - diff: unified_diff, - }; - Ok(output) + (new_text, unified_diff) } fn process( @@ -695,7 +778,7 @@ impl EditSession { tool: &StreamingEditFileTool, event_stream: &ToolCallEventStream, cx: &mut AsyncApp, - ) -> Result<(), StreamingEditFileToolOutput> { + ) -> Result<(), String> { match &self.mode { StreamingEditFileMode::Write => { if let Some(content) = &partial.content { @@ -719,7 +802,7 @@ impl EditSession { tool: &StreamingEditFileTool, event_stream: &ToolCallEventStream, cx: &mut AsyncApp, - ) -> Result<(), StreamingEditFileToolOutput> { + ) -> Result<(), String> { for event in events { match event { ToolEditEvent::ContentChunk { chunk } => { @@ -969,14 +1052,14 @@ fn extract_match( buffer: &Entity, edit_index: &usize, cx: &mut AsyncApp, -) -> Result, StreamingEditFileToolOutput> { +) -> Result, String> { match matches.len() { - 0 => Err(StreamingEditFileToolOutput::error(format!( + 0 => Err(format!( "Could not find matching text for edit at index {}. \ The old_text did not match any content in the file. \ Please read the file again to get the current content.", edit_index, - ))), + )), 1 => Ok(matches.into_iter().next().unwrap()), _ => { let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); @@ -985,12 +1068,12 @@ fn extract_match( .map(|r| (snapshot.offset_to_point(r.start).row + 1).to_string()) .collect::>() .join(", "); - Err(StreamingEditFileToolOutput::error(format!( + Err(format!( "Edit {} matched multiple locations in the file at lines: {}. \ Please provide more context in old_text to uniquely \ identify the location.", edit_index, lines - ))) + )) } } } @@ -1022,7 +1105,7 @@ fn ensure_buffer_saved( abs_path: &PathBuf, tool: &StreamingEditFileTool, cx: &mut AsyncApp, -) -> Result<(), StreamingEditFileToolOutput> { +) -> Result<(), String> { let last_read_mtime = tool .action_log .read_with(cx, |log, _| log.file_read_time(abs_path)); @@ -1063,15 +1146,14 @@ fn ensure_buffer_saved( then ask them to save or revert the file manually and inform you when it's ok to proceed." } }; - return Err(StreamingEditFileToolOutput::error(message)); + return Err(message.to_string()); } if let (Some(last_read), Some(current)) = (last_read_mtime, current_mtime) { if current != last_read { - return Err(StreamingEditFileToolOutput::error( - "The file has been modified since you last read it. \ - Please read the file again to get the current state before editing it.", - )); + return Err("The file has been modified since you last read it. \ + Please read the file again to get the current state before editing it." + .to_string()); } } @@ -1083,56 +1165,63 @@ fn resolve_path( path: &PathBuf, project: &Entity, cx: &mut App, -) -> Result { +) -> Result { let project = project.read(cx); match mode { StreamingEditFileMode::Edit => { let path = project .find_project_path(&path, cx) - .context("Can't edit file: path not found")?; + .ok_or_else(|| "Can't edit file: path not found".to_string())?; let entry = project .entry_for_path(&path, cx) - .context("Can't edit file: path not found")?; + .ok_or_else(|| "Can't edit file: path not found".to_string())?; - anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory"); - Ok(path) + if entry.is_file() { + Ok(path) + } else { + Err("Can't edit file: path is a directory".to_string()) + } } StreamingEditFileMode::Write => { if let Some(path) = project.find_project_path(&path, cx) && let Some(entry) = project.entry_for_path(&path, cx) { - anyhow::ensure!(entry.is_file(), "Can't write to file: path is a directory"); - return Ok(path); + if entry.is_file() { + return Ok(path); + } else { + return Err("Can't write to file: path is a directory".to_string()); + } } - let parent_path = path.parent().context("Can't create file: incorrect path")?; + let parent_path = path + .parent() + .ok_or_else(|| "Can't create file: incorrect path".to_string())?; let parent_project_path = project.find_project_path(&parent_path, cx); let parent_entry = parent_project_path .as_ref() .and_then(|path| project.entry_for_path(path, cx)) - .context("Can't create file: parent directory doesn't exist")?; + .ok_or_else(|| "Can't create file: parent directory doesn't exist")?; - anyhow::ensure!( - parent_entry.is_dir(), - "Can't create file: parent is not a directory" - ); + if !parent_entry.is_dir() { + return Err("Can't create file: parent is not a directory".to_string()); + } let file_name = path .file_name() .and_then(|file_name| file_name.to_str()) .and_then(|file_name| RelPath::unix(file_name).ok()) - .context("Can't create file: invalid filename")?; + .ok_or_else(|| "Can't create file: invalid filename".to_string())?; let new_file_path = parent_project_path.map(|parent| ProjectPath { path: parent.path.join(file_name), ..parent }); - new_file_path.context("Can't create file") + new_file_path.ok_or_else(|| "Can't create file".to_string()) } } } @@ -1382,10 +1471,17 @@ mod tests { }) .await; - let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else { + let StreamingEditFileToolOutput::Error { + error, + diff, + input_path, + } = result.unwrap_err() + else { panic!("expected error"); }; assert_eq!(error, "Can't edit file: path not found"); + assert!(diff.is_empty()); + assert_eq!(input_path, None); } #[gpui::test] @@ -1411,7 +1507,7 @@ mod tests { }) .await; - let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else { + let StreamingEditFileToolOutput::Error { error, .. } = result.unwrap_err() else { panic!("expected error"); }; assert!( @@ -1424,7 +1520,7 @@ mod tests { async fn test_streaming_early_buffer_open(cx: &mut TestAppContext) { let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -1447,7 +1543,7 @@ mod tests { cx.run_until_parked(); // Now send the final complete input - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Edit lines", "path": "root/file.txt", "mode": "edit", @@ -1465,7 +1561,7 @@ mod tests { async fn test_streaming_path_completeness_heuristic(cx: &mut TestAppContext) { let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "hello world"})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -1485,7 +1581,7 @@ mod tests { cx.run_until_parked(); // Send final - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Overwrite file", "path": "root/file.txt", "mode": "write", @@ -1503,7 +1599,7 @@ mod tests { async fn test_streaming_cancellation_during_partials(cx: &mut TestAppContext) { let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "hello world"})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver, mut cancellation_tx) = ToolCallEventStream::test_with_cancellation(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -1521,7 +1617,7 @@ mod tests { drop(sender); let result = task.await; - let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else { + let StreamingEditFileToolOutput::Error { error, .. } = result.unwrap_err() else { panic!("expected error"); }; assert!( @@ -1537,7 +1633,7 @@ mod tests { json!({"file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"}), ) .await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -1578,7 +1674,7 @@ mod tests { cx.run_until_parked(); // Send final complete input - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Edit multiple lines", "path": "root/file.txt", "mode": "edit", @@ -1601,7 +1697,7 @@ mod tests { #[gpui::test] async fn test_streaming_create_file_with_partials(cx: &mut TestAppContext) { let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"dir": {}})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -1625,7 +1721,7 @@ mod tests { cx.run_until_parked(); // Final with full content - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Create new file", "path": "root/dir/new_file.txt", "mode": "write", @@ -1643,12 +1739,12 @@ mod tests { async fn test_streaming_no_partials_direct_final(cx: &mut TestAppContext) { let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Send final immediately with no partials (simulates non-streaming path) - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Edit lines", "path": "root/file.txt", "mode": "edit", @@ -1669,7 +1765,7 @@ mod tests { json!({"file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"}), ) .await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -1739,7 +1835,7 @@ mod tests { ); // Send final complete input - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Edit multiple lines", "path": "root/file.txt", "mode": "edit", @@ -1767,7 +1863,7 @@ mod tests { async fn test_streaming_incremental_three_edits(cx: &mut TestAppContext) { let (tool, project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "aaa\nbbb\nccc\nddd\neee\n"})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -1835,7 +1931,7 @@ mod tests { assert_eq!(buffer_text.as_deref(), Some("AAA\nbbb\nCCC\nddd\nEEEeee\n")); // Send final - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Edit three lines", "path": "root/file.txt", "mode": "edit", @@ -1857,7 +1953,7 @@ mod tests { async fn test_streaming_edit_failure_mid_stream(cx: &mut TestAppContext) { let (tool, project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -1893,16 +1989,17 @@ mod tests { })); cx.run_until_parked(); - // Verify edit 1 was applied - let buffer_text = project.update(cx, |project, cx| { + let buffer = project.update(cx, |project, cx| { let pp = project .find_project_path(&PathBuf::from("root/file.txt"), cx) .unwrap(); - project.get_open_buffer(&pp, cx).map(|b| b.read(cx).text()) + project.get_open_buffer(&pp, cx).unwrap() }); + + // Verify edit 1 was applied + let buffer_text = buffer.read_with(cx, |buffer, _cx| buffer.text()); assert_eq!( - buffer_text.as_deref(), - Some("MODIFIED\nline 2\nline 3\n"), + buffer_text, "MODIFIED\nline 2\nline 3\n", "First edit should be applied even though second edit will fail" ); @@ -1925,20 +2022,32 @@ mod tests { drop(sender); let result = task.await; - let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else { + let StreamingEditFileToolOutput::Error { + error, + diff, + input_path, + } = result.unwrap_err() + else { panic!("expected error"); }; + assert!( error.contains("Could not find matching text for edit at index 1"), "Expected error about edit 1 failing, got: {error}" ); + // Ensure that first edit was applied successfully and that we saved the buffer + assert_eq!(input_path, Some(PathBuf::from("root/file.txt"))); + assert_eq!( + diff, + "@@ -1,3 +1,3 @@\n-line 1\n+MODIFIED\n line 2\n line 3\n" + ); } #[gpui::test] async fn test_streaming_single_edit_no_incremental(cx: &mut TestAppContext) { let (tool, project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "hello world\n"})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -1975,7 +2084,7 @@ mod tests { ); // Send final — the edit is applied during finalization - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Single edit", "path": "root/file.txt", "mode": "edit", @@ -1993,7 +2102,7 @@ mod tests { async fn test_streaming_input_partials_then_final(cx: &mut TestAppContext) { let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await; - let (sender, input): (ToolInputSender, ToolInput) = + let (mut sender, input): (ToolInputSender, ToolInput) = ToolInput::test(); let (event_stream, _event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -2020,7 +2129,7 @@ mod tests { cx.run_until_parked(); // Send the final complete input - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Edit lines", "path": "root/file.txt", "mode": "edit", @@ -2038,7 +2147,7 @@ mod tests { async fn test_streaming_input_sender_dropped_before_final(cx: &mut TestAppContext) { let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "hello world\n"})).await; - let (sender, input): (ToolInputSender, ToolInput) = + let (mut sender, input): (ToolInputSender, ToolInput) = ToolInput::test(); let (event_stream, _event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -2064,7 +2173,7 @@ mod tests { // Create a channel and send multiple partials before a final, then use // ToolInput::resolved-style immediate delivery to confirm recv() works // when partials are already buffered. - let (sender, input): (ToolInputSender, ToolInput) = + let (mut sender, input): (ToolInputSender, ToolInput) = ToolInput::test(); let (event_stream, _event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -2077,7 +2186,7 @@ mod tests { "path": "root/dir/new.txt", "mode": "write" })); - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Create", "path": "root/dir/new.txt", "mode": "write", @@ -2109,13 +2218,13 @@ mod tests { let result = test_resolve_path(&mode, "root/dir/subdir", cx); assert_eq!( - result.await.unwrap_err().to_string(), + result.await.unwrap_err(), "Can't write to file: path is a directory" ); let result = test_resolve_path(&mode, "root/dir/nonexistent_dir/new.txt", cx); assert_eq!( - result.await.unwrap_err().to_string(), + result.await.unwrap_err(), "Can't create file: parent directory doesn't exist" ); } @@ -2133,14 +2242,11 @@ mod tests { assert_resolved_path_eq(result.await, rel_path(path_without_root)); let result = test_resolve_path(&mode, "root/nonexistent.txt", cx); - assert_eq!( - result.await.unwrap_err().to_string(), - "Can't edit file: path not found" - ); + assert_eq!(result.await.unwrap_err(), "Can't edit file: path not found"); let result = test_resolve_path(&mode, "root/dir", cx); assert_eq!( - result.await.unwrap_err().to_string(), + result.await.unwrap_err(), "Can't edit file: path is a directory" ); } @@ -2149,7 +2255,7 @@ mod tests { mode: &StreamingEditFileMode, path: &str, cx: &mut TestAppContext, - ) -> anyhow::Result { + ) -> Result { init_test(cx); let fs = project::FakeFs::new(cx.executor()); @@ -2170,7 +2276,7 @@ mod tests { } #[track_caller] - fn assert_resolved_path_eq(path: anyhow::Result, expected: &RelPath) { + fn assert_resolved_path_eq(path: Result, expected: &RelPath) { let actual = path.expect("Should return valid path").path; assert_eq!(actual.as_ref(), expected); } @@ -2259,7 +2365,7 @@ mod tests { }); // Use streaming pattern so executor can pump the LSP request/response - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -2271,7 +2377,7 @@ mod tests { })); cx.run_until_parked(); - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Create main function", "path": "root/src/main.rs", "mode": "write", @@ -2310,7 +2416,7 @@ mod tests { }); }); - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let tool2 = Arc::new(StreamingEditFileTool::new( @@ -2329,7 +2435,7 @@ mod tests { })); cx.run_until_parked(); - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Update main function", "path": "root/src/main.rs", "mode": "write", @@ -3288,14 +3394,22 @@ mod tests { }) .await; - let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else { + let StreamingEditFileToolOutput::Error { + error, + diff, + input_path, + } = result.unwrap_err() + else { panic!("expected error"); }; + assert!( error.contains("has been modified since you last read it"), "Error should mention file modification, got: {}", error ); + assert!(diff.is_empty()); + assert!(input_path.is_none()); } #[gpui::test] @@ -3362,7 +3476,12 @@ mod tests { }) .await; - let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else { + let StreamingEditFileToolOutput::Error { + error, + diff, + input_path, + } = result.unwrap_err() + else { panic!("expected error"); }; assert!( @@ -3380,6 +3499,8 @@ mod tests { "Error should ask user to manually save or revert when tools aren't available, got: {}", error ); + assert!(diff.is_empty()); + assert!(input_path.is_none()); } #[gpui::test] @@ -3390,7 +3511,7 @@ mod tests { // the modified buffer and succeeds. let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "aaa\nbbb\nccc\nddd\neee\n"})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -3420,7 +3541,7 @@ mod tests { cx.run_until_parked(); // Send the final input with all three edits. - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Overlapping edits", "path": "root/file.txt", "mode": "edit", @@ -3441,7 +3562,7 @@ mod tests { #[gpui::test] async fn test_streaming_create_content_streamed(cx: &mut TestAppContext) { let (tool, project, _action_log, _fs, _thread) = setup_test(cx, json!({"dir": {}})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -3495,7 +3616,7 @@ mod tests { ); // Send final input - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Create new file", "path": "root/dir/new_file.txt", "mode": "write", @@ -3516,7 +3637,7 @@ mod tests { json!({"file.txt": "old line 1\nold line 2\nold line 3\n"}), ) .await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, mut receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -3559,7 +3680,7 @@ mod tests { }); // Send final input - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Overwrite file", "path": "root/file.txt", "mode": "write", @@ -3587,7 +3708,7 @@ mod tests { json!({"file.txt": "old line 1\nold line 2\nold line 3\n"}), ) .await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -3634,7 +3755,7 @@ mod tests { ); // Send final input with complete content - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Overwrite file", "path": "root/file.txt", "mode": "write", @@ -3656,7 +3777,7 @@ mod tests { async fn test_streaming_edit_json_fixer_escape_corruption(cx: &mut TestAppContext) { let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "hello\nworld\nfoo\n"})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -3690,7 +3811,7 @@ mod tests { cx.run_until_parked(); // Send final. - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Edit", "path": "root/file.txt", "mode": "edit", @@ -3708,7 +3829,7 @@ mod tests { async fn test_streaming_final_input_stringified_edits_succeeds(cx: &mut TestAppContext) { let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "hello\nworld\n"})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -3719,7 +3840,7 @@ mod tests { })); cx.run_until_parked(); - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Edit", "path": "root/file.txt", "mode": "edit", @@ -3823,7 +3944,7 @@ mod tests { ) { let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "old_content"})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -3849,7 +3970,7 @@ mod tests { cx.run_until_parked(); // Send final. - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Overwrite file", "mode": "write", "content": "new_content", @@ -3869,7 +3990,7 @@ mod tests { ) { let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "old_content"})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -3902,7 +4023,7 @@ mod tests { cx.run_until_parked(); // Send final. - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Overwrite file", "mode": "edit", "edits": [{"old_text": "old_content", "new_text": "new_content"}], @@ -3939,11 +4060,11 @@ mod tests { let old_text = "}\n\n\n\nfn render_search"; let new_text = "}\n\nfn render_search"; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Remove extra blank lines", "path": "root/file.rs", "mode": "edit", @@ -3980,11 +4101,11 @@ mod tests { let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.rs": file_content})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "description", "path": "root/file.rs", "mode": "edit", diff --git a/crates/language_model/src/fake_provider.rs b/crates/language_model/src/fake_provider.rs index ae01084a2657abdc86e7510aa49663cf98aabe70..50037f31facbac446de7ecf38536d1e4a24c7867 100644 --- a/crates/language_model/src/fake_provider.rs +++ b/crates/language_model/src/fake_provider.rs @@ -125,6 +125,7 @@ pub struct FakeLanguageModel { >, forbid_requests: AtomicBool, supports_thinking: AtomicBool, + supports_streaming_tools: AtomicBool, } impl Default for FakeLanguageModel { @@ -137,6 +138,7 @@ impl Default for FakeLanguageModel { current_completion_txs: Mutex::new(Vec::new()), forbid_requests: AtomicBool::new(false), supports_thinking: AtomicBool::new(false), + supports_streaming_tools: AtomicBool::new(false), } } } @@ -169,6 +171,10 @@ impl FakeLanguageModel { self.supports_thinking.store(supports, SeqCst); } + pub fn set_supports_streaming_tools(&self, supports: bool) { + self.supports_streaming_tools.store(supports, SeqCst); + } + pub fn pending_completions(&self) -> Vec { self.current_completion_txs .lock() @@ -282,6 +288,10 @@ impl LanguageModel for FakeLanguageModel { self.supports_thinking.load(SeqCst) } + fn supports_streaming_tools(&self) -> bool { + self.supports_streaming_tools.load(SeqCst) + } + fn telemetry_id(&self) -> String { "fake".to_string() } From f0df39311df651b7cc4a6335751e528601b41e3f Mon Sep 17 00:00:00 2001 From: Ben Kunkle Date: Mon, 6 Apr 2026 15:29:37 -0500 Subject: [PATCH 12/21] Consolidate prompt formatting logic into `zeta_prompt` (#53079) Self-Review Checklist: - [x] I've reviewed my own diff for quality, security, and reliability - [ ] Unsafe blocks (if any) have justifying comments - [ ] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) - [x] Tests cover the new/changed behavior - [x] Performance impact has been considered and is acceptable Closes #ISSUE Release Notes: - N/A or Added/Fixed/Improved ... --- Cargo.lock | 1 + .../src/edit_prediction_tests.rs | 59 ++++++ crates/edit_prediction/src/example_spec.rs | 111 +--------- crates/edit_prediction/src/zeta.rs | 22 +- .../edit_prediction_cli/src/parse_output.rs | 43 +--- crates/zeta_prompt/Cargo.toml | 1 + crates/zeta_prompt/src/udiff.rs | 200 ++++++++++++++++++ crates/zeta_prompt/src/zeta_prompt.rs | 157 +++++++++++++- 8 files changed, 434 insertions(+), 160 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f7597693960b2c9e66121794f9c99cdb8d6ddcea..e1a5a11ad0c0549791545cd7e020e283decb5b53 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -22542,6 +22542,7 @@ name = "zeta_prompt" version = "0.1.0" dependencies = [ "anyhow", + "imara-diff", "indoc", "serde", "strum 0.27.2", diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs index 1ba8b27aa785024a47a09c3299a1f3786a028ccf..ea7233cd976148f5eb726730635e0efaf6ceef86 100644 --- a/crates/edit_prediction/src/edit_prediction_tests.rs +++ b/crates/edit_prediction/src/edit_prediction_tests.rs @@ -2707,6 +2707,65 @@ async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppConte }); } +#[gpui::test] +async fn test_v3_prediction_strips_cursor_marker_from_edit_text(cx: &mut TestAppContext) { + let (ep_store, mut requests) = init_test_with_fake_client(cx); + let fs = FakeFs::new(cx.executor()); + + fs.insert_tree( + "/root", + json!({ + "foo.txt": "hello" + }), + ) + .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + let buffer = project + .update(cx, |project, cx| { + let path = project + .find_project_path(path!("root/foo.txt"), cx) + .unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot.anchor_before(language::Point::new(0, 5)); + + ep_store.update(cx, |ep_store, cx| { + ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + }); + + let (request, respond_tx) = requests.predict.next().await.unwrap(); + let excerpt_length = request.input.cursor_excerpt.len(); + respond_tx + .send(PredictEditsV3Response { + request_id: Uuid::new_v4().to_string(), + output: "hello<|user_cursor|> world".to_string(), + editable_range: 0..excerpt_length, + model_version: None, + }) + .unwrap(); + + cx.run_until_parked(); + + ep_store.update(cx, |ep_store, cx| { + let prediction = ep_store + .prediction_at(&buffer, None, &project, cx) + .expect("should have prediction"); + let snapshot = buffer.read(cx).snapshot(); + let edits: Vec<_> = prediction + .edits + .iter() + .map(|(range, text)| (range.to_offset(&snapshot), text.clone())) + .collect(); + + assert_eq!(edits, vec![(5..5, " world".into())]); + }); +} + fn init_test(cx: &mut TestAppContext) { cx.update(|cx| { let settings_store = SettingsStore::test(cx); diff --git a/crates/edit_prediction/src/example_spec.rs b/crates/edit_prediction/src/example_spec.rs index 4486cde22c3429568bf29f152d0f5f2ded59e8f4..a7da51173eefbcdb9e014f7dcca917e6ebebebf5 100644 --- a/crates/edit_prediction/src/example_spec.rs +++ b/crates/edit_prediction/src/example_spec.rs @@ -1,10 +1,11 @@ -use crate::udiff::DiffLine; use anyhow::{Context as _, Result}; use serde::{Deserialize, Serialize}; use std::{borrow::Cow, fmt::Write as _, mem, path::Path, sync::Arc}; use telemetry_events::EditPredictionRating; -pub const CURSOR_POSITION_MARKER: &str = "[CURSOR_POSITION]"; +pub use zeta_prompt::udiff::{ + CURSOR_POSITION_MARKER, encode_cursor_in_patch, extract_cursor_from_patch, +}; pub const INLINE_CURSOR_MARKER: &str = "<|user_cursor|>"; /// Maximum cursor file size to capture (64KB). @@ -12,64 +13,6 @@ pub const INLINE_CURSOR_MARKER: &str = "<|user_cursor|>"; /// falling back to git-based loading. pub const MAX_CURSOR_FILE_SIZE: usize = 64 * 1024; -/// Encodes a cursor position into a diff patch by adding a comment line with a caret -/// pointing to the cursor column. -/// -/// The cursor offset is relative to the start of the new text content (additions and context lines). -/// Returns the patch with cursor marker comment lines inserted after the relevant addition line. -pub fn encode_cursor_in_patch(patch: &str, cursor_offset: Option) -> String { - let Some(cursor_offset) = cursor_offset else { - return patch.to_string(); - }; - - let mut result = String::new(); - let mut line_start_offset = 0usize; - - for line in patch.lines() { - if matches!( - DiffLine::parse(line), - DiffLine::Garbage(content) - if content.starts_with('#') && content.contains(CURSOR_POSITION_MARKER) - ) { - continue; - } - - if !result.is_empty() { - result.push('\n'); - } - result.push_str(line); - - match DiffLine::parse(line) { - DiffLine::Addition(content) => { - let line_end_offset = line_start_offset + content.len(); - - if cursor_offset >= line_start_offset && cursor_offset <= line_end_offset { - let cursor_column = cursor_offset - line_start_offset; - - result.push('\n'); - result.push('#'); - for _ in 0..cursor_column { - result.push(' '); - } - write!(result, "^{}", CURSOR_POSITION_MARKER).unwrap(); - } - - line_start_offset = line_end_offset + 1; - } - DiffLine::Context(content) => { - line_start_offset += content.len() + 1; - } - _ => {} - } - } - - if patch.ends_with('\n') { - result.push('\n'); - } - - result -} - #[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)] pub struct ExampleSpec { #[serde(default)] @@ -509,53 +452,7 @@ impl ExampleSpec { pub fn expected_patches_with_cursor_positions(&self) -> Vec<(String, Option)> { self.expected_patches .iter() - .map(|patch| { - let mut clean_patch = String::new(); - let mut cursor_offset: Option = None; - let mut line_start_offset = 0usize; - let mut prev_line_start_offset = 0usize; - - for line in patch.lines() { - let diff_line = DiffLine::parse(line); - - match &diff_line { - DiffLine::Garbage(content) - if content.starts_with('#') - && content.contains(CURSOR_POSITION_MARKER) => - { - let caret_column = if let Some(caret_pos) = content.find('^') { - caret_pos - } else if let Some(_) = content.find('<') { - 0 - } else { - continue; - }; - let cursor_column = caret_column.saturating_sub('#'.len_utf8()); - cursor_offset = Some(prev_line_start_offset + cursor_column); - } - _ => { - if !clean_patch.is_empty() { - clean_patch.push('\n'); - } - clean_patch.push_str(line); - - match diff_line { - DiffLine::Addition(content) | DiffLine::Context(content) => { - prev_line_start_offset = line_start_offset; - line_start_offset += content.len() + 1; - } - _ => {} - } - } - } - } - - if patch.ends_with('\n') && !clean_patch.is_empty() { - clean_patch.push('\n'); - } - - (clean_patch, cursor_offset) - }) + .map(|patch| extract_cursor_from_patch(patch)) .collect() } diff --git a/crates/edit_prediction/src/zeta.rs b/crates/edit_prediction/src/zeta.rs index fdfe3ebcf06c8319f5ce00066fa279d79eda7eea..b4556e58b9247624e2d4caeddb5614ff5000d854 100644 --- a/crates/edit_prediction/src/zeta.rs +++ b/crates/edit_prediction/src/zeta.rs @@ -24,8 +24,9 @@ use zeta_prompt::{ParsedOutput, ZetaPromptInput}; use std::{env, ops::Range, path::Path, sync::Arc}; use zeta_prompt::{ - CURSOR_MARKER, ZetaFormat, format_zeta_prompt, get_prefill, parse_zeta2_model_output, - prompt_input_contains_special_tokens, stop_tokens_for_format, + ZetaFormat, format_zeta_prompt, get_prefill, parse_zeta2_model_output, + parsed_output_from_editable_region, prompt_input_contains_special_tokens, + stop_tokens_for_format, zeta1::{self, EDITABLE_REGION_END_MARKER}, }; @@ -181,6 +182,7 @@ pub fn request_prediction_with_zeta( let parsed_output = output_text.map(|text| ParsedOutput { new_editable_region: text, range_in_excerpt: editable_range_in_excerpt, + cursor_offset_in_new_editable_region: None, }); (request_id, parsed_output, None, None) @@ -283,10 +285,10 @@ pub fn request_prediction_with_zeta( let request_id = EditPredictionId(response.request_id.into()); let output_text = Some(response.output).filter(|s| !s.is_empty()); let model_version = response.model_version; - let parsed_output = ParsedOutput { - new_editable_region: output_text.unwrap_or_default(), - range_in_excerpt: response.editable_range, - }; + let parsed_output = parsed_output_from_editable_region( + response.editable_range, + output_text.unwrap_or_default(), + ); Some((request_id, Some(parsed_output), model_version, usage)) }) @@ -299,6 +301,7 @@ pub fn request_prediction_with_zeta( let Some(ParsedOutput { new_editable_region: mut output_text, range_in_excerpt: editable_range_in_excerpt, + cursor_offset_in_new_editable_region: cursor_offset_in_output, }) = output else { return Ok((Some((request_id, None)), None)); @@ -312,13 +315,6 @@ pub fn request_prediction_with_zeta( .text_for_range(editable_range_in_buffer.clone()) .collect::(); - // Client-side cursor marker processing (applies to both raw and v3 responses) - let cursor_offset_in_output = output_text.find(CURSOR_MARKER); - if let Some(offset) = cursor_offset_in_output { - log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}"); - output_text.replace_range(offset..offset + CURSOR_MARKER.len(), ""); - } - if let Some(debug_tx) = &debug_tx { debug_tx .unbounded_send(DebugEvent::EditPredictionFinished( diff --git a/crates/edit_prediction_cli/src/parse_output.rs b/crates/edit_prediction_cli/src/parse_output.rs index 2b41384e176ac7a6cc5c3dc7f93ddbba3cf027ae..fc85afa371a4edfe8080d602000c38ecedb98c86 100644 --- a/crates/edit_prediction_cli/src/parse_output.rs +++ b/crates/edit_prediction_cli/src/parse_output.rs @@ -5,8 +5,7 @@ use crate::{ repair, }; use anyhow::{Context as _, Result}; -use edit_prediction::example_spec::encode_cursor_in_patch; -use zeta_prompt::{CURSOR_MARKER, ZetaFormat, parse_zeta2_model_output}; +use zeta_prompt::{ZetaFormat, parse_zeta2_model_output, parsed_output_to_patch}; pub fn run_parse_output(example: &mut Example) -> Result<()> { example @@ -65,46 +64,18 @@ fn parse_zeta2_output( .context("prompt_inputs required")?; let parsed = parse_zeta2_model_output(actual_output, format, prompt_inputs)?; - let range_in_excerpt = parsed.range_in_excerpt; - + let range_in_excerpt = parsed.range_in_excerpt.clone(); let excerpt = prompt_inputs.cursor_excerpt.as_ref(); - let old_text = excerpt[range_in_excerpt.clone()].to_string(); - let mut new_text = parsed.new_editable_region; - - let cursor_offset = if let Some(offset) = new_text.find(CURSOR_MARKER) { - new_text.replace_range(offset..offset + CURSOR_MARKER.len(), ""); - Some(offset) - } else { - None - }; + let editable_region_offset = range_in_excerpt.start; + let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count(); - // Normalize trailing newlines for diff generation - let mut old_text_normalized = old_text; + let mut new_text = parsed.new_editable_region.clone(); if !new_text.is_empty() && !new_text.ends_with('\n') { new_text.push('\n'); } - if !old_text_normalized.is_empty() && !old_text_normalized.ends_with('\n') { - old_text_normalized.push('\n'); - } - - let editable_region_offset = range_in_excerpt.start; - let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count(); - let editable_region_lines = old_text_normalized.lines().count() as u32; - - let diff = language::unified_diff_with_context( - &old_text_normalized, - &new_text, - editable_region_start_line as u32, - editable_region_start_line as u32, - editable_region_lines, - ); - - let formatted_diff = format!( - "--- a/{path}\n+++ b/{path}\n{diff}", - path = example.spec.cursor_path.to_string_lossy(), - ); - let formatted_diff = encode_cursor_in_patch(&formatted_diff, cursor_offset); + let cursor_offset = parsed.cursor_offset_in_new_editable_region; + let formatted_diff = parsed_output_to_patch(prompt_inputs, parsed)?; let actual_cursor = cursor_offset.map(|editable_region_cursor_offset| { ActualCursor::from_editable_region( diff --git a/crates/zeta_prompt/Cargo.toml b/crates/zeta_prompt/Cargo.toml index 21634583d33e13cd9570041f3e8466d05cef9944..8acd91a7a43613fd63f4f46ab73e9485fd64e7d2 100644 --- a/crates/zeta_prompt/Cargo.toml +++ b/crates/zeta_prompt/Cargo.toml @@ -13,6 +13,7 @@ path = "src/zeta_prompt.rs" [dependencies] anyhow.workspace = true +imara-diff.workspace = true serde.workspace = true strum.workspace = true diff --git a/crates/zeta_prompt/src/udiff.rs b/crates/zeta_prompt/src/udiff.rs index 2658da5893ee923dc0f5798554276f5735abb51a..ab0837b9f54ac0bf9ef74038f0c876b751f70200 100644 --- a/crates/zeta_prompt/src/udiff.rs +++ b/crates/zeta_prompt/src/udiff.rs @@ -6,6 +6,10 @@ use std::{ }; use anyhow::{Context as _, Result, anyhow}; +use imara_diff::{ + Algorithm, Sink, diff, + intern::{InternedInput, Interner, Token}, +}; pub fn strip_diff_path_prefix<'a>(diff: &'a str, prefix: &str) -> Cow<'a, str> { if prefix.is_empty() { @@ -221,6 +225,181 @@ pub fn disambiguate_by_line_number( } } +pub fn unified_diff_with_context( + old_text: &str, + new_text: &str, + old_start_line: u32, + new_start_line: u32, + context_lines: u32, +) -> String { + let input = InternedInput::new(old_text, new_text); + diff( + Algorithm::Histogram, + &input, + OffsetUnifiedDiffBuilder::new(&input, old_start_line, new_start_line, context_lines), + ) +} + +struct OffsetUnifiedDiffBuilder<'a> { + before: &'a [Token], + after: &'a [Token], + interner: &'a Interner<&'a str>, + pos: u32, + before_hunk_start: u32, + after_hunk_start: u32, + before_hunk_len: u32, + after_hunk_len: u32, + old_line_offset: u32, + new_line_offset: u32, + context_lines: u32, + buffer: String, + dst: String, +} + +impl<'a> OffsetUnifiedDiffBuilder<'a> { + fn new( + input: &'a InternedInput<&'a str>, + old_line_offset: u32, + new_line_offset: u32, + context_lines: u32, + ) -> Self { + Self { + before_hunk_start: 0, + after_hunk_start: 0, + before_hunk_len: 0, + after_hunk_len: 0, + old_line_offset, + new_line_offset, + context_lines, + buffer: String::with_capacity(8), + dst: String::new(), + interner: &input.interner, + before: &input.before, + after: &input.after, + pos: 0, + } + } + + fn print_tokens(&mut self, tokens: &[Token], prefix: char) { + for &token in tokens { + writeln!(&mut self.buffer, "{prefix}{}", self.interner[token]).unwrap(); + } + } + + fn flush(&mut self) { + if self.before_hunk_len == 0 && self.after_hunk_len == 0 { + return; + } + + let end = (self.pos + self.context_lines).min(self.before.len() as u32); + self.update_pos(end, end); + + writeln!( + &mut self.dst, + "@@ -{},{} +{},{} @@", + self.before_hunk_start + 1 + self.old_line_offset, + self.before_hunk_len, + self.after_hunk_start + 1 + self.new_line_offset, + self.after_hunk_len, + ) + .unwrap(); + write!(&mut self.dst, "{}", &self.buffer).unwrap(); + self.buffer.clear(); + self.before_hunk_len = 0; + self.after_hunk_len = 0; + } + + fn update_pos(&mut self, print_to: u32, move_to: u32) { + self.print_tokens(&self.before[self.pos as usize..print_to as usize], ' '); + let len = print_to - self.pos; + self.before_hunk_len += len; + self.after_hunk_len += len; + self.pos = move_to; + } +} + +impl Sink for OffsetUnifiedDiffBuilder<'_> { + type Out = String; + + fn process_change(&mut self, before: Range, after: Range) { + if before.start - self.pos > self.context_lines * 2 { + self.flush(); + } + if self.before_hunk_len == 0 && self.after_hunk_len == 0 { + self.pos = before.start.saturating_sub(self.context_lines); + self.before_hunk_start = self.pos; + self.after_hunk_start = after.start.saturating_sub(self.context_lines); + } + + self.update_pos(before.start, before.end); + self.before_hunk_len += before.end - before.start; + self.after_hunk_len += after.end - after.start; + self.print_tokens( + &self.before[before.start as usize..before.end as usize], + '-', + ); + self.print_tokens(&self.after[after.start as usize..after.end as usize], '+'); + } + + fn finish(mut self) -> Self::Out { + self.flush(); + self.dst + } +} + +pub fn encode_cursor_in_patch(patch: &str, cursor_offset: Option) -> String { + let Some(cursor_offset) = cursor_offset else { + return patch.to_string(); + }; + + let mut result = String::new(); + let mut line_start_offset = 0usize; + + for line in patch.lines() { + if matches!( + DiffLine::parse(line), + DiffLine::Garbage(content) + if content.starts_with('#') && content.contains(CURSOR_POSITION_MARKER) + ) { + continue; + } + + if !result.is_empty() { + result.push('\n'); + } + result.push_str(line); + + match DiffLine::parse(line) { + DiffLine::Addition(content) => { + let line_end_offset = line_start_offset + content.len(); + + if cursor_offset >= line_start_offset && cursor_offset <= line_end_offset { + let cursor_column = cursor_offset - line_start_offset; + + result.push('\n'); + result.push('#'); + for _ in 0..cursor_column { + result.push(' '); + } + write!(result, "^{}", CURSOR_POSITION_MARKER).unwrap(); + } + + line_start_offset = line_end_offset + 1; + } + DiffLine::Context(content) => { + line_start_offset += content.len() + 1; + } + _ => {} + } + } + + if patch.ends_with('\n') { + result.push('\n'); + } + + result +} + pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result { apply_diff_to_string_with_hunk_offset(diff_str, text).map(|(text, _)| text) } @@ -1203,4 +1382,25 @@ mod tests { // Edit range end should be clamped to 7 (new context length). assert_eq!(hunk.edits[0].range, 4..7); } + + #[test] + fn test_unified_diff_with_context_matches_expected_context_window() { + let old_text = "line1\nline2\nline3\nline4\nline5\nCHANGE_ME\nline7\nline8\n"; + let new_text = "line1\nline2\nline3\nline4\nline5\nCHANGED\nline7\nline8\n"; + + let diff_default = unified_diff_with_context(old_text, new_text, 0, 0, 3); + assert_eq!( + diff_default, + "@@ -3,6 +3,6 @@\n line3\n line4\n line5\n-CHANGE_ME\n+CHANGED\n line7\n line8\n" + ); + + let diff_full_context = unified_diff_with_context(old_text, new_text, 0, 0, 8); + assert_eq!( + diff_full_context, + "@@ -1,8 +1,8 @@\n line1\n line2\n line3\n line4\n line5\n-CHANGE_ME\n+CHANGED\n line7\n line8\n" + ); + + let diff_no_context = unified_diff_with_context(old_text, new_text, 0, 0, 0); + assert_eq!(diff_no_context, "@@ -6,1 +6,1 @@\n-CHANGE_ME\n+CHANGED\n"); + } } diff --git a/crates/zeta_prompt/src/zeta_prompt.rs b/crates/zeta_prompt/src/zeta_prompt.rs index 0d72d6cd7a46782aa4b572a4ef564d5fe3dec417..49b86404a8ad49c27e29bb2b887fb3fc8171c35c 100644 --- a/crates/zeta_prompt/src/zeta_prompt.rs +++ b/crates/zeta_prompt/src/zeta_prompt.rs @@ -106,10 +106,19 @@ impl std::fmt::Display for ZetaFormat { impl ZetaFormat { pub fn parse(format_name: &str) -> Result { + let lower = format_name.to_lowercase(); + + // Exact case-insensitive match takes priority, bypassing ambiguity checks. + for variant in ZetaFormat::iter() { + if <&'static str>::from(&variant).to_lowercase() == lower { + return Ok(variant); + } + } + let mut results = ZetaFormat::iter().filter(|version| { <&'static str>::from(version) .to_lowercase() - .contains(&format_name.to_lowercase()) + .contains(&lower) }); let Some(result) = results.next() else { anyhow::bail!( @@ -927,11 +936,39 @@ fn cursor_in_new_text( }) } +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct ParsedOutput { /// Text that should replace the editable region pub new_editable_region: String, /// The byte range within `cursor_excerpt` that this replacement applies to pub range_in_excerpt: Range, + /// Byte offset of the cursor marker within `new_editable_region`, if present + pub cursor_offset_in_new_editable_region: Option, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct CursorPosition { + pub path: String, + pub row: usize, + pub column: usize, + pub offset: usize, + pub editable_region_offset: usize, +} + +pub fn parsed_output_from_editable_region( + range_in_excerpt: Range, + mut new_editable_region: String, +) -> ParsedOutput { + let cursor_offset_in_new_editable_region = new_editable_region.find(CURSOR_MARKER); + if let Some(offset) = cursor_offset_in_new_editable_region { + new_editable_region.replace_range(offset..offset + CURSOR_MARKER.len(), ""); + } + + ParsedOutput { + new_editable_region, + range_in_excerpt, + cursor_offset_in_new_editable_region, + } } /// Parse model output for the given zeta format @@ -999,12 +1036,97 @@ pub fn parse_zeta2_model_output( let range_in_excerpt = range_in_context.start + context_start..range_in_context.end + context_start; - Ok(ParsedOutput { - new_editable_region: output, - range_in_excerpt, + Ok(parsed_output_from_editable_region(range_in_excerpt, output)) +} + +pub fn parse_zeta2_model_output_as_patch( + output: &str, + format: ZetaFormat, + prompt_inputs: &ZetaPromptInput, +) -> Result { + let parsed = parse_zeta2_model_output(output, format, prompt_inputs)?; + parsed_output_to_patch(prompt_inputs, parsed) +} + +pub fn cursor_position_from_parsed_output( + prompt_inputs: &ZetaPromptInput, + parsed: &ParsedOutput, +) -> Option { + let cursor_offset = parsed.cursor_offset_in_new_editable_region?; + let editable_region_offset = parsed.range_in_excerpt.start; + let excerpt = prompt_inputs.cursor_excerpt.as_ref(); + + let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count(); + + let new_editable_region = &parsed.new_editable_region; + let prefix_end = cursor_offset.min(new_editable_region.len()); + let new_region_prefix = &new_editable_region[..prefix_end]; + + let row = editable_region_start_line + new_region_prefix.matches('\n').count(); + + let column = match new_region_prefix.rfind('\n') { + Some(last_newline) => cursor_offset - last_newline - 1, + None => { + let content_prefix = &excerpt[..editable_region_offset]; + let content_column = match content_prefix.rfind('\n') { + Some(last_newline) => editable_region_offset - last_newline - 1, + None => editable_region_offset, + }; + content_column + cursor_offset + } + }; + + Some(CursorPosition { + path: prompt_inputs.cursor_path.to_string_lossy().into_owned(), + row, + column, + offset: editable_region_offset + cursor_offset, + editable_region_offset: cursor_offset, }) } +pub fn parsed_output_to_patch( + prompt_inputs: &ZetaPromptInput, + parsed: ParsedOutput, +) -> Result { + let range_in_excerpt = parsed.range_in_excerpt; + let excerpt = prompt_inputs.cursor_excerpt.as_ref(); + let old_text = excerpt[range_in_excerpt.clone()].to_string(); + let mut new_text = parsed.new_editable_region; + + let mut old_text_normalized = old_text; + if !new_text.is_empty() && !new_text.ends_with('\n') { + new_text.push('\n'); + } + if !old_text_normalized.is_empty() && !old_text_normalized.ends_with('\n') { + old_text_normalized.push('\n'); + } + + let editable_region_offset = range_in_excerpt.start; + let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count() as u32; + let editable_region_lines = old_text_normalized.lines().count() as u32; + + let diff = udiff::unified_diff_with_context( + &old_text_normalized, + &new_text, + editable_region_start_line, + editable_region_start_line, + editable_region_lines, + ); + + let path = prompt_inputs + .cursor_path + .to_string_lossy() + .trim_start_matches('/') + .to_string(); + let formatted_diff = format!("--- a/{path}\n+++ b/{path}\n{diff}"); + + Ok(udiff::encode_cursor_in_patch( + &formatted_diff, + parsed.cursor_offset_in_new_editable_region, + )) +} + pub fn excerpt_range_for_format( format: ZetaFormat, ranges: &ExcerptRanges, @@ -5400,6 +5522,33 @@ mod tests { assert_eq!(apply_edit(excerpt, &output1), "new content\n"); } + #[test] + fn test_parsed_output_to_patch_round_trips_through_udiff_application() { + let excerpt = "before ctx\nctx start\neditable old\nctx end\nafter ctx\n"; + let context_start = excerpt.find("ctx start").unwrap(); + let context_end = excerpt.find("after ctx").unwrap(); + let editable_start = excerpt.find("editable old").unwrap(); + let editable_end = editable_start + "editable old\n".len(); + let input = make_input_with_context_range( + excerpt, + editable_start..editable_end, + context_start..context_end, + editable_start, + ); + + let parsed = parse_zeta2_model_output( + "editable new\n>>>>>>> UPDATED\n", + ZetaFormat::V0131GitMergeMarkersPrefix, + &input, + ) + .unwrap(); + let expected = apply_edit(excerpt, &parsed); + let patch = parsed_output_to_patch(&input, parsed).unwrap(); + let patched = udiff::apply_diff_to_string(&patch, excerpt).unwrap(); + + assert_eq!(patched, expected); + } + #[test] fn test_special_tokens_not_triggered_by_comment_separator() { // Regression test for https://github.com/zed-industries/zed/issues/52489 From 5bd78e3f8eb9070ce9a8d7fdc71df7508633baae Mon Sep 17 00:00:00 2001 From: Danilo Leal <67129314+danilo-leal@users.noreply.github.com> Date: Mon, 6 Apr 2026 17:40:58 -0300 Subject: [PATCH 13/21] sidebar: Fix space not working in archive view's search editor (#53268) Similar to https://github.com/zed-industries/zed/pull/52444 but now in the archive view's search editor. Release Notes: - N/A --- crates/agent_ui/src/threads_archive_view.rs | 7 +++++++ crates/sidebar/src/sidebar.rs | 6 +++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/crates/agent_ui/src/threads_archive_view.rs b/crates/agent_ui/src/threads_archive_view.rs index f0c02eefc34a03c5c45730ac4b53645c5b15a2e1..b7afe2c37d0c278a23d9a41a560e45c356e7b4e1 100644 --- a/crates/agent_ui/src/threads_archive_view.rs +++ b/crates/agent_ui/src/threads_archive_view.rs @@ -218,6 +218,13 @@ impl ThreadsArchiveView { handle.focus(window, cx); } + pub fn is_filter_editor_focused(&self, window: &Window, cx: &App) -> bool { + self.filter_editor + .read(cx) + .focus_handle(cx) + .is_focused(window) + } + fn update_items(&mut self, cx: &mut Context) { let sessions = ThreadMetadataStore::global(cx) .read(cx) diff --git a/crates/sidebar/src/sidebar.rs b/crates/sidebar/src/sidebar.rs index 53ae57d1a7c55f66e40e1d704859d689d41045e4..4d3e282c403d4df27781066c35837f88f3b4cccd 100644 --- a/crates/sidebar/src/sidebar.rs +++ b/crates/sidebar/src/sidebar.rs @@ -1769,7 +1769,11 @@ impl Sidebar { dispatch_context.add("ThreadsSidebar"); dispatch_context.add("menu"); - let identifier = if self.filter_editor.focus_handle(cx).is_focused(window) { + let is_archived_search_focused = matches!(&self.view, SidebarView::Archive(archive) if archive.read(cx).is_filter_editor_focused(window, cx)); + + let identifier = if self.filter_editor.focus_handle(cx).is_focused(window) + || is_archived_search_focused + { "searching" } else { "not_searching" From fb2bff879ce45b14f1cd6589edf1703e7ed4b37b Mon Sep 17 00:00:00 2001 From: Richard Feldman Date: Mon, 6 Apr 2026 17:31:59 -0400 Subject: [PATCH 14/21] Add allow_empty commits, detached worktree creation, and new git operations (#53213) Extend the git API with several new capabilities needed for worktree archival and restoration: - Add `allow_empty` flag to `CommitOptions` for creating WIP marker commits - Change `create_worktree` to accept `Option` branch, enabling detached worktree creation when `None` is passed - Add `head_sha()` to read the current HEAD commit hash - Add `update_ref()` and `delete_ref()` for managing git references - Add `stage_all_including_untracked()` to stage everything before a WIP commit - Implement all new operations in `FakeGitRepository` with functional commit history tracking, reset support, and ref management - Update existing call sites for the new `CommitOptions` field and `create_worktree` signature Part 1 of 3 in the persist-worktree stack. These are nonbreaking API additions with no behavioral changes to existing code. Release Notes: - N/A --------- Co-authored-by: Anthony Eid --- crates/collab/src/rpc.rs | 1 + crates/collab/tests/integration/git_tests.rs | 52 +++++ crates/fs/src/fake_git_repo.rs | 128 +++++++++-- crates/fs/tests/integration/fake_git_repo.rs | 12 +- crates/git/src/repository.rs | 73 +++++-- crates/git_ui/src/commit_modal.rs | 1 + crates/git_ui/src/git_panel.rs | 8 +- crates/project/src/git_store.rs | 214 +++++++++++++++++-- crates/proto/proto/git.proto | 10 + crates/proto/proto/zed.proto | 4 +- crates/proto/src/proto.rs | 4 + 11 files changed, 450 insertions(+), 57 deletions(-) diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 7ed488b0ba62c10326a0e2154f0d2ba895e20a4f..20316fc3403de0e6212d13d455c5b619000d71b1 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -435,6 +435,7 @@ impl Server { .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_read_only_project_request::) + .add_request_handler(forward_read_only_project_request::) .add_request_handler(forward_mutating_project_request::) .add_request_handler(disallow_guest_request::) .add_request_handler(disallow_guest_request::) diff --git a/crates/collab/tests/integration/git_tests.rs b/crates/collab/tests/integration/git_tests.rs index fdaacd768444bd44d8414247f922f38afb7e81d5..2fa67b072f1c3d49ef5ca1b90056fd08d57df1ba 100644 --- a/crates/collab/tests/integration/git_tests.rs +++ b/crates/collab/tests/integration/git_tests.rs @@ -424,6 +424,58 @@ async fn test_remote_git_worktrees( ); } +#[gpui::test] +async fn test_remote_git_head_sha( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + server + .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) + .await; + let active_call_a = cx_a.read(ActiveCall::global); + + client_a + .fs() + .insert_tree( + path!("/project"), + json!({ ".git": {}, "file.txt": "content" }), + ) + .await; + + let (project_a, _) = client_a.build_local_project(path!("/project"), cx_a).await; + let local_head_sha = cx_a.update(|cx| { + project_a + .read(cx) + .active_repository(cx) + .unwrap() + .update(cx, |repository, _| repository.head_sha()) + }); + let local_head_sha = local_head_sha.await.unwrap().unwrap(); + + let project_id = active_call_a + .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) + .await + .unwrap(); + let project_b = client_b.join_remote_project(project_id, cx_b).await; + + executor.run_until_parked(); + + let remote_head_sha = cx_b.update(|cx| { + project_b + .read(cx) + .active_repository(cx) + .unwrap() + .update(cx, |repository, _| repository.head_sha()) + }); + let remote_head_sha = remote_head_sha.await.unwrap(); + + assert_eq!(remote_head_sha.unwrap(), local_head_sha); +} + #[gpui::test] async fn test_linked_worktrees_sync( executor: BackgroundExecutor, diff --git a/crates/fs/src/fake_git_repo.rs b/crates/fs/src/fake_git_repo.rs index c25b0ded5daea0674629ce4bea00736cb2eb3ffb..751796fb83164b78dc5d6789f0ae7870eff16ce1 100644 --- a/crates/fs/src/fake_git_repo.rs +++ b/crates/fs/src/fake_git_repo.rs @@ -36,8 +36,16 @@ pub struct FakeGitRepository { pub(crate) is_trusted: Arc, } +#[derive(Debug, Clone)] +pub struct FakeCommitSnapshot { + pub head_contents: HashMap, + pub index_contents: HashMap, + pub sha: String, +} + #[derive(Debug, Clone)] pub struct FakeGitRepositoryState { + pub commit_history: Vec, pub event_emitter: smol::channel::Sender, pub unmerged_paths: HashMap, pub head_contents: HashMap, @@ -74,6 +82,7 @@ impl FakeGitRepositoryState { oids: Default::default(), remotes: HashMap::default(), graph_commits: Vec::new(), + commit_history: Vec::new(), stash_entries: Default::default(), } } @@ -217,11 +226,52 @@ impl GitRepository for FakeGitRepository { fn reset( &self, - _commit: String, - _mode: ResetMode, + commit: String, + mode: ResetMode, _env: Arc>, ) -> BoxFuture<'_, Result<()>> { - unimplemented!() + self.with_state_async(true, move |state| { + let pop_count = if commit == "HEAD~" || commit == "HEAD^" { + 1 + } else if let Some(suffix) = commit.strip_prefix("HEAD~") { + suffix + .parse::() + .with_context(|| format!("Invalid HEAD~ offset: {commit}"))? + } else { + match state + .commit_history + .iter() + .rposition(|entry| entry.sha == commit) + { + Some(index) => state.commit_history.len() - index, + None => anyhow::bail!("Unknown commit ref: {commit}"), + } + }; + + if pop_count == 0 || pop_count > state.commit_history.len() { + anyhow::bail!( + "Cannot reset {pop_count} commit(s): only {} in history", + state.commit_history.len() + ); + } + + let target_index = state.commit_history.len() - pop_count; + let snapshot = state.commit_history[target_index].clone(); + state.commit_history.truncate(target_index); + + match mode { + ResetMode::Soft => { + state.head_contents = snapshot.head_contents; + } + ResetMode::Mixed => { + state.head_contents = snapshot.head_contents; + state.index_contents = state.head_contents.clone(); + } + } + + state.refs.insert("HEAD".into(), snapshot.sha); + Ok(()) + }) } fn checkout_files( @@ -490,7 +540,7 @@ impl GitRepository for FakeGitRepository { fn create_worktree( &self, - branch_name: String, + branch_name: Option, path: PathBuf, from_commit: Option, ) -> BoxFuture<'_, Result<()>> { @@ -505,8 +555,10 @@ impl GitRepository for FakeGitRepository { if let Some(message) = &state.simulated_create_worktree_error { anyhow::bail!("{message}"); } - if state.branches.contains(&branch_name) { - bail!("a branch named '{}' already exists", branch_name); + if let Some(ref name) = branch_name { + if state.branches.contains(name) { + bail!("a branch named '{}' already exists", name); + } } Ok(()) })??; @@ -515,13 +567,22 @@ impl GitRepository for FakeGitRepository { fs.create_dir(&path).await?; // Create .git/worktrees// directory with HEAD, commondir, gitdir. - let ref_name = format!("refs/heads/{branch_name}"); - let worktrees_entry_dir = common_dir_path.join("worktrees").join(&branch_name); + let worktree_entry_name = branch_name + .as_deref() + .unwrap_or_else(|| path.file_name().unwrap().to_str().unwrap()); + let worktrees_entry_dir = common_dir_path.join("worktrees").join(worktree_entry_name); fs.create_dir(&worktrees_entry_dir).await?; + let sha = from_commit.unwrap_or_else(|| "fake-sha".to_string()); + let head_content = if let Some(ref branch_name) = branch_name { + let ref_name = format!("refs/heads/{branch_name}"); + format!("ref: {ref_name}") + } else { + sha.clone() + }; fs.write_file_internal( worktrees_entry_dir.join("HEAD"), - format!("ref: {ref_name}").into_bytes(), + head_content.into_bytes(), false, )?; fs.write_file_internal( @@ -544,10 +605,12 @@ impl GitRepository for FakeGitRepository { )?; // Update git state: add ref and branch. - let sha = from_commit.unwrap_or_else(|| "fake-sha".to_string()); fs.with_git_state(&dot_git_path, true, move |state| { - state.refs.insert(ref_name, sha); - state.branches.insert(branch_name); + if let Some(branch_name) = branch_name { + let ref_name = format!("refs/heads/{branch_name}"); + state.refs.insert(ref_name, sha); + state.branches.insert(branch_name); + } Ok::<(), anyhow::Error>(()) })??; Ok(()) @@ -822,11 +885,30 @@ impl GitRepository for FakeGitRepository { &self, _message: gpui::SharedString, _name_and_email: Option<(gpui::SharedString, gpui::SharedString)>, - _options: CommitOptions, + options: CommitOptions, _askpass: AskPassDelegate, _env: Arc>, ) -> BoxFuture<'_, Result<()>> { - async { Ok(()) }.boxed() + self.with_state_async(true, move |state| { + if !options.allow_empty && !options.amend && state.index_contents == state.head_contents + { + anyhow::bail!("nothing to commit (use allow_empty to create an empty commit)"); + } + + let old_sha = state.refs.get("HEAD").cloned().unwrap_or_default(); + state.commit_history.push(FakeCommitSnapshot { + head_contents: state.head_contents.clone(), + index_contents: state.index_contents.clone(), + sha: old_sha, + }); + + state.head_contents = state.index_contents.clone(); + + let new_sha = format!("fake-commit-{}", state.commit_history.len()); + state.refs.insert("HEAD".into(), new_sha); + + Ok(()) + }) } fn run_hook( @@ -1210,6 +1292,24 @@ impl GitRepository for FakeGitRepository { anyhow::bail!("commit_data_reader not supported for FakeGitRepository") } + fn update_ref(&self, ref_name: String, commit: String) -> BoxFuture<'_, Result<()>> { + self.with_state_async(true, move |state| { + state.refs.insert(ref_name, commit); + Ok(()) + }) + } + + fn delete_ref(&self, ref_name: String) -> BoxFuture<'_, Result<()>> { + self.with_state_async(true, move |state| { + state.refs.remove(&ref_name); + Ok(()) + }) + } + + fn repair_worktrees(&self) -> BoxFuture<'_, Result<()>> { + async { Ok(()) }.boxed() + } + fn set_trusted(&self, trusted: bool) { self.is_trusted .store(trusted, std::sync::atomic::Ordering::Release); diff --git a/crates/fs/tests/integration/fake_git_repo.rs b/crates/fs/tests/integration/fake_git_repo.rs index 6428083c161235001ef29daf3583520e7f7d25a2..f4192a22bb42f88f8769ef59f817b2bf2a288fb9 100644 --- a/crates/fs/tests/integration/fake_git_repo.rs +++ b/crates/fs/tests/integration/fake_git_repo.rs @@ -24,7 +24,7 @@ async fn test_fake_worktree_lifecycle(cx: &mut TestAppContext) { // Create a worktree let worktree_1_dir = worktrees_dir.join("feature-branch"); repo.create_worktree( - "feature-branch".to_string(), + Some("feature-branch".to_string()), worktree_1_dir.clone(), Some("abc123".to_string()), ) @@ -47,9 +47,13 @@ async fn test_fake_worktree_lifecycle(cx: &mut TestAppContext) { // Create a second worktree (without explicit commit) let worktree_2_dir = worktrees_dir.join("bugfix-branch"); - repo.create_worktree("bugfix-branch".to_string(), worktree_2_dir.clone(), None) - .await - .unwrap(); + repo.create_worktree( + Some("bugfix-branch".to_string()), + worktree_2_dir.clone(), + None, + ) + .await + .unwrap(); let worktrees = repo.worktrees().await.unwrap(); assert_eq!(worktrees.len(), 3); diff --git a/crates/git/src/repository.rs b/crates/git/src/repository.rs index b03fe1b0c63904bfc751ab7946f92a7c8595db00..c42d2e28cf041e40404c1b8276ddcf5d10ca5f01 100644 --- a/crates/git/src/repository.rs +++ b/crates/git/src/repository.rs @@ -329,6 +329,7 @@ impl Upstream { pub struct CommitOptions { pub amend: bool, pub signoff: bool, + pub allow_empty: bool, } #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] @@ -715,7 +716,7 @@ pub trait GitRepository: Send + Sync { fn create_worktree( &self, - branch_name: String, + branch_name: Option, path: PathBuf, from_commit: Option, ) -> BoxFuture<'_, Result<()>>; @@ -916,6 +917,12 @@ pub trait GitRepository: Send + Sync { fn commit_data_reader(&self) -> Result; + fn update_ref(&self, ref_name: String, commit: String) -> BoxFuture<'_, Result<()>>; + + fn delete_ref(&self, ref_name: String) -> BoxFuture<'_, Result<()>>; + + fn repair_worktrees(&self) -> BoxFuture<'_, Result<()>>; + fn set_trusted(&self, trusted: bool); fn is_trusted(&self) -> bool; } @@ -1660,19 +1667,20 @@ impl GitRepository for RealGitRepository { fn create_worktree( &self, - branch_name: String, + branch_name: Option, path: PathBuf, from_commit: Option, ) -> BoxFuture<'_, Result<()>> { let git_binary = self.git_binary(); - let mut args = vec![ - OsString::from("worktree"), - OsString::from("add"), - OsString::from("-b"), - OsString::from(branch_name.as_str()), - OsString::from("--"), - OsString::from(path.as_os_str()), - ]; + let mut args = vec![OsString::from("worktree"), OsString::from("add")]; + if let Some(branch_name) = &branch_name { + args.push(OsString::from("-b")); + args.push(OsString::from(branch_name.as_str())); + } else { + args.push(OsString::from("--detach")); + } + args.push(OsString::from("--")); + args.push(OsString::from(path.as_os_str())); if let Some(from_commit) = from_commit { args.push(OsString::from(from_commit)); } else { @@ -2165,6 +2173,10 @@ impl GitRepository for RealGitRepository { cmd.arg("--signoff"); } + if options.allow_empty { + cmd.arg("--allow-empty"); + } + if let Some((name, email)) = name_and_email { cmd.arg("--author").arg(&format!("{name} <{email}>")); } @@ -2176,6 +2188,39 @@ impl GitRepository for RealGitRepository { .boxed() } + fn update_ref(&self, ref_name: String, commit: String) -> BoxFuture<'_, Result<()>> { + let git_binary = self.git_binary(); + self.executor + .spawn(async move { + let args: Vec = vec!["update-ref".into(), ref_name.into(), commit.into()]; + git_binary?.run(&args).await?; + Ok(()) + }) + .boxed() + } + + fn delete_ref(&self, ref_name: String) -> BoxFuture<'_, Result<()>> { + let git_binary = self.git_binary(); + self.executor + .spawn(async move { + let args: Vec = vec!["update-ref".into(), "-d".into(), ref_name.into()]; + git_binary?.run(&args).await?; + Ok(()) + }) + .boxed() + } + + fn repair_worktrees(&self) -> BoxFuture<'_, Result<()>> { + let git_binary = self.git_binary(); + self.executor + .spawn(async move { + let args: Vec = vec!["worktree".into(), "repair".into()]; + git_binary?.run(&args).await?; + Ok(()) + }) + .boxed() + } + fn push( &self, branch_name: String, @@ -4009,7 +4054,7 @@ mod tests { // Create a new worktree repo.create_worktree( - "test-branch".to_string(), + Some("test-branch".to_string()), worktree_path.clone(), Some("HEAD".to_string()), ) @@ -4068,7 +4113,7 @@ mod tests { // Create a worktree let worktree_path = worktrees_dir.join("worktree-to-remove"); repo.create_worktree( - "to-remove".to_string(), + Some("to-remove".to_string()), worktree_path.clone(), Some("HEAD".to_string()), ) @@ -4092,7 +4137,7 @@ mod tests { // Create a worktree let worktree_path = worktrees_dir.join("dirty-wt"); repo.create_worktree( - "dirty-wt".to_string(), + Some("dirty-wt".to_string()), worktree_path.clone(), Some("HEAD".to_string()), ) @@ -4162,7 +4207,7 @@ mod tests { // Create a worktree let old_path = worktrees_dir.join("old-worktree-name"); repo.create_worktree( - "old-name".to_string(), + Some("old-name".to_string()), old_path.clone(), Some("HEAD".to_string()), ) diff --git a/crates/git_ui/src/commit_modal.rs b/crates/git_ui/src/commit_modal.rs index 432da803e6eedfec304836198f6111f5418084cc..2088ad77ec5d7e71bdfb42ebcbfab6d001f64375 100644 --- a/crates/git_ui/src/commit_modal.rs +++ b/crates/git_ui/src/commit_modal.rs @@ -453,6 +453,7 @@ impl CommitModal { CommitOptions { amend: is_amend_pending, signoff: is_signoff_enabled, + allow_empty: false, }, window, cx, diff --git a/crates/git_ui/src/git_panel.rs b/crates/git_ui/src/git_panel.rs index aac1ec1a19ab53913a830738ae528fb2c0c10248..0cb8ec6b78929d216b700b6e21cbf43a538c6f56 100644 --- a/crates/git_ui/src/git_panel.rs +++ b/crates/git_ui/src/git_panel.rs @@ -2155,6 +2155,7 @@ impl GitPanel { CommitOptions { amend: false, signoff: self.signoff_enabled, + allow_empty: false, }, window, cx, @@ -2195,6 +2196,7 @@ impl GitPanel { CommitOptions { amend: true, signoff: self.signoff_enabled, + allow_empty: false, }, window, cx, @@ -4454,7 +4456,11 @@ impl GitPanel { git_panel .update(cx, |git_panel, cx| { git_panel.commit_changes( - CommitOptions { amend, signoff }, + CommitOptions { + amend, + signoff, + allow_empty: false, + }, window, cx, ); diff --git a/crates/project/src/git_store.rs b/crates/project/src/git_store.rs index 6bc7f1ab52db8665efac7ab5631986b5ec0c8e33..e7e84ffe673881d898a56b64892887b9c8d6c809 100644 --- a/crates/project/src/git_store.rs +++ b/crates/project/src/git_store.rs @@ -329,6 +329,12 @@ pub struct GraphDataResponse<'a> { pub error: Option, } +#[derive(Clone, Debug)] +enum CreateWorktreeStartPoint { + Detached, + Branched { name: String }, +} + pub struct Repository { this: WeakEntity, snapshot: RepositorySnapshot, @@ -588,6 +594,7 @@ impl GitStore { client.add_entity_request_handler(Self::handle_create_worktree); client.add_entity_request_handler(Self::handle_remove_worktree); client.add_entity_request_handler(Self::handle_rename_worktree); + client.add_entity_request_handler(Self::handle_get_head_sha); } pub fn is_local(&self) -> bool { @@ -2340,6 +2347,7 @@ impl GitStore { CommitOptions { amend: options.amend, signoff: options.signoff, + allow_empty: options.allow_empty, }, askpass, cx, @@ -2406,12 +2414,18 @@ impl GitStore { let repository_id = RepositoryId::from_proto(envelope.payload.repository_id); let repository_handle = Self::repository_for_request(&this, repository_id, &mut cx)?; let directory = PathBuf::from(envelope.payload.directory); - let name = envelope.payload.name; + let start_point = if envelope.payload.name.is_empty() { + CreateWorktreeStartPoint::Detached + } else { + CreateWorktreeStartPoint::Branched { + name: envelope.payload.name, + } + }; let commit = envelope.payload.commit; repository_handle .update(&mut cx, |repository_handle, _| { - repository_handle.create_worktree(name, directory, commit) + repository_handle.create_worktree_with_start_point(start_point, directory, commit) }) .await??; @@ -2456,6 +2470,21 @@ impl GitStore { Ok(proto::Ack {}) } + async fn handle_get_head_sha( + this: Entity, + envelope: TypedEnvelope, + mut cx: AsyncApp, + ) -> Result { + let repository_id = RepositoryId::from_proto(envelope.payload.repository_id); + let repository_handle = Self::repository_for_request(&this, repository_id, &mut cx)?; + + let head_sha = repository_handle + .update(&mut cx, |repository_handle, _| repository_handle.head_sha()) + .await??; + + Ok(proto::GitGetHeadShaResponse { sha: head_sha }) + } + async fn handle_get_branches( this: Entity, envelope: TypedEnvelope, @@ -5493,6 +5522,7 @@ impl Repository { options: Some(proto::commit::CommitOptions { amend: options.amend, signoff: options.signoff, + allow_empty: options.allow_empty, }), askpass_id, }) @@ -5974,36 +6004,174 @@ impl Repository { }) } + fn create_worktree_with_start_point( + &mut self, + start_point: CreateWorktreeStartPoint, + path: PathBuf, + commit: Option, + ) -> oneshot::Receiver> { + if matches!( + &start_point, + CreateWorktreeStartPoint::Branched { name } if name.is_empty() + ) { + let (sender, receiver) = oneshot::channel(); + sender + .send(Err(anyhow!("branch name cannot be empty"))) + .ok(); + return receiver; + } + + let id = self.id; + let message = match &start_point { + CreateWorktreeStartPoint::Detached => "git worktree add (detached)".into(), + CreateWorktreeStartPoint::Branched { name } => { + format!("git worktree add: {name}").into() + } + }; + + self.send_job(Some(message), move |repo, _cx| async move { + let branch_name = match start_point { + CreateWorktreeStartPoint::Detached => None, + CreateWorktreeStartPoint::Branched { name } => Some(name), + }; + let remote_name = branch_name.clone().unwrap_or_default(); + + match repo { + RepositoryState::Local(LocalRepositoryState { backend, .. }) => { + backend.create_worktree(branch_name, path, commit).await + } + RepositoryState::Remote(RemoteRepositoryState { project_id, client }) => { + client + .request(proto::GitCreateWorktree { + project_id: project_id.0, + repository_id: id.to_proto(), + name: remote_name, + directory: path.to_string_lossy().to_string(), + commit, + }) + .await?; + + Ok(()) + } + } + }) + } + pub fn create_worktree( &mut self, branch_name: String, path: PathBuf, commit: Option, ) -> oneshot::Receiver> { + self.create_worktree_with_start_point( + CreateWorktreeStartPoint::Branched { name: branch_name }, + path, + commit, + ) + } + + pub fn create_worktree_detached( + &mut self, + path: PathBuf, + commit: String, + ) -> oneshot::Receiver> { + self.create_worktree_with_start_point( + CreateWorktreeStartPoint::Detached, + path, + Some(commit), + ) + } + + pub fn head_sha(&mut self) -> oneshot::Receiver>> { let id = self.id; - self.send_job( - Some(format!("git worktree add: {}", branch_name).into()), - move |repo, _cx| async move { - match repo { - RepositoryState::Local(LocalRepositoryState { backend, .. }) => { - backend.create_worktree(branch_name, path, commit).await - } - RepositoryState::Remote(RemoteRepositoryState { project_id, client }) => { - client - .request(proto::GitCreateWorktree { - project_id: project_id.0, - repository_id: id.to_proto(), - name: branch_name, - directory: path.to_string_lossy().to_string(), - commit, - }) - .await?; + self.send_job(None, move |repo, _cx| async move { + match repo { + RepositoryState::Local(LocalRepositoryState { backend, .. }) => { + Ok(backend.head_sha().await) + } + RepositoryState::Remote(RemoteRepositoryState { project_id, client }) => { + let response = client + .request(proto::GitGetHeadSha { + project_id: project_id.0, + repository_id: id.to_proto(), + }) + .await?; - Ok(()) - } + Ok(response.sha) } - }, - ) + } + }) + } + + pub fn update_ref( + &mut self, + ref_name: String, + commit: String, + ) -> oneshot::Receiver> { + self.send_job(None, move |repo, _cx| async move { + match repo { + RepositoryState::Local(LocalRepositoryState { backend, .. }) => { + backend.update_ref(ref_name, commit).await + } + RepositoryState::Remote(_) => { + anyhow::bail!("update_ref is not supported for remote repositories") + } + } + }) + } + + pub fn delete_ref(&mut self, ref_name: String) -> oneshot::Receiver> { + self.send_job(None, move |repo, _cx| async move { + match repo { + RepositoryState::Local(LocalRepositoryState { backend, .. }) => { + backend.delete_ref(ref_name).await + } + RepositoryState::Remote(_) => { + anyhow::bail!("delete_ref is not supported for remote repositories") + } + } + }) + } + + pub fn resolve_commit(&mut self, sha: String) -> oneshot::Receiver> { + self.send_job(None, move |repo, _cx| async move { + match repo { + RepositoryState::Local(LocalRepositoryState { backend, .. }) => { + let results = backend.revparse_batch(vec![sha]).await?; + Ok(results.into_iter().next().flatten().is_some()) + } + RepositoryState::Remote(_) => { + anyhow::bail!("resolve_commit is not supported for remote repositories") + } + } + }) + } + + pub fn repair_worktrees(&mut self) -> oneshot::Receiver> { + self.send_job(None, move |repo, _cx| async move { + match repo { + RepositoryState::Local(LocalRepositoryState { backend, .. }) => { + backend.repair_worktrees().await + } + RepositoryState::Remote(_) => { + anyhow::bail!("repair_worktrees is not supported for remote repositories") + } + } + }) + } + + pub fn commit_exists(&mut self, sha: String) -> oneshot::Receiver> { + self.send_job(None, move |repo, _cx| async move { + match repo { + RepositoryState::Local(LocalRepositoryState { backend, .. }) => { + let results = backend.revparse_batch(vec![sha]).await?; + Ok(results.into_iter().next().flatten().is_some()) + } + RepositoryState::Remote(_) => { + anyhow::bail!("commit_exists is not supported for remote repositories") + } + } + }) } pub fn remove_worktree(&mut self, path: PathBuf, force: bool) -> oneshot::Receiver> { diff --git a/crates/proto/proto/git.proto b/crates/proto/proto/git.proto index 0cbb635d78dddc81aa7c75340f2fbebe83a474e3..9324feb21b1f50ac1041ed0afc8b59cb9b7fe2c6 100644 --- a/crates/proto/proto/git.proto +++ b/crates/proto/proto/git.proto @@ -403,6 +403,7 @@ message Commit { message CommitOptions { bool amend = 1; bool signoff = 2; + bool allow_empty = 3; } } @@ -567,6 +568,15 @@ message GitGetWorktrees { uint64 repository_id = 2; } +message GitGetHeadSha { + uint64 project_id = 1; + uint64 repository_id = 2; +} + +message GitGetHeadShaResponse { + optional string sha = 1; +} + message GitWorktreesResponse { repeated Worktree worktrees = 1; } diff --git a/crates/proto/proto/zed.proto b/crates/proto/proto/zed.proto index 24e7c5372f2679eab1726487e1967edcef6024ed..8b62754d7af40b7c4f5e1a87ad42899d682ba453 100644 --- a/crates/proto/proto/zed.proto +++ b/crates/proto/proto/zed.proto @@ -474,7 +474,9 @@ message Envelope { GitCompareCheckpoints git_compare_checkpoints = 436; GitCompareCheckpointsResponse git_compare_checkpoints_response = 437; GitDiffCheckpoints git_diff_checkpoints = 438; - GitDiffCheckpointsResponse git_diff_checkpoints_response = 439; // current max + GitDiffCheckpointsResponse git_diff_checkpoints_response = 439; + GitGetHeadSha git_get_head_sha = 440; + GitGetHeadShaResponse git_get_head_sha_response = 441; // current max } reserved 87 to 88; diff --git a/crates/proto/src/proto.rs b/crates/proto/src/proto.rs index dd77d2a2da8d4dbc2c0f91f63cb59dd1591ee3f4..b77bd02313c13a9b04eb7762a97f9e77ac8cbaf8 100644 --- a/crates/proto/src/proto.rs +++ b/crates/proto/src/proto.rs @@ -351,6 +351,8 @@ messages!( (NewExternalAgentVersionAvailable, Background), (RemoteStarted, Background), (GitGetWorktrees, Background), + (GitGetHeadSha, Background), + (GitGetHeadShaResponse, Background), (GitWorktreesResponse, Background), (GitCreateWorktree, Background), (GitRemoveWorktree, Background), @@ -558,6 +560,7 @@ request_messages!( (GetContextServerCommand, ContextServerCommand), (RemoteStarted, Ack), (GitGetWorktrees, GitWorktreesResponse), + (GitGetHeadSha, GitGetHeadShaResponse), (GitCreateWorktree, Ack), (GitRemoveWorktree, Ack), (GitRenameWorktree, Ack), @@ -749,6 +752,7 @@ entity_messages!( ExternalAgentLoadingStatusUpdated, NewExternalAgentVersionAvailable, GitGetWorktrees, + GitGetHeadSha, GitCreateWorktree, GitRemoveWorktree, GitRenameWorktree, From bc4d25ca760277f61cd310f8796557becb5b5822 Mon Sep 17 00:00:00 2001 From: Piotr Osiewicz <24362066+osiewicz@users.noreply.github.com> Date: Mon, 6 Apr 2026 23:38:51 +0200 Subject: [PATCH 15/21] lsp: Do not pass in null diagnostic identifiers (#53272) This fixes a crash with new Preview versions of tsgo after https://github.com/microsoft/typescript-go/pull/3313 Self-Review Checklist: - [x] I've reviewed my own diff for quality, security, and reliability - [ ] Unsafe blocks (if any) have justifying comments - [ ] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) - [ ] Tests cover the new/changed behavior - [x] Performance impact has been considered and is acceptable Closes #ISSUE Release Notes: - N/A --- Cargo.lock | 2 +- Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e1a5a11ad0c0549791545cd7e020e283decb5b53..97412711a55667a4976a35313eb6c0388acc74ef 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10074,7 +10074,7 @@ dependencies = [ [[package]] name = "lsp-types" version = "0.95.1" -source = "git+https://github.com/zed-industries/lsp-types?rev=a4f410987660bf560d1e617cb78117c6b6b9f599#a4f410987660bf560d1e617cb78117c6b6b9f599" +source = "git+https://github.com/zed-industries/lsp-types?rev=c7396459fefc7886b4adfa3b596832405ae1e880#c7396459fefc7886b4adfa3b596832405ae1e880" dependencies = [ "bitflags 1.3.2", "serde", diff --git a/Cargo.toml b/Cargo.toml index a800a6c9b276c5f30d6b6eca2f9f0f660f28b02d..5cb5b991b645ec1b78b16f48493c7c8dc1426344 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -599,7 +599,7 @@ linkify = "0.10.0" libwebrtc = "0.3.26" livekit = { version = "0.7.32", features = ["tokio", "rustls-tls-native-roots"] } log = { version = "0.4.16", features = ["kv_unstable_serde", "serde"] } -lsp-types = { git = "https://github.com/zed-industries/lsp-types", rev = "a4f410987660bf560d1e617cb78117c6b6b9f599" } +lsp-types = { git = "https://github.com/zed-industries/lsp-types", rev = "c7396459fefc7886b4adfa3b596832405ae1e880" } mach2 = "0.5" markup5ever_rcdom = "0.3.0" metal = "0.33" From a018333d41bfe07102d1d0b68383ff5afac00307 Mon Sep 17 00:00:00 2001 From: Mikayla Maki Date: Mon, 6 Apr 2026 14:42:04 -0700 Subject: [PATCH 16/21] Introduce the temporary/retained workspace behavior based on whether the sidebar is open (#53267) Self-Review Checklist: - [ ] I've reviewed my own diff for quality, security, and reliability - [x] Unsafe blocks (if any) have justifying comments - [x] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) - [x] Tests cover the new/changed behavior - [x] Performance impact has been considered and is acceptable Release Notes: - N/A --- crates/agent_ui/src/agent_panel.rs | 12 +- crates/agent_ui/src/conversation_view.rs | 1 - crates/agent_ui/src/threads_archive_view.rs | 1 - crates/recent_projects/src/recent_projects.rs | 10 +- crates/settings_ui/src/settings_ui.rs | 1 - crates/sidebar/src/sidebar.rs | 53 +-- crates/sidebar/src/sidebar_tests.rs | 425 +++++++++++++----- crates/title_bar/src/title_bar.rs | 2 - crates/workspace/src/multi_workspace.rs | 234 ++++++---- crates/workspace/src/multi_workspace_tests.rs | 24 + crates/workspace/src/persistence.rs | 29 +- crates/workspace/src/workspace.rs | 32 +- crates/zed/src/visual_test_runner.rs | 28 +- crates/zed/src/zed.rs | 68 ++- 14 files changed, 634 insertions(+), 286 deletions(-) diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 5fd39509df4ec2263e47c7e87b3e4b7852eaf154..41900e71e5d3ad7e5327ee7e04f73cb05eed5a5b 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -5175,7 +5175,7 @@ mod tests { multi_workspace .read_with(cx, |multi_workspace, _cx| { assert_eq!( - multi_workspace.workspaces().len(), + multi_workspace.workspaces().count(), 1, "LocalProject should not create a new workspace" ); @@ -5451,6 +5451,11 @@ mod tests { let multi_workspace = cx.add_window(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx)); + multi_workspace + .update(cx, |multi_workspace, _, cx| { + multi_workspace.open_sidebar(cx); + }) + .unwrap(); let workspace = multi_workspace .read_with(cx, |multi_workspace, _cx| { @@ -5538,15 +5543,14 @@ mod tests { .read_with(cx, |multi_workspace, cx| { // There should be more than one workspace now (the original + the new worktree). assert!( - multi_workspace.workspaces().len() > 1, + multi_workspace.workspaces().count() > 1, "expected a new workspace to have been created, found {}", - multi_workspace.workspaces().len(), + multi_workspace.workspaces().count(), ); // Check the newest workspace's panel for the correct agent. let new_workspace = multi_workspace .workspaces() - .iter() .find(|ws| ws.entity_id() != workspace.entity_id()) .expect("should find the new workspace"); let new_panel = new_workspace diff --git a/crates/agent_ui/src/conversation_view.rs b/crates/agent_ui/src/conversation_view.rs index ce125a5d7c901ccb6fc89f405f482cbf52b94f5d..149ed2e2fc0f9b22244e0d69deebf5aa7bb7d4c5 100644 --- a/crates/agent_ui/src/conversation_view.rs +++ b/crates/agent_ui/src/conversation_view.rs @@ -3375,7 +3375,6 @@ pub(crate) mod tests { // Verify workspace1 is no longer the active workspace multi_workspace_handle .read_with(cx, |mw, _cx| { - assert_eq!(mw.active_workspace_index(), 1); assert_ne!(mw.workspace(), &workspace1); }) .unwrap(); diff --git a/crates/agent_ui/src/threads_archive_view.rs b/crates/agent_ui/src/threads_archive_view.rs index b7afe2c37d0c278a23d9a41a560e45c356e7b4e1..13b2aa1a37cd506c338d13db78bce751882e426a 100644 --- a/crates/agent_ui/src/threads_archive_view.rs +++ b/crates/agent_ui/src/threads_archive_view.rs @@ -353,7 +353,6 @@ impl ThreadsArchiveView { .map(|mw| { mw.read(cx) .workspaces() - .iter() .filter_map(|ws| ws.read(cx).database_id()) .collect() }) diff --git a/crates/recent_projects/src/recent_projects.rs b/crates/recent_projects/src/recent_projects.rs index 24010017ff9fa4eb62a1787332fed70f740ccc2d..e3bfc0dc08c95c0ce57b818e50965433a6c6bc98 100644 --- a/crates/recent_projects/src/recent_projects.rs +++ b/crates/recent_projects/src/recent_projects.rs @@ -357,7 +357,6 @@ pub fn init(cx: &mut App) { .update(cx, |multi_workspace, window, cx| { let sibling_workspace_ids: HashSet = multi_workspace .workspaces() - .iter() .filter_map(|ws| ws.read(cx).database_id()) .collect(); @@ -1113,7 +1112,6 @@ impl PickerDelegate for RecentProjectsDelegate { .update(cx, |multi_workspace, window, cx| { let workspace = multi_workspace .workspaces() - .iter() .find(|ws| ws.read(cx).database_id() == Some(workspace_id)) .cloned(); if let Some(workspace) = workspace { @@ -1932,7 +1930,6 @@ impl RecentProjectsDelegate { .update(cx, |multi_workspace, window, cx| { let workspace = multi_workspace .workspaces() - .iter() .find(|ws| ws.read(cx).database_id() == Some(workspace_id)) .cloned(); if let Some(workspace) = workspace { @@ -2055,6 +2052,11 @@ mod tests { assert_eq!(cx.update(|cx| cx.windows().len()), 1); let multi_workspace = cx.update(|cx| cx.windows()[0].downcast::().unwrap()); + multi_workspace + .update(cx, |multi_workspace, _, cx| { + multi_workspace.open_sidebar(cx); + }) + .unwrap(); multi_workspace .update(cx, |multi_workspace, _, cx| { assert!(!multi_workspace.workspace().read(cx).is_edited()) @@ -2141,7 +2143,7 @@ mod tests { ); assert!( - multi_workspace.workspaces().contains(&dirty_workspace), + multi_workspace.workspaces().any(|w| w == &dirty_workspace), "The dirty workspace should still be present in multi-workspace mode" ); diff --git a/crates/settings_ui/src/settings_ui.rs b/crates/settings_ui/src/settings_ui.rs index 634db0e247fdc370c479df0ed4f6d1f84a5284f6..4c7a98f6c0fa94e659a6db4e00aa28e2b4516e13 100644 --- a/crates/settings_ui/src/settings_ui.rs +++ b/crates/settings_ui/src/settings_ui.rs @@ -3753,7 +3753,6 @@ fn all_projects( .flat_map(|multi_workspace| { multi_workspace .workspaces() - .iter() .map(|workspace| workspace.read(cx).project().clone()) .collect::>() }), diff --git a/crates/sidebar/src/sidebar.rs b/crates/sidebar/src/sidebar.rs index 4d3e282c403d4df27781066c35837f88f3b4cccd..d6589361cd9417c2ac6d9025af92f1e096b341b1 100644 --- a/crates/sidebar/src/sidebar.rs +++ b/crates/sidebar/src/sidebar.rs @@ -434,7 +434,7 @@ impl Sidebar { }) .detach(); - let workspaces = multi_workspace.read(cx).workspaces().to_vec(); + let workspaces: Vec<_> = multi_workspace.read(cx).workspaces().cloned().collect(); cx.defer_in(window, move |this, window, cx| { for workspace in &workspaces { this.subscribe_to_workspace(workspace, window, cx); @@ -673,7 +673,6 @@ impl Sidebar { let mw = self.multi_workspace.upgrade()?; let mw = mw.read(cx); mw.workspaces() - .iter() .find(|ws| ws.read(cx).project_group_key(cx).path_list() == path_list) .cloned() } @@ -716,8 +715,8 @@ impl Sidebar { return; }; let mw = multi_workspace.read(cx); - let workspaces = mw.workspaces().to_vec(); - let active_workspace = mw.workspaces().get(mw.active_workspace_index()).cloned(); + let workspaces: Vec<_> = mw.workspaces().cloned().collect(); + let active_workspace = Some(mw.workspace().clone()); let agent_server_store = workspaces .first() @@ -1993,7 +1992,6 @@ impl Sidebar { let workspace = window.read(cx).ok().and_then(|multi_workspace| { multi_workspace .workspaces() - .iter() .find(|workspace| predicate(workspace, cx)) .cloned() })?; @@ -2010,7 +2008,6 @@ impl Sidebar { multi_workspace .read(cx) .workspaces() - .iter() .find(|workspace| predicate(workspace, cx)) .cloned() }) @@ -2203,12 +2200,10 @@ impl Sidebar { return; } - let active_workspace = self.multi_workspace.upgrade().and_then(|w| { - w.read(cx) - .workspaces() - .get(w.read(cx).active_workspace_index()) - .cloned() - }); + let active_workspace = self + .multi_workspace + .upgrade() + .map(|w| w.read(cx).workspace().clone()); if let Some(workspace) = active_workspace { self.activate_thread_locally(&metadata, &workspace, window, cx); @@ -2343,7 +2338,7 @@ impl Sidebar { return; }; - let workspaces = multi_workspace.read(cx).workspaces().to_vec(); + let workspaces: Vec<_> = multi_workspace.read(cx).workspaces().cloned().collect(); for workspace in workspaces { if let Some(agent_panel) = workspace.read(cx).panel::(cx) { let cancelled = @@ -2936,7 +2931,6 @@ impl Sidebar { .map(|mw| { mw.read(cx) .workspaces() - .iter() .filter_map(|ws| ws.read(cx).database_id()) .collect() }) @@ -3404,12 +3398,9 @@ impl Sidebar { } fn active_workspace(&self, cx: &App) -> Option> { - self.multi_workspace.upgrade().and_then(|w| { - w.read(cx) - .workspaces() - .get(w.read(cx).active_workspace_index()) - .cloned() - }) + self.multi_workspace + .upgrade() + .map(|w| w.read(cx).workspace().clone()) } fn show_thread_import_modal(&mut self, window: &mut Window, cx: &mut Context) { @@ -3517,12 +3508,11 @@ impl Sidebar { } fn show_archive(&mut self, window: &mut Window, cx: &mut Context) { - let Some(active_workspace) = self.multi_workspace.upgrade().and_then(|w| { - w.read(cx) - .workspaces() - .get(w.read(cx).active_workspace_index()) - .cloned() - }) else { + let Some(active_workspace) = self + .multi_workspace + .upgrade() + .map(|w| w.read(cx).workspace().clone()) + else { return; }; let Some(agent_panel) = active_workspace.read(cx).panel::(cx) else { @@ -3824,12 +3814,12 @@ pub fn dump_workspace_info( let multi_workspace = workspace.multi_workspace().and_then(|weak| weak.upgrade()); let workspaces: Vec> = match &multi_workspace { - Some(mw) => mw.read(cx).workspaces().to_vec(), + Some(mw) => mw.read(cx).workspaces().cloned().collect(), None => vec![this_entity.clone()], }; - let active_index = multi_workspace + let active_workspace = multi_workspace .as_ref() - .map(|mw| mw.read(cx).active_workspace_index()); + .map(|mw| mw.read(cx).workspace().clone()); writeln!(output, "MultiWorkspace: {} workspace(s)", workspaces.len()).ok(); @@ -3841,13 +3831,10 @@ pub fn dump_workspace_info( } } - if let Some(index) = active_index { - writeln!(output, "Active workspace index: {index}").ok(); - } writeln!(output).ok(); for (index, ws) in workspaces.iter().enumerate() { - let is_active = active_index == Some(index); + let is_active = active_workspace.as_ref() == Some(ws); writeln!( output, "--- Workspace {index}{} ---", diff --git a/crates/sidebar/src/sidebar_tests.rs b/crates/sidebar/src/sidebar_tests.rs index a50c5dadbdbff77ccadd81dd96196a469e920e87..60881acfe9461f7897d6013831970444b7a65544 100644 --- a/crates/sidebar/src/sidebar_tests.rs +++ b/crates/sidebar/src/sidebar_tests.rs @@ -77,6 +77,18 @@ async fn init_test_project( fn setup_sidebar( multi_workspace: &Entity, cx: &mut gpui::VisualTestContext, +) -> Entity { + let sidebar = setup_sidebar_closed(multi_workspace, cx); + multi_workspace.update_in(cx, |mw, window, cx| { + mw.toggle_sidebar(window, cx); + }); + cx.run_until_parked(); + sidebar +} + +fn setup_sidebar_closed( + multi_workspace: &Entity, + cx: &mut gpui::VisualTestContext, ) -> Entity { let multi_workspace = multi_workspace.clone(); let sidebar = @@ -172,16 +184,7 @@ fn save_thread_metadata( cx.run_until_parked(); } -fn open_and_focus_sidebar(sidebar: &Entity, cx: &mut gpui::VisualTestContext) { - let multi_workspace = sidebar.read_with(cx, |s, _| s.multi_workspace.upgrade()); - if let Some(multi_workspace) = multi_workspace { - multi_workspace.update_in(cx, |mw, window, cx| { - if !mw.sidebar_open() { - mw.toggle_sidebar(window, cx); - } - }); - } - cx.run_until_parked(); +fn focus_sidebar(sidebar: &Entity, cx: &mut gpui::VisualTestContext) { sidebar.update_in(cx, |_, window, cx| { cx.focus_self(window); }); @@ -544,7 +547,7 @@ async fn test_workspace_lifecycle(cx: &mut TestAppContext) { // Remove the second workspace multi_workspace.update_in(cx, |mw, window, cx| { - let workspace = mw.workspaces()[1].clone(); + let workspace = mw.workspaces().nth(1).cloned().unwrap(); mw.remove(&workspace, window, cx); }); cx.run_until_parked(); @@ -604,7 +607,7 @@ async fn test_view_more_batched_expansion(cx: &mut TestAppContext) { assert!(entries.iter().any(|e| e.contains("View More"))); // Focus and navigate to View More, then confirm to expand by one batch - open_and_focus_sidebar(&sidebar, cx); + focus_sidebar(&sidebar, cx); for _ in 0..7 { cx.dispatch_action(SelectNext); } @@ -915,7 +918,7 @@ async fn test_keyboard_select_next_and_previous(cx: &mut TestAppContext) { // Entries: [header, thread3, thread2, thread1] // Focusing the sidebar does not set a selection; select_next/select_previous // handle None gracefully by starting from the first or last entry. - open_and_focus_sidebar(&sidebar, cx); + focus_sidebar(&sidebar, cx); assert_eq!(sidebar.read_with(cx, |s, _| s.selection), None); // First SelectNext from None starts at index 0 @@ -970,7 +973,7 @@ async fn test_keyboard_select_first_and_last(cx: &mut TestAppContext) { multi_workspace.update_in(cx, |_, _window, cx| cx.notify()); cx.run_until_parked(); - open_and_focus_sidebar(&sidebar, cx); + focus_sidebar(&sidebar, cx); // SelectLast jumps to the end cx.dispatch_action(SelectLast); @@ -993,7 +996,7 @@ async fn test_keyboard_focus_in_does_not_set_selection(cx: &mut TestAppContext) // Open the sidebar so it's rendered, then focus it to trigger focus_in. // focus_in no longer sets a default selection. - open_and_focus_sidebar(&sidebar, cx); + focus_sidebar(&sidebar, cx); assert_eq!(sidebar.read_with(cx, |s, _| s.selection), None); // Manually set a selection, blur, then refocus — selection should be preserved @@ -1030,7 +1033,7 @@ async fn test_keyboard_confirm_on_project_header_toggles_collapse(cx: &mut TestA ); // Focus the sidebar and select the header (index 0) - open_and_focus_sidebar(&sidebar, cx); + focus_sidebar(&sidebar, cx); sidebar.update_in(cx, |sidebar, _window, _cx| { sidebar.selection = Some(0); }); @@ -1071,7 +1074,7 @@ async fn test_keyboard_confirm_on_view_more_expands(cx: &mut TestAppContext) { assert!(entries.iter().any(|e| e.contains("View More"))); // Focus sidebar (selection starts at None), then navigate down to the "View More" entry (index 6) - open_and_focus_sidebar(&sidebar, cx); + focus_sidebar(&sidebar, cx); for _ in 0..7 { cx.dispatch_action(SelectNext); } @@ -1105,7 +1108,7 @@ async fn test_keyboard_expand_and_collapse_selected_entry(cx: &mut TestAppContex ); // Focus sidebar and manually select the header (index 0). Press left to collapse. - open_and_focus_sidebar(&sidebar, cx); + focus_sidebar(&sidebar, cx); sidebar.update_in(cx, |sidebar, _window, _cx| { sidebar.selection = Some(0); }); @@ -1144,7 +1147,7 @@ async fn test_keyboard_collapse_from_child_selects_parent(cx: &mut TestAppContex cx.run_until_parked(); // Focus sidebar (selection starts at None), then navigate down to the thread (child) - open_and_focus_sidebar(&sidebar, cx); + focus_sidebar(&sidebar, cx); cx.dispatch_action(SelectNext); cx.dispatch_action(SelectNext); assert_eq!(sidebar.read_with(cx, |s, _| s.selection), Some(1)); @@ -1179,7 +1182,7 @@ async fn test_keyboard_navigation_on_empty_list(cx: &mut TestAppContext) { ); // Focus sidebar — focus_in does not set a selection - open_and_focus_sidebar(&sidebar, cx); + focus_sidebar(&sidebar, cx); assert_eq!(sidebar.read_with(cx, |s, _| s.selection), None); // First SelectNext from None starts at index 0 (header) @@ -1211,7 +1214,7 @@ async fn test_selection_clamps_after_entry_removal(cx: &mut TestAppContext) { cx.run_until_parked(); // Focus sidebar (selection starts at None), navigate down to the thread (index 1) - open_and_focus_sidebar(&sidebar, cx); + focus_sidebar(&sidebar, cx); cx.dispatch_action(SelectNext); cx.dispatch_action(SelectNext); assert_eq!(sidebar.read_with(cx, |s, _| s.selection), Some(1)); @@ -1492,7 +1495,7 @@ async fn test_escape_clears_search_and_restores_full_list(cx: &mut TestAppContex ); // User types a search query to filter down. - open_and_focus_sidebar(&sidebar, cx); + focus_sidebar(&sidebar, cx); type_in_search(&sidebar, "alpha", cx); assert_eq!( visible_entries_as_strings(&sidebar, cx), @@ -1540,8 +1543,9 @@ async fn test_search_only_shows_workspace_headers_with_matches(cx: &mut TestAppC }); cx.run_until_parked(); - let project_b = - multi_workspace.read_with(cx, |mw, cx| mw.workspaces()[1].read(cx).project().clone()); + let project_b = multi_workspace.read_with(cx, |mw, cx| { + mw.workspaces().nth(1).unwrap().read(cx).project().clone() + }); for (id, title, hour) in [ ("b1", "Refactor sidebar layout", 3), @@ -1621,8 +1625,9 @@ async fn test_search_matches_workspace_name(cx: &mut TestAppContext) { }); cx.run_until_parked(); - let project_b = - multi_workspace.read_with(cx, |mw, cx| mw.workspaces()[1].read(cx).project().clone()); + let project_b = multi_workspace.read_with(cx, |mw, cx| { + mw.workspaces().nth(1).unwrap().read(cx).project().clone() + }); for (id, title, hour) in [ ("b1", "Refactor sidebar layout", 3), @@ -1764,7 +1769,7 @@ async fn test_search_finds_threads_inside_collapsed_groups(cx: &mut TestAppConte // User focuses the sidebar and collapses the group using keyboard: // manually select the header, then press SelectParent to collapse. - open_and_focus_sidebar(&sidebar, cx); + focus_sidebar(&sidebar, cx); sidebar.update_in(cx, |sidebar, _window, _cx| { sidebar.selection = Some(0); }); @@ -1807,7 +1812,7 @@ async fn test_search_then_keyboard_navigate_and_confirm(cx: &mut TestAppContext) } cx.run_until_parked(); - open_and_focus_sidebar(&sidebar, cx); + focus_sidebar(&sidebar, cx); // User types "fix" — two threads match. type_in_search(&sidebar, "fix", cx); @@ -1856,6 +1861,13 @@ async fn test_confirm_on_historical_thread_activates_workspace(cx: &mut TestAppC }); cx.run_until_parked(); + let (workspace_0, workspace_1) = multi_workspace.read_with(cx, |mw, _| { + ( + mw.workspaces().next().unwrap().clone(), + mw.workspaces().nth(1).unwrap().clone(), + ) + }); + save_thread_metadata( acp::SessionId::new(Arc::from("hist-1")), "Historical Thread".into(), @@ -1875,13 +1887,13 @@ async fn test_confirm_on_historical_thread_activates_workspace(cx: &mut TestAppC // Switch to workspace 1 so we can verify the confirm switches back. multi_workspace.update_in(cx, |mw, window, cx| { - let workspace = mw.workspaces()[1].clone(); + let workspace = mw.workspaces().nth(1).unwrap().clone(); mw.activate(workspace, window, cx); }); cx.run_until_parked(); assert_eq!( - multi_workspace.read_with(cx, |mw, _| mw.active_workspace_index()), - 1 + multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()), + workspace_1 ); // Confirm on the historical (non-live) thread at index 1. @@ -1895,8 +1907,8 @@ async fn test_confirm_on_historical_thread_activates_workspace(cx: &mut TestAppC cx.run_until_parked(); assert_eq!( - multi_workspace.read_with(cx, |mw, _| mw.active_workspace_index()), - 0 + multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()), + workspace_0 ); } @@ -2037,7 +2049,8 @@ async fn test_focused_thread_tracks_user_intent(cx: &mut TestAppContext) { let panel_b = add_agent_panel(&workspace_b, cx); cx.run_until_parked(); - let workspace_a = multi_workspace.read_with(cx, |mw, _cx| mw.workspaces()[0].clone()); + let workspace_a = + multi_workspace.read_with(cx, |mw, _cx| mw.workspaces().next().unwrap().clone()); // ── 1. Initial state: focused thread derived from active panel ───── sidebar.read_with(cx, |sidebar, _cx| { @@ -2135,7 +2148,7 @@ async fn test_focused_thread_tracks_user_intent(cx: &mut TestAppContext) { }); multi_workspace.update_in(cx, |mw, window, cx| { - let workspace = mw.workspaces()[0].clone(); + let workspace = mw.workspaces().next().unwrap().clone(); mw.activate(workspace, window, cx); }); cx.run_until_parked(); @@ -2190,8 +2203,8 @@ async fn test_focused_thread_tracks_user_intent(cx: &mut TestAppContext) { // Switching workspaces via the multi_workspace (simulates clicking // a workspace header) should clear focused_thread. multi_workspace.update_in(cx, |mw, window, cx| { - if let Some(index) = mw.workspaces().iter().position(|w| w == &workspace_b) { - let workspace = mw.workspaces()[index].clone(); + let workspace = mw.workspaces().find(|w| *w == &workspace_b).cloned(); + if let Some(workspace) = workspace { mw.activate(workspace, window, cx); } }); @@ -2477,6 +2490,8 @@ async fn test_cmd_n_shows_new_thread_entry_in_absorbed_worktree(cx: &mut TestApp let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(main_project.clone(), window, cx)); + let sidebar = setup_sidebar(&multi_workspace, cx); + let worktree_workspace = multi_workspace.update_in(cx, |mw, window, cx| { mw.test_add_workspace(worktree_project.clone(), window, cx) }); @@ -2485,12 +2500,10 @@ async fn test_cmd_n_shows_new_thread_entry_in_absorbed_worktree(cx: &mut TestApp // Switch to the worktree workspace. multi_workspace.update_in(cx, |mw, window, cx| { - let workspace = mw.workspaces()[1].clone(); + let workspace = mw.workspaces().nth(1).unwrap().clone(); mw.activate(workspace, window, cx); }); - let sidebar = setup_sidebar(&multi_workspace, cx); - // Create a non-empty thread in the worktree workspace. let connection = StubAgentConnection::new(); connection.set_next_prompt_updates(vec![acp::SessionUpdate::AgentMessageChunk( @@ -3027,6 +3040,8 @@ async fn test_absorbed_worktree_running_thread_shows_live_status(cx: &mut TestAp let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(main_project.clone(), window, cx)); + let sidebar = setup_sidebar(&multi_workspace, cx); + let worktree_workspace = multi_workspace.update_in(cx, |mw, window, cx| { mw.test_add_workspace(worktree_project.clone(), window, cx) }); @@ -3037,12 +3052,10 @@ async fn test_absorbed_worktree_running_thread_shows_live_status(cx: &mut TestAp // Switch back to the main workspace before setting up the sidebar. multi_workspace.update_in(cx, |mw, window, cx| { - let workspace = mw.workspaces()[0].clone(); + let workspace = mw.workspaces().next().unwrap().clone(); mw.activate(workspace, window, cx); }); - let sidebar = setup_sidebar(&multi_workspace, cx); - // Start a thread in the worktree workspace's panel and keep it // generating (don't resolve it). let connection = StubAgentConnection::new(); @@ -3127,6 +3140,8 @@ async fn test_absorbed_worktree_completion_triggers_notification(cx: &mut TestAp let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(main_project.clone(), window, cx)); + let sidebar = setup_sidebar(&multi_workspace, cx); + let worktree_workspace = multi_workspace.update_in(cx, |mw, window, cx| { mw.test_add_workspace(worktree_project.clone(), window, cx) }); @@ -3134,12 +3149,10 @@ async fn test_absorbed_worktree_completion_triggers_notification(cx: &mut TestAp let worktree_panel = add_agent_panel(&worktree_workspace, cx); multi_workspace.update_in(cx, |mw, window, cx| { - let workspace = mw.workspaces()[0].clone(); + let workspace = mw.workspaces().next().unwrap().clone(); mw.activate(workspace, window, cx); }); - let sidebar = setup_sidebar(&multi_workspace, cx); - let connection = StubAgentConnection::new(); open_thread_with_connection(&worktree_panel, connection.clone(), cx); send_message(&worktree_panel, cx); @@ -3231,12 +3244,12 @@ async fn test_clicking_worktree_thread_opens_workspace_when_none_exists(cx: &mut // Only 1 workspace should exist. assert_eq!( - multi_workspace.read_with(cx, |mw, _| mw.workspaces().len()), + multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()), 1, ); // Focus the sidebar and select the worktree thread. - open_and_focus_sidebar(&sidebar, cx); + focus_sidebar(&sidebar, cx); sidebar.update_in(cx, |sidebar, _window, _cx| { sidebar.selection = Some(1); // index 0 is header, 1 is the thread }); @@ -3248,11 +3261,11 @@ async fn test_clicking_worktree_thread_opens_workspace_when_none_exists(cx: &mut // A new workspace should have been created for the worktree path. let new_workspace = multi_workspace.read_with(cx, |mw, _| { assert_eq!( - mw.workspaces().len(), + mw.workspaces().count(), 2, "confirming a worktree thread without a workspace should open one", ); - mw.workspaces()[1].clone() + mw.workspaces().nth(1).unwrap().clone() }); let new_path_list = @@ -3318,7 +3331,7 @@ async fn test_clicking_worktree_thread_does_not_briefly_render_as_separate_proje vec!["v [project]", " WT Thread {wt-feature-a}"], ); - open_and_focus_sidebar(&sidebar, cx); + focus_sidebar(&sidebar, cx); sidebar.update_in(cx, |sidebar, _window, _cx| { sidebar.selection = Some(1); // index 0 is header, 1 is the thread }); @@ -3444,18 +3457,19 @@ async fn test_clicking_absorbed_worktree_thread_activates_worktree_workspace( let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(main_project.clone(), window, cx)); + let sidebar = setup_sidebar(&multi_workspace, cx); + let worktree_workspace = multi_workspace.update_in(cx, |mw, window, cx| { mw.test_add_workspace(worktree_project.clone(), window, cx) }); // Activate the main workspace before setting up the sidebar. - multi_workspace.update_in(cx, |mw, window, cx| { - let workspace = mw.workspaces()[0].clone(); - mw.activate(workspace, window, cx); + let main_workspace = multi_workspace.update_in(cx, |mw, window, cx| { + let workspace = mw.workspaces().next().unwrap().clone(); + mw.activate(workspace.clone(), window, cx); + workspace }); - let sidebar = setup_sidebar(&multi_workspace, cx); - save_named_thread_metadata("thread-main", "Main Thread", &main_project, cx).await; save_named_thread_metadata("thread-wt", "WT Thread", &worktree_project, cx).await; @@ -3475,13 +3489,13 @@ async fn test_clicking_absorbed_worktree_thread_activates_worktree_workspace( .expect("should find the worktree thread entry"); assert_eq!( - multi_workspace.read_with(cx, |mw, _| mw.active_workspace_index()), - 0, + multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()), + main_workspace, "main workspace should be active initially" ); // Focus the sidebar and select the absorbed worktree thread. - open_and_focus_sidebar(&sidebar, cx); + focus_sidebar(&sidebar, cx); sidebar.update_in(cx, |sidebar, _window, _cx| { sidebar.selection = Some(wt_thread_index); }); @@ -3491,9 +3505,7 @@ async fn test_clicking_absorbed_worktree_thread_activates_worktree_workspace( cx.run_until_parked(); // The worktree workspace should now be active, not the main one. - let active_workspace = multi_workspace.read_with(cx, |mw, _| { - mw.workspaces()[mw.active_workspace_index()].clone() - }); + let active_workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()); assert_eq!( active_workspace, worktree_workspace, "clicking an absorbed worktree thread should activate the worktree workspace" @@ -3520,25 +3532,27 @@ async fn test_activate_archived_thread_with_saved_paths_activates_matching_works let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_a.clone(), window, cx)); - multi_workspace.update_in(cx, |mw, window, cx| { - mw.test_add_workspace(project_b.clone(), window, cx); - }); - let sidebar = setup_sidebar(&multi_workspace, cx); + let workspace_b = multi_workspace.update_in(cx, |mw, window, cx| { + mw.test_add_workspace(project_b.clone(), window, cx) + }); + let workspace_a = + multi_workspace.read_with(cx, |mw, _| mw.workspaces().next().unwrap().clone()); + // Save a thread with path_list pointing to project-b. let session_id = acp::SessionId::new(Arc::from("archived-1")); save_test_thread_metadata(&session_id, &project_b, cx).await; // Ensure workspace A is active. multi_workspace.update_in(cx, |mw, window, cx| { - let workspace = mw.workspaces()[0].clone(); + let workspace = mw.workspaces().next().unwrap().clone(); mw.activate(workspace, window, cx); }); cx.run_until_parked(); assert_eq!( - multi_workspace.read_with(cx, |mw, _| mw.active_workspace_index()), - 0 + multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()), + workspace_a ); // Call activate_archived_thread – should resolve saved paths and @@ -3562,8 +3576,8 @@ async fn test_activate_archived_thread_with_saved_paths_activates_matching_works cx.run_until_parked(); assert_eq!( - multi_workspace.read_with(cx, |mw, _| mw.active_workspace_index()), - 1, + multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()), + workspace_b, "should have activated the workspace matching the saved path_list" ); } @@ -3588,21 +3602,23 @@ async fn test_activate_archived_thread_cwd_fallback_with_matching_workspace( let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_a, window, cx)); - multi_workspace.update_in(cx, |mw, window, cx| { - mw.test_add_workspace(project_b, window, cx); - }); - let sidebar = setup_sidebar(&multi_workspace, cx); + let workspace_b = multi_workspace.update_in(cx, |mw, window, cx| { + mw.test_add_workspace(project_b, window, cx) + }); + let workspace_a = + multi_workspace.read_with(cx, |mw, _| mw.workspaces().next().unwrap().clone()); + // Start with workspace A active. multi_workspace.update_in(cx, |mw, window, cx| { - let workspace = mw.workspaces()[0].clone(); + let workspace = mw.workspaces().next().unwrap().clone(); mw.activate(workspace, window, cx); }); cx.run_until_parked(); assert_eq!( - multi_workspace.read_with(cx, |mw, _| mw.active_workspace_index()), - 0 + multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()), + workspace_a ); // No thread saved to the store – cwd is the only path hint. @@ -3625,8 +3641,8 @@ async fn test_activate_archived_thread_cwd_fallback_with_matching_workspace( cx.run_until_parked(); assert_eq!( - multi_workspace.read_with(cx, |mw, _| mw.active_workspace_index()), - 1, + multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()), + workspace_b, "should have activated the workspace matching the cwd" ); } @@ -3651,21 +3667,21 @@ async fn test_activate_archived_thread_no_paths_no_cwd_uses_active_workspace( let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_a, window, cx)); - multi_workspace.update_in(cx, |mw, window, cx| { - mw.test_add_workspace(project_b, window, cx); - }); - let sidebar = setup_sidebar(&multi_workspace, cx); + let workspace_b = multi_workspace.update_in(cx, |mw, window, cx| { + mw.test_add_workspace(project_b, window, cx) + }); + // Activate workspace B (index 1) to make it the active one. multi_workspace.update_in(cx, |mw, window, cx| { - let workspace = mw.workspaces()[1].clone(); + let workspace = mw.workspaces().nth(1).unwrap().clone(); mw.activate(workspace, window, cx); }); cx.run_until_parked(); assert_eq!( - multi_workspace.read_with(cx, |mw, _| mw.active_workspace_index()), - 1 + multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()), + workspace_b ); // No saved thread, no cwd – should fall back to the active workspace. @@ -3688,8 +3704,8 @@ async fn test_activate_archived_thread_no_paths_no_cwd_uses_active_workspace( cx.run_until_parked(); assert_eq!( - multi_workspace.read_with(cx, |mw, _| mw.active_workspace_index()), - 1, + multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()), + workspace_b, "should have stayed on the active workspace when no path info is available" ); } @@ -3719,7 +3735,7 @@ async fn test_activate_archived_thread_saved_paths_opens_new_workspace(cx: &mut let session_id = acp::SessionId::new(Arc::from("archived-new-ws")); assert_eq!( - multi_workspace.read_with(cx, |mw, _| mw.workspaces().len()), + multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()), 1, "should start with one workspace" ); @@ -3743,7 +3759,7 @@ async fn test_activate_archived_thread_saved_paths_opens_new_workspace(cx: &mut cx.run_until_parked(); assert_eq!( - multi_workspace.read_with(cx, |mw, _| mw.workspaces().len()), + multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()), 2, "should have opened a second workspace for the archived thread's saved paths" ); @@ -3768,6 +3784,10 @@ async fn test_activate_archived_thread_reuses_workspace_in_another_window(cx: &m cx.add_window(|window, cx| MultiWorkspace::test_new(project_b, window, cx)); let multi_workspace_a_entity = multi_workspace_a.root(cx).unwrap(); + let multi_workspace_b_entity = multi_workspace_b.root(cx).unwrap(); + + let cx_b = &mut gpui::VisualTestContext::from_window(multi_workspace_b.into(), cx); + let _sidebar_b = setup_sidebar(&multi_workspace_b_entity, cx_b); let cx_a = &mut gpui::VisualTestContext::from_window(multi_workspace_a.into(), cx); let sidebar = setup_sidebar(&multi_workspace_a_entity, cx_a); @@ -3794,14 +3814,14 @@ async fn test_activate_archived_thread_reuses_workspace_in_another_window(cx: &m assert_eq!( multi_workspace_a - .read_with(cx_a, |mw, _| mw.workspaces().len()) + .read_with(cx_a, |mw, _| mw.workspaces().count()) .unwrap(), 1, "should not add the other window's workspace into the current window" ); assert_eq!( multi_workspace_b - .read_with(cx_a, |mw, _| mw.workspaces().len()) + .read_with(cx_a, |mw, _| mw.workspaces().count()) .unwrap(), 1, "should reuse the existing workspace in the other window" @@ -3871,14 +3891,14 @@ async fn test_activate_archived_thread_reuses_workspace_in_another_window_with_t assert_eq!( multi_workspace_a - .read_with(cx_a, |mw, _| mw.workspaces().len()) + .read_with(cx_a, |mw, _| mw.workspaces().count()) .unwrap(), 1, "should not add the other window's workspace into the current window" ); assert_eq!( multi_workspace_b - .read_with(cx_a, |mw, _| mw.workspaces().len()) + .read_with(cx_a, |mw, _| mw.workspaces().count()) .unwrap(), 1, "should reuse the existing workspace in the other window" @@ -3921,6 +3941,10 @@ async fn test_activate_archived_thread_prefers_current_window_for_matching_paths cx.add_window(|window, cx| MultiWorkspace::test_new(project_a, window, cx)); let multi_workspace_a_entity = multi_workspace_a.root(cx).unwrap(); + let multi_workspace_b_entity = multi_workspace_b.root(cx).unwrap(); + + let cx_b = &mut gpui::VisualTestContext::from_window(multi_workspace_b.into(), cx); + let _sidebar_b = setup_sidebar(&multi_workspace_b_entity, cx_b); let cx_a = &mut gpui::VisualTestContext::from_window(multi_workspace_a.into(), cx); let sidebar_a = setup_sidebar(&multi_workspace_a_entity, cx_a); @@ -3958,14 +3982,14 @@ async fn test_activate_archived_thread_prefers_current_window_for_matching_paths }); assert_eq!( multi_workspace_a - .read_with(cx_a, |mw, _| mw.workspaces().len()) + .read_with(cx_a, |mw, _| mw.workspaces().count()) .unwrap(), 1, "current window should continue reusing its existing workspace" ); assert_eq!( multi_workspace_b - .read_with(cx_a, |mw, _| mw.workspaces().len()) + .read_with(cx_a, |mw, _| mw.workspaces().count()) .unwrap(), 1, "other windows should not be activated just because they also match the saved paths" @@ -4029,19 +4053,20 @@ async fn test_archive_thread_uses_next_threads_own_workspace(cx: &mut TestAppCon let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(main_project.clone(), window, cx)); + let sidebar = setup_sidebar(&multi_workspace, cx); + let worktree_workspace = multi_workspace.update_in(cx, |mw, window, cx| { mw.test_add_workspace(worktree_project.clone(), window, cx) }); // Activate main workspace so the sidebar tracks the main panel. multi_workspace.update_in(cx, |mw, window, cx| { - let workspace = mw.workspaces()[0].clone(); + let workspace = mw.workspaces().next().unwrap().clone(); mw.activate(workspace, window, cx); }); - let sidebar = setup_sidebar(&multi_workspace, cx); - - let main_workspace = multi_workspace.read_with(cx, |mw, _| mw.workspaces()[0].clone()); + let main_workspace = + multi_workspace.read_with(cx, |mw, _| mw.workspaces().next().unwrap().clone()); let main_panel = add_agent_panel(&main_workspace, cx); let _worktree_panel = add_agent_panel(&worktree_workspace, cx); @@ -4195,10 +4220,10 @@ async fn test_linked_worktree_threads_not_duplicated_across_groups(cx: &mut Test let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_only.clone(), window, cx)); + let sidebar = setup_sidebar(&multi_workspace, cx); multi_workspace.update_in(cx, |mw, window, cx| { mw.test_add_workspace(multi_root.clone(), window, cx); }); - let sidebar = setup_sidebar(&multi_workspace, cx); // Save a thread under the linked worktree path. save_named_thread_metadata("wt-thread", "Worktree Thread", &worktree_project, cx).await; @@ -4313,8 +4338,8 @@ async fn test_thread_switcher_ordering(cx: &mut TestAppContext) { // so all three have last_accessed_at set. // Access order is: A (most recent), B, C (oldest). - // ── 1. Open switcher: threads sorted by last_accessed_at ─────────── - open_and_focus_sidebar(&sidebar, cx); + // ── 1. Open switcher: threads sorted by last_accessed_at ───────────────── + focus_sidebar(&sidebar, cx); sidebar.update_in(cx, |sidebar, window, cx| { sidebar.on_toggle_thread_switcher(&ToggleThreadSwitcher::default(), window, cx); }); @@ -4759,6 +4784,170 @@ async fn test_linked_worktree_workspace_shows_main_worktree_threads(cx: &mut Tes ); } +async fn init_multi_project_test( + paths: &[&str], + cx: &mut TestAppContext, +) -> (Arc, Entity) { + agent_ui::test_support::init_test(cx); + cx.update(|cx| { + cx.update_flags(false, vec!["agent-v2".into()]); + ThreadStore::init_global(cx); + ThreadMetadataStore::init_global(cx); + language_model::LanguageModelRegistry::test(cx); + prompt_store::init(cx); + }); + let fs = FakeFs::new(cx.executor()); + for path in paths { + fs.insert_tree(path, serde_json::json!({ ".git": {}, "src": {} })) + .await; + } + cx.update(|cx| ::set_global(fs.clone(), cx)); + let project = + project::Project::test(fs.clone() as Arc, [paths[0].as_ref()], cx).await; + (fs, project) +} + +async fn add_test_project( + path: &str, + fs: &Arc, + multi_workspace: &Entity, + cx: &mut gpui::VisualTestContext, +) -> Entity { + let project = project::Project::test(fs.clone() as Arc, [path.as_ref()], cx).await; + let workspace = multi_workspace.update_in(cx, |mw, window, cx| { + mw.test_add_workspace(project, window, cx) + }); + cx.run_until_parked(); + workspace +} + +#[gpui::test] +async fn test_transient_workspace_lifecycle(cx: &mut TestAppContext) { + let (fs, project_a) = + init_multi_project_test(&["/project-a", "/project-b", "/project-c"], cx).await; + let (multi_workspace, cx) = + cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_a, window, cx)); + let _sidebar = setup_sidebar_closed(&multi_workspace, cx); + + // Sidebar starts closed. Initial workspace A is transient. + let workspace_a = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()); + assert!(!multi_workspace.read_with(cx, |mw, _| mw.sidebar_open())); + assert_eq!( + multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()), + 1 + ); + assert!(multi_workspace.read_with(cx, |mw, _| mw.workspace() == &workspace_a)); + + // Add B — replaces A as the transient workspace. + let workspace_b = add_test_project("/project-b", &fs, &multi_workspace, cx).await; + assert_eq!( + multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()), + 1 + ); + assert!(multi_workspace.read_with(cx, |mw, _| mw.workspace() == &workspace_b)); + + // Add C — replaces B as the transient workspace. + let workspace_c = add_test_project("/project-c", &fs, &multi_workspace, cx).await; + assert_eq!( + multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()), + 1 + ); + assert!(multi_workspace.read_with(cx, |mw, _| mw.workspace() == &workspace_c)); +} + +#[gpui::test] +async fn test_transient_workspace_retained(cx: &mut TestAppContext) { + let (fs, project_a) = init_multi_project_test( + &["/project-a", "/project-b", "/project-c", "/project-d"], + cx, + ) + .await; + let (multi_workspace, cx) = + cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_a, window, cx)); + let _sidebar = setup_sidebar(&multi_workspace, cx); + assert!(multi_workspace.read_with(cx, |mw, _| mw.sidebar_open())); + + // Add B — retained since sidebar is open. + let workspace_a = add_test_project("/project-b", &fs, &multi_workspace, cx).await; + assert_eq!( + multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()), + 2 + ); + + // Switch to A — B survives. (Switching from one internal workspace, to another) + multi_workspace.update_in(cx, |mw, window, cx| mw.activate(workspace_a, window, cx)); + cx.run_until_parked(); + assert_eq!( + multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()), + 2 + ); + + // Close sidebar — both A and B remain retained. + multi_workspace.update_in(cx, |mw, window, cx| mw.close_sidebar(window, cx)); + cx.run_until_parked(); + assert_eq!( + multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()), + 2 + ); + + // Add C — added as new transient workspace. (switching from retained, to transient) + let workspace_c = add_test_project("/project-c", &fs, &multi_workspace, cx).await; + assert_eq!( + multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()), + 3 + ); + assert!(multi_workspace.read_with(cx, |mw, _| mw.workspace() == &workspace_c)); + + // Add D — replaces C as the transient workspace (Have retained and transient workspaces, transient workspace is dropped) + let workspace_d = add_test_project("/project-d", &fs, &multi_workspace, cx).await; + assert_eq!( + multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()), + 3 + ); + assert!(multi_workspace.read_with(cx, |mw, _| mw.workspace() == &workspace_d)); +} + +#[gpui::test] +async fn test_transient_workspace_promotion(cx: &mut TestAppContext) { + let (fs, project_a) = + init_multi_project_test(&["/project-a", "/project-b", "/project-c"], cx).await; + let (multi_workspace, cx) = + cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_a, window, cx)); + setup_sidebar_closed(&multi_workspace, cx); + + // Add B — replaces A as the transient workspace (A is discarded). + let workspace_b = add_test_project("/project-b", &fs, &multi_workspace, cx).await; + assert_eq!( + multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()), + 1 + ); + assert!(multi_workspace.read_with(cx, |mw, _| mw.workspace() == &workspace_b)); + + // Open sidebar — promotes the transient B to retained. + multi_workspace.update_in(cx, |mw, window, cx| { + mw.toggle_sidebar(window, cx); + }); + cx.run_until_parked(); + assert_eq!( + multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()), + 1 + ); + assert!(multi_workspace.read_with(cx, |mw, _| mw.workspaces().any(|w| w == &workspace_b))); + + // Close sidebar — the retained B remains. + multi_workspace.update_in(cx, |mw, window, cx| { + mw.toggle_sidebar(window, cx); + }); + + // Add C — added as new transient workspace. + let workspace_c = add_test_project("/project-c", &fs, &multi_workspace, cx).await; + assert_eq!( + multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()), + 2 + ); + assert!(multi_workspace.read_with(cx, |mw, _| mw.workspace() == &workspace_c)); +} + #[gpui::test] async fn test_legacy_thread_with_canonical_path_opens_main_repo_workspace(cx: &mut TestAppContext) { init_test(cx); @@ -4843,12 +5032,12 @@ async fn test_legacy_thread_with_canonical_path_opens_main_repo_workspace(cx: &m // Verify only 1 workspace before clicking. assert_eq!( - multi_workspace.read_with(cx, |mw, _| mw.workspaces().len()), + multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()), 1, ); // Focus and select the legacy thread, then confirm. - open_and_focus_sidebar(&sidebar, cx); + focus_sidebar(&sidebar, cx); let thread_index = sidebar.read_with(cx, |sidebar, _| { sidebar .contents @@ -5057,7 +5246,12 @@ mod property_test { match operation { Operation::SaveThread { workspace_index } => { let project = multi_workspace.read_with(cx, |mw, cx| { - mw.workspaces()[workspace_index].read(cx).project().clone() + mw.workspaces() + .nth(workspace_index) + .unwrap() + .read(cx) + .project() + .clone() }); save_thread_to_path(state, &project, cx); } @@ -5144,7 +5338,7 @@ mod property_test { } Operation::RemoveWorkspace { index } => { let removed = multi_workspace.update_in(cx, |mw, window, cx| { - let workspace = mw.workspaces()[index].clone(); + let workspace = mw.workspaces().nth(index).unwrap().clone(); mw.remove(&workspace, window, cx) }); if removed { @@ -5158,8 +5352,8 @@ mod property_test { } } Operation::SwitchWorkspace { index } => { - let workspace = - multi_workspace.read_with(cx, |mw, _| mw.workspaces()[index].clone()); + let workspace = multi_workspace + .read_with(cx, |mw, _| mw.workspaces().nth(index).unwrap().clone()); multi_workspace.update_in(cx, |mw, window, cx| { mw.activate(workspace, window, cx); }); @@ -5209,8 +5403,9 @@ mod property_test { .await; // Re-scan the main workspace's project so it discovers the new worktree. - let main_workspace = - multi_workspace.read_with(cx, |mw, _| mw.workspaces()[workspace_index].clone()); + let main_workspace = multi_workspace.read_with(cx, |mw, _| { + mw.workspaces().nth(workspace_index).unwrap().clone() + }); let main_project = main_workspace.read_with(cx, |ws, _| ws.project().clone()); main_project .update(cx, |p, cx| p.git_scans_complete(cx)) @@ -5297,7 +5492,11 @@ mod property_test { let Some(multi_workspace) = sidebar.multi_workspace.upgrade() else { anyhow::bail!("sidebar should still have an associated multi-workspace"); }; - let workspaces = multi_workspace.read(cx).workspaces().to_vec(); + let workspaces = multi_workspace + .read(cx) + .workspaces() + .cloned() + .collect::>(); let thread_store = ThreadMetadataStore::global(cx); let sidebar_thread_ids: HashSet = sidebar diff --git a/crates/title_bar/src/title_bar.rs b/crates/title_bar/src/title_bar.rs index 440249907adb6d29602ad8e950d0fd26a2d1c31d..dfcd933dc20df9a6f6643402719f2ec1143cc7fe 100644 --- a/crates/title_bar/src/title_bar.rs +++ b/crates/title_bar/src/title_bar.rs @@ -740,7 +740,6 @@ impl TitleBar { .map(|mw| { mw.read(cx) .workspaces() - .iter() .filter_map(|ws| ws.read(cx).database_id()) .collect() }) @@ -803,7 +802,6 @@ impl TitleBar { .map(|mw| { mw.read(cx) .workspaces() - .iter() .filter_map(|ws| ws.read(cx).database_id()) .collect() }) diff --git a/crates/workspace/src/multi_workspace.rs b/crates/workspace/src/multi_workspace.rs index a0c5eaabc629073dd9a46ac1b5073ddfbd26bd28..a61ad3576c57ecd8b1811363d6b5607ead737821 100644 --- a/crates/workspace/src/multi_workspace.rs +++ b/crates/workspace/src/multi_workspace.rs @@ -40,10 +40,7 @@ actions!( CloseWorkspaceSidebar, /// Moves focus to or from the workspace sidebar without closing it. FocusWorkspaceSidebar, - /// Switches to the next workspace. - NextWorkspace, - /// Switches to the previous workspace. - PreviousWorkspace, + //TODO: Restore next/previous workspace ] ); @@ -221,10 +218,57 @@ impl SidebarHandle for Entity { } } +/// Tracks which workspace the user is currently looking at. +/// +/// `Persistent` workspaces live in the `workspaces` vec and are shown in the +/// sidebar. `Transient` workspaces exist outside the vec and are discarded +/// when the user switches away. +enum ActiveWorkspace { + /// A persistent workspace, identified by index into the `workspaces` vec. + Persistent(usize), + /// A workspace not in the `workspaces` vec that will be discarded on + /// switch or promoted to persistent when the sidebar is opened. + Transient(Entity), +} + +impl ActiveWorkspace { + fn persistent_index(&self) -> Option { + match self { + Self::Persistent(index) => Some(*index), + Self::Transient(_) => None, + } + } + + fn transient_workspace(&self) -> Option<&Entity> { + match self { + Self::Transient(workspace) => Some(workspace), + Self::Persistent(_) => None, + } + } + + /// Sets the active workspace to transient, returning the previous + /// transient workspace (if any). + fn set_transient(&mut self, workspace: Entity) -> Option> { + match std::mem::replace(self, Self::Transient(workspace)) { + Self::Transient(old) => Some(old), + Self::Persistent(_) => None, + } + } + + /// Sets the active workspace to persistent at the given index, + /// returning the previous transient workspace (if any). + fn set_persistent(&mut self, index: usize) -> Option> { + match std::mem::replace(self, Self::Persistent(index)) { + Self::Transient(workspace) => Some(workspace), + Self::Persistent(_) => None, + } + } +} + pub struct MultiWorkspace { window_id: WindowId, workspaces: Vec>, - active_workspace_index: usize, + active_workspace: ActiveWorkspace, project_group_keys: Vec, sidebar: Option>, sidebar_open: bool, @@ -260,12 +304,15 @@ impl MultiWorkspace { } }); let quit_subscription = cx.on_app_quit(Self::app_will_quit); - let settings_subscription = - cx.observe_global_in::(window, |this, window, cx| { - if DisableAiSettings::get_global(cx).disable_ai && this.sidebar_open { - this.close_sidebar(window, cx); + let settings_subscription = cx.observe_global_in::(window, { + let mut previous_disable_ai = DisableAiSettings::get_global(cx).disable_ai; + move |this, window, cx| { + if DisableAiSettings::get_global(cx).disable_ai != previous_disable_ai { + this.collapse_to_single_workspace(window, cx); + previous_disable_ai = DisableAiSettings::get_global(cx).disable_ai; } - }); + } + }); Self::subscribe_to_workspace(&workspace, window, cx); let weak_self = cx.weak_entity(); workspace.update(cx, |workspace, cx| { @@ -273,9 +320,9 @@ impl MultiWorkspace { }); Self { window_id: window.window_handle().window_id(), - project_group_keys: vec![workspace.read(cx).project_group_key(cx)], - workspaces: vec![workspace], - active_workspace_index: 0, + project_group_keys: Vec::new(), + workspaces: Vec::new(), + active_workspace: ActiveWorkspace::Transient(workspace), sidebar: None, sidebar_open: false, sidebar_overlay: None, @@ -337,7 +384,7 @@ impl MultiWorkspace { return; } - if self.sidebar_open { + if self.sidebar_open() { self.close_sidebar(window, cx); } else { self.open_sidebar(cx); @@ -353,7 +400,7 @@ impl MultiWorkspace { return; } - if self.sidebar_open { + if self.sidebar_open() { self.close_sidebar(window, cx); } } @@ -363,7 +410,7 @@ impl MultiWorkspace { return; } - if self.sidebar_open { + if self.sidebar_open() { let sidebar_is_focused = self .sidebar .as_ref() @@ -388,8 +435,13 @@ impl MultiWorkspace { pub fn open_sidebar(&mut self, cx: &mut Context) { self.sidebar_open = true; + if let ActiveWorkspace::Transient(workspace) = &self.active_workspace { + let workspace = workspace.clone(); + let index = self.promote_transient(workspace, cx); + self.active_workspace = ActiveWorkspace::Persistent(index); + } let sidebar_focus_handle = self.sidebar.as_ref().map(|s| s.focus_handle(cx)); - for workspace in &self.workspaces { + for workspace in self.workspaces.iter() { workspace.update(cx, |workspace, _cx| { workspace.set_sidebar_focus_handle(sidebar_focus_handle.clone()); }); @@ -400,7 +452,7 @@ impl MultiWorkspace { pub fn close_sidebar(&mut self, window: &mut Window, cx: &mut Context) { self.sidebar_open = false; - for workspace in &self.workspaces { + for workspace in self.workspaces.iter() { workspace.update(cx, |workspace, _cx| { workspace.set_sidebar_focus_handle(None); }); @@ -415,7 +467,7 @@ impl MultiWorkspace { pub fn close_window(&mut self, _: &CloseWindow, window: &mut Window, cx: &mut Context) { cx.spawn_in(window, async move |this, cx| { let workspaces = this.update(cx, |multi_workspace, _cx| { - multi_workspace.workspaces().to_vec() + multi_workspace.workspaces().cloned().collect::>() })?; for workspace in workspaces { @@ -657,6 +709,12 @@ impl MultiWorkspace { return Task::ready(Ok(workspace)); } + if let Some(transient) = self.active_workspace.transient_workspace() { + if transient.read(cx).project_group_key(cx).path_list() == &path_list { + return Task::ready(Ok(transient.clone())); + } + } + let paths = path_list.paths().to_vec(); let app_state = self.workspace().read(cx).app_state().clone(); let requesting_window = window.window_handle().downcast::(); @@ -680,25 +738,23 @@ impl MultiWorkspace { } pub fn workspace(&self) -> &Entity { - &self.workspaces[self.active_workspace_index] - } - - pub fn workspaces(&self) -> &[Entity] { - &self.workspaces + match &self.active_workspace { + ActiveWorkspace::Persistent(index) => &self.workspaces[*index], + ActiveWorkspace::Transient(workspace) => workspace, + } } - pub fn active_workspace_index(&self) -> usize { - self.active_workspace_index + pub fn workspaces(&self) -> impl Iterator> { + self.workspaces + .iter() + .chain(self.active_workspace.transient_workspace()) } - /// Adds a workspace to this window without changing which workspace is - /// active. + /// Adds a workspace to this window as persistent without changing which + /// workspace is active. Unlike `activate()`, this always inserts into the + /// persistent list regardless of sidebar state — it's used for system- + /// initiated additions like deserialization and worktree discovery. pub fn add(&mut self, workspace: Entity, window: &Window, cx: &mut Context) { - if !self.multi_workspace_enabled(cx) { - self.set_single_workspace(workspace, cx); - return; - } - self.insert_workspace(workspace, window, cx); } @@ -709,26 +765,74 @@ impl MultiWorkspace { window: &mut Window, cx: &mut Context, ) { - if !self.multi_workspace_enabled(cx) { - self.set_single_workspace(workspace, cx); + // Re-activating the current workspace is a no-op. + if self.workspace() == &workspace { + self.focus_active_workspace(window, cx); return; } - let index = self.insert_workspace(workspace, &*window, cx); - let changed = self.active_workspace_index != index; - self.active_workspace_index = index; - if changed { - cx.emit(MultiWorkspaceEvent::ActiveWorkspaceChanged); - self.serialize(cx); + // Resolve where we're going. + let new_index = if let Some(index) = self.workspaces.iter().position(|w| *w == workspace) { + Some(index) + } else if self.sidebar_open { + Some(self.insert_workspace(workspace.clone(), &*window, cx)) + } else { + None + }; + + // Transition the active workspace. + if let Some(index) = new_index { + if let Some(old) = self.active_workspace.set_persistent(index) { + if self.sidebar_open { + self.promote_transient(old, cx); + } else { + self.detach_workspace(&old, cx); + cx.emit(MultiWorkspaceEvent::WorkspaceRemoved(old.entity_id())); + } + } + } else { + Self::subscribe_to_workspace(&workspace, window, cx); + let weak_self = cx.weak_entity(); + workspace.update(cx, |workspace, cx| { + workspace.set_multi_workspace(weak_self, cx); + }); + if let Some(old) = self.active_workspace.set_transient(workspace) { + self.detach_workspace(&old, cx); + cx.emit(MultiWorkspaceEvent::WorkspaceRemoved(old.entity_id())); + } } + + cx.emit(MultiWorkspaceEvent::ActiveWorkspaceChanged); + self.serialize(cx); self.focus_active_workspace(window, cx); cx.notify(); } - fn set_single_workspace(&mut self, workspace: Entity, cx: &mut Context) { - self.workspaces[0] = workspace; - self.active_workspace_index = 0; - cx.emit(MultiWorkspaceEvent::ActiveWorkspaceChanged); + /// Promotes a former transient workspace into the persistent list. + /// Returns the index of the newly inserted workspace. + fn promote_transient(&mut self, workspace: Entity, cx: &mut Context) -> usize { + let project_group_key = workspace.read(cx).project().read(cx).project_group_key(cx); + self.add_project_group_key(project_group_key); + self.workspaces.push(workspace.clone()); + cx.emit(MultiWorkspaceEvent::WorkspaceAdded(workspace)); + self.workspaces.len() - 1 + } + + /// Collapses to a single transient workspace, discarding all persistent + /// workspaces. Used when multi-workspace is disabled (e.g. disable_ai). + fn collapse_to_single_workspace(&mut self, window: &mut Window, cx: &mut Context) { + if self.sidebar_open { + self.close_sidebar(window, cx); + } + let active = self.workspace().clone(); + for workspace in std::mem::take(&mut self.workspaces) { + if workspace != active { + self.detach_workspace(&workspace, cx); + cx.emit(MultiWorkspaceEvent::WorkspaceRemoved(workspace.entity_id())); + } + } + self.project_group_keys.clear(); + self.active_workspace = ActiveWorkspace::Transient(active); cx.notify(); } @@ -784,7 +888,7 @@ impl MultiWorkspace { } fn sync_sidebar_to_workspace(&self, workspace: &Entity, cx: &mut Context) { - if self.sidebar_open { + if self.sidebar_open() { let sidebar_focus_handle = self.sidebar.as_ref().map(|s| s.focus_handle(cx)); workspace.update(cx, |workspace, _| { workspace.set_sidebar_focus_handle(sidebar_focus_handle); @@ -792,30 +896,6 @@ impl MultiWorkspace { } } - fn cycle_workspace(&mut self, delta: isize, window: &mut Window, cx: &mut Context) { - let count = self.workspaces.len() as isize; - if count <= 1 { - return; - } - let current = self.active_workspace_index as isize; - let next = ((current + delta).rem_euclid(count)) as usize; - let workspace = self.workspaces[next].clone(); - self.activate(workspace, window, cx); - } - - fn next_workspace(&mut self, _: &NextWorkspace, window: &mut Window, cx: &mut Context) { - self.cycle_workspace(1, window, cx); - } - - fn previous_workspace( - &mut self, - _: &PreviousWorkspace, - window: &mut Window, - cx: &mut Context, - ) { - self.cycle_workspace(-1, window, cx); - } - pub(crate) fn serialize(&mut self, cx: &mut Context) { self._serialize_task = Some(cx.spawn(async move |this, cx| { let Some((window_id, state)) = this @@ -1070,7 +1150,7 @@ impl MultiWorkspace { let new_workspace = cx.new(|cx| Workspace::new(None, project, app_state, window, cx)); self.workspaces[0] = new_workspace.clone(); - self.active_workspace_index = 0; + self.active_workspace = ActiveWorkspace::Persistent(0); Self::subscribe_to_workspace(&new_workspace, window, cx); @@ -1090,10 +1170,12 @@ impl MultiWorkspace { } else { let removed_workspace = self.workspaces.remove(index); - if self.active_workspace_index >= self.workspaces.len() { - self.active_workspace_index = self.workspaces.len() - 1; - } else if self.active_workspace_index > index { - self.active_workspace_index -= 1; + if let Some(active_index) = self.active_workspace.persistent_index() { + if active_index >= self.workspaces.len() { + self.active_workspace = ActiveWorkspace::Persistent(self.workspaces.len() - 1); + } else if active_index > index { + self.active_workspace = ActiveWorkspace::Persistent(active_index - 1); + } } self.detach_workspace(&removed_workspace, cx); @@ -1343,8 +1425,6 @@ impl Render for MultiWorkspace { this.focus_sidebar(window, cx); }, )) - .on_action(cx.listener(Self::next_workspace)) - .on_action(cx.listener(Self::previous_workspace)) .on_action(cx.listener(Self::move_active_workspace_to_new_window)) .on_action(cx.listener( |this: &mut Self, action: &ToggleThreadSwitcher, window, cx| { diff --git a/crates/workspace/src/multi_workspace_tests.rs b/crates/workspace/src/multi_workspace_tests.rs index 3083c23f6e3add91b0389a961567fc88e2043678..ab6ca43d5aff482b637add9083b1ad9d388d7993 100644 --- a/crates/workspace/src/multi_workspace_tests.rs +++ b/crates/workspace/src/multi_workspace_tests.rs @@ -99,6 +99,10 @@ async fn test_project_group_keys_initial(cx: &mut TestAppContext) { let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(project, window, cx)); + multi_workspace.update(cx, |mw, cx| { + mw.open_sidebar(cx); + }); + multi_workspace.read_with(cx, |mw, _cx| { let keys: Vec<&ProjectGroupKey> = mw.project_group_keys().collect(); assert_eq!(keys.len(), 1, "should have exactly one key on creation"); @@ -125,6 +129,10 @@ async fn test_project_group_keys_add_workspace(cx: &mut TestAppContext) { let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_a, window, cx)); + multi_workspace.update(cx, |mw, cx| { + mw.open_sidebar(cx); + }); + multi_workspace.read_with(cx, |mw, _cx| { assert_eq!(mw.project_group_keys().count(), 1); }); @@ -162,6 +170,10 @@ async fn test_project_group_keys_duplicate_not_added(cx: &mut TestAppContext) { let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_a, window, cx)); + multi_workspace.update(cx, |mw, cx| { + mw.open_sidebar(cx); + }); + multi_workspace.update_in(cx, |mw, window, cx| { mw.test_add_workspace(project_a2, window, cx); }); @@ -189,6 +201,10 @@ async fn test_project_group_keys_on_worktree_added(cx: &mut TestAppContext) { let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx)); + multi_workspace.update(cx, |mw, cx| { + mw.open_sidebar(cx); + }); + // Add a second worktree to the same project. let (worktree, _) = project .update(cx, |project, cx| { @@ -232,6 +248,10 @@ async fn test_project_group_keys_on_worktree_removed(cx: &mut TestAppContext) { let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx)); + multi_workspace.update(cx, |mw, cx| { + mw.open_sidebar(cx); + }); + // Remove one worktree. let worktree_b_id = project.read_with(cx, |project, cx| { project @@ -282,6 +302,10 @@ async fn test_project_group_keys_across_multiple_workspaces_and_worktree_changes let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_a.clone(), window, cx)); + multi_workspace.update(cx, |mw, cx| { + mw.open_sidebar(cx); + }); + multi_workspace.update_in(cx, |mw, window, cx| { mw.test_add_workspace(project_b, window, cx); }); diff --git a/crates/workspace/src/persistence.rs b/crates/workspace/src/persistence.rs index 644ff0282df216e79d6be24918d29b802e50a0e8..2994e9d0f67d73a30838f922c9b6a0b01b21ed14 100644 --- a/crates/workspace/src/persistence.rs +++ b/crates/workspace/src/persistence.rs @@ -2535,6 +2535,10 @@ mod tests { let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(project1.clone(), window, cx)); + multi_workspace.update(cx, |mw, cx| { + mw.open_sidebar(cx); + }); + multi_workspace.update_in(cx, |mw, _, cx| { mw.set_random_database_id(cx); }); @@ -2564,7 +2568,7 @@ mod tests { // --- Remove the second workspace (index 1) --- multi_workspace.update_in(cx, |mw, window, cx| { - let ws = mw.workspaces()[1].clone(); + let ws = mw.workspaces().nth(1).unwrap().clone(); mw.remove(&ws, window, cx); }); @@ -4191,6 +4195,10 @@ mod tests { let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(project1.clone(), window, cx)); + multi_workspace.update(cx, |mw, cx| { + mw.open_sidebar(cx); + }); + multi_workspace.update_in(cx, |mw, _, cx| { mw.set_random_database_id(cx); }); @@ -4233,7 +4241,7 @@ mod tests { // Remove workspace at index 1 (the second workspace). multi_workspace.update_in(cx, |mw, window, cx| { - let ws = mw.workspaces()[1].clone(); + let ws = mw.workspaces().nth(1).unwrap().clone(); mw.remove(&ws, window, cx); }); @@ -4288,6 +4296,10 @@ mod tests { let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(project1.clone(), window, cx)); + multi_workspace.update(cx, |mw, cx| { + mw.open_sidebar(cx); + }); + multi_workspace.update_in(cx, |mw, _, cx| { mw.workspace().update(cx, |ws, _cx| { ws.set_database_id(ws1_id); @@ -4339,7 +4351,7 @@ mod tests { // Remove workspace2 (index 1). multi_workspace.update_in(cx, |mw, window, cx| { - let ws = mw.workspaces()[1].clone(); + let ws = mw.workspaces().nth(1).unwrap().clone(); mw.remove(&ws, window, cx); }); @@ -4385,6 +4397,10 @@ mod tests { let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(project1.clone(), window, cx)); + multi_workspace.update(cx, |mw, cx| { + mw.open_sidebar(cx); + }); + multi_workspace.update_in(cx, |mw, _, cx| { mw.set_random_database_id(cx); }); @@ -4418,7 +4434,7 @@ mod tests { // Remove workspace2 — this pushes a task to pending_removal_tasks. multi_workspace.update_in(cx, |mw, window, cx| { - let ws = mw.workspaces()[1].clone(); + let ws = mw.workspaces().nth(1).unwrap().clone(); mw.remove(&ws, window, cx); }); @@ -4427,7 +4443,6 @@ mod tests { let all_tasks = multi_workspace.update_in(cx, |mw, window, cx| { let mut tasks: Vec> = mw .workspaces() - .iter() .map(|workspace| { workspace.update(cx, |workspace, cx| { workspace.flush_serialization(window, cx) @@ -4747,6 +4762,10 @@ mod tests { let (multi_workspace, cx) = cx .add_window_view(|window, cx| MultiWorkspace::test_new(project_2.clone(), window, cx)); + multi_workspace.update(cx, |mw, cx| { + mw.open_sidebar(cx); + }); + multi_workspace.update_in(cx, |mw, window, cx| { mw.test_add_workspace(project_1.clone(), window, cx); }); diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index cc5d1e8635e9194522fea5506fef4084f8133c53..7979ffe828cbf8c4da5a40a29eaa6537f1433c3c 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -32,8 +32,8 @@ pub use crate::notifications::NotificationFrame; pub use dock::Panel; pub use multi_workspace::{ CloseWorkspaceSidebar, DraggedSidebar, FocusWorkspaceSidebar, MultiWorkspace, - MultiWorkspaceEvent, NextWorkspace, PreviousWorkspace, Sidebar, SidebarEvent, SidebarHandle, - SidebarRenderState, SidebarSide, ToggleWorkspaceSidebar, sidebar_side_context_menu, + MultiWorkspaceEvent, Sidebar, SidebarEvent, SidebarHandle, SidebarRenderState, SidebarSide, + ToggleWorkspaceSidebar, sidebar_side_context_menu, }; pub use path_list::{PathList, SerializedPathList}; pub use toast_layer::{ToastAction, ToastLayer, ToastView}; @@ -9079,7 +9079,7 @@ pub fn workspace_windows_for_location( }; multi_workspace.read(cx).is_ok_and(|multi_workspace| { - multi_workspace.workspaces().iter().any(|workspace| { + multi_workspace.workspaces().any(|workspace| { match workspace.read(cx).workspace_location(cx) { WorkspaceLocation::Location(location, _) => { match (&location, serialized_location) { @@ -10741,6 +10741,12 @@ mod tests { cx.add_window(|window, cx| MultiWorkspace::test_new(project_a.clone(), window, cx)); cx.run_until_parked(); + multi_workspace_handle + .update(cx, |mw, _window, cx| { + mw.open_sidebar(cx); + }) + .unwrap(); + let workspace_a = multi_workspace_handle .read_with(cx, |mw, _| mw.workspace().clone()) .unwrap(); @@ -10754,7 +10760,7 @@ mod tests { // Activate workspace A multi_workspace_handle .update(cx, |mw, window, cx| { - let workspace = mw.workspaces()[0].clone(); + let workspace = mw.workspaces().next().unwrap().clone(); mw.activate(workspace, window, cx); }) .unwrap(); @@ -10776,7 +10782,7 @@ mod tests { // Verify workspace A is active multi_workspace_handle .read_with(cx, |mw, _| { - assert_eq!(mw.active_workspace_index(), 0); + assert_eq!(mw.workspace(), &workspace_a); }) .unwrap(); @@ -10792,8 +10798,8 @@ mod tests { multi_workspace_handle .read_with(cx, |mw, _| { assert_eq!( - mw.active_workspace_index(), - 1, + mw.workspace(), + &workspace_b, "workspace B should be activated when it prompts" ); }) @@ -14511,6 +14517,12 @@ mod tests { cx.add_window(|window, cx| MultiWorkspace::test_new(project_a.clone(), window, cx)); cx.run_until_parked(); + multi_workspace_handle + .update(cx, |mw, _window, cx| { + mw.open_sidebar(cx); + }) + .unwrap(); + let workspace_a = multi_workspace_handle .read_with(cx, |mw, _| mw.workspace().clone()) .unwrap(); @@ -14524,7 +14536,7 @@ mod tests { // Switch to workspace A multi_workspace_handle .update(cx, |mw, window, cx| { - let workspace = mw.workspaces()[0].clone(); + let workspace = mw.workspaces().next().unwrap().clone(); mw.activate(workspace, window, cx); }) .unwrap(); @@ -14570,7 +14582,7 @@ mod tests { // Switch to workspace B multi_workspace_handle .update(cx, |mw, window, cx| { - let workspace = mw.workspaces()[1].clone(); + let workspace = mw.workspaces().nth(1).unwrap().clone(); mw.activate(workspace, window, cx); }) .unwrap(); @@ -14579,7 +14591,7 @@ mod tests { // Switch back to workspace A multi_workspace_handle .update(cx, |mw, window, cx| { - let workspace = mw.workspaces()[0].clone(); + let workspace = mw.workspaces().next().unwrap().clone(); mw.activate(workspace, window, cx); }) .unwrap(); diff --git a/crates/zed/src/visual_test_runner.rs b/crates/zed/src/visual_test_runner.rs index f1ed73fe89f0980a2705631063dcf4efbbe84bfb..b59123a1a159487f802210f3916e16856daf8e61 100644 --- a/crates/zed/src/visual_test_runner.rs +++ b/crates/zed/src/visual_test_runner.rs @@ -2606,7 +2606,7 @@ fn run_multi_workspace_sidebar_visual_tests( // Add worktree to workspace 1 (index 0) so it shows as "private-test-remote" let add_worktree1_task = multi_workspace_window .update(cx, |multi_workspace, _window, cx| { - let workspace1 = &multi_workspace.workspaces()[0]; + let workspace1 = multi_workspace.workspaces().next().unwrap(); let project = workspace1.read(cx).project().clone(); project.update(cx, |project, cx| { project.find_or_create_worktree(&workspace1_dir, true, cx) @@ -2625,7 +2625,7 @@ fn run_multi_workspace_sidebar_visual_tests( // Add worktree to workspace 2 (index 1) so it shows as "zed" let add_worktree2_task = multi_workspace_window .update(cx, |multi_workspace, _window, cx| { - let workspace2 = &multi_workspace.workspaces()[1]; + let workspace2 = multi_workspace.workspaces().nth(1).unwrap(); let project = workspace2.read(cx).project().clone(); project.update(cx, |project, cx| { project.find_or_create_worktree(&workspace2_dir, true, cx) @@ -2644,7 +2644,7 @@ fn run_multi_workspace_sidebar_visual_tests( // Switch to workspace 1 so it's highlighted as active (index 0) multi_workspace_window .update(cx, |multi_workspace, window, cx| { - let workspace = multi_workspace.workspaces()[0].clone(); + let workspace = multi_workspace.workspaces().next().unwrap().clone(); multi_workspace.activate(workspace, window, cx); }) .context("Failed to activate workspace 1")?; @@ -2672,7 +2672,7 @@ fn run_multi_workspace_sidebar_visual_tests( let save_tasks = multi_workspace_window .update(cx, |multi_workspace, _window, cx| { let thread_store = agent::ThreadStore::global(cx); - let workspaces = multi_workspace.workspaces().to_vec(); + let workspaces: Vec<_> = multi_workspace.workspaces().cloned().collect(); let mut tasks = Vec::new(); for (index, workspace) in workspaces.iter().enumerate() { @@ -3211,7 +3211,7 @@ edition = "2021" // Add the git project as a worktree let add_worktree_task = workspace_window .update(cx, |multi_workspace, _window, cx| { - let workspace = &multi_workspace.workspaces()[0]; + let workspace = multi_workspace.workspaces().next().unwrap(); let project = workspace.read(cx).project().clone(); project.update(cx, |project, cx| { project.find_or_create_worktree(&project_path, true, cx) @@ -3236,7 +3236,7 @@ edition = "2021" // Open the project panel let (weak_workspace, async_window_cx) = workspace_window .update(cx, |multi_workspace, window, cx| { - let workspace = &multi_workspace.workspaces()[0]; + let workspace = multi_workspace.workspaces().next().unwrap(); (workspace.read(cx).weak_handle(), window.to_async(cx)) }) .context("Failed to get workspace handle")?; @@ -3250,7 +3250,7 @@ edition = "2021" workspace_window .update(cx, |multi_workspace, window, cx| { - let workspace = &multi_workspace.workspaces()[0]; + let workspace = multi_workspace.workspaces().next().unwrap(); workspace.update(cx, |workspace, cx| { workspace.add_panel(project_panel, window, cx); workspace.open_panel::(window, cx); @@ -3263,7 +3263,7 @@ edition = "2021" // Open main.rs in the editor let open_file_task = workspace_window .update(cx, |multi_workspace, window, cx| { - let workspace = &multi_workspace.workspaces()[0]; + let workspace = multi_workspace.workspaces().next().unwrap(); workspace.update(cx, |workspace, cx| { let worktree = workspace.project().read(cx).worktrees(cx).next(); if let Some(worktree) = worktree { @@ -3291,7 +3291,7 @@ edition = "2021" // Load the AgentPanel let (weak_workspace, async_window_cx) = workspace_window .update(cx, |multi_workspace, window, cx| { - let workspace = &multi_workspace.workspaces()[0]; + let workspace = multi_workspace.workspaces().next().unwrap(); (workspace.read(cx).weak_handle(), window.to_async(cx)) }) .context("Failed to get workspace handle for agent panel")?; @@ -3335,7 +3335,7 @@ edition = "2021" workspace_window .update(cx, |multi_workspace, window, cx| { - let workspace = &multi_workspace.workspaces()[0]; + let workspace = multi_workspace.workspaces().next().unwrap(); workspace.update(cx, |workspace, cx| { workspace.add_panel(panel.clone(), window, cx); workspace.open_panel::(window, cx); @@ -3512,7 +3512,7 @@ edition = "2021" .is_none() }); let workspace_count = workspace_window.update(cx, |multi_workspace, _window, _cx| { - multi_workspace.workspaces().len() + multi_workspace.workspaces().count() })?; if workspace_count == 2 && status_cleared { creation_complete = true; @@ -3531,7 +3531,7 @@ edition = "2021" // error state by injecting the stub server, and shrink the panel so the // editor content is visible. workspace_window.update(cx, |multi_workspace, window, cx| { - let new_workspace = &multi_workspace.workspaces()[1]; + let new_workspace = multi_workspace.workspaces().nth(1).unwrap(); new_workspace.update(cx, |workspace, cx| { if let Some(new_panel) = workspace.panel::(cx) { new_panel.update(cx, |panel, cx| { @@ -3544,7 +3544,7 @@ edition = "2021" // Type and send a message so the thread target dropdown disappears. let new_panel = workspace_window.update(cx, |multi_workspace, _window, cx| { - let new_workspace = &multi_workspace.workspaces()[1]; + let new_workspace = multi_workspace.workspaces().nth(1).unwrap(); new_workspace.read(cx).panel::(cx) })?; if let Some(new_panel) = new_panel { @@ -3585,7 +3585,7 @@ edition = "2021" workspace_window .update(cx, |multi_workspace, _window, cx| { - let workspace = &multi_workspace.workspaces()[0]; + let workspace = multi_workspace.workspaces().next().unwrap(); let project = workspace.read(cx).project().clone(); project.update(cx, |project, cx| { let worktree_ids: Vec<_> = diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index ed49236a9da6b69f80c8c981eaddaa16ca69face..03e128415e1aa8390d1b95816755d3644064dada 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -1524,7 +1524,7 @@ fn quit(_: &Quit, cx: &mut App) { let window = *window; let workspaces = window .update(cx, |multi_workspace, _, _| { - multi_workspace.workspaces().to_vec() + multi_workspace.workspaces().cloned().collect::>() }) .log_err(); @@ -2458,7 +2458,6 @@ mod tests { .update(cx, |multi_workspace, window, cx| { let mut tasks = multi_workspace .workspaces() - .iter() .map(|workspace| { workspace.update(cx, |workspace, cx| { workspace.flush_serialization(window, cx) @@ -2610,7 +2609,7 @@ mod tests { cx.run_until_parked(); multi_workspace_1 .update(cx, |multi_workspace, _window, cx| { - assert_eq!(multi_workspace.workspaces().len(), 2); + assert_eq!(multi_workspace.workspaces().count(), 2); assert!(multi_workspace.sidebar_open()); let workspace = multi_workspace.workspace().read(cx); assert_eq!( @@ -5512,6 +5511,11 @@ mod tests { let project = project1.clone(); |window, cx| MultiWorkspace::test_new(project, window, cx) }); + window + .update(cx, |multi_workspace, _, cx| { + multi_workspace.open_sidebar(cx); + }) + .unwrap(); cx.run_until_parked(); assert_eq!(cx.windows().len(), 1, "Should start with 1 window"); @@ -5534,7 +5538,7 @@ mod tests { let workspace1 = window .read_with(cx, |multi_workspace, _| { - multi_workspace.workspaces()[0].clone() + multi_workspace.workspaces().next().unwrap().clone() }) .unwrap(); @@ -5543,8 +5547,8 @@ mod tests { multi_workspace.activate(workspace2.clone(), window, cx); multi_workspace.activate(workspace3.clone(), window, cx); // Switch back to workspace1 for test setup - multi_workspace.activate(workspace1, window, cx); - assert_eq!(multi_workspace.active_workspace_index(), 0); + multi_workspace.activate(workspace1.clone(), window, cx); + assert_eq!(multi_workspace.workspace(), &workspace1); }) .unwrap(); @@ -5553,8 +5557,8 @@ mod tests { // Verify setup: 3 workspaces, workspace 0 active, still 1 window window .read_with(cx, |multi_workspace, _| { - assert_eq!(multi_workspace.workspaces().len(), 3); - assert_eq!(multi_workspace.active_workspace_index(), 0); + assert_eq!(multi_workspace.workspaces().count(), 3); + assert_eq!(multi_workspace.workspace(), &workspace1); }) .unwrap(); assert_eq!(cx.windows().len(), 1); @@ -5577,8 +5581,8 @@ mod tests { window .read_with(cx, |multi_workspace, cx| { assert_eq!( - multi_workspace.active_workspace_index(), - 2, + multi_workspace.workspace(), + &workspace3, "Should have switched to workspace 3 which contains /dir3" ); let active_item = multi_workspace @@ -5611,8 +5615,8 @@ mod tests { window .read_with(cx, |multi_workspace, cx| { assert_eq!( - multi_workspace.active_workspace_index(), - 1, + multi_workspace.workspace(), + &workspace2, "Should have switched to workspace 2 which contains /dir2" ); let active_item = multi_workspace @@ -5660,8 +5664,8 @@ mod tests { window .read_with(cx, |multi_workspace, cx| { assert_eq!( - multi_workspace.active_workspace_index(), - 0, + multi_workspace.workspace(), + &workspace1, "Should have switched back to workspace 0 which contains /dir1" ); let active_item = multi_workspace @@ -5711,6 +5715,11 @@ mod tests { let project = project1.clone(); |window, cx| MultiWorkspace::test_new(project, window, cx) }); + window1 + .update(cx, |multi_workspace, _, cx| { + multi_workspace.open_sidebar(cx); + }) + .unwrap(); cx.run_until_parked(); @@ -5737,6 +5746,11 @@ mod tests { let project = project3.clone(); |window, cx| MultiWorkspace::test_new(project, window, cx) }); + window2 + .update(cx, |multi_workspace, _, cx| { + multi_workspace.open_sidebar(cx); + }) + .unwrap(); cx.run_until_parked(); assert_eq!(cx.windows().len(), 2); @@ -5771,7 +5785,7 @@ mod tests { // Verify workspace1_1 is active window1 .read_with(cx, |multi_workspace, _| { - assert_eq!(multi_workspace.active_workspace_index(), 0); + assert_eq!(multi_workspace.workspace(), &workspace1_1); }) .unwrap(); @@ -5837,7 +5851,7 @@ mod tests { // Verify workspace1_1 is still active (not workspace1_2 with dirty item) window1 .read_with(cx, |multi_workspace, _| { - assert_eq!(multi_workspace.active_workspace_index(), 0); + assert_eq!(multi_workspace.workspace(), &workspace1_1); }) .unwrap(); @@ -5848,8 +5862,8 @@ mod tests { window1 .read_with(cx, |multi_workspace, _| { assert_eq!( - multi_workspace.active_workspace_index(), - 1, + multi_workspace.workspace(), + &workspace1_2, "Case 2: Non-active workspace should be activated when it has dirty item" ); }) @@ -6002,6 +6016,12 @@ mod tests { .await .expect("failed to open first workspace"); + window_a + .update(cx, |multi_workspace, _, cx| { + multi_workspace.open_sidebar(cx); + }) + .unwrap(); + window_a .update(cx, |multi_workspace, window, cx| { multi_workspace.open_project(vec![dir2.into()], OpenMode::Activate, window, cx) @@ -6028,13 +6048,19 @@ mod tests { .await .expect("failed to open third workspace"); + window_b + .update(cx, |multi_workspace, _, cx| { + multi_workspace.open_sidebar(cx); + }) + .unwrap(); + // Currently dir2 is active because it was added last. // So, switch window_a's active workspace to dir1 (index 0). // This sets up a non-trivial assertion: after restore, dir1 should // still be active rather than whichever workspace happened to restore last. window_a .update(cx, |multi_workspace, window, cx| { - let workspace = multi_workspace.workspaces()[0].clone(); + let workspace = multi_workspace.workspaces().next().unwrap().clone(); multi_workspace.activate(workspace, window, cx); }) .unwrap(); @@ -6150,7 +6176,7 @@ mod tests { ProjectGroupKey::new(None, PathList::new(&[dir2])), ] ); - assert_eq!(mw.workspaces().len(), 1); + assert_eq!(mw.workspaces().count(), 1); }) .unwrap(); @@ -6161,7 +6187,7 @@ mod tests { mw.project_group_keys().cloned().collect::>(), vec![ProjectGroupKey::new(None, PathList::new(&[dir3]))] ); - assert_eq!(mw.workspaces().len(), 1); + assert_eq!(mw.workspaces().count(), 1); }) .unwrap(); } From 3a5dc8ef6aaed33f7c57f9eeab36bbc71a19bc59 Mon Sep 17 00:00:00 2001 From: Yoni Sirote <96873891+yonisirote@users.noreply.github.com> Date: Tue, 7 Apr 2026 01:08:49 +0300 Subject: [PATCH 17/21] Restore ACP slash commands when reopening threads (#53209) Self-Review Checklist: - [x] I've reviewed my own diff for quality, security, and reliability - [x] Unsafe blocks (if any) have justifying comments - [x] The content is consistent with the UI/UX checklist - [x] Tests cover the new/changed behavior - [x] Performance impact has been considered and is acceptable Closes #52239 ## Summary Note - The code in this fix is AI generated. OpenCode ACP chats lost slash-command support after Zed restarted and a thread was reopened. The UI no longer showed `/ for commands`, and slash commands like `/help` were treated as unsupported. ## Root Cause ACP available commands were treated as transient UI state instead of durable thread state. - `AcpThread` handled `AvailableCommandsUpdate` but did not retain the commands on the thread - restored thread views rebuilt `SessionCapabilities` with an empty `available_commands` list - the message-editor placeholder started in the wrong state for restored threads - live command updates could be applied to the wrong thread view ## Fix - persisted `available_commands` on `AcpThread` - restored `SessionCapabilities` from thread state - reused the same command augmentation logic for restore and live update paths - updated live command handling to target the correct thread view - initialized the message-editor placeholder from current command availability - added a regression test for the restore path ## Verification - `cargo test -p agent_ui conversation_view::tests::test_restored_threads_keep_available_commands -- --exact --nocapture` - `./script/clippy -p agent_ui --tests` Release Notes: - Fixed ACP slash commands disappearing after reopening restored threads. --------- Co-authored-by: Ben Brandt --- crates/acp_thread/src/acp_thread.rs | 11 +- crates/agent_ui/src/conversation_view.rs | 214 +++++++++++++++--- .../src/conversation_view/thread_view.rs | 3 +- 3 files changed, 191 insertions(+), 37 deletions(-) diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 0bcb8254c8b8123eef3faaa913bb360de8dcc76d..36c9fb40c4a573e09da05618a29c1898cced60ad 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -1032,6 +1032,7 @@ pub struct AcpThread { connection: Rc, token_usage: Option, prompt_capabilities: acp::PromptCapabilities, + available_commands: Vec, _observe_prompt_capabilities: Task>, terminals: HashMap>, pending_terminal_output: HashMap>>, @@ -1220,6 +1221,7 @@ impl AcpThread { session_id, token_usage: None, prompt_capabilities, + available_commands: Vec::new(), _observe_prompt_capabilities: task, terminals: HashMap::default(), pending_terminal_output: HashMap::default(), @@ -1239,6 +1241,10 @@ impl AcpThread { self.prompt_capabilities.clone() } + pub fn available_commands(&self) -> &[acp::AvailableCommand] { + &self.available_commands + } + pub fn draft_prompt(&self) -> Option<&[acp::ContentBlock]> { self.draft_prompt.as_deref() } @@ -1419,7 +1425,10 @@ impl AcpThread { acp::SessionUpdate::AvailableCommandsUpdate(acp::AvailableCommandsUpdate { available_commands, .. - }) => cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands)), + }) => { + self.available_commands = available_commands.clone(); + cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands)); + } acp::SessionUpdate::CurrentModeUpdate(acp::CurrentModeUpdate { current_mode_id, .. diff --git a/crates/agent_ui/src/conversation_view.rs b/crates/agent_ui/src/conversation_view.rs index 149ed2e2fc0f9b22244e0d69deebf5aa7bb7d4c5..7c9acfdf27d5b750afe4b8817af7f657f5fcdecc 100644 --- a/crates/agent_ui/src/conversation_view.rs +++ b/crates/agent_ui/src/conversation_view.rs @@ -812,7 +812,7 @@ impl ConversationView { let agent_id = self.agent.agent_id(); let session_capabilities = Arc::new(RwLock::new(SessionCapabilities::new( thread.read(cx).prompt_capabilities(), - vec![], + thread.read(cx).available_commands().to_vec(), ))); let action_log = thread.read(cx).action_log().clone(); @@ -1448,40 +1448,24 @@ impl ConversationView { self.emit_token_limit_telemetry_if_needed(thread, cx); } AcpThreadEvent::AvailableCommandsUpdated(available_commands) => { - let mut available_commands = available_commands.clone(); + if let Some(thread_view) = self.thread_view(&thread_id) { + let has_commands = !available_commands.is_empty(); - if thread - .read(cx) - .connection() - .auth_methods() - .iter() - .any(|method| method.id().0.as_ref() == "claude-login") - { - available_commands.push(acp::AvailableCommand::new("login", "Authenticate")); - available_commands.push(acp::AvailableCommand::new("logout", "Authenticate")); - } - - let has_commands = !available_commands.is_empty(); - if let Some(active) = self.active_thread() { - active.update(cx, |active, _cx| { - active - .session_capabilities - .write() - .set_available_commands(available_commands); - }); - } - - let agent_display_name = self - .agent_server_store - .read(cx) - .agent_display_name(&self.agent.agent_id()) - .unwrap_or_else(|| self.agent.agent_id().0.to_string().into()); + let agent_display_name = self + .agent_server_store + .read(cx) + .agent_display_name(&self.agent.agent_id()) + .unwrap_or_else(|| self.agent.agent_id().0.to_string().into()); - if let Some(active) = self.active_thread() { let new_placeholder = placeholder_text(agent_display_name.as_ref(), has_commands); - active.update(cx, |active, cx| { - active.message_editor.update(cx, |editor, cx| { + + thread_view.update(cx, |thread_view, cx| { + thread_view + .session_capabilities + .write() + .set_available_commands(available_commands.clone()); + thread_view.message_editor.update(cx, |editor, cx| { editor.set_placeholder_text(&new_placeholder, window, cx); }); }); @@ -2348,9 +2332,9 @@ impl ConversationView { } } + #[cfg(feature = "audio")] fn play_notification_sound(&self, window: &Window, cx: &mut App) { - let settings = AgentSettings::get_global(cx); - let _visible = window.is_window_active() + let visible = window.is_window_active() && if let Some(mw) = window.root::().flatten() { self.agent_panel_visible(&mw, cx) } else { @@ -2358,8 +2342,8 @@ impl ConversationView { .upgrade() .is_some_and(|workspace| AgentPanel::is_visible(&workspace, cx)) }; - #[cfg(feature = "audio")] - if settings.play_sound_when_agent_done.should_play(_visible) { + let settings = AgentSettings::get_global(cx); + if settings.play_sound_when_agent_done.should_play(visible) { Audio::play_sound(Sound::AgentDone, cx); } } @@ -2989,6 +2973,166 @@ pub(crate) mod tests { }); } + #[derive(Clone)] + struct RestoredAvailableCommandsConnection; + + impl AgentConnection for RestoredAvailableCommandsConnection { + fn agent_id(&self) -> AgentId { + AgentId::new("restored-available-commands") + } + + fn telemetry_id(&self) -> SharedString { + "restored-available-commands".into() + } + + fn new_session( + self: Rc, + project: Entity, + _work_dirs: PathList, + cx: &mut App, + ) -> Task>> { + let thread = build_test_thread( + self, + project, + "RestoredAvailableCommandsConnection", + SessionId::new("new-session"), + cx, + ); + Task::ready(Ok(thread)) + } + + fn supports_load_session(&self) -> bool { + true + } + + fn load_session( + self: Rc, + session_id: acp::SessionId, + project: Entity, + _work_dirs: PathList, + _title: Option, + cx: &mut App, + ) -> Task>> { + let thread = build_test_thread( + self, + project, + "RestoredAvailableCommandsConnection", + session_id, + cx, + ); + + thread + .update(cx, |thread, cx| { + thread.handle_session_update( + acp::SessionUpdate::AvailableCommandsUpdate( + acp::AvailableCommandsUpdate::new(vec![acp::AvailableCommand::new( + "help", "Get help", + )]), + ), + cx, + ) + }) + .expect("available commands update should succeed"); + + Task::ready(Ok(thread)) + } + + fn auth_methods(&self) -> &[acp::AuthMethod] { + &[] + } + + fn authenticate( + &self, + _method_id: acp::AuthMethodId, + _cx: &mut App, + ) -> Task> { + Task::ready(Ok(())) + } + + fn prompt( + &self, + _id: Option, + _params: acp::PromptRequest, + _cx: &mut App, + ) -> Task> { + Task::ready(Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))) + } + + fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {} + + fn into_any(self: Rc) -> Rc { + self + } + } + + #[gpui::test] + async fn test_restored_threads_keep_available_commands(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, [], cx).await; + let (multi_workspace, cx) = + cx.add_window_view(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx)); + let workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()); + + let thread_store = cx.update(|_window, cx| cx.new(|cx| ThreadStore::new(cx))); + let connection_store = + cx.update(|_window, cx| cx.new(|cx| AgentConnectionStore::new(project.clone(), cx))); + + let conversation_view = cx.update(|window, cx| { + cx.new(|cx| { + ConversationView::new( + Rc::new(StubAgentServer::new(RestoredAvailableCommandsConnection)), + connection_store, + Agent::Custom { id: "Test".into() }, + Some(SessionId::new("restored-session")), + None, + None, + None, + workspace.downgrade(), + project, + Some(thread_store), + None, + window, + cx, + ) + }) + }); + + cx.run_until_parked(); + + let message_editor = message_editor(&conversation_view, cx); + let editor = + message_editor.update(cx, |message_editor, _cx| message_editor.editor().clone()); + let placeholder = editor.update(cx, |editor, cx| editor.placeholder_text(cx)); + + active_thread(&conversation_view, cx).read_with(cx, |view, _cx| { + let available_commands = view + .session_capabilities + .read() + .available_commands() + .to_vec(); + assert_eq!(available_commands.len(), 1); + assert_eq!(available_commands[0].name.as_str(), "help"); + assert_eq!(available_commands[0].description.as_str(), "Get help"); + }); + + assert_eq!( + placeholder, + Some("Message Test — @ to include context, / for commands".to_string()) + ); + + message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("/help", window, cx); + }); + + let contents_result = message_editor + .update(cx, |editor, cx| editor.contents(false, cx)) + .await; + + assert!(contents_result.is_ok()); + } + #[gpui::test] async fn test_resume_thread_uses_session_cwd_when_inside_project(cx: &mut TestAppContext) { init_test(cx); diff --git a/crates/agent_ui/src/conversation_view/thread_view.rs b/crates/agent_ui/src/conversation_view/thread_view.rs index 25af09832f3473aa690c7b205e1b56bab86e9709..9f9b5dff00536953b76a50b65a4ab64e427bc554 100644 --- a/crates/agent_ui/src/conversation_view/thread_view.rs +++ b/crates/agent_ui/src/conversation_view/thread_view.rs @@ -344,7 +344,8 @@ impl ThreadView { ) -> Self { let id = thread.read(cx).session_id().clone(); - let placeholder = placeholder_text(agent_display_name.as_ref(), false); + let has_commands = !session_capabilities.read().available_commands().is_empty(); + let placeholder = placeholder_text(agent_display_name.as_ref(), has_commands); let history_subscription = history.as_ref().map(|h| { cx.observe(h, |this, history, cx| { From 9c5f3b10fdd0b029cdd983aeb3310f7d4bf91e6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Soares?= <37777652+Dnreikronos@users.noreply.github.com> Date: Mon, 6 Apr 2026 23:46:30 -0300 Subject: [PATCH 18/21] terminal_view: Reset cursor blink on `SendText` and `SendKeystroke` actions (#53171) Self-Review Checklist: - [x] I've reviewed my own diff for quality, security, and reliability - [x] Unsafe blocks (if any) have justifying comments - [x] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) - [ ] Tests cover the new/changed behavior - [x] Performance impact has been considered and is acceptable Closes #53115 Release Notes: - Fixed terminal cursor blink not resetting when navigating with action-bound keys (e.g., alt+left/right on macOS, alt+b/f on Linux) ## Demo ### Before the fix The cursor stays invisible after word-jumping because the blink cycle keeps going without resetting. https://github.com/user-attachments/assets/00dbdba6-d793-4a23-abcc-37887f4d1262 ### After the fix The cursor shows up at the new position right after each word-jump, then blinks again as expected. https://github.com/user-attachments/assets/48d5906c-4899-4f4a-adbd-5908ebea0cfb --- crates/terminal_view/src/terminal_view.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/crates/terminal_view/src/terminal_view.rs b/crates/terminal_view/src/terminal_view.rs index 3ecc6c844db834da91e2f24c3f0cf2d460b5f246..acccd6129f75ee2f5213fa359203220a7fee08c0 100644 --- a/crates/terminal_view/src/terminal_view.rs +++ b/crates/terminal_view/src/terminal_view.rs @@ -850,6 +850,7 @@ impl TerminalView { fn send_text(&mut self, text: &SendText, _: &mut Window, cx: &mut Context) { self.clear_bell(cx); + self.blink_manager.update(cx, BlinkManager::pause_blinking); self.terminal.update(cx, |term, _| { term.input(text.0.to_string().into_bytes()); }); @@ -858,6 +859,7 @@ impl TerminalView { fn send_keystroke(&mut self, text: &SendKeystroke, _: &mut Window, cx: &mut Context) { if let Some(keystroke) = Keystroke::parse(&text.0).log_err() { self.clear_bell(cx); + self.blink_manager.update(cx, BlinkManager::pause_blinking); self.process_keystroke(&keystroke, cx); } } From 46fc6938a6da938e381707a8c62fe5d0eb2a3d86 Mon Sep 17 00:00:00 2001 From: Sean Hagstrom Date: Mon, 6 Apr 2026 19:47:02 -0700 Subject: [PATCH 19/21] vim: Add editor setting for changing regex mode default in vim searches (#53092) Closes #48007 Release Notes: - Added editor setting for changing regex mode default in vim searches Summary: - Based on the report in #48007 and the discussion here https://github.com/zed-industries/zed/pull/48127#issuecomment-3838678903 - There was feedback mentioning that vim-mode needs to default vim-searches to use regex-mode (even when the editor regex-search setting is disabled). However, it was suggested that a vim search setting could be configured to adjust this behaviour. - In this PR a new vim setting was added to change whether vim-searches will use regex-mode by default, so now users can can configure vim-search to not use regex-mode when typing the `/` character (or using the vim search command). Screen Captures: https://github.com/user-attachments/assets/172669fb-ab78-41a1-9485-c973825543c5 --- assets/settings/default.json | 1 + .../settings_content/src/settings_content.rs | 1 + crates/settings_ui/src/page_data.rs | 20 +++++- crates/vim/src/normal/search.rs | 64 ++++++++++++++++++- crates/vim/src/vim.rs | 2 + docs/src/vim.md | 2 + 6 files changed, 88 insertions(+), 2 deletions(-) diff --git a/assets/settings/default.json b/assets/settings/default.json index 5e1eb0e68d2f8a17f89422597aa29b99516333e8..63e906e3b11206fc458f8d7353f3ecba0abeb825 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -2417,6 +2417,7 @@ "toggle_relative_line_numbers": false, "use_system_clipboard": "always", "use_smartcase_find": false, + "use_regex_search": true, "gdefault": false, "highlight_on_yank_duration": 200, "custom_digraphs": {}, diff --git a/crates/settings_content/src/settings_content.rs b/crates/settings_content/src/settings_content.rs index 325e86e9e3af0fb888c2691f4be1b0fdeb06dfb4..6c60a7010f7cfc5b4fadf9a8cc386fe6e3267abc 100644 --- a/crates/settings_content/src/settings_content.rs +++ b/crates/settings_content/src/settings_content.rs @@ -763,6 +763,7 @@ pub struct VimSettingsContent { pub toggle_relative_line_numbers: Option, pub use_system_clipboard: Option, pub use_smartcase_find: Option, + pub use_regex_search: Option, /// When enabled, the `:substitute` command replaces all matches in a line /// by default. The 'g' flag then toggles this behavior., pub gdefault: Option, diff --git a/crates/settings_ui/src/page_data.rs b/crates/settings_ui/src/page_data.rs index 20a0c53534988a873b3b3f6e393eefd5bb0b3f7c..9978832c05bb29c97f118fccbe301214d81fa0c6 100644 --- a/crates/settings_ui/src/page_data.rs +++ b/crates/settings_ui/src/page_data.rs @@ -2447,7 +2447,7 @@ fn editor_page() -> SettingsPage { ] } - fn vim_settings_section() -> [SettingsPageItem; 12] { + fn vim_settings_section() -> [SettingsPageItem; 13] { [ SettingsPageItem::SectionHeader("Vim"), SettingsPageItem::SettingItem(SettingItem { @@ -2556,6 +2556,24 @@ fn editor_page() -> SettingsPage { metadata: None, files: USER, }), + SettingsPageItem::SettingItem(SettingItem { + title: "Regex Search", + description: "Use regex search by default in Vim search.", + field: Box::new(SettingField { + json_path: Some("vim.use_regex_search"), + pick: |settings_content| { + settings_content.vim.as_ref()?.use_regex_search.as_ref() + }, + write: |settings_content, value| { + settings_content + .vim + .get_or_insert_default() + .use_regex_search = value; + }, + }), + metadata: None, + files: USER, + }), SettingsPageItem::SettingItem(SettingItem { title: "Cursor Shape - Normal Mode", description: "Cursor shape for normal mode.", diff --git a/crates/vim/src/normal/search.rs b/crates/vim/src/normal/search.rs index 6a8394f44710b7e241b7ba38f4913899a5afbce6..22c453c877ec89fdbf432d19d89167285b78b12f 100644 --- a/crates/vim/src/normal/search.rs +++ b/crates/vim/src/normal/search.rs @@ -245,7 +245,7 @@ impl Vim { search_bar.set_replacement(None, cx); let mut options = SearchOptions::NONE; - if action.regex { + if action.regex && VimSettings::get_global(cx).use_regex_search { options |= SearchOptions::REGEX; } if action.backwards { @@ -1446,4 +1446,66 @@ mod test { // The cursor should be at the match location on line 3 (row 2). cx.assert_state("hello world\nfoo bar\nhello ˇagain\n", Mode::Normal); } + + #[gpui::test] + async fn test_vim_search_respects_search_settings(cx: &mut gpui::TestAppContext) { + let mut cx = VimTestContext::new(cx, true).await; + + cx.update_global(|store: &mut SettingsStore, cx| { + store.update_user_settings(cx, |settings| { + settings.vim.get_or_insert_default().use_regex_search = Some(false); + }); + }); + + cx.set_state("ˇcontent", Mode::Normal); + cx.simulate_keystrokes("/"); + cx.run_until_parked(); + + // Verify search options are set from settings + let search_bar = cx.workspace(|workspace, _, cx| { + workspace + .active_pane() + .read(cx) + .toolbar() + .read(cx) + .item_of_type::() + .expect("Buffer search bar should be active") + }); + + cx.update_entity(search_bar, |bar, _window, _cx| { + assert!( + !bar.has_search_option(search::SearchOptions::REGEX), + "Vim search open without regex mode" + ); + }); + + cx.simulate_keystrokes("escape"); + cx.run_until_parked(); + + cx.update_global(|store: &mut SettingsStore, cx| { + store.update_user_settings(cx, |settings| { + settings.vim.get_or_insert_default().use_regex_search = Some(true); + }); + }); + + cx.simulate_keystrokes("/"); + cx.run_until_parked(); + + let search_bar = cx.workspace(|workspace, _, cx| { + workspace + .active_pane() + .read(cx) + .toolbar() + .read(cx) + .item_of_type::() + .expect("Buffer search bar should be active") + }); + + cx.update_entity(search_bar, |bar, _window, _cx| { + assert!( + bar.has_search_option(search::SearchOptions::REGEX), + "Vim search opens with regex mode" + ); + }); + } } diff --git a/crates/vim/src/vim.rs b/crates/vim/src/vim.rs index 6e1849340f17b776a34546dd9a118dc55e8dab84..a66111cae1576744c4c51d717984d67c12fc8235 100644 --- a/crates/vim/src/vim.rs +++ b/crates/vim/src/vim.rs @@ -2141,6 +2141,7 @@ struct VimSettings { pub toggle_relative_line_numbers: bool, pub use_system_clipboard: settings::UseSystemClipboard, pub use_smartcase_find: bool, + pub use_regex_search: bool, pub gdefault: bool, pub custom_digraphs: HashMap>, pub highlight_on_yank_duration: u64, @@ -2227,6 +2228,7 @@ impl Settings for VimSettings { toggle_relative_line_numbers: vim.toggle_relative_line_numbers.unwrap(), use_system_clipboard: vim.use_system_clipboard.unwrap(), use_smartcase_find: vim.use_smartcase_find.unwrap(), + use_regex_search: vim.use_regex_search.unwrap(), gdefault: vim.gdefault.unwrap(), custom_digraphs: vim.custom_digraphs.unwrap(), highlight_on_yank_duration: vim.highlight_on_yank_duration.unwrap(), diff --git a/docs/src/vim.md b/docs/src/vim.md index 1798f16a93244f2694b30ffa70119da1e4498fdc..8e93edff081681a3e094c811e2d76822766ef67e 100644 --- a/docs/src/vim.md +++ b/docs/src/vim.md @@ -562,6 +562,7 @@ You can change the following settings to modify vim mode's behavior: | use_system_clipboard | Determines how system clipboard is used:
  • "always": use for all operations
  • "never": only use when explicitly specified
  • "on_yank": use for yank operations
| "always" | | use_multiline_find | deprecated | | use_smartcase_find | If `true`, `f` and `t` motions are case-insensitive when the target letter is lowercase. | false | +| use_regex_search | If `true`, then vim search will use regex mode | true | | gdefault | If `true`, the `:substitute` command replaces all matches in a line by default (as if `g` flag was given). The `g` flag then toggles this, replacing only the first match. | false | | toggle_relative_line_numbers | If `true`, line numbers are relative in normal mode and absolute in insert mode, giving you the best of both options. | false | | custom_digraphs | An object that allows you to add custom digraphs. Read below for an example. | {} | @@ -587,6 +588,7 @@ Here's an example of these settings changed: "default_mode": "insert", "use_system_clipboard": "never", "use_smartcase_find": true, + "use_regex_search": true, "gdefault": true, "toggle_relative_line_numbers": true, "highlight_on_yank_duration": 50, From 092c7058a49e4386afbb935740f0c220c3cafde0 Mon Sep 17 00:00:00 2001 From: Juan Pablo Briones Date: Mon, 6 Apr 2026 22:51:54 -0400 Subject: [PATCH 20/21] vim: Fix % for multiline comments and preprocessor directives (#53148) Implements: [49806](https://github.com/zed-industries/zed/discussions/49806) Closes: [24820](https://github.com/zed-industries/zed/issues/24820) Zeds impl of `%` didn't handle preprocessor directives and multiline To implement this feature for multiline comment, a tree-sitter query is used to check if we are inside a comment range and then replicate the logic used in brackets. For preprocessor directives using `TextObjects` wasn't a option, so it was implemented through a text based query that searches for the next preprocessor directives. Using text based queries might not be the best for performance, so I'm open to any suggestions. Release Notes: - Fixed vim's matching '%' to handle multiline comments `/* */` and preprocessor directives `#if #else #endif`. --- crates/vim/src/motion.rs | 261 +++++++++++++++++- .../vim/test_data/test_matching_comments.json | 10 + ...test_matching_preprocessor_directives.json | 18 ++ 3 files changed, 288 insertions(+), 1 deletion(-) create mode 100644 crates/vim/test_data/test_matching_comments.json create mode 100644 crates/vim/test_data/test_matching_preprocessor_directives.json diff --git a/crates/vim/src/motion.rs b/crates/vim/src/motion.rs index 6bf2afd09ae07ff8453a481a8d6e6e6a254e670f..6e992704f54bf7aba3cc775d906a90281234dbd0 100644 --- a/crates/vim/src/motion.rs +++ b/crates/vim/src/motion.rs @@ -7,7 +7,7 @@ use editor::{ }, }; use gpui::{Action, Context, Window, actions, px}; -use language::{CharKind, Point, Selection, SelectionGoal}; +use language::{CharKind, Point, Selection, SelectionGoal, TextObject, TreeSitterOptions}; use multi_buffer::MultiBufferRow; use schemars::JsonSchema; use serde::Deserialize; @@ -2451,6 +2451,10 @@ fn find_matching_bracket_text_based( .take_while(|(_, char_offset)| *char_offset < line_range.end) .find_map(|(ch, char_offset)| get_bracket_pair(ch).map(|info| (info, char_offset))); + if bracket_info.is_none() { + return find_matching_c_preprocessor_directive(map, line_range); + } + let (open, close, is_opening) = bracket_info?.0; let bracket_offset = bracket_info?.1; @@ -2482,6 +2486,122 @@ fn find_matching_bracket_text_based( None } +fn find_matching_c_preprocessor_directive( + map: &DisplaySnapshot, + line_range: Range, +) -> Option { + let line_start = map + .buffer_chars_at(line_range.start) + .skip_while(|(c, _)| *c == ' ' || *c == '\t') + .map(|(c, _)| c) + .take(6) + .collect::(); + + if line_start.starts_with("#if") + || line_start.starts_with("#else") + || line_start.starts_with("#elif") + { + let mut depth = 0i32; + for (ch, char_offset) in map.buffer_chars_at(line_range.end) { + if ch != '\n' { + continue; + } + let mut line_offset = char_offset + '\n'.len_utf8(); + + // Skip leading whitespace + map.buffer_chars_at(line_offset) + .take_while(|(c, _)| *c == ' ' || *c == '\t') + .for_each(|(_, _)| line_offset += 1); + + // Check what directive starts the next line + let next_line_start = map + .buffer_chars_at(line_offset) + .map(|(c, _)| c) + .take(6) + .collect::(); + + if next_line_start.starts_with("#if") { + depth += 1; + } else if next_line_start.starts_with("#endif") { + if depth > 0 { + depth -= 1; + } else { + return Some(line_offset); + } + } else if next_line_start.starts_with("#else") || next_line_start.starts_with("#elif") { + if depth == 0 { + return Some(line_offset); + } + } + } + } else if line_start.starts_with("#endif") { + let mut depth = 0i32; + for (ch, char_offset) in + map.reverse_buffer_chars_at(line_range.start.saturating_sub_usize(1)) + { + let mut line_offset = if char_offset == MultiBufferOffset(0) { + MultiBufferOffset(0) + } else if ch != '\n' { + continue; + } else { + char_offset + '\n'.len_utf8() + }; + + // Skip leading whitespace + map.buffer_chars_at(line_offset) + .take_while(|(c, _)| *c == ' ' || *c == '\t') + .for_each(|(_, _)| line_offset += 1); + + // Check what directive starts this line + let line_start = map + .buffer_chars_at(line_offset) + .skip_while(|(c, _)| *c == ' ' || *c == '\t') + .map(|(c, _)| c) + .take(6) + .collect::(); + + if line_start.starts_with("\n\n") { + // empty line + continue; + } else if line_start.starts_with("#endif") { + depth += 1; + } else if line_start.starts_with("#if") { + if depth > 0 { + depth -= 1; + } else { + return Some(line_offset); + } + } + } + } + None +} + +fn comment_delimiter_pair( + map: &DisplaySnapshot, + offset: MultiBufferOffset, +) -> Option<(Range, Range)> { + let snapshot = map.buffer_snapshot(); + snapshot + .text_object_ranges(offset..offset, TreeSitterOptions::default()) + .find_map(|(range, obj)| { + if !matches!(obj, TextObject::InsideComment | TextObject::AroundComment) + || !range.contains(&offset) + { + return None; + } + + let mut chars = snapshot.chars_at(range.start); + if (Some('/'), Some('*')) != (chars.next(), chars.next()) { + return None; + } + + let open_range = range.start..range.start + 2usize; + let close_range = range.end - 2..range.end; + Some((open_range, close_range)) + }) +} + fn matching( map: &DisplaySnapshot, display_point: DisplayPoint, @@ -2609,6 +2729,32 @@ fn matching( continue; } + if let Some((open_range, close_range)) = comment_delimiter_pair(map, offset) { + if open_range.contains(&offset) { + return close_range.start.to_display_point(map); + } + + if close_range.contains(&offset) { + return open_range.start.to_display_point(map); + } + + let open_candidate = (open_range.start >= offset + && line_range.contains(&open_range.start)) + .then_some((open_range.start.saturating_sub(offset), close_range.start)); + + let close_candidate = (close_range.start >= offset + && line_range.contains(&close_range.start)) + .then_some((close_range.start.saturating_sub(offset), open_range.start)); + + if let Some((_, destination)) = [open_candidate, close_candidate] + .into_iter() + .flatten() + .min_by_key(|(distance, _)| *distance) + { + return destination.to_display_point(map); + } + } + closest_pair_destination .map(|destination| destination.to_display_point(map)) .unwrap_or_else(|| { @@ -3497,6 +3643,119 @@ mod test { ); } + #[gpui::test] + async fn test_matching_comments(cx: &mut gpui::TestAppContext) { + let mut cx = NeovimBackedTestContext::new(cx).await; + + cx.set_shared_state(indoc! {r"ˇ/* + this is a comment + */"}) + .await; + cx.simulate_shared_keystrokes("%").await; + cx.shared_state().await.assert_eq(indoc! {r"/* + this is a comment + ˇ*/"}); + cx.simulate_shared_keystrokes("%").await; + cx.shared_state().await.assert_eq(indoc! {r"ˇ/* + this is a comment + */"}); + cx.simulate_shared_keystrokes("%").await; + cx.shared_state().await.assert_eq(indoc! {r"/* + this is a comment + ˇ*/"}); + + cx.set_shared_state("ˇ// comment").await; + cx.simulate_shared_keystrokes("%").await; + cx.shared_state().await.assert_eq("ˇ// comment"); + } + + #[gpui::test] + async fn test_matching_preprocessor_directives(cx: &mut gpui::TestAppContext) { + let mut cx = NeovimBackedTestContext::new(cx).await; + + cx.set_shared_state(indoc! {r"#ˇif + + #else + + #endif + "}) + .await; + cx.simulate_shared_keystrokes("%").await; + cx.shared_state().await.assert_eq(indoc! {r"#if + + ˇ#else + + #endif + "}); + + cx.simulate_shared_keystrokes("%").await; + cx.shared_state().await.assert_eq(indoc! {r"#if + + #else + + ˇ#endif + "}); + + cx.simulate_shared_keystrokes("%").await; + cx.shared_state().await.assert_eq(indoc! {r"ˇ#if + + #else + + #endif + "}); + + cx.set_shared_state(indoc! {r" + #ˇif + #if + + #else + + #endif + + #else + #endif + "}) + .await; + + cx.simulate_shared_keystrokes("%").await; + cx.shared_state().await.assert_eq(indoc! {r" + #if + #if + + #else + + #endif + + ˇ#else + #endif + "}); + + cx.simulate_shared_keystrokes("% %").await; + cx.shared_state().await.assert_eq(indoc! {r" + ˇ#if + #if + + #else + + #endif + + #else + #endif + "}); + cx.simulate_shared_keystrokes("j % % %").await; + cx.shared_state().await.assert_eq(indoc! {r" + #if + ˇ#if + + #else + + #endif + + #else + #endif + "}); + } + #[gpui::test] async fn test_unmatched_forward(cx: &mut gpui::TestAppContext) { let mut cx = NeovimBackedTestContext::new(cx).await; diff --git a/crates/vim/test_data/test_matching_comments.json b/crates/vim/test_data/test_matching_comments.json new file mode 100644 index 0000000000000000000000000000000000000000..7fcf5e46e1ea16f2be794ff76b583242b33aabc0 --- /dev/null +++ b/crates/vim/test_data/test_matching_comments.json @@ -0,0 +1,10 @@ +{"Put":{"state":"ˇ/*\n this is a comment\n*/"}} +{"Key":"%"} +{"Get":{"state":"/*\n this is a comment\nˇ*/","mode":"Normal"}} +{"Key":"%"} +{"Get":{"state":"ˇ/*\n this is a comment\n*/","mode":"Normal"}} +{"Key":"%"} +{"Get":{"state":"/*\n this is a comment\nˇ*/","mode":"Normal"}} +{"Put":{"state":"ˇ// comment"}} +{"Key":"%"} +{"Get":{"state":"ˇ// comment","mode":"Normal"}} diff --git a/crates/vim/test_data/test_matching_preprocessor_directives.json b/crates/vim/test_data/test_matching_preprocessor_directives.json new file mode 100644 index 0000000000000000000000000000000000000000..9f0bd9792ee8dad5029f4ecaf325c231755530e1 --- /dev/null +++ b/crates/vim/test_data/test_matching_preprocessor_directives.json @@ -0,0 +1,18 @@ +{"Put":{"state":"#ˇif\n\n#else\n\n#endif\n"}} +{"Key":"%"} +{"Get":{"state":"#if\n\nˇ#else\n\n#endif\n","mode":"Normal"}} +{"Key":"%"} +{"Get":{"state":"#if\n\n#else\n\nˇ#endif\n","mode":"Normal"}} +{"Key":"%"} +{"Get":{"state":"ˇ#if\n\n#else\n\n#endif\n","mode":"Normal"}} +{"Put":{"state":"#ˇif\n #if\n\n #else\n\n #endif\n\n#else\n#endif\n"}} +{"Key":"%"} +{"Get":{"state":"#if\n #if\n\n #else\n\n #endif\n\nˇ#else\n#endif\n","mode":"Normal"}} +{"Key":"%"} +{"Key":"%"} +{"Get":{"state":"ˇ#if\n #if\n\n #else\n\n #endif\n\n#else\n#endif\n","mode":"Normal"}} +{"Key":"j"} +{"Key":"%"} +{"Key":"%"} +{"Key":"%"} +{"Get":{"state":"#if\n ˇ#if\n\n #else\n\n #endif\n\n#else\n#endif\n","mode":"Normal"}} From 2aa1559080fe3fcb919415ef75fcccaf85eaa017 Mon Sep 17 00:00:00 2001 From: Markos Narinian Date: Tue, 7 Apr 2026 06:55:12 +0300 Subject: [PATCH 21/21] agent_ui: Add padding to markdown output in card layout (#53194) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Self-Review Checklist: - [x] I've reviewed my own diff for quality, security, and reliability - [x] Unsafe blocks (if any) have justifying comments - [x] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) - [x] Tests cover the new/changed behavior - [x] Performance impact has been considered and is acceptable Closes #53191 Release Notes: - Added padding to markdown output when rendered in card layout. Before: Screenshot 2026-04-01 at 3 32 After: Screenshot 2026-04-05 at 4 47 38 PM --- crates/agent_ui/src/conversation_view/thread_view.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/crates/agent_ui/src/conversation_view/thread_view.rs b/crates/agent_ui/src/conversation_view/thread_view.rs index 9f9b5dff00536953b76a50b65a4ab64e427bc554..685621eb3c93632f1e7410bbbad22b623d5e18c7 100644 --- a/crates/agent_ui/src/conversation_view/thread_view.rs +++ b/crates/agent_ui/src/conversation_view/thread_view.rs @@ -7390,9 +7390,8 @@ impl ThreadView { .gap_2() .map(|this| { if card_layout { - this.when(context_ix > 0, |this| { - this.pt_2() - .border_t_1() + this.p_2().when(context_ix > 0, |this| { + this.border_t_1() .border_color(self.tool_card_border_color(cx)) }) } else {