From c357dc25fc6e611016e1bee8a74f3f6f9e57bbdc Mon Sep 17 00:00:00 2001 From: Danilo Leal <67129314+danilo-leal@users.noreply.github.com> Date: Thu, 4 Dec 2025 13:44:48 -0300 Subject: [PATCH 01/81] git_ui: Clean up the commit view UI (#44162) --- crates/editor/src/editor.rs | 5 + crates/editor/src/element.rs | 22 ++-- crates/git_ui/src/commit_view.rs | 205 +++++++++++++++++-------------- 3 files changed, 132 insertions(+), 100 deletions(-) diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index 4b352e2d8298f3c9ae2c0d38bd6b443d62a61996..0de2dc8423b39ab2b336adb3cb17f79cc4a8f6e7 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -20090,6 +20090,11 @@ impl Editor { self.show_indent_guides } + pub fn disable_indent_guides(&mut self) -> Option { + self.show_indent_guides = Some(false); + self.show_indent_guides + } + pub fn toggle_line_numbers( &mut self, _: &ToggleLineNumbers, diff --git a/crates/editor/src/element.rs b/crates/editor/src/element.rs index 3319af92eb04015bd3bd01760235e3dba0047975..fb9dc31a94441c81bccedfea66e2881acaf7ed82 100644 --- a/crates/editor/src/element.rs +++ b/crates/editor/src/element.rs @@ -3915,6 +3915,8 @@ impl EditorElement { ) -> impl IntoElement { let editor = self.editor.read(cx); let multi_buffer = editor.buffer.read(cx); + let is_read_only = self.editor.read(cx).read_only(cx); + let file_status = multi_buffer .all_diff_hunks_expanded() .then(|| editor.status_for_buffer_id(for_excerpt.buffer_id, cx)) @@ -3967,7 +3969,7 @@ impl EditorElement { .gap_1p5() .when(is_sticky, |el| el.shadow_md()) .border_1() - .map(|div| { + .map(|border| { let border_color = if is_selected && is_folded && focus_handle.contains_focused(window, cx) @@ -3976,7 +3978,7 @@ impl EditorElement { } else { colors.border }; - div.border_color(border_color) + border.border_color(border_color) }) .bg(colors.editor_subheader_background) .hover(|style| style.bg(colors.element_hover)) @@ -4056,13 +4058,15 @@ impl EditorElement { 
}) .take(1), ) - .child( - h_flex() - .size_3() - .justify_center() - .flex_shrink_0() - .children(indicator), - ) + .when(!is_read_only, |this| { + this.child( + h_flex() + .size_3() + .justify_center() + .flex_shrink_0() + .children(indicator), + ) + }) .child( h_flex() .cursor_pointer() diff --git a/crates/git_ui/src/commit_view.rs b/crates/git_ui/src/commit_view.rs index 31ac8139a63be218f652204ebe29d43e526c5a02..8a4504c1873193e81658c19c6b1115a9212e7760 100644 --- a/crates/git_ui/src/commit_view.rs +++ b/crates/git_ui/src/commit_view.rs @@ -1,7 +1,7 @@ use anyhow::{Context as _, Result}; use buffer_diff::{BufferDiff, BufferDiffSnapshot}; use editor::display_map::{BlockPlacement, BlockProperties, BlockStyle}; -use editor::{Addon, Editor, EditorEvent, ExcerptId, ExcerptRange, MultiBuffer}; +use editor::{Editor, EditorEvent, ExcerptId, ExcerptRange, MultiBuffer}; use git::repository::{CommitDetails, CommitDiff, RepoPath}; use git::{GitHostingProviderRegistry, GitRemote, parse_git_remote_url}; use gpui::{ @@ -11,9 +11,8 @@ use gpui::{ }; use language::{ Anchor, Buffer, Capability, DiskState, File, LanguageRegistry, LineEnding, ReplicaId, Rope, - TextBuffer, ToPoint, + TextBuffer, }; -use multi_buffer::ExcerptInfo; use multi_buffer::PathKey; use project::{Project, WorktreeId, git_store::Repository}; use std::{ @@ -22,11 +21,9 @@ use std::{ sync::Arc, }; use theme::ActiveTheme; -use ui::{ - Avatar, Button, ButtonCommon, Clickable, Color, Icon, IconName, IconSize, Label, - LabelCommon as _, LabelSize, SharedString, div, h_flex, v_flex, -}; +use ui::{Avatar, DiffStat, Tooltip, prelude::*}; use util::{ResultExt, paths::PathStyle, rel_path::RelPath, truncate_and_trailoff}; +use workspace::item::TabTooltipContent; use workspace::{ Item, ItemHandle, ItemNavHistory, ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView, Workspace, @@ -151,11 +148,11 @@ impl CommitView { let editor = cx.new(|cx| { let mut editor = Editor::for_multibuffer(multibuffer.clone(), 
Some(project.clone()), window, cx); + editor.disable_inline_diagnostics(); + editor.disable_indent_guides(); editor.set_expand_all_diff_hunks(cx); - editor.register_addon(CommitViewAddon { - multibuffer: multibuffer.downgrade(), - }); + editor }); let commit_sha = Arc::::from(commit.sha.as_ref()); @@ -357,6 +354,41 @@ impl CommitView { .into_any() } + fn calculate_changed_lines(&self, cx: &App) -> (u32, u32) { + let snapshot = self.multibuffer.read(cx).snapshot(cx); + let mut total_additions = 0u32; + let mut total_deletions = 0u32; + + let mut seen_buffers = std::collections::HashSet::new(); + for (_, buffer, _) in snapshot.excerpts() { + let buffer_id = buffer.remote_id(); + if !seen_buffers.insert(buffer_id) { + continue; + } + + let Some(diff) = snapshot.diff_for_buffer_id(buffer_id) else { + continue; + }; + + let base_text = diff.base_text(); + + for hunk in diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, buffer) { + let added_rows = hunk.range.end.row.saturating_sub(hunk.range.start.row); + total_additions += added_rows; + + let base_start = base_text + .offset_to_point(hunk.diff_base_byte_range.start) + .row; + let base_end = base_text.offset_to_point(hunk.diff_base_byte_range.end).row; + let deleted_rows = base_end.saturating_sub(base_start); + + total_deletions += deleted_rows; + } + } + + (total_additions, total_deletions) + } + fn render_header(&self, window: &mut Window, cx: &mut Context) -> impl IntoElement { let commit = &self.commit; let author_name = commit.author_name.clone(); @@ -380,46 +412,72 @@ impl CommitView { ) }); - v_flex() - .p_4() - .pl_0() - .gap_4() + let (additions, deletions) = self.calculate_changed_lines(cx); + + let commit_diff_stat = if additions > 0 || deletions > 0 { + Some(DiffStat::new( + "commit-diff-stat", + additions as usize, + deletions as usize, + )) + } else { + None + }; + + h_flex() .border_b_1() - .border_color(cx.theme().colors().border) + .border_color(cx.theme().colors().border_variant) + .child( + 
h_flex() + .w(self.editor.read(cx).last_gutter_dimensions().full_width()) + .justify_center() + .child(self.render_commit_avatar(&commit.sha, rems_from_px(48.), window, cx)), + ) .child( h_flex() + .py_4() + .pl_1() + .pr_4() + .w_full() .items_start() - .child( - h_flex() - .w(self.editor.read(cx).last_gutter_dimensions().full_width()) - .justify_center() - .child(self.render_commit_avatar( - &commit.sha, - gpui::rems(3.0), - window, - cx, - )), - ) + .justify_between() + .flex_wrap() .child( v_flex() - .gap_1() .child( h_flex() - .gap_3() - .items_baseline() + .gap_1() .child(Label::new(author_name).color(Color::Default)) .child( - Label::new(format!("commit {}", commit.sha)) - .color(Color::Muted), + Label::new(format!("Commit:{}", commit.sha)) + .color(Color::Muted) + .size(LabelSize::Small) + .truncate() + .buffer_font(cx), ), ) - .child(Label::new(date_string).color(Color::Muted)), + .child( + h_flex() + .gap_1p5() + .child( + Label::new(date_string) + .color(Color::Muted) + .size(LabelSize::Small), + ) + .child( + Label::new("•") + .color(Color::Ignored) + .size(LabelSize::Small), + ) + .children(commit_diff_stat), + ), ) - .child(div().flex_grow()) .children(github_url.map(|url| { Button::new("view_on_github", "View on GitHub") .icon(IconName::Github) - .style(ui::ButtonStyle::Subtle) + .icon_color(Color::Muted) + .icon_size(IconSize::Small) + .icon_position(IconPosition::Start) .on_click(move |_, _, cx| cx.open_url(&url)) })), ) @@ -714,55 +772,6 @@ impl language::File for GitBlob { // } // } -struct CommitViewAddon { - multibuffer: WeakEntity, -} - -impl Addon for CommitViewAddon { - fn render_buffer_header_controls( - &self, - excerpt: &ExcerptInfo, - _window: &Window, - cx: &App, - ) -> Option { - let multibuffer = self.multibuffer.upgrade()?; - let snapshot = multibuffer.read(cx).snapshot(cx); - let excerpts = snapshot.excerpts().collect::>(); - let current_idx = excerpts.iter().position(|(id, _, _)| *id == excerpt.id)?; - let (_, _, current_range) = 
&excerpts[current_idx]; - - let start_row = current_range.context.start.to_point(&excerpt.buffer).row; - - let prev_end_row = if current_idx > 0 { - let (_, prev_buffer, prev_range) = &excerpts[current_idx - 1]; - if prev_buffer.remote_id() == excerpt.buffer_id { - prev_range.context.end.to_point(&excerpt.buffer).row - } else { - 0 - } - } else { - 0 - }; - - let skipped_lines = start_row.saturating_sub(prev_end_row); - - if skipped_lines > 0 { - Some( - Label::new(format!("{} unchanged lines", skipped_lines)) - .color(Color::Muted) - .size(LabelSize::Small) - .into_any_element(), - ) - } else { - None - } - } - - fn to_any(&self) -> &dyn Any { - self - } -} - async fn build_buffer( mut text: String, blob: Arc, @@ -865,13 +874,28 @@ impl Item for CommitView { fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString { let short_sha = self.commit.sha.get(0..7).unwrap_or(&*self.commit.sha); let subject = truncate_and_trailoff(self.commit.message.split('\n').next().unwrap(), 20); - format!("{short_sha} - {subject}").into() + format!("{short_sha} — {subject}").into() } - fn tab_tooltip_text(&self, _: &App) -> Option { + fn tab_tooltip_content(&self, _: &App) -> Option { let short_sha = self.commit.sha.get(0..16).unwrap_or(&*self.commit.sha); let subject = self.commit.message.split('\n').next().unwrap(); - Some(format!("{short_sha} - {subject}").into()) + + Some(TabTooltipContent::Custom(Box::new(Tooltip::element({ + let subject = subject.to_string(); + let short_sha = short_sha.to_string(); + + move |_, _| { + v_flex() + .child(Label::new(subject.clone())) + .child( + Label::new(short_sha.clone()) + .color(Color::Muted) + .size(LabelSize::Small), + ) + .into_any_element() + } + })))) } fn to_item_events(event: &EditorEvent, f: impl FnMut(ItemEvent)) { @@ -988,12 +1012,11 @@ impl Item for CommitView { impl Render for CommitView { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { let is_stash = self.stash.is_some(); - div() + + 
v_flex() .key_context(if is_stash { "StashDiff" } else { "CommitDiff" }) - .bg(cx.theme().colors().editor_background) - .flex() - .flex_col() .size_full() + .bg(cx.theme().colors().editor_background) .child(self.render_header(window, cx)) .child(div().flex_grow().child(self.editor.clone())) } @@ -1013,7 +1036,7 @@ impl EventEmitter for CommitViewToolbar {} impl Render for CommitViewToolbar { fn render(&mut self, _window: &mut Window, _cx: &mut Context) -> impl IntoElement { - div() + div().hidden() } } From 07af011eb447a1b8afc2ad490a77da33fa76fb33 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Thu, 4 Dec 2025 18:14:10 +0100 Subject: [PATCH 02/81] worktree: Fix git ignored directories dropping their contents when they are refreshed (#44143) Closes https://github.com/zed-industries/zed/issues/38653 Release Notes: - Fixed git ignored directories appearing as empty when their content changes on windows Co-authored by: Smit Barmase --- crates/worktree/src/worktree.rs | 44 +++++-- crates/worktree/src/worktree_tests.rs | 169 ++++++++++++++++++++++++++ 2 files changed, 201 insertions(+), 12 deletions(-) diff --git a/crates/worktree/src/worktree.rs b/crates/worktree/src/worktree.rs index 942e692a020049b102a0d810bfbf1a9074962611..5d1baceb2cebcadb54f5b47f357470861bb5b964 100644 --- a/crates/worktree/src/worktree.rs +++ b/crates/worktree/src/worktree.rs @@ -428,7 +428,7 @@ impl Worktree { let mut entry = Entry::new( RelPath::empty().into(), &metadata, - &next_entry_id, + ProjectEntryId::new(&next_entry_id), snapshot.root_char_bag, None, ); @@ -2736,13 +2736,30 @@ impl BackgroundScannerState { } } - async fn insert_entry( + fn entry_id_for( &mut self, - mut entry: Entry, - fs: &dyn Fs, - watcher: &dyn Watcher, - ) -> Entry { - self.reuse_entry_id(&mut entry); + next_entry_id: &AtomicUsize, + path: &RelPath, + metadata: &fs::Metadata, + ) -> ProjectEntryId { + // If an entry with the same inode was removed from the worktree during this scan, + // then it *might* represent the 
same file or directory. But the OS might also have + // re-used the inode for a completely different file or directory. + // + // Conditionally reuse the old entry's id: + // * if the mtime is the same, the file was probably been renamed. + // * if the path is the same, the file may just have been updated + if let Some(removed_entry) = self.removed_entries.remove(&metadata.inode) { + if removed_entry.mtime == Some(metadata.mtime) || *removed_entry.path == *path { + return removed_entry.id; + } + } else if let Some(existing_entry) = self.snapshot.entry_for_path(path) { + return existing_entry.id; + } + ProjectEntryId::new(next_entry_id) + } + + async fn insert_entry(&mut self, entry: Entry, fs: &dyn Fs, watcher: &dyn Watcher) -> Entry { let entry = self.snapshot.insert_entry(entry, fs); if entry.path.file_name() == Some(&DOT_GIT) { self.insert_git_repository(entry.path.clone(), fs, watcher) @@ -3389,13 +3406,13 @@ impl Entry { fn new( path: Arc, metadata: &fs::Metadata, - next_entry_id: &AtomicUsize, + id: ProjectEntryId, root_char_bag: CharBag, canonical_path: Option>, ) -> Self { let char_bag = char_bag_for_path(root_char_bag, &path); Self { - id: ProjectEntryId::new(next_entry_id), + id, kind: if metadata.is_dir { EntryKind::PendingDir } else { @@ -3682,8 +3699,10 @@ impl BackgroundScanner { .await; if ignore_stack.is_abs_path_ignored(root_abs_path.as_path(), true) { root_entry.is_ignored = true; + let mut root_entry = root_entry.clone(); + state.reuse_entry_id(&mut root_entry); state - .insert_entry(root_entry.clone(), self.fs.as_ref(), self.watcher.as_ref()) + .insert_entry(root_entry, self.fs.as_ref(), self.watcher.as_ref()) .await; } if root_entry.is_dir() { @@ -4289,7 +4308,7 @@ impl BackgroundScanner { let mut child_entry = Entry::new( child_path.clone(), &child_metadata, - &next_entry_id, + ProjectEntryId::new(&next_entry_id), root_char_bag, None, ); @@ -4476,10 +4495,11 @@ impl BackgroundScanner { .ignore_stack_for_abs_path(&abs_path, metadata.is_dir, 
self.fs.as_ref()) .await; let is_external = !canonical_path.starts_with(&root_canonical_path); + let entry_id = state.entry_id_for(self.next_entry_id.as_ref(), path, &metadata); let mut fs_entry = Entry::new( path.clone(), &metadata, - self.next_entry_id.as_ref(), + entry_id, state.snapshot.root_char_bag, if metadata.is_symlink { Some(canonical_path.as_path().to_path_buf().into()) diff --git a/crates/worktree/src/worktree_tests.rs b/crates/worktree/src/worktree_tests.rs index 50e2c6acae0013a75e346ba754f9c9f861196b58..08086118aacb37215227690532b927b3c7c46123 100644 --- a/crates/worktree/src/worktree_tests.rs +++ b/crates/worktree/src/worktree_tests.rs @@ -1533,6 +1533,175 @@ async fn test_create_dir_all_on_create_entry(cx: &mut TestAppContext) { }); } +#[gpui::test] +async fn test_create_file_in_expanded_gitignored_dir(cx: &mut TestAppContext) { + // Tests the behavior of our worktree refresh when a file in a gitignored directory + // is created. + init_test(cx); + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree( + "/root", + json!({ + ".gitignore": "ignored_dir\n", + "ignored_dir": { + "existing_file.txt": "existing content", + "another_file.txt": "another content", + }, + }), + ) + .await; + + let tree = Worktree::local( + Path::new("/root"), + true, + fs.clone(), + Default::default(), + &mut cx.to_async(), + ) + .await + .unwrap(); + + cx.read(|cx| tree.read(cx).as_local().unwrap().scan_complete()) + .await; + + tree.read_with(cx, |tree, _| { + let ignored_dir = tree.entry_for_path(rel_path("ignored_dir")).unwrap(); + assert!(ignored_dir.is_ignored); + assert_eq!(ignored_dir.kind, EntryKind::UnloadedDir); + }); + + tree.update(cx, |tree, cx| { + tree.load_file(rel_path("ignored_dir/existing_file.txt"), cx) + }) + .await + .unwrap(); + + tree.read_with(cx, |tree, _| { + let ignored_dir = tree.entry_for_path(rel_path("ignored_dir")).unwrap(); + assert!(ignored_dir.is_ignored); + assert_eq!(ignored_dir.kind, EntryKind::Dir); + + assert!( + 
tree.entry_for_path(rel_path("ignored_dir/existing_file.txt")) + .is_some() + ); + assert!( + tree.entry_for_path(rel_path("ignored_dir/another_file.txt")) + .is_some() + ); + }); + + let entry = tree + .update(cx, |tree, cx| { + tree.create_entry(rel_path("ignored_dir/new_file.txt").into(), false, None, cx) + }) + .await + .unwrap(); + assert!(entry.into_included().is_some()); + + cx.executor().run_until_parked(); + + tree.read_with(cx, |tree, _| { + let ignored_dir = tree.entry_for_path(rel_path("ignored_dir")).unwrap(); + assert!(ignored_dir.is_ignored); + assert_eq!( + ignored_dir.kind, + EntryKind::Dir, + "ignored_dir should still be loaded, not UnloadedDir" + ); + + assert!( + tree.entry_for_path(rel_path("ignored_dir/existing_file.txt")) + .is_some(), + "existing_file.txt should still be visible" + ); + assert!( + tree.entry_for_path(rel_path("ignored_dir/another_file.txt")) + .is_some(), + "another_file.txt should still be visible" + ); + assert!( + tree.entry_for_path(rel_path("ignored_dir/new_file.txt")) + .is_some(), + "new_file.txt should be visible" + ); + }); +} + +#[gpui::test] +async fn test_fs_event_for_gitignored_dir_does_not_lose_contents(cx: &mut TestAppContext) { + // Tests the behavior of our worktree refresh when a directory modification for a gitignored directory + // is triggered. 
+ init_test(cx); + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree( + "/root", + json!({ + ".gitignore": "ignored_dir\n", + "ignored_dir": { + "file1.txt": "content1", + "file2.txt": "content2", + }, + }), + ) + .await; + + let tree = Worktree::local( + Path::new("/root"), + true, + fs.clone(), + Default::default(), + &mut cx.to_async(), + ) + .await + .unwrap(); + + cx.read(|cx| tree.read(cx).as_local().unwrap().scan_complete()) + .await; + + // Load a file to expand the ignored directory + tree.update(cx, |tree, cx| { + tree.load_file(rel_path("ignored_dir/file1.txt"), cx) + }) + .await + .unwrap(); + + tree.read_with(cx, |tree, _| { + let ignored_dir = tree.entry_for_path(rel_path("ignored_dir")).unwrap(); + assert_eq!(ignored_dir.kind, EntryKind::Dir); + assert!( + tree.entry_for_path(rel_path("ignored_dir/file1.txt")) + .is_some() + ); + assert!( + tree.entry_for_path(rel_path("ignored_dir/file2.txt")) + .is_some() + ); + }); + + fs.emit_fs_event("/root/ignored_dir", Some(fs::PathEventKind::Changed)); + tree.flush_fs_events(cx).await; + + tree.read_with(cx, |tree, _| { + let ignored_dir = tree.entry_for_path(rel_path("ignored_dir")).unwrap(); + assert_eq!( + ignored_dir.kind, + EntryKind::Dir, + "ignored_dir should still be loaded (Dir), not UnloadedDir" + ); + assert!( + tree.entry_for_path(rel_path("ignored_dir/file1.txt")) + .is_some(), + "file1.txt should still be visible after directory fs event" + ); + assert!( + tree.entry_for_path(rel_path("ignored_dir/file2.txt")) + .is_some(), + "file2.txt should still be visible after directory fs event" + ); + }); +} + #[gpui::test(iterations = 100)] async fn test_random_worktree_operations_during_initial_scan( cx: &mut TestAppContext, From 74a1b5d14db73d6bbf0524a2f67e425455bc801c Mon Sep 17 00:00:00 2001 From: Liffindra Angga Zaaldian <3760093+findrakecil@users.noreply.github.com> Date: Fri, 5 Dec 2025 01:04:06 +0700 Subject: [PATCH 03/81] Update PHP language server docs (#44001) Reformat 
document structure like other language docs, improve information flow, add missing requirements, and fix typos. Release Notes: - N/A --------- Co-authored-by: Piotr Osiewicz <24362066+osiewicz@users.noreply.github.com> --- docs/src/languages/php.md | 114 +++++++++++++++++++++++++++++--------- 1 file changed, 89 insertions(+), 25 deletions(-) diff --git a/docs/src/languages/php.md b/docs/src/languages/php.md index 73d5ecbf37eae6ab9b7e710c132025d217fe57bd..1a9f1cdadef394f1178fed87d4061eb0f3232cfd 100644 --- a/docs/src/languages/php.md +++ b/docs/src/languages/php.md @@ -2,34 +2,44 @@ PHP support is available through the [PHP extension](https://github.com/zed-extensions/php). -- Tree-sitter: https://github.com/tree-sitter/tree-sitter-php -- Language Servers: - - [phpactor](https://github.com/phpactor/phpactor) - - [intelephense](https://github.com/bmewburn/vscode-intelephense/) +- Tree-sitter: [tree-sitter/tree-sitter-php](https://github.com/tree-sitter/tree-sitter-php) +- Language Server: [phpactor/phpactor](https://github.com/phpactor/phpactor) +- Alternate Language Server: [bmewburn/vscode-intelephense](https://github.com/bmewburn/vscode-intelephense/) -## Choosing a language server +## Install PHP -The PHP extension offers both `phpactor` and `intelephense` language server support. +The PHP extension requires PHP to be installed and available in your `PATH`: -`phpactor` is enabled by default. 
+```sh +# macOS via Homebrew +brew install php -### Phpactor +# Debian/Ubuntu +sudo apt-get install php-cli -The Zed PHP Extension can install `phpactor` automatically but requires `php` to be installed and available in your path: +# CentOS 8+/RHEL +sudo dnf install php-cli -```sh -# brew install php # macOS -# sudo apt-get install php # Debian/Ubuntu -# yum install php # CentOS/RHEL -# pacman -S php # Arch Linux +# Arch Linux +sudo pacman -S php + +# check PHP path +## macOS and Linux which php + +## Windows +where php ``` +## Choosing a language server + +The PHP extension uses [LSP language servers](https://microsoft.github.io/language-server-protocol) with Phpactor as the default. If you want to use other language servers that support Zed (e.g. Intelephense or PHP Tools), make sure to follow the documentation on how to implement it. + ### Intelephense -[Intelephense](https://intelephense.com/) is a [proprietary](https://github.com/bmewburn/vscode-intelephense/blob/master/LICENSE.txt#L29) language server for PHP operating under a freemium model. Certain features require purchase of a [premium license](https://intelephense.com/). +[Intelephense](https://intelephense.com/) is a [proprietary](https://github.com/bmewburn/vscode-intelephense/blob/master/LICENSE.txt#L29) language server for PHP operating under a freemium model. Certain features require purchase of a [premium license](https://intelephense.com/buy). -To switch to `intelephense`, add the following to your `settings.json`: +To use Intelephense, add the following to your `settings.json`: ```json [settings] { @@ -41,7 +51,9 @@ To switch to `intelephense`, add the following to your `settings.json`: } ``` -To use the premium features, you can place your [licence.txt file](https://intelephense.com/faq.html) at `~/intelephense/licence.txt` inside your home directory. 
Alternatively, you can pass the licence key or a path to a file containing the licence key as an initialization option for the `intelephense` language server. To do this, add the following to your `settings.json`: +To use the premium features, you can place your license file inside your home directory at `~/intelephense/licence.txt` for macOS and Linux, or `%USERPROFILE%/intelephense/licence.txt` on Windows. + +Alternatively, you can pass the licence key or a path to a file containing the licence key as an initialization option. To do this, add the following to your `settings.json`: ```json [settings] { @@ -55,15 +67,67 @@ To use the premium features, you can place your [licence.txt file](https://intel } ``` +### PHP Tools + +[PHP Tools](https://www.devsense.com/) is a proprietary language server that offers free and premium features. You need to [purchase a license](https://www.devsense.com/en/purchase) to activate the premium features. + +To use PHP Tools, add the following to your `settings.json`: + +```json [settings] +{ + "languages": { + "PHP": { + "language_servers": ["phptools", "!intelephense", "!phpactor", "..."] + } + } +} +``` + +To use the premium features, you can add your license in `initialization_options` in your `settings.json`: + +```json [settings] +{ + "lsp": { + "phptools": { + "initialization_options": { + "0": "your_license_key" + } + } + } +} +``` + +or, set environment variable `DEVSENSE_PHP_LS_LICENSE` on `.env` file in your project. + +```env +DEVSENSE_PHP_LS_LICENSE="your_license_key" +``` + +Check out the documentation of [PHP Tools for Zed](https://docs.devsense.com/other/zed/) for more details. + +### Phpactor + +To use Phpactor instead of Intelephense or any other tools, add the following to your `settings.json`: + +```json [settings] +{ + "languages": { + "PHP": { + "language_servers": ["phpactor", "!intelephense", "!phptools", "..."] + } + } +} +``` + ## PHPDoc Zed supports syntax highlighting for PHPDoc comments. 
- Tree-sitter: [claytonrcarter/tree-sitter-phpdoc](https://github.com/claytonrcarter/tree-sitter-phpdoc) -## Setting up Xdebug +## Debugging -Zed’s PHP extension provides a debug adapter for PHP and Xdebug. The adapter name is `Xdebug`. Here a couple ways you can use it: +The PHP extension provides a debug adapter for PHP via Xdebug. There are several ways to use it: ```json [ @@ -83,10 +147,10 @@ Zed’s PHP extension provides a debug adapter for PHP and Xdebug. The adapter n ] ``` -In case you run into issues: +These are common troubleshooting tips, in case you run into issues: -- ensure that you have Xdebug installed for the version of PHP you’re running -- ensure that Xdebug is configured to run in `debug` mode -- ensure that Xdebug is actually starting a debugging session -- check that the host and port matches between Xdebug and Zed -- look at the diagnostics log by using the `xdebug_info()` function in the page you’re trying to debug +- Ensure that you have Xdebug installed for the version of PHP you’re running. +- Ensure that Xdebug is configured to run in `debug` mode. +- Ensure that Xdebug is actually starting a debugging session. +- Ensure that the host and port matches between Xdebug and Zed. +- Look at the diagnostics log by using the `xdebug_info()` function in the page you’re trying to debug. From d5ed9d3e3a96492c049a1ab50819f196ed255037 Mon Sep 17 00:00:00 2001 From: Cole Miller Date: Thu, 4 Dec 2025 13:25:30 -0500 Subject: [PATCH 04/81] git: Don't call `git2::Repository::find_remote` for every blamed buffer (#44107) We already store the remote URLs for `origin` and `upstream` in the `RepositorySnapshot`, so just use that data. Follow-up to #44092. 
Release Notes: - N/A --- .../20221109000000_test_schema.sql | 2 ++ ...dd_remote_urls_to_project_repositories.sql | 2 ++ crates/collab/src/db/queries/projects.rs | 6 ++++ crates/collab/src/db/queries/rooms.rs | 2 ++ .../src/db/tables/project_repository.rs | 2 ++ crates/collab/src/tests/editor_tests.rs | 5 --- crates/editor/src/git/blame.rs | 22 ++++++++----- crates/git/src/blame.rs | 8 +---- crates/git/src/repository.rs | 31 ++++++------------- crates/project/src/git_store.rs | 19 +++++------- crates/proto/proto/git.proto | 4 ++- 11 files changed, 51 insertions(+), 52 deletions(-) create mode 100644 crates/collab/migrations/20251203234258_add_remote_urls_to_project_repositories.sql diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index a736ddfd1fe3334b1b847e820bd1816cb625ddca..32a2ed2e1331fc7b16f859accd895a7bce055804 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -121,6 +121,8 @@ CREATE TABLE "project_repositories" ( "merge_message" VARCHAR, "branch_summary" VARCHAR, "head_commit_details" VARCHAR, + "remote_upstream_url" VARCHAR, + "remote_origin_url" VARCHAR, PRIMARY KEY (project_id, id) ); diff --git a/crates/collab/migrations/20251203234258_add_remote_urls_to_project_repositories.sql b/crates/collab/migrations/20251203234258_add_remote_urls_to_project_repositories.sql new file mode 100644 index 0000000000000000000000000000000000000000..e1396de27d90fb2c872197d25198743d19be86f8 --- /dev/null +++ b/crates/collab/migrations/20251203234258_add_remote_urls_to_project_repositories.sql @@ -0,0 +1,2 @@ +ALTER TABLE "project_repositories" ADD COLUMN "remote_upstream_url" VARCHAR; +ALTER TABLE "project_repositories" ADD COLUMN "remote_origin_url" VARCHAR; diff --git a/crates/collab/src/db/queries/projects.rs b/crates/collab/src/db/queries/projects.rs index 
51a0ef83323ec70675283d2fdec7ca1ad791b12d..6f1d8b884d15041eadaa9073a5bd99e5ed352502 100644 --- a/crates/collab/src/db/queries/projects.rs +++ b/crates/collab/src/db/queries/projects.rs @@ -362,6 +362,8 @@ impl Database { entry_ids: ActiveValue::set("[]".into()), head_commit_details: ActiveValue::set(None), merge_message: ActiveValue::set(None), + remote_upstream_url: ActiveValue::set(None), + remote_origin_url: ActiveValue::set(None), } }), ) @@ -511,6 +513,8 @@ impl Database { serde_json::to_string(&update.current_merge_conflicts).unwrap(), )), merge_message: ActiveValue::set(update.merge_message.clone()), + remote_upstream_url: ActiveValue::set(update.remote_upstream_url.clone()), + remote_origin_url: ActiveValue::set(update.remote_origin_url.clone()), }) .on_conflict( OnConflict::columns([ @@ -1005,6 +1009,8 @@ impl Database { is_last_update: true, merge_message: db_repository_entry.merge_message, stash_entries: Vec::new(), + remote_upstream_url: db_repository_entry.remote_upstream_url.clone(), + remote_origin_url: db_repository_entry.remote_origin_url.clone(), }); } } diff --git a/crates/collab/src/db/queries/rooms.rs b/crates/collab/src/db/queries/rooms.rs index f020b99b5f1030cfe9391498512258e6db249bac..eafb5cac44a510bf4ced0434a9b4adfadff0ebbc 100644 --- a/crates/collab/src/db/queries/rooms.rs +++ b/crates/collab/src/db/queries/rooms.rs @@ -796,6 +796,8 @@ impl Database { is_last_update: true, merge_message: db_repository.merge_message, stash_entries: Vec::new(), + remote_upstream_url: db_repository.remote_upstream_url.clone(), + remote_origin_url: db_repository.remote_origin_url.clone(), }); } } diff --git a/crates/collab/src/db/tables/project_repository.rs b/crates/collab/src/db/tables/project_repository.rs index eb653ecee37d48ce79e26450eb85d87dec411c1e..190ae8d79c54bb78daef4a1568ec75683eb0b0f2 100644 --- a/crates/collab/src/db/tables/project_repository.rs +++ b/crates/collab/src/db/tables/project_repository.rs @@ -22,6 +22,8 @@ pub struct Model { pub 
branch_summary: Option, // A JSON object representing the current Head commit values pub head_commit_details: Option, + pub remote_upstream_url: Option, + pub remote_origin_url: Option, } #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] diff --git a/crates/collab/src/tests/editor_tests.rs b/crates/collab/src/tests/editor_tests.rs index 785a6457c8fdb57f84a8e7b5a8487f0ceae3d025..149a48db7439cc28e76fac5aae8b6e11f0837991 100644 --- a/crates/collab/src/tests/editor_tests.rs +++ b/crates/collab/src/tests/editor_tests.rs @@ -3518,7 +3518,6 @@ async fn test_git_blame_is_forwarded(cx_a: &mut TestAppContext, cx_b: &mut TestA .into_iter() .map(|(sha, message)| (sha.parse().unwrap(), message.into())) .collect(), - remote_url: Some("git@github.com:zed-industries/zed.git".to_string()), }; client_a.fs().set_blame_for_repo( Path::new(path!("/my-repo/.git")), @@ -3603,10 +3602,6 @@ async fn test_git_blame_is_forwarded(cx_a: &mut TestAppContext, cx_b: &mut TestA for (idx, (buffer, entry)) in entries.iter().flatten().enumerate() { let details = blame.details_for_entry(*buffer, entry).unwrap(); assert_eq!(details.message, format!("message for idx-{}", idx)); - assert_eq!( - details.permalink.unwrap().to_string(), - format!("https://github.com/zed-industries/zed/commit/{}", entry.sha) - ); } }); }); diff --git a/crates/editor/src/git/blame.rs b/crates/editor/src/git/blame.rs index 008630faef7cc1ccb3b9703e4b11c0b88b7cf17c..67df69aadab43a45c2941703e10bb81af2b8dd78 100644 --- a/crates/editor/src/git/blame.rs +++ b/crates/editor/src/git/blame.rs @@ -508,7 +508,19 @@ impl GitBlame { let buffer_edits = buffer.update(cx, |buffer, _| buffer.subscribe()); let blame_buffer = project.blame_buffer(&buffer, None, cx); - Some(async move { (id, snapshot, buffer_edits, blame_buffer.await) }) + let remote_url = project + .git_store() + .read(cx) + .repository_and_path_for_buffer_id(buffer.read(cx).remote_id(), cx) + .and_then(|(repo, _)| { + repo.read(cx) + .remote_upstream_url + .clone() + 
.or(repo.read(cx).remote_origin_url.clone()) + }); + Some( + async move { (id, snapshot, buffer_edits, blame_buffer.await, remote_url) }, + ) }) .collect::>() }); @@ -524,13 +536,9 @@ impl GitBlame { .await; let mut res = vec![]; let mut errors = vec![]; - for (id, snapshot, buffer_edits, blame) in blame { + for (id, snapshot, buffer_edits, blame, remote_url) in blame { match blame { - Ok(Some(Blame { - entries, - messages, - remote_url, - })) => { + Ok(Some(Blame { entries, messages })) => { let entries = build_blame_entry_sum_tree( entries, snapshot.max_point().row, diff --git a/crates/git/src/blame.rs b/crates/git/src/blame.rs index e58b9cb7e0427bf3af1c88f473debba0b6f94f59..6325eacc8201d812d14dfdf4853f4004e22c263e 100644 --- a/crates/git/src/blame.rs +++ b/crates/git/src/blame.rs @@ -19,7 +19,6 @@ pub use git2 as libgit; pub struct Blame { pub entries: Vec, pub messages: HashMap, - pub remote_url: Option, } #[derive(Clone, Debug, Default)] @@ -36,7 +35,6 @@ impl Blame { working_directory: &Path, path: &RepoPath, content: &Rope, - remote_url: Option, ) -> Result { let output = run_git_blame(git_binary, working_directory, path, content).await?; let mut entries = parse_git_blame(&output)?; @@ -53,11 +51,7 @@ impl Blame { .await .context("failed to get commit messages")?; - Ok(Self { - entries, - messages, - remote_url, - }) + Ok(Self { entries, messages }) } } diff --git a/crates/git/src/repository.rs b/crates/git/src/repository.rs index f79bade2d6bc12553b173c4f4e86989a961e6d31..70cbf6e3c58b7d8f6b690a554370d34262f541e3 100644 --- a/crates/git/src/repository.rs +++ b/crates/git/src/repository.rs @@ -1494,28 +1494,17 @@ impl GitRepository for RealGitRepository { let git_binary_path = self.any_git_binary_path.clone(); let executor = self.executor.clone(); - async move { - let remote_url = if let Some(remote_url) = self.remote_url("upstream").await { - Some(remote_url) - } else if let Some(remote_url) = self.remote_url("origin").await { - Some(remote_url) - } else { - 
None - }; - executor - .spawn(async move { - crate::blame::Blame::for_path( - &git_binary_path, - &working_directory?, - &path, - &content, - remote_url, - ) - .await - }) + executor + .spawn(async move { + crate::blame::Blame::for_path( + &git_binary_path, + &working_directory?, + &path, + &content, + ) .await - } - .boxed() + }) + .boxed() } fn file_history(&self, path: RepoPath) -> BoxFuture<'_, Result> { diff --git a/crates/project/src/git_store.rs b/crates/project/src/git_store.rs index 81511b21be3599b4686b9fd11aac5118711f11fa..0b74a04e1db5c0f2b7c8934d1bbe7d38b1d1ad1b 100644 --- a/crates/project/src/git_store.rs +++ b/crates/project/src/git_store.rs @@ -3296,6 +3296,8 @@ impl RepositorySnapshot { .iter() .map(stash_to_proto) .collect(), + remote_upstream_url: self.remote_upstream_url.clone(), + remote_origin_url: self.remote_origin_url.clone(), } } @@ -3365,6 +3367,8 @@ impl RepositorySnapshot { .iter() .map(stash_to_proto) .collect(), + remote_upstream_url: self.remote_upstream_url.clone(), + remote_origin_url: self.remote_origin_url.clone(), } } @@ -5395,6 +5399,8 @@ impl Repository { cx.emit(RepositoryEvent::StashEntriesChanged) } self.snapshot.stash_entries = new_stash_entries; + self.snapshot.remote_upstream_url = update.remote_upstream_url; + self.snapshot.remote_origin_url = update.remote_origin_url; let edits = update .removed_statuses @@ -5954,11 +5960,7 @@ fn serialize_blame_buffer_response(blame: Option) -> proto::B .collect::>(); proto::BlameBufferResponse { - blame_response: Some(proto::blame_buffer_response::BlameResponse { - entries, - messages, - remote_url: blame.remote_url, - }), + blame_response: Some(proto::blame_buffer_response::BlameResponse { entries, messages }), } } @@ -5995,11 +5997,7 @@ fn deserialize_blame_buffer_response( .filter_map(|message| Some((git::Oid::from_bytes(&message.oid).ok()?, message.message))) .collect::>(); - Some(Blame { - entries, - messages, - remote_url: response.remote_url, - }) + Some(Blame { entries, 
messages }) } fn branch_to_proto(branch: &git::repository::Branch) -> proto::Branch { @@ -6147,7 +6145,6 @@ async fn compute_snapshot( events.push(RepositoryEvent::BranchChanged); } - // Used by edit prediction data collection let remote_origin_url = backend.remote_url("origin").await; let remote_upstream_url = backend.remote_url("upstream").await; diff --git a/crates/proto/proto/git.proto b/crates/proto/proto/git.proto index aa0668ceabddc7627fcc3593b86ad2f4e40a6ac7..6e3573b91a690290b71e626f3bd67fc81d8d8e92 100644 --- a/crates/proto/proto/git.proto +++ b/crates/proto/proto/git.proto @@ -124,6 +124,8 @@ message UpdateRepository { optional GitCommitDetails head_commit_details = 11; optional string merge_message = 12; repeated StashEntry stash_entries = 13; + optional string remote_upstream_url = 14; + optional string remote_origin_url = 15; } message RemoveRepository { @@ -500,8 +502,8 @@ message BlameBufferResponse { message BlameResponse { repeated BlameEntry entries = 1; repeated CommitMessage messages = 2; - optional string remote_url = 4; reserved 3; + reserved 4; } optional BlameResponse blame_response = 5; From 9ae77ec3c9fb0c8d5fe85f370432487f5b8b22d6 Mon Sep 17 00:00:00 2001 From: vipex <101529155+vipexv@users.noreply.github.com> Date: Thu, 4 Dec 2025 19:48:06 +0100 Subject: [PATCH 05/81] markdown: Don't adjust indentation when inserting with multiple cursors (#40794) Closes #40757 ## Summary This PR addresses an issue where Zed incorrectly adjusts the indentation of Markdown lists when inserting text using multiple cursors. Currently: - Editing individual lines with a single cursor behaves correctly (no unwanted indentation changes). - Using multiple cursors, Zed automatically adjusts the indentation, unlike VS Code, which preserves the existing formatting. ## Tasks - [x] Implement a new test to verify correct Markdown indentation behavior with multiple cursors. 
- [x] Apply the fix to prevent Zed from adjusting indentation when inserting text on multiple cursors. ------------------------ Release Notes: - Fixed an issue where inserting text with multiple cursors inside a nested Markdown list would cause it to lose its indentation. --------- Co-authored-by: Smit Barmase --- Cargo.lock | 1 + crates/editor/Cargo.toml | 1 + crates/editor/src/editor_tests.rs | 59 +++++++++++++++++++++++ crates/languages/src/markdown/config.toml | 1 + crates/languages/src/markdown/indents.scm | 3 ++ 5 files changed, 65 insertions(+) create mode 100644 crates/languages/src/markdown/indents.scm diff --git a/Cargo.lock b/Cargo.lock index 5078c79e21ce1a580a6e055a7ce8ab4295f56906..87557afcb1b868cf9321bc0a4746e92687bb456d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5405,6 +5405,7 @@ dependencies = [ "tree-sitter-bash", "tree-sitter-c", "tree-sitter-html", + "tree-sitter-md", "tree-sitter-python", "tree-sitter-rust", "tree-sitter-typescript", diff --git a/crates/editor/Cargo.toml b/crates/editor/Cargo.toml index 2aa02e293dd44d5fdd920ac8cd98da48b9c1a912..736916ebbf74f20f11e8c03a0e584bd8ae92e07d 100644 --- a/crates/editor/Cargo.toml +++ b/crates/editor/Cargo.toml @@ -118,6 +118,7 @@ tree-sitter-rust.workspace = true tree-sitter-typescript.workspace = true tree-sitter-yaml.workspace = true tree-sitter-bash.workspace = true +tree-sitter-md.workspace = true unindent.workspace = true util = { workspace = true, features = ["test-support"] } workspace = { workspace = true, features = ["test-support"] } diff --git a/crates/editor/src/editor_tests.rs b/crates/editor/src/editor_tests.rs index 64c335e2e4b0dc660efe1b28bb87984fba8aafb4..7ab3dcc2345dd8a140b7c4762dc5afadb9cef484 100644 --- a/crates/editor/src/editor_tests.rs +++ b/crates/editor/src/editor_tests.rs @@ -27498,6 +27498,65 @@ async fn test_paste_url_from_other_app_creates_markdown_link_over_selected_text( )); } +#[gpui::test] +async fn test_markdown_list_indent_with_multi_cursor(cx: &mut 
gpui::TestAppContext) { + init_test(cx, |_| {}); + + let markdown_language = languages::language("markdown", tree_sitter_md::LANGUAGE.into()); + let mut cx = EditorTestContext::new(cx).await; + + cx.update_buffer(|buffer, cx| buffer.set_language(Some(markdown_language), cx)); + + cx.set_state(&indoc! {" + - [ ] Item 1 + - [ ] Item 1.a + - [ˇ] Item 2 + - [ˇ] Item 2.a + - [ˇ] Item 2.b + " + }); + + cx.update_editor(|editor, window, cx| { + editor.handle_input("X", window, cx); + }); + + cx.assert_editor_state(indoc! {" + - [ ] Item 1 + - [ ] Item 1.a + - [Xˇ] Item 2 + - [Xˇ] Item 2.a + - [Xˇ] Item 2.b + " + }); +} + +#[gpui::test] +async fn test_markdown_list_indent_with_newline(cx: &mut gpui::TestAppContext) { + init_test(cx, |_| {}); + + let markdown_language = languages::language("markdown", tree_sitter_md::LANGUAGE.into()); + let mut cx = EditorTestContext::new(cx).await; + + cx.update_buffer(|buffer, cx| buffer.set_language(Some(markdown_language), cx)); + + cx.set_state(indoc! {" + - [x] list item + - [x] sub list itemˇ + " + }); + + cx.update_editor(|editor, window, cx| { + editor.newline(&Newline, window, cx); + }); + + cx.assert_editor_state(indoc! 
{" + - [x] list item + - [x] sub list item + ˇ + " + }); +} + #[gpui::test] async fn test_paste_url_from_zed_copy_creates_markdown_link_over_selected_text( cx: &mut gpui::TestAppContext, diff --git a/crates/languages/src/markdown/config.toml b/crates/languages/src/markdown/config.toml index 36071cb5392462a51c10e0513b39979580ec67f5..2bbda0ef43e9a49b483dbe22cdf0473c8fbcf73c 100644 --- a/crates/languages/src/markdown/config.toml +++ b/crates/languages/src/markdown/config.toml @@ -24,4 +24,5 @@ rewrap_prefixes = [ auto_indent_on_paste = false auto_indent_using_last_non_empty_line = false tab_size = 2 +decrease_indent_pattern = "^\\s*$" prettier_parser_name = "markdown" diff --git a/crates/languages/src/markdown/indents.scm b/crates/languages/src/markdown/indents.scm new file mode 100644 index 0000000000000000000000000000000000000000..7fde3226bbbeb0fb9f0f7a1d90a328923a5228b3 --- /dev/null +++ b/crates/languages/src/markdown/indents.scm @@ -0,0 +1,3 @@ +(list (list_item) @indent) + +(list_item (list) @indent) From bdb8caa42e88c670be5278dab5819e770b92a133 Mon Sep 17 00:00:00 2001 From: Danilo Leal <67129314+danilo-leal@users.noreply.github.com> Date: Thu, 4 Dec 2025 16:47:27 -0300 Subject: [PATCH 06/81] git_ui: Fix indent guides not showing for file buffers in the commit view (#44166) Follow up to https://github.com/zed-industries/zed/pull/44162 where my strategy for not displaying the indent guides only in the commit message was wrong given I ended up... disabling indent guides for all the buffers. This PR adds a new method to the editor where we can disable it for a specific buffer ID following the pattern of `disable_header_for_buffer`. 
Release Notes: - N/A --- crates/editor/src/editor.rs | 17 ++++++++++++++--- crates/editor/src/indent_guides.rs | 4 ++++ crates/git_ui/src/commit_view.rs | 3 ++- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index 0de2dc8423b39ab2b336adb3cb17f79cc4a8f6e7..306d7a272b0b8c33e66803ccdbbd74194fde403a 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -1079,6 +1079,7 @@ pub struct Editor { show_breakpoints: Option, show_wrap_guides: Option, show_indent_guides: Option, + buffers_with_disabled_indent_guides: HashSet, highlight_order: usize, highlighted_rows: HashMap>, background_highlights: HashMap, @@ -2204,6 +2205,7 @@ impl Editor { show_breakpoints: None, show_wrap_guides: None, show_indent_guides, + buffers_with_disabled_indent_guides: HashSet::default(), highlight_order: 0, highlighted_rows: HashMap::default(), background_highlights: HashMap::default(), @@ -20090,9 +20092,18 @@ impl Editor { self.show_indent_guides } - pub fn disable_indent_guides(&mut self) -> Option { - self.show_indent_guides = Some(false); - self.show_indent_guides + pub fn disable_indent_guides_for_buffer( + &mut self, + buffer_id: BufferId, + cx: &mut Context, + ) { + self.buffers_with_disabled_indent_guides.insert(buffer_id); + cx.notify(); + } + + pub fn has_indent_guides_disabled_for_buffer(&self, buffer_id: BufferId) -> bool { + self.buffers_with_disabled_indent_guides + .contains(&buffer_id) } pub fn toggle_line_numbers( diff --git a/crates/editor/src/indent_guides.rs b/crates/editor/src/indent_guides.rs index 7c392d27531472a413ce4d32d09cce4eb722e462..f186f9da77aca5a0d34cdc05272032f93862b1d2 100644 --- a/crates/editor/src/indent_guides.rs +++ b/crates/editor/src/indent_guides.rs @@ -181,6 +181,10 @@ pub fn indent_guides_in_range( .buffer_snapshot() .indent_guides_in_range(start_anchor..end_anchor, ignore_disabled_for_language, cx) .filter(|indent_guide| { + if 
editor.has_indent_guides_disabled_for_buffer(indent_guide.buffer_id) { + return false; + } + if editor.is_buffer_folded(indent_guide.buffer_id, cx) { return false; } diff --git a/crates/git_ui/src/commit_view.rs b/crates/git_ui/src/commit_view.rs index 8a4504c1873193e81658c19c6b1115a9212e7760..7d191c1ae461ac36007dcbadc0b2e10f7dc53599 100644 --- a/crates/git_ui/src/commit_view.rs +++ b/crates/git_ui/src/commit_view.rs @@ -150,7 +150,6 @@ impl CommitView { Editor::for_multibuffer(multibuffer.clone(), Some(project.clone()), window, cx); editor.disable_inline_diagnostics(); - editor.disable_indent_guides(); editor.set_expand_all_diff_hunks(cx); editor @@ -259,6 +258,8 @@ impl CommitView { this.editor.update(cx, |editor, cx| { editor.disable_header_for_buffer(message_buffer.read(cx).remote_id(), cx); + editor + .disable_indent_guides_for_buffer(message_buffer.read(cx).remote_id(), cx); editor.insert_blocks( [BlockProperties { From 43f977c6b92411b82c757d1b168e72937b8d416a Mon Sep 17 00:00:00 2001 From: Danilo Leal <67129314+danilo-leal@users.noreply.github.com> Date: Thu, 4 Dec 2025 16:48:03 -0300 Subject: [PATCH 07/81] terminal view: Use tooltip element for the tab tooltip (#44169) Just recently realized we don't need this custom component for it given we now have `Tooltip::element`. UI result is exactly the same; nothing changes. 
Release Notes: - N/A --- .../terminal_view/src/terminal_tab_tooltip.rs | 36 ------------------- crates/terminal_view/src/terminal_view.rs | 30 ++++++++++------ 2 files changed, 19 insertions(+), 47 deletions(-) delete mode 100644 crates/terminal_view/src/terminal_tab_tooltip.rs diff --git a/crates/terminal_view/src/terminal_tab_tooltip.rs b/crates/terminal_view/src/terminal_tab_tooltip.rs deleted file mode 100644 index 6324c0999a8231bb1e633ef39343944783029895..0000000000000000000000000000000000000000 --- a/crates/terminal_view/src/terminal_tab_tooltip.rs +++ /dev/null @@ -1,36 +0,0 @@ -use gpui::{IntoElement, Render}; -use ui::{Divider, prelude::*, tooltip_container}; - -pub struct TerminalTooltip { - title: SharedString, - pid: u32, -} - -impl TerminalTooltip { - pub fn new(title: impl Into, pid: u32) -> Self { - Self { - title: title.into(), - pid, - } - } -} - -impl Render for TerminalTooltip { - fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - tooltip_container(cx, move |this, _cx| { - this.occlude() - .on_mouse_move(|_, _window, cx| cx.stop_propagation()) - .child( - v_flex() - .gap_1() - .child(Label::new(self.title.clone())) - .child(Divider::horizontal()) - .child( - Label::new(format!("Process ID (PID): {}", self.pid)) - .color(Color::Muted) - .size(LabelSize::Small), - ), - ) - }) - } -} diff --git a/crates/terminal_view/src/terminal_view.rs b/crates/terminal_view/src/terminal_view.rs index 4d567d902ff4f9271a0bdcf6a4db94d0e3a34ec6..98f7a17a2778e05b258f2ab6135cb94ba91ba547 100644 --- a/crates/terminal_view/src/terminal_view.rs +++ b/crates/terminal_view/src/terminal_view.rs @@ -4,7 +4,6 @@ pub mod terminal_panel; mod terminal_path_like_target; pub mod terminal_scrollbar; mod terminal_slash_command; -pub mod terminal_tab_tooltip; use assistant_slash_command::SlashCommandRegistry; use editor::{EditorSettings, actions::SelectAll, blink_manager::BlinkManager}; @@ -32,9 +31,8 @@ use terminal_panel::TerminalPanel; use 
terminal_path_like_target::{hover_path_like_target, open_path_like_target}; use terminal_scrollbar::TerminalScrollHandle; use terminal_slash_command::TerminalSlashCommand; -use terminal_tab_tooltip::TerminalTooltip; use ui::{ - ContextMenu, Icon, IconName, Label, ScrollAxes, Scrollbars, Tooltip, WithScrollbar, h_flex, + ContextMenu, Divider, ScrollAxes, Scrollbars, Tooltip, WithScrollbar, prelude::*, scrollbars::{self, GlobalSetting, ScrollbarVisibility}, }; @@ -1140,14 +1138,24 @@ impl Item for TerminalView { type Event = ItemEvent; fn tab_tooltip_content(&self, cx: &App) -> Option { - let terminal = self.terminal().read(cx); - let title = terminal.title(false); - let pid = terminal.pid_getter()?.fallback_pid(); - - Some(TabTooltipContent::Custom(Box::new(move |_window, cx| { - cx.new(|_| TerminalTooltip::new(title.clone(), pid.as_u32())) - .into() - }))) + Some(TabTooltipContent::Custom(Box::new(Tooltip::element({ + let terminal = self.terminal().read(cx); + let title = terminal.title(false); + let pid = terminal.pid_getter()?.fallback_pid(); + + move |_, _| { + v_flex() + .gap_1() + .child(Label::new(title.clone())) + .child(h_flex().flex_grow().child(Divider::horizontal())) + .child( + Label::new(format!("Process ID (PID): {}", pid)) + .color(Color::Muted) + .size(LabelSize::Small), + ) + .into_any_element() + } + })))) } fn tab_content(&self, params: TabContentParams, _window: &Window, cx: &App) -> AnyElement { From cd8679e81a2d3c32b10bf65a3d583b6886d2a2f3 Mon Sep 17 00:00:00 2001 From: Ian Chamberlain Date: Thu, 4 Dec 2025 12:37:32 -0800 Subject: [PATCH 08/81] Allow trailing commas in builtin JSONC schemas (#43854) The JSON language server looks for a top-level `allowTrailingCommas` flag to decide whether it should warn for trailing commas. Since the JSONC parser for these builtin files can handle trailing commas, adding this flag to the schema also prevents a warning for those commas.
I don't think there's an issue that is only for this specific issue, but it relates to *many* existing / older issues: - #18509 - #17487 - #40970 - #18509 - #21303 Release Notes: - Suppress warning for trailing commas in builtin JSON files (`settings.json`, `keymap.json`, etc.) --- crates/settings/src/keymap_file.rs | 5 ++++- crates/settings/src/settings_store.rs | 3 ++- crates/snippet_provider/src/format.rs | 3 ++- crates/task/src/debug_format.rs | 1 + crates/task/src/task_template.rs | 3 ++- crates/util/src/schemars.rs | 17 +++++++++++++++++ 6 files changed, 28 insertions(+), 4 deletions(-) diff --git a/crates/settings/src/keymap_file.rs b/crates/settings/src/keymap_file.rs index fc86afca2a1cbcd0a26777aa2ccb1fcb29b193a5..2ef1dfc5385592b9757eff5ec631af818ae1869c 100644 --- a/crates/settings/src/keymap_file.rs +++ b/crates/settings/src/keymap_file.rs @@ -15,6 +15,7 @@ use util::ResultExt as _; use util::{ asset_str, markdown::{MarkdownEscaped, MarkdownInlineCode, MarkdownString}, + schemars::AllowTrailingCommas, }; use crate::SettingsAssets; @@ -451,7 +452,9 @@ impl KeymapFile { /// Creates a JSON schema generator, suitable for generating json schemas /// for actions pub fn action_schema_generator() -> schemars::SchemaGenerator { - schemars::generate::SchemaSettings::draft2019_09().into_generator() + schemars::generate::SchemaSettings::draft2019_09() + .with_transform(AllowTrailingCommas) + .into_generator() } pub fn generate_json_schema_for_registered_actions(cx: &mut App) -> Value { diff --git a/crates/settings/src/settings_store.rs b/crates/settings/src/settings_store.rs index 181b8b417879be63fe85dbe6d08adca2d97929bd..72e2d3ef099659c5ad27e7f1aaafaee24354d4a9 100644 --- a/crates/settings/src/settings_store.rs +++ b/crates/settings/src/settings_store.rs @@ -25,7 +25,7 @@ use std::{ use util::{ ResultExt as _, rel_path::RelPath, - schemars::{DefaultDenyUnknownFields, replace_subschema}, + schemars::{AllowTrailingCommas, DefaultDenyUnknownFields, replace_subschema}, 
}; pub type EditorconfigProperties = ec4rs::Properties; @@ -1010,6 +1010,7 @@ impl SettingsStore { pub fn json_schema(&self, params: &SettingsJsonSchemaParams) -> Value { let mut generator = schemars::generate::SchemaSettings::draft2019_09() .with_transform(DefaultDenyUnknownFields) + .with_transform(AllowTrailingCommas) .into_generator(); UserSettingsContent::json_schema(&mut generator); diff --git a/crates/snippet_provider/src/format.rs b/crates/snippet_provider/src/format.rs index 0bbf137aed506f4cc7793f5dbe80ee144b620bf4..f9abb987d919b3a8bc7ab558e4bc86bac5e0b5a9 100644 --- a/crates/snippet_provider/src/format.rs +++ b/crates/snippet_provider/src/format.rs @@ -2,7 +2,7 @@ use collections::HashMap; use schemars::{JsonSchema, json_schema}; use serde::Deserialize; use std::borrow::Cow; -use util::schemars::DefaultDenyUnknownFields; +use util::schemars::{AllowTrailingCommas, DefaultDenyUnknownFields}; #[derive(Deserialize)] pub struct VsSnippetsFile { @@ -14,6 +14,7 @@ impl VsSnippetsFile { pub fn generate_json_schema() -> serde_json::Value { let schema = schemars::generate::SchemaSettings::draft2019_09() .with_transform(DefaultDenyUnknownFields) + .with_transform(AllowTrailingCommas) .into_generator() .root_schema_for::(); diff --git a/crates/task/src/debug_format.rs b/crates/task/src/debug_format.rs index 38089670e23f815221c274a2ccc4619b9e8bb327..5609e2565c8497ad2e92fb8b7d0e6738a1cb663c 100644 --- a/crates/task/src/debug_format.rs +++ b/crates/task/src/debug_format.rs @@ -357,6 +357,7 @@ impl DebugTaskFile { "$schema": meta_schema, "title": "Debug Configurations", "description": "Configuration for debug scenarios", + "allowTrailingCommas": true, "type": "array", "items": { "type": "object", diff --git a/crates/task/src/task_template.rs b/crates/task/src/task_template.rs index 33ff610b6e1ba509c75138ad4bf35478e69deaf1..0c319db0616862489b7b7d21912142a01ee89fcb 100644 --- a/crates/task/src/task_template.rs +++ b/crates/task/src/task_template.rs @@ -4,7 +4,7 @@ use 
schemars::JsonSchema; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; use std::path::PathBuf; -use util::schemars::DefaultDenyUnknownFields; +use util::schemars::{AllowTrailingCommas, DefaultDenyUnknownFields}; use util::serde::default_true; use util::{ResultExt, truncate_and_remove_front}; @@ -118,6 +118,7 @@ impl TaskTemplates { pub fn generate_json_schema() -> serde_json::Value { let schema = schemars::generate::SchemaSettings::draft2019_09() .with_transform(DefaultDenyUnknownFields) + .with_transform(AllowTrailingCommas) .into_generator() .root_schema_for::(); diff --git a/crates/util/src/schemars.rs b/crates/util/src/schemars.rs index 9314eda4ac4d5003d7186c3115137e2e54c66794..8124ca8cfef62cb4ea320da6423d7ad95a09eb78 100644 --- a/crates/util/src/schemars.rs +++ b/crates/util/src/schemars.rs @@ -53,3 +53,20 @@ impl schemars::transform::Transform for DefaultDenyUnknownFields { transform_subschemas(self, schema); } } + +/// Defaults `allowTrailingCommas` to `true`, for use with `json-language-server`. +/// This can be applied to any schema that will be treated as `jsonc`. +/// +/// Note that this is non-recursive and only applied to the root schema. +#[derive(Clone)] +pub struct AllowTrailingCommas; + +impl schemars::transform::Transform for AllowTrailingCommas { + fn transform(&mut self, schema: &mut schemars::Schema) { + if let Some(object) = schema.as_object_mut() + && !object.contains_key("allowTrailingCommas") + { + object.insert("allowTrailingCommas".to_string(), true.into()); + } + } +} From 76167109db7b2d899f2e88ffe04a84ca718dca03 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Thu, 4 Dec 2025 12:48:39 -0800 Subject: [PATCH 09/81] Add experimental LSP-based context retrieval system for edit prediction (#44036) To do * [x] Default to no context retrieval. 
Allow opting in to LSP-based retrieval via a setting (for users in `zeta2` feature flag) * [x] Feed this context to models when enabled * [x] Make the zeta2 context view work well with LSP retrieval * [x] Add a UI for the setting (for feature-flagged users) * [x] Ensure Zeta CLI `context` command is usable --- * [ ] Filter out LSP definitions that are too large / entire files (e.g. modules) * [ ] Introduce timeouts * [ ] Test with other LSPs * [ ] Figure out hangs Release Notes: - N/A --------- Co-authored-by: Ben Kunkle Co-authored-by: Agus Zubiaga --- Cargo.lock | 28 +- Cargo.toml | 2 + .../src/edit_prediction_button.rs | 28 +- crates/edit_prediction_context2/Cargo.toml | 42 + crates/edit_prediction_context2/LICENSE-GPL | 1 + .../src/assemble_excerpts.rs | 324 ++++++ .../src/edit_prediction_context2.rs | 465 +++++++++ .../src/edit_prediction_context_tests.rs | 360 +++++++ .../src/fake_definition_lsp.rs | 329 ++++++ .../src/extension_store_test.rs | 2 +- crates/language/src/buffer.rs | 14 + crates/language/src/buffer_tests.rs | 57 +- crates/language/src/language_registry.rs | 18 +- crates/language/src/language_settings.rs | 8 + crates/language/src/outline.rs | 50 +- crates/language/src/syntax_map.rs | 13 + .../remote_server/src/remote_editing_tests.rs | 6 +- .../settings/src/settings_content/language.rs | 2 + crates/text/src/anchor.rs | 8 +- crates/ui/src/components/data_table.rs | 22 +- crates/zeta/Cargo.toml | 1 + crates/zeta/src/assemble_excerpts.rs | 173 --- crates/zeta/src/retrieval_search.rs | 364 ++----- crates/zeta/src/sweep_ai.rs | 28 +- crates/zeta/src/zeta.rs | 984 ++++++++---------- crates/zeta2_tools/Cargo.toml | 1 - crates/zeta2_tools/src/zeta2_context_view.rs | 310 +++--- crates/zeta2_tools/src/zeta2_tools.rs | 12 + crates/zeta_cli/src/main.rs | 76 +- crates/zeta_cli/src/predict.rs | 61 +- crates/zeta_cli/src/util.rs | 28 +- 31 files changed, 2479 insertions(+), 1338 deletions(-) create mode 100644 crates/edit_prediction_context2/Cargo.toml create 
mode 120000 crates/edit_prediction_context2/LICENSE-GPL create mode 100644 crates/edit_prediction_context2/src/assemble_excerpts.rs create mode 100644 crates/edit_prediction_context2/src/edit_prediction_context2.rs create mode 100644 crates/edit_prediction_context2/src/edit_prediction_context_tests.rs create mode 100644 crates/edit_prediction_context2/src/fake_definition_lsp.rs delete mode 100644 crates/zeta/src/assemble_excerpts.rs diff --git a/Cargo.lock b/Cargo.lock index 87557afcb1b868cf9321bc0a4746e92687bb456d..6d41fbe96fac878f496e93461c180e1c184216d6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5342,6 +5342,32 @@ dependencies = [ "zlog", ] +[[package]] +name = "edit_prediction_context2" +version = "0.1.0" +dependencies = [ + "anyhow", + "collections", + "env_logger 0.11.8", + "futures 0.3.31", + "gpui", + "indoc", + "language", + "log", + "lsp", + "parking_lot", + "pretty_assertions", + "project", + "serde", + "serde_json", + "settings", + "smallvec", + "text", + "tree-sitter", + "util", + "zlog", +] + [[package]] name = "editor" version = "0.1.0" @@ -21693,6 +21719,7 @@ dependencies = [ "db", "edit_prediction", "edit_prediction_context", + "edit_prediction_context2", "editor", "feature_flags", "fs", @@ -21742,7 +21769,6 @@ dependencies = [ "clap", "client", "cloud_llm_client", - "cloud_zeta2_prompt", "collections", "edit_prediction_context", "editor", diff --git a/Cargo.toml b/Cargo.toml index 59b9a53d4a60b28582625fb90b64b934079cdc40..62a44dbf35fefbf02a1b570146b0bf24cea6dcd8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -56,6 +56,7 @@ members = [ "crates/edit_prediction", "crates/edit_prediction_button", "crates/edit_prediction_context", + "crates/edit_prediction_context2", "crates/zeta2_tools", "crates/editor", "crates/eval", @@ -316,6 +317,7 @@ image_viewer = { path = "crates/image_viewer" } edit_prediction = { path = "crates/edit_prediction" } edit_prediction_button = { path = "crates/edit_prediction_button" } edit_prediction_context = { path = 
"crates/edit_prediction_context" } +edit_prediction_context2 = { path = "crates/edit_prediction_context2" } zeta2_tools = { path = "crates/zeta2_tools" } inspector_ui = { path = "crates/inspector_ui" } install_cli = { path = "crates/install_cli" } diff --git a/crates/edit_prediction_button/src/edit_prediction_button.rs b/crates/edit_prediction_button/src/edit_prediction_button.rs index 8ce8441859b7cc747a2b566dedd913e58259969d..8b234497376aefdc972681c877a1122f3f9cee17 100644 --- a/crates/edit_prediction_button/src/edit_prediction_button.rs +++ b/crates/edit_prediction_button/src/edit_prediction_button.rs @@ -1105,9 +1105,33 @@ impl EditPredictionButton { .separator(); } - let menu = self.build_language_settings_menu(menu, window, cx); - let menu = self.add_provider_switching_section(menu, provider, cx); + menu = self.build_language_settings_menu(menu, window, cx); + + if cx.has_flag::() { + let settings = all_language_settings(None, cx); + let context_retrieval = settings.edit_predictions.use_context; + menu = menu.separator().header("Context Retrieval").item( + ContextMenuEntry::new("Enable Context Retrieval") + .toggleable(IconPosition::Start, context_retrieval) + .action(workspace::ToggleEditPrediction.boxed_clone()) + .handler({ + let fs = self.fs.clone(); + move |_, cx| { + update_settings_file(fs.clone(), cx, move |settings, _| { + settings + .project + .all_languages + .features + .get_or_insert_default() + .experimental_edit_prediction_context_retrieval = + Some(!context_retrieval) + }); + } + }), + ); + } + menu = self.add_provider_switching_section(menu, provider, cx); menu }) } diff --git a/crates/edit_prediction_context2/Cargo.toml b/crates/edit_prediction_context2/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..597884b44821e24a930c8730225be4c6bf1c90f6 --- /dev/null +++ b/crates/edit_prediction_context2/Cargo.toml @@ -0,0 +1,42 @@ +[package] +name = "edit_prediction_context2" +version = "0.1.0" +edition.workspace = true 
+publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/edit_prediction_context2.rs" + +[dependencies] +parking_lot.workspace = true +anyhow.workspace = true +collections.workspace = true +futures.workspace = true +gpui.workspace = true +language.workspace = true +lsp.workspace = true +project.workspace = true +log.workspace = true +serde.workspace = true +smallvec.workspace = true +tree-sitter.workspace = true +util.workspace = true + +[dev-dependencies] +env_logger.workspace = true +indoc.workspace = true +futures.workspace = true +gpui = { workspace = true, features = ["test-support"] } +language = { workspace = true, features = ["test-support"] } +lsp = { workspace = true, features = ["test-support"] } +pretty_assertions.workspace = true +project = {workspace= true, features = ["test-support"]} +serde_json.workspace = true +settings = {workspace= true, features = ["test-support"]} +text = { workspace = true, features = ["test-support"] } +util = { workspace = true, features = ["test-support"] } +zlog.workspace = true diff --git a/crates/edit_prediction_context2/LICENSE-GPL b/crates/edit_prediction_context2/LICENSE-GPL new file mode 120000 index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4 --- /dev/null +++ b/crates/edit_prediction_context2/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/edit_prediction_context2/src/assemble_excerpts.rs b/crates/edit_prediction_context2/src/assemble_excerpts.rs new file mode 100644 index 0000000000000000000000000000000000000000..b3b8d4f8bc480053a1e9ab9d498d5350039ed609 --- /dev/null +++ b/crates/edit_prediction_context2/src/assemble_excerpts.rs @@ -0,0 +1,324 @@ +use crate::RelatedExcerpt; +use language::{BufferSnapshot, OffsetRangeExt as _, Point}; +use std::ops::Range; + +#[cfg(not(test))] +const MAX_OUTLINE_ITEM_BODY_SIZE: usize = 512; +#[cfg(test)] +const MAX_OUTLINE_ITEM_BODY_SIZE: usize = 
24; + +pub fn assemble_excerpts( + buffer: &BufferSnapshot, + mut input_ranges: Vec>, +) -> Vec { + merge_ranges(&mut input_ranges); + + let mut outline_ranges = Vec::new(); + let outline_items = buffer.outline_items_as_points_containing(0..buffer.len(), false, None); + let mut outline_ix = 0; + for input_range in &mut input_ranges { + *input_range = clip_range_to_lines(input_range, false, buffer); + + while let Some(outline_item) = outline_items.get(outline_ix) { + let item_range = clip_range_to_lines(&outline_item.range, false, buffer); + + if item_range.start > input_range.start { + break; + } + + if item_range.end > input_range.start { + let body_range = outline_item + .body_range(buffer) + .map(|body| clip_range_to_lines(&body, true, buffer)) + .filter(|body_range| { + body_range.to_offset(buffer).len() > MAX_OUTLINE_ITEM_BODY_SIZE + }); + + add_outline_item( + item_range.clone(), + body_range.clone(), + buffer, + &mut outline_ranges, + ); + + if let Some(body_range) = body_range + && input_range.start < body_range.start + { + let mut child_outline_ix = outline_ix + 1; + while let Some(next_outline_item) = outline_items.get(child_outline_ix) { + if next_outline_item.range.end > body_range.end { + break; + } + if next_outline_item.depth == outline_item.depth + 1 { + let next_item_range = + clip_range_to_lines(&next_outline_item.range, false, buffer); + + add_outline_item( + next_item_range, + next_outline_item + .body_range(buffer) + .map(|body| clip_range_to_lines(&body, true, buffer)), + buffer, + &mut outline_ranges, + ); + child_outline_ix += 1; + } + } + } + } + + outline_ix += 1; + } + } + + input_ranges.extend_from_slice(&outline_ranges); + merge_ranges(&mut input_ranges); + + input_ranges + .into_iter() + .map(|range| { + let offset_range = range.to_offset(buffer); + RelatedExcerpt { + point_range: range, + anchor_range: buffer.anchor_before(offset_range.start) + ..buffer.anchor_after(offset_range.end), + text: buffer.as_rope().slice(offset_range), + } 
+ }) + .collect() +} + +fn clip_range_to_lines( + range: &Range, + inward: bool, + buffer: &BufferSnapshot, +) -> Range { + let mut range = range.clone(); + if inward { + if range.start.column > 0 { + range.start.column = buffer.line_len(range.start.row); + } + range.end.column = 0; + } else { + range.start.column = 0; + if range.end.column > 0 { + range.end.column = buffer.line_len(range.end.row); + } + } + range +} + +fn add_outline_item( + mut item_range: Range, + body_range: Option>, + buffer: &BufferSnapshot, + outline_ranges: &mut Vec>, +) { + if let Some(mut body_range) = body_range { + if body_range.start.column > 0 { + body_range.start.column = buffer.line_len(body_range.start.row); + } + body_range.end.column = 0; + + let head_range = item_range.start..body_range.start; + if head_range.start < head_range.end { + outline_ranges.push(head_range); + } + + let tail_range = body_range.end..item_range.end; + if tail_range.start < tail_range.end { + outline_ranges.push(tail_range); + } + } else { + item_range.start.column = 0; + item_range.end.column = buffer.line_len(item_range.end.row); + outline_ranges.push(item_range); + } +} + +pub fn merge_ranges(ranges: &mut Vec>) { + ranges.sort_unstable_by(|a, b| a.start.cmp(&b.start).then(b.end.cmp(&a.end))); + + let mut index = 1; + while index < ranges.len() { + let mut prev_range_end = ranges[index - 1].end; + if prev_range_end.column > 0 { + prev_range_end += Point::new(1, 0); + } + + if (prev_range_end + Point::new(1, 0)) + .cmp(&ranges[index].start) + .is_ge() + { + let removed = ranges.remove(index); + if removed.end.cmp(&ranges[index - 1].end).is_gt() { + ranges[index - 1].end = removed.end; + } + } else { + index += 1; + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use gpui::{TestAppContext, prelude::*}; + use indoc::indoc; + use language::{Buffer, Language, LanguageConfig, LanguageMatcher, OffsetRangeExt}; + use pretty_assertions::assert_eq; + use std::{fmt::Write as _, sync::Arc}; + use 
util::test::marked_text_ranges; + + #[gpui::test] + fn test_rust(cx: &mut TestAppContext) { + let table = [ + ( + indoc! {r#" + struct User { + first_name: String, + «last_name»: String, + age: u32, + email: String, + create_at: Instant, + } + + impl User { + pub fn first_name(&self) -> String { + self.first_name.clone() + } + + pub fn full_name(&self) -> String { + « format!("{} {}", self.first_name, self.last_name) + » } + } + "#}, + indoc! {r#" + struct User { + first_name: String, + last_name: String, + … + } + + impl User { + … + pub fn full_name(&self) -> String { + format!("{} {}", self.first_name, self.last_name) + } + } + "#}, + ), + ( + indoc! {r#" + struct «User» { + first_name: String, + last_name: String, + age: u32, + } + + impl User { + // methods + } + "# + }, + indoc! {r#" + struct User { + first_name: String, + last_name: String, + age: u32, + } + … + "#}, + ), + ( + indoc! {r#" + trait «FooProvider» { + const NAME: &'static str; + + fn provide_foo(&self, id: usize) -> Foo; + + fn provide_foo_batched(&self, ids: &[usize]) -> Vec { + ids.iter() + .map(|id| self.provide_foo(*id)) + .collect() + } + + fn sync(&self); + } + "# + }, + indoc! 
{r#" + trait FooProvider { + const NAME: &'static str; + + fn provide_foo(&self, id: usize) -> Foo; + + fn provide_foo_batched(&self, ids: &[usize]) -> Vec { + … + } + + fn sync(&self); + } + "#}, + ), + ]; + + for (input, expected_output) in table { + let (input, ranges) = marked_text_ranges(&input, false); + let buffer = + cx.new(|cx| Buffer::local(input, cx).with_language(Arc::new(rust_lang()), cx)); + buffer.read_with(cx, |buffer, _cx| { + let ranges: Vec> = ranges + .into_iter() + .map(|range| range.to_point(&buffer)) + .collect(); + + let excerpts = assemble_excerpts(&buffer.snapshot(), ranges); + + let output = format_excerpts(buffer, &excerpts); + assert_eq!(output, expected_output); + }); + } + } + + fn format_excerpts(buffer: &Buffer, excerpts: &[RelatedExcerpt]) -> String { + let mut output = String::new(); + let file_line_count = buffer.max_point().row; + let mut current_row = 0; + for excerpt in excerpts { + if excerpt.text.is_empty() { + continue; + } + if current_row < excerpt.point_range.start.row { + writeln!(&mut output, "…").unwrap(); + } + current_row = excerpt.point_range.start.row; + + for line in excerpt.text.to_string().lines() { + output.push_str(line); + output.push('\n'); + current_row += 1; + } + } + if current_row < file_line_count { + writeln!(&mut output, "…").unwrap(); + } + output + } + + fn rust_lang() -> Language { + Language::new( + LanguageConfig { + name: "Rust".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + ..Default::default() + }, + Some(language::tree_sitter_rust::LANGUAGE.into()), + ) + .with_outline_query(include_str!("../../languages/src/rust/outline.scm")) + .unwrap() + } +} diff --git a/crates/edit_prediction_context2/src/edit_prediction_context2.rs b/crates/edit_prediction_context2/src/edit_prediction_context2.rs new file mode 100644 index 0000000000000000000000000000000000000000..f8790478547ddb8b7b873015846f2af6c1bcbc2c --- /dev/null +++ 
b/crates/edit_prediction_context2/src/edit_prediction_context2.rs @@ -0,0 +1,465 @@ +use crate::assemble_excerpts::assemble_excerpts; +use anyhow::Result; +use collections::HashMap; +use futures::{FutureExt, StreamExt as _, channel::mpsc, future}; +use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity}; +use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, Rope, ToOffset as _}; +use project::{LocationLink, Project, ProjectPath}; +use serde::{Serialize, Serializer}; +use smallvec::SmallVec; +use std::{ + collections::hash_map, + ops::Range, + sync::Arc, + time::{Duration, Instant}, +}; +use util::{RangeExt as _, ResultExt}; + +mod assemble_excerpts; +#[cfg(test)] +mod edit_prediction_context_tests; +#[cfg(test)] +mod fake_definition_lsp; + +pub struct RelatedExcerptStore { + project: WeakEntity, + related_files: Vec, + cache: HashMap>, + update_tx: mpsc::UnboundedSender<(Entity, Anchor)>, +} + +pub enum RelatedExcerptStoreEvent { + StartedRefresh, + FinishedRefresh { + cache_hit_count: usize, + cache_miss_count: usize, + mean_definition_latency: Duration, + max_definition_latency: Duration, + }, +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +struct Identifier { + pub name: String, + pub range: Range, +} + +enum DefinitionTask { + CacheHit(Arc), + CacheMiss(Task>>>), +} + +#[derive(Debug)] +struct CacheEntry { + definitions: SmallVec<[CachedDefinition; 1]>, +} + +#[derive(Clone, Debug)] +struct CachedDefinition { + path: ProjectPath, + buffer: Entity, + anchor_range: Range, +} + +#[derive(Clone, Debug, Serialize)] +pub struct RelatedFile { + #[serde(serialize_with = "serialize_project_path")] + pub path: ProjectPath, + #[serde(skip)] + pub buffer: WeakEntity, + pub excerpts: Vec, + pub max_row: u32, +} + +impl RelatedFile { + pub fn merge_excerpts(&mut self) { + self.excerpts.sort_unstable_by(|a, b| { + a.point_range + .start + .cmp(&b.point_range.start) + .then(b.point_range.end.cmp(&a.point_range.end)) + 
}); + + let mut index = 1; + while index < self.excerpts.len() { + if self.excerpts[index - 1] + .point_range + .end + .cmp(&self.excerpts[index].point_range.start) + .is_ge() + { + let removed = self.excerpts.remove(index); + if removed + .point_range + .end + .cmp(&self.excerpts[index - 1].point_range.end) + .is_gt() + { + self.excerpts[index - 1].point_range.end = removed.point_range.end; + self.excerpts[index - 1].anchor_range.end = removed.anchor_range.end; + } + } else { + index += 1; + } + } + } +} + +#[derive(Clone, Debug, Serialize)] +pub struct RelatedExcerpt { + #[serde(skip)] + pub anchor_range: Range, + #[serde(serialize_with = "serialize_point_range")] + pub point_range: Range, + #[serde(serialize_with = "serialize_rope")] + pub text: Rope, +} + +fn serialize_project_path( + project_path: &ProjectPath, + serializer: S, +) -> Result { + project_path.path.serialize(serializer) +} + +fn serialize_rope(rope: &Rope, serializer: S) -> Result { + rope.to_string().serialize(serializer) +} + +fn serialize_point_range( + range: &Range, + serializer: S, +) -> Result { + [ + [range.start.row, range.start.column], + [range.end.row, range.end.column], + ] + .serialize(serializer) +} + +const DEBOUNCE_DURATION: Duration = Duration::from_millis(100); + +impl EventEmitter for RelatedExcerptStore {} + +impl RelatedExcerptStore { + pub fn new(project: &Entity, cx: &mut Context) -> Self { + let (update_tx, mut update_rx) = mpsc::unbounded::<(Entity, Anchor)>(); + cx.spawn(async move |this, cx| { + let executor = cx.background_executor().clone(); + while let Some((mut buffer, mut position)) = update_rx.next().await { + let mut timer = executor.timer(DEBOUNCE_DURATION).fuse(); + loop { + futures::select_biased! 
{ + next = update_rx.next() => { + if let Some((new_buffer, new_position)) = next { + buffer = new_buffer; + position = new_position; + timer = executor.timer(DEBOUNCE_DURATION).fuse(); + } else { + return anyhow::Ok(()); + } + } + _ = timer => break, + } + } + + Self::fetch_excerpts(this.clone(), buffer, position, cx).await?; + } + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + + RelatedExcerptStore { + project: project.downgrade(), + update_tx, + related_files: Vec::new(), + cache: Default::default(), + } + } + + pub fn refresh(&mut self, buffer: Entity, position: Anchor, _: &mut Context) { + self.update_tx.unbounded_send((buffer, position)).ok(); + } + + pub fn related_files(&self) -> &[RelatedFile] { + &self.related_files + } + + async fn fetch_excerpts( + this: WeakEntity, + buffer: Entity, + position: Anchor, + cx: &mut AsyncApp, + ) -> Result<()> { + let (project, snapshot) = this.read_with(cx, |this, cx| { + (this.project.upgrade(), buffer.read(cx).snapshot()) + })?; + let Some(project) = project else { + return Ok(()); + }; + + let file = snapshot.file().cloned(); + if let Some(file) = &file { + log::debug!("retrieving_context buffer:{}", file.path().as_unix_str()); + } + + this.update(cx, |_, cx| { + cx.emit(RelatedExcerptStoreEvent::StartedRefresh); + })?; + + let identifiers = cx + .background_spawn(async move { identifiers_for_position(&snapshot, position) }) + .await; + + let async_cx = cx.clone(); + let start_time = Instant::now(); + let futures = this.update(cx, |this, cx| { + identifiers + .into_iter() + .filter_map(|identifier| { + let task = if let Some(entry) = this.cache.get(&identifier) { + DefinitionTask::CacheHit(entry.clone()) + } else { + DefinitionTask::CacheMiss( + this.project + .update(cx, |project, cx| { + project.definitions(&buffer, identifier.range.start, cx) + }) + .ok()?, + ) + }; + + let cx = async_cx.clone(); + let project = project.clone(); + Some(async move { + match task { + DefinitionTask::CacheHit(cache_entry) => { + 
Some((identifier, cache_entry, None)) + } + DefinitionTask::CacheMiss(task) => { + let locations = task.await.log_err()??; + let duration = start_time.elapsed(); + cx.update(|cx| { + ( + identifier, + Arc::new(CacheEntry { + definitions: locations + .into_iter() + .filter_map(|location| { + process_definition(location, &project, cx) + }) + .collect(), + }), + Some(duration), + ) + }) + .ok() + } + } + }) + }) + .collect::>() + })?; + + let mut cache_hit_count = 0; + let mut cache_miss_count = 0; + let mut mean_definition_latency = Duration::ZERO; + let mut max_definition_latency = Duration::ZERO; + let mut new_cache = HashMap::default(); + new_cache.reserve(futures.len()); + for (identifier, entry, duration) in future::join_all(futures).await.into_iter().flatten() { + new_cache.insert(identifier, entry); + if let Some(duration) = duration { + cache_miss_count += 1; + mean_definition_latency += duration; + max_definition_latency = max_definition_latency.max(duration); + } else { + cache_hit_count += 1; + } + } + mean_definition_latency /= cache_miss_count.max(1) as u32; + + let (new_cache, related_files) = rebuild_related_files(new_cache, cx).await?; + + if let Some(file) = &file { + log::debug!( + "finished retrieving context buffer:{}, latency:{:?}", + file.path().as_unix_str(), + start_time.elapsed() + ); + } + + this.update(cx, |this, cx| { + this.cache = new_cache; + this.related_files = related_files; + cx.emit(RelatedExcerptStoreEvent::FinishedRefresh { + cache_hit_count, + cache_miss_count, + mean_definition_latency, + max_definition_latency, + }); + })?; + + anyhow::Ok(()) + } +} + +async fn rebuild_related_files( + new_entries: HashMap>, + cx: &mut AsyncApp, +) -> Result<(HashMap>, Vec)> { + let mut snapshots = HashMap::default(); + for entry in new_entries.values() { + for definition in &entry.definitions { + if let hash_map::Entry::Vacant(e) = snapshots.entry(definition.buffer.entity_id()) { + definition + .buffer + .read_with(cx, |buffer, _| 
buffer.parsing_idle())? + .await; + e.insert( + definition + .buffer + .read_with(cx, |buffer, _| buffer.snapshot())?, + ); + } + } + } + + Ok(cx + .background_spawn(async move { + let mut files = Vec::::new(); + let mut ranges_by_buffer = HashMap::<_, Vec>>::default(); + let mut paths_by_buffer = HashMap::default(); + for entry in new_entries.values() { + for definition in &entry.definitions { + let Some(snapshot) = snapshots.get(&definition.buffer.entity_id()) else { + continue; + }; + paths_by_buffer.insert(definition.buffer.entity_id(), definition.path.clone()); + ranges_by_buffer + .entry(definition.buffer.clone()) + .or_default() + .push(definition.anchor_range.to_point(snapshot)); + } + } + + for (buffer, ranges) in ranges_by_buffer { + let Some(snapshot) = snapshots.get(&buffer.entity_id()) else { + continue; + }; + let Some(project_path) = paths_by_buffer.get(&buffer.entity_id()) else { + continue; + }; + let excerpts = assemble_excerpts(snapshot, ranges); + files.push(RelatedFile { + path: project_path.clone(), + buffer: buffer.downgrade(), + excerpts, + max_row: snapshot.max_point().row, + }); + } + + files.sort_by_key(|file| file.path.clone()); + (new_entries, files) + }) + .await) +} + +fn process_definition( + location: LocationLink, + project: &Entity, + cx: &mut App, +) -> Option { + let buffer = location.target.buffer.read(cx); + let anchor_range = location.target.range; + let file = buffer.file()?; + let worktree = project.read(cx).worktree_for_id(file.worktree_id(cx), cx)?; + if worktree.read(cx).is_single_file() { + return None; + } + Some(CachedDefinition { + path: ProjectPath { + worktree_id: file.worktree_id(cx), + path: file.path().clone(), + }, + buffer: location.target.buffer, + anchor_range, + }) +} + +/// Gets all of the identifiers that are present in the given line, and its containing +/// outline items. 
+fn identifiers_for_position(buffer: &BufferSnapshot, position: Anchor) -> Vec { + let offset = position.to_offset(buffer); + let point = buffer.offset_to_point(offset); + + let line_range = Point::new(point.row, 0)..Point::new(point.row + 1, 0).min(buffer.max_point()); + let mut ranges = vec![line_range.to_offset(&buffer)]; + + // Include the range of the outline item itself, but not its body. + let outline_items = buffer.outline_items_as_offsets_containing(offset..offset, false, None); + for item in outline_items { + if let Some(body_range) = item.body_range(&buffer) { + ranges.push(item.range.start..body_range.start.to_offset(&buffer)); + } else { + ranges.push(item.range.clone()); + } + } + + ranges.sort_by(|a, b| a.start.cmp(&b.start).then(b.end.cmp(&a.end))); + ranges.dedup_by(|a, b| { + if a.start <= b.end { + b.start = b.start.min(a.start); + b.end = b.end.max(a.end); + true + } else { + false + } + }); + + let mut identifiers = Vec::new(); + let outer_range = + ranges.first().map_or(0, |r| r.start)..ranges.last().map_or(buffer.len(), |r| r.end); + + let mut captures = buffer + .syntax + .captures(outer_range.clone(), &buffer.text, |grammar| { + grammar + .highlights_config + .as_ref() + .map(|config| &config.query) + }); + + for range in ranges { + captures.set_byte_range(range.start..outer_range.end); + + let mut last_range = None; + while let Some(capture) = captures.peek() { + let node_range = capture.node.byte_range(); + if node_range.start > range.end { + break; + } + let config = captures.grammars()[capture.grammar_index] + .highlights_config + .as_ref(); + + if let Some(config) = config + && config.identifier_capture_indices.contains(&capture.index) + && range.contains_inclusive(&node_range) + && Some(&node_range) != last_range.as_ref() + { + let name = buffer.text_for_range(node_range.clone()).collect(); + identifiers.push(Identifier { + range: buffer.anchor_after(node_range.start) + ..buffer.anchor_before(node_range.end), + name, + }); + 
last_range = Some(node_range); + } + + captures.advance(); + } + } + + identifiers +} diff --git a/crates/edit_prediction_context2/src/edit_prediction_context_tests.rs b/crates/edit_prediction_context2/src/edit_prediction_context_tests.rs new file mode 100644 index 0000000000000000000000000000000000000000..05d1becc2167837a5f9741d77e7bc96c2f5b8d34 --- /dev/null +++ b/crates/edit_prediction_context2/src/edit_prediction_context_tests.rs @@ -0,0 +1,360 @@ +use super::*; +use futures::channel::mpsc::UnboundedReceiver; +use gpui::TestAppContext; +use indoc::indoc; +use language::{Language, LanguageConfig, LanguageMatcher, Point, ToPoint as _, tree_sitter_rust}; +use lsp::FakeLanguageServer; +use project::{FakeFs, LocationLink, Project}; +use serde_json::json; +use settings::SettingsStore; +use std::sync::Arc; +use util::path; + +#[gpui::test] +async fn test_edit_prediction_context(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/root"), test_project_1()).await; + + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let mut servers = setup_fake_lsp(&project, cx); + + let (buffer, _handle) = project + .update(cx, |project, cx| { + project.open_local_buffer_with_lsp(path!("/root/src/main.rs"), cx) + }) + .await + .unwrap(); + + let _server = servers.next().await.unwrap(); + cx.run_until_parked(); + + let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(&project, cx)); + related_excerpt_store.update(cx, |store, cx| { + let position = { + let buffer = buffer.read(cx); + let offset = buffer.text().find("todo").unwrap(); + buffer.anchor_before(offset) + }; + + store.refresh(buffer.clone(), position, cx); + }); + + cx.executor().advance_clock(DEBOUNCE_DURATION); + related_excerpt_store.update(cx, |store, _| { + let excerpts = store.related_files(); + assert_related_files( + &excerpts, + &[ + ( + "src/company.rs", + &[indoc! 
{" + pub struct Company { + owner: Arc, + address: Address, + }"}], + ), + ( + "src/main.rs", + &[ + indoc! {" + pub struct Session { + company: Arc, + } + + impl Session { + pub fn set_company(&mut self, company: Arc) {"}, + indoc! {" + } + }"}, + ], + ), + ( + "src/person.rs", + &[ + indoc! {" + impl Person { + pub fn get_first_name(&self) -> &str { + &self.first_name + }"}, + "}", + ], + ), + ], + ); + }); +} + +#[gpui::test] +async fn test_fake_definition_lsp(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/root"), test_project_1()).await; + + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let mut servers = setup_fake_lsp(&project, cx); + + let (buffer, _handle) = project + .update(cx, |project, cx| { + project.open_local_buffer_with_lsp(path!("/root/src/main.rs"), cx) + }) + .await + .unwrap(); + + let _server = servers.next().await.unwrap(); + cx.run_until_parked(); + + let buffer_text = buffer.read_with(cx, |buffer, _| buffer.text()); + + let definitions = project + .update(cx, |project, cx| { + let offset = buffer_text.find("Address {").unwrap(); + project.definitions(&buffer, offset, cx) + }) + .await + .unwrap() + .unwrap(); + assert_definitions(&definitions, &["pub struct Address {"], cx); + + let definitions = project + .update(cx, |project, cx| { + let offset = buffer_text.find("State::CA").unwrap(); + project.definitions(&buffer, offset, cx) + }) + .await + .unwrap() + .unwrap(); + assert_definitions(&definitions, &["pub enum State {"], cx); + + let definitions = project + .update(cx, |project, cx| { + let offset = buffer_text.find("to_string()").unwrap(); + project.definitions(&buffer, offset, cx) + }) + .await + .unwrap() + .unwrap(); + assert_definitions(&definitions, &["pub fn to_string(&self) -> String {"], cx); +} + +fn init_test(cx: &mut TestAppContext) { + let settings_store = cx.update(|cx| SettingsStore::test(cx)); + cx.set_global(settings_store); + 
env_logger::try_init().ok(); +} + +fn setup_fake_lsp( + project: &Entity, + cx: &mut TestAppContext, +) -> UnboundedReceiver { + let (language_registry, fs) = project.read_with(cx, |project, _| { + (project.languages().clone(), project.fs().clone()) + }); + let language = rust_lang(); + language_registry.add(language.clone()); + fake_definition_lsp::register_fake_definition_server(&language_registry, language, fs) +} + +fn test_project_1() -> serde_json::Value { + let person_rs = indoc! {r#" + pub struct Person { + first_name: String, + last_name: String, + email: String, + age: u32, + } + + impl Person { + pub fn get_first_name(&self) -> &str { + &self.first_name + } + + pub fn get_last_name(&self) -> &str { + &self.last_name + } + + pub fn get_email(&self) -> &str { + &self.email + } + + pub fn get_age(&self) -> u32 { + self.age + } + } + "#}; + + let address_rs = indoc! {r#" + pub struct Address { + street: String, + city: String, + state: State, + zip: u32, + } + + pub enum State { + CA, + OR, + WA, + TX, + // ... + } + + impl Address { + pub fn get_street(&self) -> &str { + &self.street + } + + pub fn get_city(&self) -> &str { + &self.city + } + + pub fn get_state(&self) -> State { + self.state + } + + pub fn get_zip(&self) -> u32 { + self.zip + } + } + "#}; + + let company_rs = indoc! {r#" + use super::person::Person; + use super::address::Address; + + pub struct Company { + owner: Arc, + address: Address, + } + + impl Company { + pub fn get_owner(&self) -> &Person { + &self.owner + } + + pub fn get_address(&self) -> &Address { + &self.address + } + + pub fn to_string(&self) -> String { + format!("{} ({})", self.owner.first_name, self.address.city) + } + } + "#}; + + let main_rs = indoc! 
{r#" + use std::sync::Arc; + use super::person::Person; + use super::address::Address; + use super::company::Company; + + pub struct Session { + company: Arc, + } + + impl Session { + pub fn set_company(&mut self, company: Arc) { + self.company = company; + if company.owner != self.company.owner { + log("new owner", company.owner.get_first_name()); todo(); + } + } + } + + fn main() { + let company = Company { + owner: Arc::new(Person { + first_name: "John".to_string(), + last_name: "Doe".to_string(), + email: "john@example.com".to_string(), + age: 30, + }), + address: Address { + street: "123 Main St".to_string(), + city: "Anytown".to_string(), + state: State::CA, + zip: 12345, + }, + }; + + println!("Company: {}", company.to_string()); + } + "#}; + + json!({ + "src": { + "person.rs": person_rs, + "address.rs": address_rs, + "company.rs": company_rs, + "main.rs": main_rs, + }, + }) +} + +fn assert_related_files(actual_files: &[RelatedFile], expected_files: &[(&str, &[&str])]) { + let actual_files = actual_files + .iter() + .map(|file| { + let excerpts = file + .excerpts + .iter() + .map(|excerpt| excerpt.text.to_string()) + .collect::>(); + (file.path.path.as_unix_str(), excerpts) + }) + .collect::>(); + let expected_excerpts = expected_files + .iter() + .map(|(path, texts)| { + ( + *path, + texts + .iter() + .map(|line| line.to_string()) + .collect::>(), + ) + }) + .collect::>(); + pretty_assertions::assert_eq!(actual_files, expected_excerpts) +} + +fn assert_definitions(definitions: &[LocationLink], first_lines: &[&str], cx: &mut TestAppContext) { + let actual_first_lines = definitions + .iter() + .map(|definition| { + definition.target.buffer.read_with(cx, |buffer, _| { + let mut start = definition.target.range.start.to_point(&buffer); + start.column = 0; + let end = Point::new(start.row, buffer.line_len(start.row)); + buffer + .text_for_range(start..end) + .collect::() + .trim() + .to_string() + }) + }) + .collect::>(); + + assert_eq!(actual_first_lines, 
first_lines); +} + +pub(crate) fn rust_lang() -> Arc { + Arc::new( + Language::new( + LanguageConfig { + name: "Rust".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["rs".to_string()], + first_line_pattern: None, + }, + ..Default::default() + }, + Some(tree_sitter_rust::LANGUAGE.into()), + ) + .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm")) + .unwrap() + .with_outline_query(include_str!("../../languages/src/rust/outline.scm")) + .unwrap(), + ) +} diff --git a/crates/edit_prediction_context2/src/fake_definition_lsp.rs b/crates/edit_prediction_context2/src/fake_definition_lsp.rs new file mode 100644 index 0000000000000000000000000000000000000000..31fb681309c610a37c7f886390ef5adb92ee78ef --- /dev/null +++ b/crates/edit_prediction_context2/src/fake_definition_lsp.rs @@ -0,0 +1,329 @@ +use collections::HashMap; +use futures::channel::mpsc::UnboundedReceiver; +use language::{Language, LanguageRegistry}; +use lsp::{ + FakeLanguageServer, LanguageServerBinary, TextDocumentSyncCapability, TextDocumentSyncKind, Uri, +}; +use parking_lot::Mutex; +use project::Fs; +use std::{ops::Range, path::PathBuf, sync::Arc}; +use tree_sitter::{Parser, QueryCursor, StreamingIterator, Tree}; + +/// Registers a fake language server that implements go-to-definition using tree-sitter, +/// making the assumption that all names are unique, and all variables' types are +/// explicitly declared. 
+pub fn register_fake_definition_server( + language_registry: &Arc, + language: Arc, + fs: Arc, +) -> UnboundedReceiver { + let index = Arc::new(Mutex::new(DefinitionIndex::new(language.clone()))); + + language_registry.register_fake_lsp( + language.name(), + language::FakeLspAdapter { + name: "fake-definition-lsp", + initialization_options: None, + prettier_plugins: Vec::new(), + disk_based_diagnostics_progress_token: None, + disk_based_diagnostics_sources: Vec::new(), + language_server_binary: LanguageServerBinary { + path: PathBuf::from("fake-definition-lsp"), + arguments: Vec::new(), + env: None, + }, + capabilities: lsp::ServerCapabilities { + definition_provider: Some(lsp::OneOf::Left(true)), + text_document_sync: Some(TextDocumentSyncCapability::Kind( + TextDocumentSyncKind::FULL, + )), + ..Default::default() + }, + label_for_completion: None, + initializer: Some(Box::new({ + move |server| { + server.handle_notification::({ + let index = index.clone(); + move |params, _cx| { + index + .lock() + .open_buffer(params.text_document.uri, ¶ms.text_document.text); + } + }); + + server.handle_notification::({ + let index = index.clone(); + let fs = fs.clone(); + move |params, cx| { + let uri = params.text_document.uri; + let path = uri.to_file_path().ok(); + index.lock().mark_buffer_closed(&uri); + + if let Some(path) = path { + let index = index.clone(); + let fs = fs.clone(); + cx.spawn(async move |_cx| { + if let Ok(content) = fs.load(&path).await { + index.lock().index_file(uri, &content); + } + }) + .detach(); + } + } + }); + + server.handle_notification::({ + let index = index.clone(); + let fs = fs.clone(); + move |params, cx| { + let index = index.clone(); + let fs = fs.clone(); + cx.spawn(async move |_cx| { + for event in params.changes { + if index.lock().is_buffer_open(&event.uri) { + continue; + } + + match event.typ { + lsp::FileChangeType::DELETED => { + index.lock().remove_definitions_for_file(&event.uri); + } + lsp::FileChangeType::CREATED + | 
lsp::FileChangeType::CHANGED => { + if let Some(path) = event.uri.to_file_path().ok() { + if let Ok(content) = fs.load(&path).await { + index.lock().index_file(event.uri, &content); + } + } + } + _ => {} + } + } + }) + .detach(); + } + }); + + server.handle_notification::({ + let index = index.clone(); + move |params, _cx| { + if let Some(change) = params.content_changes.into_iter().last() { + index + .lock() + .index_file(params.text_document.uri, &change.text); + } + } + }); + + server.handle_notification::( + { + let index = index.clone(); + let fs = fs.clone(); + move |params, cx| { + let index = index.clone(); + let fs = fs.clone(); + let files = fs.as_fake().files(); + cx.spawn(async move |_cx| { + for folder in params.event.added { + let Ok(path) = folder.uri.to_file_path() else { + continue; + }; + for file in &files { + if let Some(uri) = Uri::from_file_path(&file).ok() + && file.starts_with(&path) + && let Ok(content) = fs.load(&file).await + { + index.lock().index_file(uri, &content); + } + } + } + }) + .detach(); + } + }, + ); + + server.set_request_handler::({ + let index = index.clone(); + move |params, _cx| { + let result = index.lock().get_definitions( + params.text_document_position_params.text_document.uri, + params.text_document_position_params.position, + ); + async move { Ok(result) } + } + }); + } + })), + }, + ) +} + +struct DefinitionIndex { + language: Arc, + definitions: HashMap>, + files: HashMap, +} + +#[derive(Debug)] +struct FileEntry { + contents: String, + is_open_in_buffer: bool, +} + +impl DefinitionIndex { + fn new(language: Arc) -> Self { + Self { + language, + definitions: HashMap::default(), + files: HashMap::default(), + } + } + + fn remove_definitions_for_file(&mut self, uri: &Uri) { + self.definitions.retain(|_, locations| { + locations.retain(|loc| &loc.uri != uri); + !locations.is_empty() + }); + self.files.remove(uri); + } + + fn open_buffer(&mut self, uri: Uri, content: &str) { + self.index_file_inner(uri, content, 
true); + } + + fn mark_buffer_closed(&mut self, uri: &Uri) { + if let Some(entry) = self.files.get_mut(uri) { + entry.is_open_in_buffer = false; + } + } + + fn is_buffer_open(&self, uri: &Uri) -> bool { + self.files + .get(uri) + .map(|entry| entry.is_open_in_buffer) + .unwrap_or(false) + } + + fn index_file(&mut self, uri: Uri, content: &str) { + self.index_file_inner(uri, content, false); + } + + fn index_file_inner(&mut self, uri: Uri, content: &str, is_open_in_buffer: bool) -> Option<()> { + self.remove_definitions_for_file(&uri); + let grammar = self.language.grammar()?; + let outline_config = grammar.outline_config.as_ref()?; + let mut parser = Parser::new(); + parser.set_language(&grammar.ts_language).ok()?; + let tree = parser.parse(content, None)?; + let declarations = extract_declarations_from_tree(&tree, content, outline_config); + for (name, byte_range) in declarations { + let range = byte_range_to_lsp_range(content, byte_range); + let location = lsp::Location { + uri: uri.clone(), + range, + }; + self.definitions + .entry(name) + .or_insert_with(Vec::new) + .push(location); + } + self.files.insert( + uri, + FileEntry { + contents: content.to_string(), + is_open_in_buffer, + }, + ); + + Some(()) + } + + fn get_definitions( + &mut self, + uri: Uri, + position: lsp::Position, + ) -> Option { + let entry = self.files.get(&uri)?; + let name = word_at_position(&entry.contents, position)?; + let locations = self.definitions.get(name).cloned()?; + Some(lsp::GotoDefinitionResponse::Array(locations)) + } +} + +fn extract_declarations_from_tree( + tree: &Tree, + content: &str, + outline_config: &language::OutlineConfig, +) -> Vec<(String, Range)> { + let mut cursor = QueryCursor::new(); + let mut declarations = Vec::new(); + let mut matches = cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes()); + while let Some(query_match) = matches.next() { + let mut name_range: Option> = None; + let mut has_item_range = false; + + for capture in 
query_match.captures { + let range = capture.node.byte_range(); + if capture.index == outline_config.name_capture_ix { + name_range = Some(range); + } else if capture.index == outline_config.item_capture_ix { + has_item_range = true; + } + } + + if let Some(name_range) = name_range + && has_item_range + { + let name = content[name_range.clone()].to_string(); + if declarations.iter().any(|(n, _)| n == &name) { + continue; + } + declarations.push((name, name_range)); + } + } + declarations +} + +fn byte_range_to_lsp_range(content: &str, byte_range: Range) -> lsp::Range { + let start = byte_offset_to_position(content, byte_range.start); + let end = byte_offset_to_position(content, byte_range.end); + lsp::Range { start, end } +} + +fn byte_offset_to_position(content: &str, offset: usize) -> lsp::Position { + let mut line = 0; + let mut character = 0; + let mut current_offset = 0; + for ch in content.chars() { + if current_offset >= offset { + break; + } + if ch == '\n' { + line += 1; + character = 0; + } else { + character += 1; + } + current_offset += ch.len_utf8(); + } + lsp::Position { line, character } +} + +fn word_at_position(content: &str, position: lsp::Position) -> Option<&str> { + let mut lines = content.lines(); + let line = lines.nth(position.line as usize)?; + let column = position.character as usize; + if column > line.len() { + return None; + } + let start = line[..column] + .rfind(|c: char| !c.is_alphanumeric() && c != '_') + .map(|i| i + 1) + .unwrap_or(0); + let end = line[column..] 
+ .find(|c: char| !c.is_alphanumeric() && c != '_') + .map(|i| i + column) + .unwrap_or(line.len()); + Some(&line[start..end]).filter(|word| !word.is_empty()) +} diff --git a/crates/extension_host/src/extension_store_test.rs b/crates/extension_host/src/extension_store_test.rs index 85a3a720ce8c62fc4317756ec264926c981864c4..6d3aadeb5ac498b3948d871a0a87f7ecf49b6bd8 100644 --- a/crates/extension_host/src/extension_store_test.rs +++ b/crates/extension_host/src/extension_store_test.rs @@ -705,7 +705,7 @@ async fn test_extension_store_with_test_extension(cx: &mut TestAppContext) { .await .unwrap(); - let mut fake_servers = language_registry.register_fake_language_server( + let mut fake_servers = language_registry.register_fake_lsp_server( LanguageServerName("gleam".into()), lsp::ServerCapabilities { completion_provider: Some(Default::default()), diff --git a/crates/language/src/buffer.rs b/crates/language/src/buffer.rs index a46f7cc35912d4c6da42ba69f7aee6d25caca2e7..7166a01ef64bff9e47c70cac47910f714ae2dc39 100644 --- a/crates/language/src/buffer.rs +++ b/crates/language/src/buffer.rs @@ -4022,6 +4022,20 @@ impl BufferSnapshot { }) } + pub fn outline_items_as_offsets_containing( + &self, + range: Range, + include_extra_context: bool, + theme: Option<&SyntaxTheme>, + ) -> Vec> { + self.outline_items_containing_internal( + range, + include_extra_context, + theme, + |buffer, range| range.to_offset(buffer), + ) + } + fn outline_items_containing_internal( &self, range: Range, diff --git a/crates/language/src/buffer_tests.rs b/crates/language/src/buffer_tests.rs index efef0a08127bc66f9c6d8f21fe5a545dbee20fb1..e95bc544a56ecf9d561936ca48b10ccffcb23e72 100644 --- a/crates/language/src/buffer_tests.rs +++ b/crates/language/src/buffer_tests.rs @@ -784,28 +784,48 @@ async fn test_outline(cx: &mut gpui::TestAppContext) { .unindent(); let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); - let outline = buffer.update(cx, |buffer, _| 
buffer.snapshot().outline(None)); + let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot()); + let outline = snapshot.outline(None); - assert_eq!( + pretty_assertions::assert_eq!( outline .items .iter() - .map(|item| (item.text.as_str(), item.depth)) + .map(|item| ( + item.text.as_str(), + item.depth, + item.to_point(&snapshot).body_range(&snapshot) + .map(|range| minimize_space(&snapshot.text_for_range(range).collect::())) + )) .collect::>(), &[ - ("struct Person", 0), - ("name", 1), - ("age", 1), - ("mod module", 0), - ("enum LoginState", 1), - ("LoggedOut", 2), - ("LoggingOn", 2), - ("LoggedIn", 2), - ("person", 3), - ("time", 3), - ("impl Eq for Person", 0), - ("impl Drop for Person", 0), - ("fn drop", 1), + ("struct Person", 0, Some("name: String, age: usize,".to_string())), + ("name", 1, None), + ("age", 1, None), + ( + "mod module", + 0, + Some( + "enum LoginState { LoggedOut, LoggingOn, LoggedIn { person: Person, time: Instant, } }".to_string() + ) + ), + ( + "enum LoginState", + 1, + Some("LoggedOut, LoggingOn, LoggedIn { person: Person, time: Instant, }".to_string()) + ), + ("LoggedOut", 2, None), + ("LoggingOn", 2, None), + ("LoggedIn", 2, Some("person: Person, time: Instant,".to_string())), + ("person", 3, None), + ("time", 3, None), + ("impl Eq for Person", 0, None), + ( + "impl Drop for Person", + 0, + Some("fn drop(&mut self) { println!(\"bye\"); }".to_string()) + ), + ("fn drop", 1, Some("println!(\"bye\");".to_string())), ] ); @@ -840,6 +860,11 @@ async fn test_outline(cx: &mut gpui::TestAppContext) { ] ); + fn minimize_space(text: &str) -> String { + static WHITESPACE: LazyLock = LazyLock::new(|| Regex::new("[\\n\\s]+").unwrap()); + WHITESPACE.replace_all(text, " ").trim().to_string() + } + async fn search<'a>( outline: &'a Outline, query: &'a str, diff --git a/crates/language/src/language_registry.rs b/crates/language/src/language_registry.rs index 022eb89e6d2b378b8c4305c81887060d776bb411..a0b04efd1b1366a101812d8656965637c13769a5 100644 
--- a/crates/language/src/language_registry.rs +++ b/crates/language/src/language_registry.rs @@ -437,26 +437,14 @@ impl LanguageRegistry { language_name: impl Into, mut adapter: crate::FakeLspAdapter, ) -> futures::channel::mpsc::UnboundedReceiver { - let language_name = language_name.into(); let adapter_name = LanguageServerName(adapter.name.into()); let capabilities = adapter.capabilities.clone(); let initializer = adapter.initializer.take(); - let adapter = CachedLspAdapter::new(Arc::new(adapter)); - { - let mut state = self.state.write(); - state - .lsp_adapters - .entry(language_name) - .or_default() - .push(adapter.clone()); - state.all_lsp_adapters.insert(adapter.name(), adapter); - } - - self.register_fake_language_server(adapter_name, capabilities, initializer) + self.register_fake_lsp_adapter(language_name, adapter); + self.register_fake_lsp_server(adapter_name, capabilities, initializer) } /// Register a fake lsp adapter (without the language server) - /// The returned channel receives a new instance of the language server every time it is started #[cfg(any(feature = "test-support", test))] pub fn register_fake_lsp_adapter( &self, @@ -479,7 +467,7 @@ impl LanguageRegistry { /// Register a fake language server (without the adapter) /// The returned channel receives a new instance of the language server every time it is started #[cfg(any(feature = "test-support", test))] - pub fn register_fake_language_server( + pub fn register_fake_lsp_server( &self, lsp_name: LanguageServerName, capabilities: lsp::ServerCapabilities, diff --git a/crates/language/src/language_settings.rs b/crates/language/src/language_settings.rs index 3bf4e35c6b5cfd7f2a1f221bde4cec181998ab6a..068f8e1aa39ca3422fda8eb5706c00de6f2f62ce 100644 --- a/crates/language/src/language_settings.rs +++ b/crates/language/src/language_settings.rs @@ -373,6 +373,8 @@ impl InlayHintSettings { pub struct EditPredictionSettings { /// The provider that supplies edit predictions. 
pub provider: settings::EditPredictionProvider, + /// Whether to use the experimental edit prediction context retrieval system. + pub use_context: bool, /// A list of globs representing files that edit predictions should be disabled for. /// This list adds to a pre-existing, sensible default set of globs. /// Any additional ones you add are combined with them. @@ -622,6 +624,11 @@ impl settings::Settings for AllLanguageSettings { .features .as_ref() .and_then(|f| f.edit_prediction_provider); + let use_edit_prediction_context = all_languages + .features + .as_ref() + .and_then(|f| f.experimental_edit_prediction_context_retrieval) + .unwrap_or_default(); let edit_predictions = all_languages.edit_predictions.clone().unwrap(); let edit_predictions_mode = edit_predictions.mode.unwrap(); @@ -668,6 +675,7 @@ impl settings::Settings for AllLanguageSettings { } else { EditPredictionProvider::None }, + use_context: use_edit_prediction_context, disabled_globs: disabled_globs .iter() .filter_map(|g| { diff --git a/crates/language/src/outline.rs b/crates/language/src/outline.rs index 2ce2b42734465a4710a7439f5e2225debc96b04a..875042bfc83ae42fb580ab848029902d68988511 100644 --- a/crates/language/src/outline.rs +++ b/crates/language/src/outline.rs @@ -1,4 +1,4 @@ -use crate::{BufferSnapshot, Point, ToPoint}; +use crate::{BufferSnapshot, Point, ToPoint, ToTreeSitterPoint}; use fuzzy::{StringMatch, StringMatchCandidate}; use gpui::{BackgroundExecutor, HighlightStyle}; use std::ops::Range; @@ -48,6 +48,54 @@ impl OutlineItem { .map(|r| r.start.to_point(buffer)..r.end.to_point(buffer)), } } + + pub fn body_range(&self, buffer: &BufferSnapshot) -> Option> { + if let Some(range) = self.body_range.as_ref() { + return Some(range.start.to_point(buffer)..range.end.to_point(buffer)); + } + + let range = self.range.start.to_point(buffer)..self.range.end.to_point(buffer); + let start_indent = buffer.indent_size_for_line(range.start.row); + let node = buffer.syntax_ancestor(range.clone())?; + + 
let mut cursor = node.walk(); + loop { + let node = cursor.node(); + if node.start_position() >= range.start.to_ts_point() + && node.end_position() <= range.end.to_ts_point() + { + break; + } + cursor.goto_first_child_for_point(range.start.to_ts_point()); + } + + if !cursor.goto_last_child() { + return None; + } + let body_node = loop { + let node = cursor.node(); + if node.child_count() > 0 { + break node; + } + if !cursor.goto_previous_sibling() { + return None; + } + }; + + let mut start_row = body_node.start_position().row as u32; + let mut end_row = body_node.end_position().row as u32; + + while start_row < end_row && buffer.indent_size_for_line(start_row) == start_indent { + start_row += 1; + } + while start_row < end_row && buffer.indent_size_for_line(end_row - 1) == start_indent { + end_row -= 1; + } + if start_row < end_row { + return Some(Point::new(start_row, 0)..Point::new(end_row, 0)); + } + None + } } impl Outline { diff --git a/crates/language/src/syntax_map.rs b/crates/language/src/syntax_map.rs index 8574d52ff900563ddfb733c09204caab5eb6ae44..17285ca315fb64dd518d00039d28266c0a7f51ab 100644 --- a/crates/language/src/syntax_map.rs +++ b/crates/language/src/syntax_map.rs @@ -1215,6 +1215,19 @@ impl<'a> SyntaxMapMatches<'a> { true } + + // pub fn set_byte_range(&mut self, range: Range) { + // for layer in &mut self.layers { + // layer.matches.set_byte_range(range.clone()); + // layer.advance(); + // } + // self.layers.sort_unstable_by_key(|layer| layer.sort_key()); + // self.active_layer_count = self + // .layers + // .iter() + // .position(|layer| !layer.has_next) + // .unwrap_or(self.layers.len()); + // } } impl SyntaxMapCapturesLayer<'_> { diff --git a/crates/remote_server/src/remote_editing_tests.rs b/crates/remote_server/src/remote_editing_tests.rs index 1e6ecddb5f2599a0ded0180f3afd3df0f197f037..a91d1d055d582eb2f2de4883314ad5984238103a 100644 --- a/crates/remote_server/src/remote_editing_tests.rs +++ 
b/crates/remote_server/src/remote_editing_tests.rs @@ -452,7 +452,7 @@ async fn test_remote_lsp(cx: &mut TestAppContext, server_cx: &mut TestAppContext }); let mut fake_lsp = server_cx.update(|cx| { - headless.read(cx).languages.register_fake_language_server( + headless.read(cx).languages.register_fake_lsp_server( LanguageServerName("rust-analyzer".into()), lsp::ServerCapabilities { completion_provider: Some(lsp::CompletionOptions::default()), @@ -476,7 +476,7 @@ async fn test_remote_lsp(cx: &mut TestAppContext, server_cx: &mut TestAppContext ..FakeLspAdapter::default() }, ); - headless.read(cx).languages.register_fake_language_server( + headless.read(cx).languages.register_fake_lsp_server( LanguageServerName("fake-analyzer".into()), lsp::ServerCapabilities { completion_provider: Some(lsp::CompletionOptions::default()), @@ -669,7 +669,7 @@ async fn test_remote_cancel_language_server_work( }); let mut fake_lsp = server_cx.update(|cx| { - headless.read(cx).languages.register_fake_language_server( + headless.read(cx).languages.register_fake_lsp_server( LanguageServerName("rust-analyzer".into()), Default::default(), None, diff --git a/crates/settings/src/settings_content/language.rs b/crates/settings/src/settings_content/language.rs index 6b8a372269d44935e20426a0b669fed96a33dadf..b466b4e0dd88bf41e0f77f67a38842305c11906f 100644 --- a/crates/settings/src/settings_content/language.rs +++ b/crates/settings/src/settings_content/language.rs @@ -62,6 +62,8 @@ impl merge_from::MergeFrom for AllLanguageSettingsContent { pub struct FeaturesContent { /// Determines which edit prediction provider to use. pub edit_prediction_provider: Option, + /// Enables the experimental edit prediction context retrieval system. + pub experimental_edit_prediction_context_retrieval: Option, } /// The provider that supplies edit predictions. 
diff --git a/crates/text/src/anchor.rs b/crates/text/src/anchor.rs index c6d47a1e233b2fdf58fbc73adb622fc801832335..bf660b1302466e2b244a86b3d1e58ea2b6991067 100644 --- a/crates/text/src/anchor.rs +++ b/crates/text/src/anchor.rs @@ -8,10 +8,14 @@ use sum_tree::{Bias, Dimensions}; /// A timestamped position in a buffer #[derive(Copy, Clone, Eq, PartialEq, Hash)] pub struct Anchor { + /// The timestamp of the operation that inserted the text + /// in which this anchor is located. pub timestamp: clock::Lamport, - /// The byte offset in the buffer + /// The byte offset into the text inserted in the operation + /// at `timestamp`. pub offset: usize, - /// Describes which character the anchor is biased towards + /// Whether this anchor stays attached to the character *before* or *after* + /// the offset. pub bias: Bias, pub buffer_id: Option, } diff --git a/crates/ui/src/components/data_table.rs b/crates/ui/src/components/data_table.rs index f7cce2b85ffa3aeb9f97634c6c0fa65c46f4a8e7..9cd2a5cb7a0d802d170fcfbe6a812027c779d942 100644 --- a/crates/ui/src/components/data_table.rs +++ b/crates/ui/src/components/data_table.rs @@ -485,6 +485,7 @@ pub struct Table { interaction_state: Option>, col_widths: Option>, map_row: Option), &mut Window, &mut App) -> AnyElement>>, + use_ui_font: bool, empty_table_callback: Option AnyElement>>, } @@ -498,6 +499,7 @@ impl Table { rows: TableContents::Vec(Vec::new()), interaction_state: None, map_row: None, + use_ui_font: true, empty_table_callback: None, col_widths: None, } @@ -590,6 +592,11 @@ impl Table { self } + pub fn no_ui_font(mut self) -> Self { + self.use_ui_font = false; + self + } + pub fn map_row( mut self, callback: impl Fn((usize, Stateful
), &mut Window, &mut App) -> AnyElement + 'static, @@ -618,8 +625,8 @@ fn base_cell_style(width: Option) -> Div { .overflow_hidden() } -fn base_cell_style_text(width: Option, cx: &App) -> Div { - base_cell_style(width).text_ui(cx) +fn base_cell_style_text(width: Option, use_ui_font: bool, cx: &App) -> Div { + base_cell_style(width).when(use_ui_font, |el| el.text_ui(cx)) } pub fn render_table_row( @@ -656,7 +663,12 @@ pub fn render_table_row( .map(IntoElement::into_any_element) .into_iter() .zip(column_widths) - .map(|(cell, width)| base_cell_style_text(width, cx).px_1().py_0p5().child(cell)), + .map(|(cell, width)| { + base_cell_style_text(width, table_context.use_ui_font, cx) + .px_1() + .py_0p5() + .child(cell) + }), ); let row = if let Some(map_row) = table_context.map_row { @@ -700,7 +712,7 @@ pub fn render_table_header( .border_color(cx.theme().colors().border) .children(headers.into_iter().enumerate().zip(column_widths).map( |((header_idx, h), width)| { - base_cell_style_text(width, cx) + base_cell_style_text(width, table_context.use_ui_font, cx) .child(h) .id(ElementId::NamedInteger( shared_element_id.clone(), @@ -739,6 +751,7 @@ pub struct TableRenderContext { pub total_row_count: usize, pub column_widths: Option<[Length; COLS]>, pub map_row: Option), &mut Window, &mut App) -> AnyElement>>, + pub use_ui_font: bool, } impl TableRenderContext { @@ -748,6 +761,7 @@ impl TableRenderContext { total_row_count: table.rows.len(), column_widths: table.col_widths.as_ref().map(|widths| widths.lengths(cx)), map_row: table.map_row.clone(), + use_ui_font: table.use_ui_font, } } } diff --git a/crates/zeta/Cargo.toml b/crates/zeta/Cargo.toml index 7429fcb8e8d5e4b485f69ea87c37d7d670c3b199..b90934e67c2a689e1f7bb9704ff28a408de3049a 100644 --- a/crates/zeta/Cargo.toml +++ b/crates/zeta/Cargo.toml @@ -30,6 +30,7 @@ credentials_provider.workspace = true db.workspace = true edit_prediction.workspace = true edit_prediction_context.workspace = true 
+edit_prediction_context2.workspace = true editor.workspace = true feature_flags.workspace = true fs.workspace = true diff --git a/crates/zeta/src/assemble_excerpts.rs b/crates/zeta/src/assemble_excerpts.rs deleted file mode 100644 index f2a5b5adb1fcffab945cd9bdb88153bc5e494138..0000000000000000000000000000000000000000 --- a/crates/zeta/src/assemble_excerpts.rs +++ /dev/null @@ -1,173 +0,0 @@ -use cloud_llm_client::predict_edits_v3::Excerpt; -use edit_prediction_context::Line; -use language::{BufferSnapshot, Point}; -use std::ops::Range; - -pub fn assemble_excerpts( - buffer: &BufferSnapshot, - merged_line_ranges: impl IntoIterator>, -) -> Vec { - let mut output = Vec::new(); - - let outline_items = buffer.outline_items_as_points_containing(0..buffer.len(), false, None); - let mut outline_items = outline_items.into_iter().peekable(); - - for range in merged_line_ranges { - let point_range = Point::new(range.start.0, 0)..Point::new(range.end.0, 0); - - while let Some(outline_item) = outline_items.peek() { - if outline_item.range.start >= point_range.start { - break; - } - if outline_item.range.end > point_range.start { - let mut point_range = outline_item.source_range_for_text.clone(); - point_range.start.column = 0; - point_range.end.column = buffer.line_len(point_range.end.row); - - output.push(Excerpt { - start_line: Line(point_range.start.row), - text: buffer - .text_for_range(point_range.clone()) - .collect::() - .into(), - }) - } - outline_items.next(); - } - - output.push(Excerpt { - start_line: Line(point_range.start.row), - text: buffer - .text_for_range(point_range.clone()) - .collect::() - .into(), - }) - } - - output -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use super::*; - use cloud_llm_client::predict_edits_v3; - use gpui::{TestAppContext, prelude::*}; - use indoc::indoc; - use language::{Buffer, Language, LanguageConfig, LanguageMatcher, OffsetRangeExt}; - use pretty_assertions::assert_eq; - use util::test::marked_text_ranges; - - 
#[gpui::test] - fn test_rust(cx: &mut TestAppContext) { - let table = [ - ( - indoc! {r#" - struct User { - first_name: String, - « last_name: String, - ageˇ: u32, - » email: String, - create_at: Instant, - } - - impl User { - pub fn first_name(&self) -> String { - self.first_name.clone() - } - - pub fn full_name(&self) -> String { - « format!("{} {}", self.first_name, self.last_name) - » } - } - "#}, - indoc! {r#" - 1|struct User { - … - 3| last_name: String, - 4| age<|cursor|>: u32, - … - 9|impl User { - … - 14| pub fn full_name(&self) -> String { - 15| format!("{} {}", self.first_name, self.last_name) - … - "#}, - ), - ( - indoc! {r#" - struct User { - first_name: String, - « last_name: String, - age: u32, - } - »"# - }, - indoc! {r#" - 1|struct User { - … - 3| last_name: String, - 4| age: u32, - 5|} - "#}, - ), - ]; - - for (input, expected_output) in table { - let input_without_ranges = input.replace(['«', '»'], ""); - let input_without_caret = input.replace('ˇ', ""); - let cursor_offset = input_without_ranges.find('ˇ'); - let (input, ranges) = marked_text_ranges(&input_without_caret, false); - let buffer = - cx.new(|cx| Buffer::local(input, cx).with_language(Arc::new(rust_lang()), cx)); - buffer.read_with(cx, |buffer, _cx| { - let insertions = cursor_offset - .map(|offset| { - let point = buffer.offset_to_point(offset); - vec![( - predict_edits_v3::Point { - line: Line(point.row), - column: point.column, - }, - "<|cursor|>", - )] - }) - .unwrap_or_default(); - let ranges: Vec> = ranges - .into_iter() - .map(|range| { - let point_range = range.to_point(&buffer); - Line(point_range.start.row)..Line(point_range.end.row) - }) - .collect(); - - let mut output = String::new(); - cloud_zeta2_prompt::write_excerpts( - assemble_excerpts(&buffer.snapshot(), ranges).iter(), - &insertions, - Line(buffer.max_point().row), - true, - &mut output, - ); - assert_eq!(output, expected_output); - }); - } - } - - fn rust_lang() -> Language { - Language::new( - LanguageConfig { - 
name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(language::tree_sitter_rust::LANGUAGE.into()), - ) - .with_outline_query(include_str!("../../languages/src/rust/outline.scm")) - .unwrap() - } -} diff --git a/crates/zeta/src/retrieval_search.rs b/crates/zeta/src/retrieval_search.rs index bcc0233ff7e872a151ecddf2cf55a3cb434f02b3..f429f167744422c3641b5a68ca662af48c8e1614 100644 --- a/crates/zeta/src/retrieval_search.rs +++ b/crates/zeta/src/retrieval_search.rs @@ -1,6 +1,7 @@ use anyhow::Result; use cloud_zeta2_prompt::retrieval_prompt::SearchToolQuery; use collections::HashMap; +use edit_prediction_context2::{RelatedExcerpt, RelatedFile}; use futures::{ StreamExt, channel::mpsc::{self, UnboundedSender}, @@ -8,7 +9,7 @@ use futures::{ use gpui::{AppContext, AsyncApp, Entity}; use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt, Point, ToOffset, ToPoint}; use project::{ - Project, WorktreeSettings, + Project, ProjectPath, WorktreeSettings, search::{SearchQuery, SearchResult}, }; use smol::channel; @@ -20,14 +21,14 @@ use util::{ use workspace::item::Settings as _; #[cfg(feature = "eval-support")] -type CachedSearchResults = std::collections::BTreeMap>>; +type CachedSearchResults = std::collections::BTreeMap>>; pub async fn run_retrieval_searches( queries: Vec, project: Entity, #[cfg(feature = "eval-support")] eval_cache: Option>, cx: &mut AsyncApp, -) -> Result, Vec>>> { +) -> Result> { #[cfg(feature = "eval-support")] let cache = if let Some(eval_cache) = eval_cache { use crate::EvalCacheEntryKind; @@ -54,24 +55,44 @@ pub async fn run_retrieval_searches( if let Some(cached_results) = eval_cache.read(key) { let file_results = serde_json::from_str::(&cached_results) .context("Failed to deserialize cached search results")?; - let mut results = HashMap::default(); + let mut results = Vec::new(); for (path, ranges) in file_results { + let project_path = 
project.update(cx, |project, cx| { + project.find_project_path(path, cx).unwrap() + })?; let buffer = project .update(cx, |project, cx| { - let project_path = project.find_project_path(path, cx).unwrap(); - project.open_buffer(project_path, cx) + project.open_buffer(project_path.clone(), cx) })? .await?; let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?; let mut ranges: Vec<_> = ranges .into_iter() - .map(|range| { - snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end) - }) + .map( + |Range { + start: (start_row, start_col), + end: (end_row, end_col), + }| { + snapshot.anchor_before(Point::new(start_row, start_col)) + ..snapshot.anchor_after(Point::new(end_row, end_col)) + }, + ) .collect(); merge_anchor_ranges(&mut ranges, &snapshot); - results.insert(buffer, ranges); + results.push(RelatedFile { + path: project_path, + buffer: buffer.downgrade(), + excerpts: ranges + .into_iter() + .map(|range| RelatedExcerpt { + point_range: range.to_point(&snapshot), + text: snapshot.as_rope().slice(range.to_offset(&snapshot)), + anchor_range: range, + }) + .collect(), + max_row: snapshot.max_point().row, + }); } return Ok(results); @@ -117,14 +138,29 @@ pub async fn run_retrieval_searches( #[cfg(feature = "eval-support")] let cache = cache.clone(); cx.background_spawn(async move { - let mut results: HashMap, Vec>> = HashMap::default(); + let mut results: Vec = Vec::default(); let mut snapshots = HashMap::default(); let mut total_bytes = 0; - 'outer: while let Some((buffer, snapshot, excerpts)) = results_rx.next().await { - snapshots.insert(buffer.entity_id(), snapshot); - let existing = results.entry(buffer).or_default(); - existing.reserve(excerpts.len()); + 'outer: while let Some((project_path, buffer, snapshot, excerpts)) = results_rx.next().await + { + let existing = results + .iter_mut() + .find(|related_file| related_file.buffer.entity_id() == buffer.entity_id()); + let existing = match existing { + Some(existing) => existing, + None => 
{ + results.push(RelatedFile { + path: project_path, + buffer: buffer.downgrade(), + excerpts: Vec::new(), + max_row: snapshot.max_point().row, + }); + results.last_mut().unwrap() + } + }; + // let existing = results.entry(buffer).or_default(); + existing.excerpts.reserve(excerpts.len()); for (range, size) in excerpts { // Blunt trimming of the results until we have a proper algorithmic filtering step @@ -133,24 +169,34 @@ pub async fn run_retrieval_searches( break 'outer; } total_bytes += size; - existing.push(range); + existing.excerpts.push(RelatedExcerpt { + point_range: range.to_point(&snapshot), + text: snapshot.as_rope().slice(range.to_offset(&snapshot)), + anchor_range: range, + }); } + snapshots.insert(buffer.entity_id(), snapshot); } #[cfg(feature = "eval-support")] if let Some((cache, queries, key)) = cache { let cached_results: CachedSearchResults = results .iter() - .filter_map(|(buffer, ranges)| { - let snapshot = snapshots.get(&buffer.entity_id())?; - let path = snapshot.file().map(|f| f.path()); - let mut ranges = ranges + .map(|related_file| { + let mut ranges = related_file + .excerpts .iter() - .map(|range| range.to_offset(&snapshot)) + .map( + |RelatedExcerpt { + point_range: Range { start, end }, + .. 
+ }| { + (start.row, start.column)..(end.row, end.column) + }, + ) .collect::>(); ranges.sort_unstable_by_key(|range| (range.start, range.end)); - - Some((path?.as_std_path().to_path_buf(), ranges)) + (related_file.path.path.as_std_path().to_path_buf(), ranges) }) .collect(); cache.write( @@ -160,10 +206,8 @@ pub async fn run_retrieval_searches( ); } - for (buffer, ranges) in results.iter_mut() { - if let Some(snapshot) = snapshots.get(&buffer.entity_id()) { - merge_anchor_ranges(ranges, snapshot); - } + for related_file in results.iter_mut() { + related_file.merge_excerpts(); } Ok(results) @@ -171,6 +215,7 @@ pub async fn run_retrieval_searches( .await } +#[cfg(feature = "eval-support")] pub(crate) fn merge_anchor_ranges(ranges: &mut Vec>, snapshot: &BufferSnapshot) { ranges.sort_unstable_by(|a, b| { a.start @@ -201,6 +246,7 @@ const MAX_RESULTS_LEN: usize = MAX_EXCERPT_LEN * 5; struct SearchJob { buffer: Entity, snapshot: BufferSnapshot, + project_path: ProjectPath, ranges: Vec>, query_ix: usize, jobs_tx: channel::Sender, @@ -208,7 +254,12 @@ struct SearchJob { async fn run_query( input_query: SearchToolQuery, - results_tx: UnboundedSender<(Entity, BufferSnapshot, Vec<(Range, usize)>)>, + results_tx: UnboundedSender<( + ProjectPath, + Entity, + BufferSnapshot, + Vec<(Range, usize)>, + )>, path_style: PathStyle, exclude_matcher: PathMatcher, project: &Entity, @@ -257,12 +308,21 @@ async fn run_query( .read_with(cx, |buffer, _| buffer.parsing_idle())? 
.await; let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; + let Some(file) = snapshot.file() else { + continue; + }; + + let project_path = cx.update(|cx| ProjectPath { + worktree_id: file.worktree_id(cx), + path: file.path().clone(), + })?; let expanded_ranges: Vec<_> = ranges .into_iter() .filter_map(|range| expand_to_parent_range(&range, &snapshot)) .collect(); jobs_tx .send(SearchJob { + project_path, buffer, snapshot, ranges: expanded_ranges, @@ -301,6 +361,13 @@ async fn run_query( while let Some(SearchResult::Buffer { buffer, ranges }) = results_rx.next().await { let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; + let Some(file) = snapshot.file() else { + continue; + }; + let project_path = cx.update(|cx| ProjectPath { + worktree_id: file.worktree_id(cx), + path: file.path().clone(), + })?; let ranges = ranges .into_iter() @@ -314,7 +381,8 @@ async fn run_query( }) .collect(); - let send_result = results_tx.unbounded_send((buffer.clone(), snapshot.clone(), ranges)); + let send_result = + results_tx.unbounded_send((project_path, buffer.clone(), snapshot.clone(), ranges)); if let Err(err) = send_result && !err.is_disconnected() @@ -330,7 +398,12 @@ async fn run_query( } async fn process_nested_search_job( - results_tx: &UnboundedSender<(Entity, BufferSnapshot, Vec<(Range, usize)>)>, + results_tx: &UnboundedSender<( + ProjectPath, + Entity, + BufferSnapshot, + Vec<(Range, usize)>, + )>, queries: &Vec, content_query: &Option, job: SearchJob, @@ -347,6 +420,7 @@ async fn process_nested_search_job( } job.jobs_tx .send(SearchJob { + project_path: job.project_path, buffer: job.buffer, snapshot: job.snapshot, ranges: subranges, @@ -382,7 +456,8 @@ async fn process_nested_search_job( }) .collect(); - let send_result = results_tx.unbounded_send((job.buffer, job.snapshot, matches)); + let send_result = + results_tx.unbounded_send((job.project_path, job.buffer, job.snapshot, matches)); if let Err(err) = send_result && 
!err.is_disconnected() @@ -413,230 +488,3 @@ fn expand_to_parent_range( let node = snapshot.syntax_ancestor(line_range)?; Some(node.byte_range()) } - -#[cfg(test)] -mod tests { - use super::*; - use crate::assemble_excerpts::assemble_excerpts; - use cloud_zeta2_prompt::write_codeblock; - use edit_prediction_context::Line; - use gpui::TestAppContext; - use indoc::indoc; - use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust}; - use pretty_assertions::assert_eq; - use project::FakeFs; - use serde_json::json; - use settings::SettingsStore; - use std::path::Path; - use util::path; - - #[gpui::test] - async fn test_retrieval(cx: &mut TestAppContext) { - init_test(cx); - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - path!("/root"), - json!({ - "user.rs": indoc!{" - pub struct Organization { - owner: Arc, - } - - pub struct User { - first_name: String, - last_name: String, - } - - impl Organization { - pub fn owner(&self) -> Arc { - self.owner.clone() - } - } - - impl User { - pub fn new(first_name: String, last_name: String) -> Self { - Self { - first_name, - last_name - } - } - - pub fn first_name(&self) -> String { - self.first_name.clone() - } - - pub fn last_name(&self) -> String { - self.last_name.clone() - } - } - "}, - "main.rs": indoc!{r#" - fn main() { - let user = User::new(FIRST_NAME.clone(), "doe".into()); - println!("user {:?}", user); - } - "#}, - }), - ) - .await; - - let project = Project::test(fs, vec![Path::new(path!("/root"))], cx).await; - project.update(cx, |project, _cx| { - project.languages().add(rust_lang().into()) - }); - - assert_results( - &project, - SearchToolQuery { - glob: "user.rs".into(), - syntax_node: vec!["impl\\s+User".into(), "pub\\s+fn\\s+first_name".into()], - content: None, - }, - indoc! 
{r#" - `````root/user.rs - … - impl User { - … - pub fn first_name(&self) -> String { - self.first_name.clone() - } - … - ````` - "#}, - cx, - ) - .await; - - assert_results( - &project, - SearchToolQuery { - glob: "user.rs".into(), - syntax_node: vec!["impl\\s+User".into()], - content: Some("\\.clone".into()), - }, - indoc! {r#" - `````root/user.rs - … - impl User { - … - pub fn first_name(&self) -> String { - self.first_name.clone() - … - pub fn last_name(&self) -> String { - self.last_name.clone() - … - ````` - "#}, - cx, - ) - .await; - - assert_results( - &project, - SearchToolQuery { - glob: "*.rs".into(), - syntax_node: vec![], - content: Some("\\.clone".into()), - }, - indoc! {r#" - `````root/main.rs - fn main() { - let user = User::new(FIRST_NAME.clone(), "doe".into()); - … - ````` - - `````root/user.rs - … - impl Organization { - pub fn owner(&self) -> Arc { - self.owner.clone() - … - impl User { - … - pub fn first_name(&self) -> String { - self.first_name.clone() - … - pub fn last_name(&self) -> String { - self.last_name.clone() - … - ````` - "#}, - cx, - ) - .await; - } - - async fn assert_results( - project: &Entity, - query: SearchToolQuery, - expected_output: &str, - cx: &mut TestAppContext, - ) { - let results = run_retrieval_searches( - vec![query], - project.clone(), - #[cfg(feature = "eval-support")] - None, - &mut cx.to_async(), - ) - .await - .unwrap(); - - let mut results = results.into_iter().collect::>(); - results.sort_by_key(|results| { - results - .0 - .read_with(cx, |buffer, _| buffer.file().unwrap().path().clone()) - }); - - let mut output = String::new(); - for (buffer, ranges) in results { - buffer.read_with(cx, |buffer, cx| { - let excerpts = ranges.into_iter().map(|range| { - let point_range = range.to_point(buffer); - if point_range.end.column > 0 { - Line(point_range.start.row)..Line(point_range.end.row + 1) - } else { - Line(point_range.start.row)..Line(point_range.end.row) - } - }); - - write_codeblock( - 
&buffer.file().unwrap().full_path(cx), - assemble_excerpts(&buffer.snapshot(), excerpts).iter(), - &[], - Line(buffer.max_point().row), - false, - &mut output, - ); - }); - } - output.pop(); - - assert_eq!(output, expected_output); - } - - fn rust_lang() -> Language { - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - ) - .with_outline_query(include_str!("../../languages/src/rust/outline.scm")) - .unwrap() - } - - fn init_test(cx: &mut TestAppContext) { - cx.update(move |cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - zlog::init_test(); - }); - } -} diff --git a/crates/zeta/src/sweep_ai.rs b/crates/zeta/src/sweep_ai.rs index 8fd5398f3facc807d99951c48c749e9247fb5670..0bc0d1d41e2393212f865e402912f6d760aa252e 100644 --- a/crates/zeta/src/sweep_ai.rs +++ b/crates/zeta/src/sweep_ai.rs @@ -1,6 +1,7 @@ use anyhow::{Context as _, Result}; use cloud_llm_client::predict_edits_v3::Event; use credentials_provider::CredentialsProvider; +use edit_prediction_context2::RelatedFile; use futures::{AsyncReadExt as _, FutureExt, future::Shared}; use gpui::{ App, AppContext as _, Entity, Task, @@ -49,6 +50,7 @@ impl SweepAi { position: language::Anchor, events: Vec>, recent_paths: &VecDeque, + related_files: Vec, diagnostic_search_range: Range, cx: &mut App, ) -> Task>> { @@ -120,6 +122,19 @@ impl SweepAi { }) .collect::>(); + let retrieval_chunks = related_files + .iter() + .flat_map(|related_file| { + related_file.excerpts.iter().map(|excerpt| FileChunk { + file_path: related_file.path.path.as_unix_str().to_string(), + start_line: excerpt.point_range.start.row as usize, + end_line: excerpt.point_range.end.row as usize, + content: excerpt.text.to_string(), + timestamp: None, + }) + }) + .collect(); + let diagnostic_entries = 
snapshot.diagnostics_in_range(diagnostic_search_range, false); let mut diagnostic_content = String::new(); let mut diagnostic_count = 0; @@ -168,7 +183,7 @@ impl SweepAi { multiple_suggestions: false, branch: None, file_chunks, - retrieval_chunks: vec![], + retrieval_chunks, recent_user_actions: vec![], use_bytes: true, // TODO @@ -320,7 +335,7 @@ struct AutocompleteRequest { pub cursor_position: usize, pub original_file_contents: String, pub file_chunks: Vec, - pub retrieval_chunks: Vec, + pub retrieval_chunks: Vec, pub recent_user_actions: Vec, pub multiple_suggestions: bool, pub privacy_mode_enabled: bool, @@ -337,15 +352,6 @@ struct FileChunk { pub timestamp: Option, } -#[derive(Debug, Clone, Serialize)] -struct RetrievalChunk { - pub file_path: String, - pub start_line: usize, - pub end_line: usize, - pub content: String, - pub timestamp: u64, -} - #[derive(Debug, Clone, Serialize)] struct UserAction { pub action_type: ActionType, diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 33d37d9e3aa0c5c89830d5ec86663330da1daf77..576067b9844cd668c69411d7a4098975db4a5d26 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -1,7 +1,7 @@ use anyhow::{Context as _, Result, anyhow, bail}; use arrayvec::ArrayVec; use client::{Client, EditPredictionUsage, UserStore}; -use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, Signature}; +use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat}; use cloud_llm_client::{ AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason, EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, @@ -14,31 +14,39 @@ use collections::{HashMap, HashSet}; use command_palette_hooks::CommandPaletteFilter; use db::kvp::{Dismissable, KEY_VALUE_STORE}; use edit_prediction_context::{ - DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions, - EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, 
Line, - SyntaxIndex, SyntaxIndexState, + EditPredictionContextOptions, EditPredictionExcerpt, EditPredictionExcerptOptions, + EditPredictionScoreOptions, Line, SyntaxIndex, +}; +use edit_prediction_context2::{ + RelatedExcerpt, RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile, }; use feature_flags::{FeatureFlag, FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag}; -use futures::channel::mpsc::UnboundedReceiver; -use futures::channel::{mpsc, oneshot}; -use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _, select_biased}; +use futures::{ + AsyncReadExt as _, FutureExt as _, StreamExt as _, + channel::{ + mpsc::{self, UnboundedReceiver}, + oneshot, + }, + select_biased, +}; use gpui::BackgroundExecutor; use gpui::{ App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions, http_client::{self, AsyncBody, Method}, prelude::*, }; +use language::language_settings::all_language_settings; use language::{ Anchor, Buffer, DiagnosticSet, File, LanguageServerId, Point, ToOffset as _, ToPoint, }; use language::{BufferSnapshot, OffsetRangeExt}; use language_model::{LlmApiToken, RefreshLlmTokenListener}; use open_ai::FunctionDefinition; -use project::{DisableAiSettings, Project, ProjectPath, WorktreeId}; +use project::{DisableAiSettings, Project, ProjectItem as _, ProjectPath, WorktreeId}; use release_channel::AppVersion; use semver::Version; use serde::de::DeserializeOwned; -use settings::{EditPredictionProvider, Settings as _, SettingsStore, update_settings_file}; +use settings::{EditPredictionProvider, Settings, SettingsStore, update_settings_file}; use std::any::{Any as _, TypeId}; use std::collections::{VecDeque, hash_map}; use telemetry_events::EditPredictionRating; @@ -52,11 +60,9 @@ use std::sync::{Arc, LazyLock}; use std::time::{Duration, Instant}; use std::{env, mem}; use thiserror::Error; -use util::rel_path::RelPathBuf; use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt}; use 
workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification}; -pub mod assemble_excerpts; mod license_detection; mod onboarding_modal; mod prediction; @@ -71,7 +77,6 @@ pub mod zeta1; #[cfg(test)] mod zeta_tests; -use crate::assemble_excerpts::assemble_excerpts; use crate::license_detection::LicenseDetectionWatcher; use crate::onboarding_modal::ZedPredictModal; pub use crate::prediction::EditPrediction; @@ -115,8 +120,7 @@ pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPrediction target_before_cursor_over_total_bytes: 0.5, }; -pub const DEFAULT_CONTEXT_OPTIONS: ContextMode = - ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS); +pub const DEFAULT_CONTEXT_OPTIONS: ContextMode = ContextMode::Lsp(DEFAULT_EXCERPT_OPTIONS); pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions { excerpt: DEFAULT_EXCERPT_OPTIONS, @@ -190,6 +194,7 @@ pub struct Zeta { llm_token: LlmApiToken, _llm_token_subscription: Subscription, projects: HashMap, + use_context: bool, options: ZetaOptions, update_required: bool, debug_tx: Option>, @@ -225,6 +230,7 @@ pub struct ZetaOptions { pub enum ContextMode { Agentic(AgenticContextOptions), Syntax(EditPredictionContextOptions), + Lsp(EditPredictionExcerptOptions), } #[derive(Debug, Clone, PartialEq)] @@ -237,6 +243,7 @@ impl ContextMode { match self { ContextMode::Agentic(options) => &options.excerpt, ContextMode::Syntax(options) => &options.excerpt, + ContextMode::Lsp(options) => &options, } } } @@ -244,23 +251,22 @@ impl ContextMode { #[derive(Debug)] pub enum ZetaDebugInfo { ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo), - SearchQueriesGenerated(ZetaSearchQueryDebugInfo), - SearchQueriesExecuted(ZetaContextRetrievalDebugInfo), - ContextRetrievalFinished(ZetaContextRetrievalDebugInfo), + ContextRetrievalFinished(ZetaContextRetrievalFinishedDebugInfo), EditPredictionRequested(ZetaEditPredictionDebugInfo), } #[derive(Debug)] pub struct 
ZetaContextRetrievalStartedDebugInfo { - pub project: Entity, + pub project_entity_id: EntityId, pub timestamp: Instant, pub search_prompt: String, } #[derive(Debug)] -pub struct ZetaContextRetrievalDebugInfo { - pub project: Entity, +pub struct ZetaContextRetrievalFinishedDebugInfo { + pub project_entity_id: EntityId, pub timestamp: Instant, + pub metadata: Vec<(&'static str, SharedString)>, } #[derive(Debug)] @@ -273,17 +279,9 @@ pub struct ZetaEditPredictionDebugInfo { pub response_rx: oneshot::Receiver<(Result, Duration)>, } -#[derive(Debug)] -pub struct ZetaSearchQueryDebugInfo { - pub project: Entity, - pub timestamp: Instant, - pub search_queries: Vec, -} - pub type RequestDebugInfo = predict_edits_v3::DebugInfo; struct ZetaProject { - syntax_index: Option>, events: VecDeque>, last_event: Option, recent_paths: VecDeque, @@ -291,16 +289,26 @@ struct ZetaProject { current_prediction: Option, next_pending_prediction_id: usize, pending_predictions: ArrayVec, + context_updates_tx: smol::channel::Sender<()>, + context_updates_rx: smol::channel::Receiver<()>, last_prediction_refresh: Option<(EntityId, Instant)>, cancelled_predictions: HashSet, - context: Option, Vec>>>, - refresh_context_task: Option>>>, - refresh_context_debounce_task: Option>>, - refresh_context_timestamp: Option, + context: ZetaProjectContext, license_detection_watchers: HashMap>, _subscription: gpui::Subscription, } +enum ZetaProjectContext { + Syntax(Entity), + Lsp(Entity), + Agentic { + refresh_context_task: Option>>>, + refresh_context_debounce_task: Option>>, + refresh_context_timestamp: Option, + context: Vec, + }, +} + impl ZetaProject { pub fn events(&self, cx: &App) -> Vec> { self.events @@ -521,11 +529,12 @@ impl Zeta { }) .detach(); - Self { + let mut this = Self { projects: HashMap::default(), client, user_store, options: DEFAULT_OPTIONS, + use_context: false, llm_token, _llm_token_subscription: cx.subscribe( &refresh_llm_token_listener, @@ -549,7 +558,22 @@ impl Zeta { 
reject_predictions_tx: reject_tx, rated_predictions: Default::default(), shown_predictions: Default::default(), - } + }; + + this.enable_or_disable_context_retrieval(cx); + let weak_this = cx.weak_entity(); + cx.on_flags_ready(move |_, cx| { + weak_this + .update(cx, |this, cx| this.enable_or_disable_context_retrieval(cx)) + .ok(); + }) + .detach(); + cx.observe_global::(|this, cx| { + this.enable_or_disable_context_retrieval(cx); + }) + .detach(); + + this } pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) { @@ -584,29 +608,29 @@ impl Zeta { self.options = options; } + pub fn set_use_context(&mut self, use_context: bool) { + self.use_context = use_context; + } + pub fn clear_history(&mut self) { for zeta_project in self.projects.values_mut() { zeta_project.events.clear(); } } - pub fn context_for_project( - &self, + pub fn context_for_project<'a>( + &'a self, project: &Entity, - ) -> impl Iterator, &[Range])> { + cx: &'a App, + ) -> &'a [RelatedFile] { self.projects .get(&project.entity_id()) - .and_then(|project| { - Some( - project - .context - .as_ref()? - .iter() - .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())), - ) + .and_then(|project| match &project.context { + ZetaProjectContext::Syntax(_) => None, + ZetaProjectContext::Lsp(store) => Some(store.read(cx).related_files()), + ZetaProjectContext::Agentic { context, .. 
} => Some(context.as_slice()), }) - .into_iter() - .flatten() + .unwrap_or(&[]) } pub fn usage(&self, cx: &App) -> Option { @@ -636,34 +660,122 @@ impl Zeta { project: &Entity, cx: &mut Context, ) -> &mut ZetaProject { + let entity_id = project.entity_id(); + let (context_updates_tx, context_updates_rx) = smol::channel::unbounded(); self.projects - .entry(project.entity_id()) + .entry(entity_id) .or_insert_with(|| ZetaProject { - syntax_index: if let ContextMode::Syntax(_) = &self.options.context { - Some(cx.new(|cx| { + context: match &self.options.context { + ContextMode::Agentic(_) => ZetaProjectContext::Agentic { + refresh_context_task: None, + refresh_context_debounce_task: None, + refresh_context_timestamp: None, + context: Vec::new(), + }, + ContextMode::Syntax(_) => ZetaProjectContext::Syntax(cx.new(|cx| { SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx) - })) - } else { - None + })), + ContextMode::Lsp(_) => { + let related_excerpt_store = + cx.new(|cx| RelatedExcerptStore::new(project, cx)); + cx.subscribe( + &related_excerpt_store, + move |this, _, event, _| match event { + RelatedExcerptStoreEvent::StartedRefresh => { + if let Some(debug_tx) = this.debug_tx.clone() { + debug_tx + .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted( + ZetaContextRetrievalStartedDebugInfo { + project_entity_id: entity_id, + timestamp: Instant::now(), + search_prompt: String::new(), + }, + )) + .ok(); + } + } + RelatedExcerptStoreEvent::FinishedRefresh { + cache_hit_count, + cache_miss_count, + mean_definition_latency, + max_definition_latency, + } => { + if let Some(debug_tx) = this.debug_tx.clone() { + debug_tx + .unbounded_send( + ZetaDebugInfo::ContextRetrievalFinished( + ZetaContextRetrievalFinishedDebugInfo { + project_entity_id: entity_id, + timestamp: Instant::now(), + metadata: vec![ + ( + "Cache Hits", + format!( + "{}/{}", + cache_hit_count, + cache_hit_count + + cache_miss_count + ) + .into(), + ), + ( + "Max LSP Time", + format!( + "{} 
ms", + max_definition_latency + .as_millis() + ) + .into(), + ), + ( + "Mean LSP Time", + format!( + "{} ms", + mean_definition_latency + .as_millis() + ) + .into(), + ), + ], + }, + ), + ) + .ok(); + } + if let Some(project_state) = this.projects.get(&entity_id) { + project_state.context_updates_tx.send_blocking(()).ok(); + } + } + }, + ) + .detach(); + ZetaProjectContext::Lsp(related_excerpt_store) + } }, events: VecDeque::new(), last_event: None, recent_paths: VecDeque::new(), + context_updates_rx, + context_updates_tx, registered_buffers: HashMap::default(), current_prediction: None, cancelled_predictions: HashSet::default(), pending_predictions: ArrayVec::new(), next_pending_prediction_id: 0, last_prediction_refresh: None, - context: None, - refresh_context_task: None, - refresh_context_debounce_task: None, - refresh_context_timestamp: None, license_detection_watchers: HashMap::default(), _subscription: cx.subscribe(&project, Self::handle_project_event), }) } + pub fn project_context_updates( + &self, + project: &Entity, + ) -> Option> { + let project_state = self.projects.get(&project.entity_id())?; + Some(project_state.context_updates_rx.clone()) + } + fn handle_project_event( &mut self, project: Entity, @@ -1349,6 +1461,11 @@ impl Zeta { position, events, &zeta_project.recent_paths, + if self.use_context { + self.context_for_project(&project, cx).to_vec() + } else { + Vec::new() + }, diagnostic_search_range.clone(), cx, ), @@ -1480,73 +1597,34 @@ impl Zeta { trigger: PredictEditsRequestTrigger, cx: &mut Context, ) -> Task>> { - let project_state = self.projects.get(&project.entity_id()); - - let index_state = project_state.and_then(|state| { - state - .syntax_index - .as_ref() - .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone())) - }); let options = self.options.clone(); let buffer_snapshotted_at = Instant::now(); - let Some(excerpt_path) = active_snapshot + + let Some((excerpt_path, active_project_path)) = active_snapshot 
.file() - .map(|path| -> Arc { path.full_path(cx).into() }) + .map(|file| -> Arc { file.full_path(cx).into() }) + .zip(active_buffer.read(cx).project_path(cx)) else { return Task::ready(Err(anyhow!("No file path for excerpt"))); }; + let client = self.client.clone(); let llm_token = self.llm_token.clone(); let app_version = AppVersion::global(cx); - let worktree_snapshots = project - .read(cx) - .worktrees(cx) - .map(|worktree| worktree.read(cx).snapshot()) - .collect::>(); let debug_tx = self.debug_tx.clone(); let diagnostics = active_snapshot.diagnostic_sets().clone(); let file = active_buffer.read(cx).file(); - let parent_abs_path = project::File::from_dyn(file).and_then(|f| { - let mut path = f.worktree.read(cx).absolutize(&f.path); - if path.pop() { Some(path) } else { None } - }); + + let active_file_full_path = file.as_ref().map(|f| f.full_path(cx)); // TODO data collection let can_collect_data = file .as_ref() .map_or(false, |file| self.can_collect_file(project, file, cx)); - let empty_context_files = HashMap::default(); - let context_files = project_state - .and_then(|project_state| project_state.context.as_ref()) - .unwrap_or(&empty_context_files); - - #[cfg(feature = "eval-support")] - let parsed_fut = futures::future::join_all( - context_files - .keys() - .map(|buffer| buffer.read(cx).parsing_idle()), - ); - - let mut included_files = context_files - .iter() - .filter_map(|(buffer_entity, ranges)| { - let buffer = buffer_entity.read(cx); - Some(( - buffer_entity.clone(), - buffer.snapshot(), - buffer.file()?.full_path(cx).into(), - ranges.clone(), - )) - }) - .collect::>(); - - included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| { - (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len())) - }); + let mut included_files = self.context_for_project(project, cx).to_vec(); #[cfg(feature = "eval-support")] let eval_cache = self.eval_cache.clone(); @@ -1554,15 +1632,6 @@ impl Zeta { let request_task = cx.background_spawn({ let 
active_buffer = active_buffer.clone(); async move { - #[cfg(feature = "eval-support")] - parsed_fut.await; - - let index_state = if let Some(index_state) = index_state { - Some(index_state.lock_owned().await) - } else { - None - }; - let cursor_offset = position.to_offset(&active_snapshot); let cursor_point = cursor_offset.to_point(&active_snapshot); @@ -1576,110 +1645,84 @@ impl Zeta { options.max_diagnostic_bytes, ); - let cloud_request = match options.context { - ContextMode::Agentic(context_options) => { - let Some(excerpt) = EditPredictionExcerpt::select_from_buffer( - cursor_point, - &active_snapshot, - &context_options.excerpt, - index_state.as_deref(), - ) else { - return Ok((None, None)); - }; + let excerpt_options = options.context.excerpt(); - let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start) - ..active_snapshot.anchor_before(excerpt.range.end); + let Some(excerpt) = EditPredictionExcerpt::select_from_buffer( + cursor_point, + &active_snapshot, + &excerpt_options, + None, + ) else { + return Ok((None, None)); + }; - if let Some(buffer_ix) = - included_files.iter().position(|(_, snapshot, _, _)| { - snapshot.remote_id() == active_snapshot.remote_id() - }) - { - let (_, buffer, _, ranges) = &mut included_files[buffer_ix]; - ranges.push(excerpt_anchor_range); - retrieval_search::merge_anchor_ranges(ranges, buffer); - let last_ix = included_files.len() - 1; - included_files.swap(buffer_ix, last_ix); - } else { - included_files.push(( - active_buffer.clone(), - active_snapshot.clone(), - excerpt_path.clone(), - vec![excerpt_anchor_range], - )); - } + let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start) + ..active_snapshot.anchor_before(excerpt.range.end); + let related_excerpt = RelatedExcerpt { + anchor_range: excerpt_anchor_range.clone(), + point_range: Point::new(excerpt.line_range.start.0, 0) + ..Point::new(excerpt.line_range.end.0, 0), + text: active_snapshot.as_rope().slice(excerpt.range), + }; + + if 
let Some(buffer_ix) = included_files + .iter() + .position(|file| file.buffer.entity_id() == active_buffer.entity_id()) + { + let file = &mut included_files[buffer_ix]; + file.excerpts.push(related_excerpt); + file.merge_excerpts(); + let last_ix = included_files.len() - 1; + included_files.swap(buffer_ix, last_ix); + } else { + let active_file = RelatedFile { + path: active_project_path, + buffer: active_buffer.downgrade(), + excerpts: vec![related_excerpt], + max_row: active_snapshot.max_point().row, + }; + included_files.push(active_file); + } - let included_files = included_files + let included_files = included_files + .iter() + .map(|related_file| predict_edits_v3::IncludedFile { + path: Arc::from(related_file.path.path.as_std_path()), + max_row: Line(related_file.max_row), + excerpts: related_file + .excerpts .iter() - .map(|(_, snapshot, path, ranges)| { - let ranges = ranges - .iter() - .map(|range| { - let point_range = range.to_point(&snapshot); - Line(point_range.start.row)..Line(point_range.end.row) - }) - .collect::>(); - let excerpts = assemble_excerpts(&snapshot, ranges); - predict_edits_v3::IncludedFile { - path: path.clone(), - max_row: Line(snapshot.max_point().row), - excerpts, - } + .map(|excerpt| predict_edits_v3::Excerpt { + start_line: Line(excerpt.point_range.start.row), + text: excerpt.text.to_string().into(), }) - .collect::>(); - - predict_edits_v3::PredictEditsRequest { - excerpt_path, - excerpt: String::new(), - excerpt_line_range: Line(0)..Line(0), - excerpt_range: 0..0, - cursor_point: predict_edits_v3::Point { - line: predict_edits_v3::Line(cursor_point.row), - column: cursor_point.column, - }, - included_files, - referenced_declarations: vec![], - events, - can_collect_data, - diagnostic_groups, - diagnostic_groups_truncated, - debug_info: debug_tx.is_some(), - prompt_max_bytes: Some(options.max_prompt_bytes), - prompt_format: options.prompt_format, - // TODO [zeta2] - signatures: vec![], - excerpt_parent: None, - git_info: None, - 
trigger, - } - } - ContextMode::Syntax(context_options) => { - let Some(context) = EditPredictionContext::gather_context( - cursor_point, - &active_snapshot, - parent_abs_path.as_deref(), - &context_options, - index_state.as_deref(), - ) else { - return Ok((None, None)); - }; - - make_syntax_context_cloud_request( - excerpt_path, - context, - events, - can_collect_data, - diagnostic_groups, - diagnostic_groups_truncated, - None, - debug_tx.is_some(), - &worktree_snapshots, - index_state.as_deref(), - Some(options.max_prompt_bytes), - options.prompt_format, - trigger, - ) - } + .collect(), + }) + .collect::>(); + + let cloud_request = predict_edits_v3::PredictEditsRequest { + excerpt_path, + excerpt: String::new(), + excerpt_line_range: Line(0)..Line(0), + excerpt_range: 0..0, + cursor_point: predict_edits_v3::Point { + line: predict_edits_v3::Line(cursor_point.row), + column: cursor_point.column, + }, + included_files, + referenced_declarations: vec![], + events, + can_collect_data, + diagnostic_groups, + diagnostic_groups_truncated, + debug_info: debug_tx.is_some(), + prompt_max_bytes: Some(options.max_prompt_bytes), + prompt_format: options.prompt_format, + // TODO [zeta2] + signatures: vec![], + excerpt_parent: None, + git_info: None, + trigger, }; let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request); @@ -1787,18 +1830,17 @@ impl Zeta { } let get_buffer_from_context = |path: &Path| { - included_files - .iter() - .find_map(|(_, buffer, probe_path, ranges)| { - if probe_path.as_ref() == path { - Some((buffer, ranges.as_slice())) - } else { - None - } - }) + if Some(path) == active_file_full_path.as_deref() { + Some(( + &active_snapshot, + std::slice::from_ref(&excerpt_anchor_range), + )) + } else { + None + } }; - let (edited_buffer_snapshot, edits) = match options.prompt_format { + let (_, edits) = match options.prompt_format { PromptFormat::NumLinesUniDiff => { // TODO: Implement parsing of multi-file diffs crate::udiff::parse_diff(&output_text, 
get_buffer_from_context).await? @@ -1822,24 +1864,13 @@ impl Zeta { } }; - let edited_buffer = included_files - .iter() - .find_map(|(buffer, snapshot, _, _)| { - if snapshot.remote_id() == edited_buffer_snapshot.remote_id() { - Some(buffer.clone()) - } else { - None - } - }) - .context("Failed to find buffer in included_buffers")?; - anyhow::Ok(( Some(( request_id, Some(( inputs, - edited_buffer, - edited_buffer_snapshot.clone(), + active_buffer, + active_snapshot.clone(), edits, received_response_at, )), @@ -2058,61 +2089,78 @@ impl Zeta { pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10); pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3); - // Refresh the related excerpts when the user just beguns editing after - // an idle period, and after they pause editing. - fn refresh_context_if_needed( + pub fn refresh_context_if_needed( &mut self, project: &Entity, buffer: &Entity, cursor_position: language::Anchor, cx: &mut Context, ) { - if !matches!(self.edit_prediction_model, ZetaEditPredictionModel::Zeta2) { + if !self.use_context { return; } - - if !matches!(&self.options().context, ContextMode::Agentic { .. 
}) { - return; - } - let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else { return; }; - let now = Instant::now(); - let was_idle = zeta_project - .refresh_context_timestamp - .map_or(true, |timestamp| { - now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION - }); - zeta_project.refresh_context_timestamp = Some(now); - zeta_project.refresh_context_debounce_task = Some(cx.spawn({ - let buffer = buffer.clone(); - let project = project.clone(); - async move |this, cx| { - if was_idle { - log::debug!("refetching edit prediction context after idle"); - } else { - cx.background_executor() - .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION) - .await; - log::debug!("refetching edit prediction context after pause"); - } - this.update(cx, |this, cx| { - let task = this.refresh_context(project.clone(), buffer, cursor_position, cx); + match &mut zeta_project.context { + ZetaProjectContext::Syntax(_entity) => {} + ZetaProjectContext::Lsp(related_excerpt_store) => { + related_excerpt_store.update(cx, |store, cx| { + store.refresh(buffer.clone(), cursor_position, cx); + }); + } + ZetaProjectContext::Agentic { + refresh_context_debounce_task, + refresh_context_timestamp, + .. 
+ } => { + let now = Instant::now(); + let was_idle = refresh_context_timestamp.map_or(true, |timestamp| { + now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION + }); + *refresh_context_timestamp = Some(now); + *refresh_context_debounce_task = Some(cx.spawn({ + let buffer = buffer.clone(); + let project = project.clone(); + async move |this, cx| { + if was_idle { + log::debug!("refetching edit prediction context after idle"); + } else { + cx.background_executor() + .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION) + .await; + log::debug!("refetching edit prediction context after pause"); + } + this.update(cx, |this, cx| { + let task = this.refresh_context_with_agentic_retrieval( + project.clone(), + buffer, + cursor_position, + cx, + ); - if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) { - zeta_project.refresh_context_task = Some(task.log_err()); - }; - }) - .ok() + if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) + { + if let ZetaProjectContext::Agentic { + refresh_context_task, + .. + } = &mut zeta_project.context + { + *refresh_context_task = Some(task.log_err()); + } + }; + }) + .ok() + } + })); } - })); + } } // Refresh the related excerpts asynchronously. Ensure the task runs to completion, // and avoid spawning more than one concurrent task. 
- pub fn refresh_context( + pub fn refresh_context_with_agentic_retrieval( &mut self, project: Entity, buffer: Entity, @@ -2162,12 +2210,14 @@ impl Zeta { } }; + let retrieval_started_at = Instant::now(); + if let Some(debug_tx) = &debug_tx { debug_tx .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted( ZetaContextRetrievalStartedDebugInfo { - project: project.clone(), - timestamp: Instant::now(), + project_entity_id: project.entity_id(), + timestamp: retrieval_started_at, search_prompt: prompt.clone(), }, )) @@ -2260,19 +2310,8 @@ impl Zeta { queries.extend(input.queries); } - if let Some(debug_tx) = &debug_tx { - debug_tx - .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated( - ZetaSearchQueryDebugInfo { - project: project.clone(), - timestamp: Instant::now(), - search_queries: queries.clone(), - }, - )) - .ok(); - } - log::trace!("Running retrieval search: {queries:#?}"); + let query_generation_finished_at = Instant::now(); let related_excerpts_result = retrieval_search::run_retrieval_searches( queries, @@ -2284,54 +2323,62 @@ impl Zeta { .await; log::trace!("Search queries executed"); - - if let Some(debug_tx) = &debug_tx { - debug_tx - .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted( - ZetaContextRetrievalDebugInfo { - project: project.clone(), - timestamp: Instant::now(), - }, - )) - .ok(); - } + let query_execution_finished_at = Instant::now(); this.update(cx, |this, _cx| { let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else { return Ok(()); }; - zeta_project.refresh_context_task.take(); - if let Some(debug_tx) = &this.debug_tx { - debug_tx - .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished( - ZetaContextRetrievalDebugInfo { - project, - timestamp: Instant::now(), - }, - )) - .ok(); - } - match related_excerpts_result { - Ok(excerpts) => { - zeta_project.context = Some(excerpts); - Ok(()) + if let ZetaProjectContext::Agentic { + refresh_context_task, + context, + .. 
+ } = &mut zeta_project.context + { + refresh_context_task.take(); + if let Some(debug_tx) = &this.debug_tx { + debug_tx + .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished( + ZetaContextRetrievalFinishedDebugInfo { + project_entity_id: project.entity_id(), + timestamp: Instant::now(), + metadata: vec![ + ( + "query_generation", + format!( + "{:?}", + query_generation_finished_at - retrieval_started_at + ) + .into(), + ), + ( + "search_execution", + format!( + "{:?}", + query_execution_finished_at + - query_generation_finished_at + ) + .into(), + ), + ], + }, + )) + .ok(); + } + match related_excerpts_result { + Ok(excerpts) => { + *context = excerpts; + Ok(()) + } + Err(error) => Err(error), } - Err(error) => Err(error), + } else { + Ok(()) } })? }) } - pub fn set_context( - &mut self, - project: Entity, - context: HashMap, Vec>>, - ) { - if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) { - zeta_project.context = Some(context); - } - } - fn gather_nearby_diagnostics( cursor_offset: usize, diagnostic_sets: &[(LanguageServerId, DiagnosticSet)], @@ -2378,92 +2425,13 @@ impl Zeta { (results, diagnostic_groups_truncated) } - // TODO: Dedupe with similar code in request_prediction? 
- pub fn cloud_request_for_zeta_cli( - &mut self, - project: &Entity, - buffer: &Entity, - position: language::Anchor, - cx: &mut Context, - ) -> Task> { - let project_state = self.projects.get(&project.entity_id()); - - let index_state = project_state.and_then(|state| { - state - .syntax_index - .as_ref() - .map(|index| index.read_with(cx, |index, _cx| index.state().clone())) - }); - let options = self.options.clone(); - let snapshot = buffer.read(cx).snapshot(); - let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else { - return Task::ready(Err(anyhow!("No file path for excerpt"))); - }; - let worktree_snapshots = project - .read(cx) - .worktrees(cx) - .map(|worktree| worktree.read(cx).snapshot()) - .collect::>(); - - let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| { - let mut path = f.worktree.read(cx).absolutize(&f.path); - if path.pop() { Some(path) } else { None } - }); - - cx.background_spawn(async move { - let index_state = if let Some(index_state) = index_state { - Some(index_state.lock_owned().await) - } else { - None - }; - - let cursor_point = position.to_point(&snapshot); - - let debug_info = true; - EditPredictionContext::gather_context( - cursor_point, - &snapshot, - parent_abs_path.as_deref(), - match &options.context { - ContextMode::Agentic(_) => { - // TODO - panic!("Llm mode not supported in zeta cli yet"); - } - ContextMode::Syntax(edit_prediction_context_options) => { - edit_prediction_context_options - } - }, - index_state.as_deref(), - ) - .context("Failed to select excerpt") - .map(|context| { - make_syntax_context_cloud_request( - excerpt_path.into(), - context, - // TODO pass everything - Vec::new(), - false, - Vec::new(), - false, - None, - debug_info, - &worktree_snapshots, - index_state.as_deref(), - Some(options.max_prompt_bytes), - options.prompt_format, - PredictEditsRequestTrigger::Other, - ) - }) - }) - } - pub fn wait_for_initial_indexing( &mut self, project: &Entity, cx: 
&mut Context, ) -> Task> { let zeta_project = self.get_or_init_zeta_project(project, cx); - if let Some(syntax_index) = &zeta_project.syntax_index { + if let ZetaProjectContext::Syntax(syntax_index) = &zeta_project.context { syntax_index.read(cx).wait_for_initial_file_indexing(cx) } else { Task::ready(Ok(())) @@ -2555,6 +2523,11 @@ impl Zeta { self.client.telemetry().flush_events().detach(); cx.notify(); } + + fn enable_or_disable_context_retrieval(&mut self, cx: &mut Context<'_, Zeta>) { + self.use_context = cx.has_flag::() + && all_language_settings(None, cx).edit_predictions.use_context; + } } pub fn text_from_response(mut res: open_ai::Response) -> Option { @@ -2597,131 +2570,6 @@ pub struct ZedUpdateRequiredError { minimum_version: Version, } -fn make_syntax_context_cloud_request( - excerpt_path: Arc, - context: EditPredictionContext, - events: Vec>, - can_collect_data: bool, - diagnostic_groups: Vec, - diagnostic_groups_truncated: bool, - git_info: Option, - debug_info: bool, - worktrees: &Vec, - index_state: Option<&SyntaxIndexState>, - prompt_max_bytes: Option, - prompt_format: PromptFormat, - trigger: PredictEditsRequestTrigger, -) -> predict_edits_v3::PredictEditsRequest { - let mut signatures = Vec::new(); - let mut declaration_to_signature_index = HashMap::default(); - let mut referenced_declarations = Vec::new(); - - for snippet in context.declarations { - let project_entry_id = snippet.declaration.project_entry_id(); - let Some(path) = worktrees.iter().find_map(|worktree| { - worktree.entry_for_id(project_entry_id).map(|entry| { - let mut full_path = RelPathBuf::new(); - full_path.push(worktree.root_name()); - full_path.push(&entry.path); - full_path - }) - }) else { - continue; - }; - - let parent_index = index_state.and_then(|index_state| { - snippet.declaration.parent().and_then(|parent| { - add_signature( - parent, - &mut declaration_to_signature_index, - &mut signatures, - index_state, - ) - }) - }); - - let (text, text_is_truncated) = 
snippet.declaration.item_text(); - referenced_declarations.push(predict_edits_v3::ReferencedDeclaration { - path: path.as_std_path().into(), - text: text.into(), - range: snippet.declaration.item_line_range(), - text_is_truncated, - signature_range: snippet.declaration.signature_range_in_item_text(), - parent_index, - signature_score: snippet.score(DeclarationStyle::Signature), - declaration_score: snippet.score(DeclarationStyle::Declaration), - score_components: snippet.components, - }); - } - - let excerpt_parent = index_state.and_then(|index_state| { - context - .excerpt - .parent_declarations - .last() - .and_then(|(parent, _)| { - add_signature( - *parent, - &mut declaration_to_signature_index, - &mut signatures, - index_state, - ) - }) - }); - - predict_edits_v3::PredictEditsRequest { - excerpt_path, - excerpt: context.excerpt_text.body, - excerpt_line_range: context.excerpt.line_range, - excerpt_range: context.excerpt.range, - cursor_point: predict_edits_v3::Point { - line: predict_edits_v3::Line(context.cursor_point.row), - column: context.cursor_point.column, - }, - referenced_declarations, - included_files: vec![], - signatures, - excerpt_parent, - events, - can_collect_data, - diagnostic_groups, - diagnostic_groups_truncated, - git_info, - debug_info, - prompt_max_bytes, - prompt_format, - trigger, - } -} - -fn add_signature( - declaration_id: DeclarationId, - declaration_to_signature_index: &mut HashMap, - signatures: &mut Vec, - index: &SyntaxIndexState, -) -> Option { - if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) { - return Some(*signature_index); - } - let Some(parent_declaration) = index.declaration(declaration_id) else { - log::error!("bug: missing parent declaration"); - return None; - }; - let parent_index = parent_declaration.parent().and_then(|parent| { - add_signature(parent, declaration_to_signature_index, signatures, index) - }); - let (text, text_is_truncated) = parent_declaration.signature_text(); - 
let signature_index = signatures.len(); - signatures.push(Signature { - text: text.into(), - text_is_truncated, - parent_index, - range: parent_declaration.signature_line_range(), - }); - declaration_to_signature_index.insert(declaration_id, signature_index); - Some(signature_index) -} - #[cfg(feature = "eval-support")] pub type EvalCacheKey = (EvalCacheEntryKind, u64); @@ -2917,7 +2765,6 @@ mod tests { use cloud_llm_client::{ EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody, }; - use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery}; use futures::{ AsyncReadExt, StreamExt, channel::{mpsc, oneshot}, @@ -2929,6 +2776,7 @@ mod tests { }; use indoc::indoc; use language::OffsetRangeExt as _; + use lsp::LanguageServerId; use open_ai::Usage; use pretty_assertions::{assert_eq, assert_matches}; use project::{FakeFs, Project}; @@ -2959,7 +2807,8 @@ mod tests { let buffer1 = project .update(cx, |project, cx| { - let path = project.find_project_path(path!("root/1.txt"), cx).unwrap(); + let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap(); + project.set_active_path(Some(path.clone()), cx); project.open_buffer(path, cx) }) .await @@ -2995,58 +2844,38 @@ mod tests { assert_matches!(prediction, BufferEditPrediction::Local { .. 
}); }); - // Context refresh - let refresh_task = zeta.update(cx, |zeta, cx| { - zeta.refresh_context(project.clone(), buffer1.clone(), position, cx) - }); - let (_request, respond_tx) = requests.predict.next().await.unwrap(); - respond_tx - .send(open_ai::Response { - id: Uuid::new_v4().to_string(), - object: "response".into(), - created: 0, - model: "model".into(), - choices: vec![open_ai::Choice { - index: 0, - message: open_ai::RequestMessage::Assistant { - content: None, - tool_calls: vec![open_ai::ToolCall { - id: "search".into(), - content: open_ai::ToolCallContent::Function { - function: open_ai::FunctionContent { - name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME - .to_string(), - arguments: serde_json::to_string(&SearchToolInput { - queries: Box::new([SearchToolQuery { - glob: "root/2.txt".to_string(), - syntax_node: vec![], - content: Some(".".into()), - }]), - }) - .unwrap(), - }, - }, - }], - }, - finish_reason: None, - }], - usage: Usage { - prompt_tokens: 0, - completion_tokens: 0, - total_tokens: 0, - }, - }) - .unwrap(); - refresh_task.await.unwrap(); - zeta.update(cx, |zeta, _cx| { zeta.reject_current_prediction(EditPredictionRejectReason::Discarded, &project); }); - // Prediction for another file - zeta.update(cx, |zeta, cx| { - zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx) + // Prediction for diagnostic in another file + + let diagnostic = lsp::Diagnostic { + range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)), + severity: Some(lsp::DiagnosticSeverity::ERROR), + message: "Sentence is incomplete".to_string(), + ..Default::default() + }; + + project.update(cx, |project, cx| { + project.lsp_store().update(cx, |lsp_store, cx| { + lsp_store + .update_diagnostics( + LanguageServerId(0), + lsp::PublishDiagnosticsParams { + uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(), + diagnostics: vec![diagnostic], + version: None, + }, + None, + language::DiagnosticSourceKind::Pushed, 
+ &[], + cx, + ) + .unwrap(); + }); }); + let (_request, respond_tx) = requests.predict.next().await.unwrap(); respond_tx .send(model_response(indoc! {r#" @@ -4018,7 +3847,6 @@ mod tests { let mut buf = Vec::new(); body.read_to_end(&mut buf).await.ok(); let req = serde_json::from_slice(&buf).unwrap(); - let (res_tx, res_rx) = oneshot::channel(); predict_req_tx.unbounded_send((req, res_tx)).unwrap(); serde_json::to_string(&res_rx.await?).unwrap() diff --git a/crates/zeta2_tools/Cargo.toml b/crates/zeta2_tools/Cargo.toml index 607e24c895d96de1464ff1bfa2a4dfa01c5d9669..8e20224736c658d4d80d678b29d4231ec7e4b2f5 100644 --- a/crates/zeta2_tools/Cargo.toml +++ b/crates/zeta2_tools/Cargo.toml @@ -15,7 +15,6 @@ path = "src/zeta2_tools.rs" anyhow.workspace = true client.workspace = true cloud_llm_client.workspace = true -cloud_zeta2_prompt.workspace = true collections.workspace = true edit_prediction_context.workspace = true editor.workspace = true diff --git a/crates/zeta2_tools/src/zeta2_context_view.rs b/crates/zeta2_tools/src/zeta2_context_view.rs index 54f1ea2d813f7c00d30b12e341fb3e5ac3f155dc..882846929a62f90f349d40f8f6b6996f83613ec7 100644 --- a/crates/zeta2_tools/src/zeta2_context_view.rs +++ b/crates/zeta2_tools/src/zeta2_context_view.rs @@ -8,26 +8,25 @@ use std::{ use anyhow::Result; use client::{Client, UserStore}; -use cloud_zeta2_prompt::retrieval_prompt::SearchToolQuery; use editor::{Editor, PathKey}; use futures::StreamExt as _; use gpui::{ Animation, AnimationExt, App, AppContext as _, Context, Entity, EventEmitter, FocusHandle, - Focusable, ParentElement as _, SharedString, Styled as _, Task, TextAlign, Window, actions, - pulsating_between, + Focusable, InteractiveElement as _, IntoElement as _, ParentElement as _, SharedString, + Styled as _, Task, TextAlign, Window, actions, div, pulsating_between, }; use multi_buffer::MultiBuffer; use project::Project; use text::OffsetRangeExt; use ui::{ - ButtonCommon, Clickable, Color, Disableable, FluentBuilder as _, 
Icon, IconButton, IconName, - IconSize, InteractiveElement, IntoElement, ListHeader, ListItem, StyledTypography, div, h_flex, - v_flex, + ButtonCommon, Clickable, Disableable, FluentBuilder as _, IconButton, IconName, + StyledTypography as _, h_flex, v_flex, }; + use workspace::Item; use zeta::{ - Zeta, ZetaContextRetrievalDebugInfo, ZetaContextRetrievalStartedDebugInfo, ZetaDebugInfo, - ZetaSearchQueryDebugInfo, + Zeta, ZetaContextRetrievalFinishedDebugInfo, ZetaContextRetrievalStartedDebugInfo, + ZetaDebugInfo, }; pub struct Zeta2ContextView { @@ -42,10 +41,8 @@ pub struct Zeta2ContextView { #[derive(Debug)] struct RetrievalRun { editor: Entity, - search_queries: Vec, started_at: Instant, - search_results_generated_at: Option, - search_results_executed_at: Option, + metadata: Vec<(&'static str, SharedString)>, finished_at: Option, } @@ -97,22 +94,12 @@ impl Zeta2ContextView { ) { match event { ZetaDebugInfo::ContextRetrievalStarted(info) => { - if info.project == self.project { + if info.project_entity_id == self.project.entity_id() { self.handle_context_retrieval_started(info, window, cx); } } - ZetaDebugInfo::SearchQueriesGenerated(info) => { - if info.project == self.project { - self.handle_search_queries_generated(info, window, cx); - } - } - ZetaDebugInfo::SearchQueriesExecuted(info) => { - if info.project == self.project { - self.handle_search_queries_executed(info, window, cx); - } - } ZetaDebugInfo::ContextRetrievalFinished(info) => { - if info.project == self.project { + if info.project_entity_id == self.project.entity_id() { self.handle_context_retrieval_finished(info, window, cx); } } @@ -129,7 +116,7 @@ impl Zeta2ContextView { if self .runs .back() - .is_some_and(|run| run.search_results_executed_at.is_none()) + .is_some_and(|run| run.finished_at.is_none()) { self.runs.pop_back(); } @@ -144,11 +131,9 @@ impl Zeta2ContextView { self.runs.push_back(RetrievalRun { editor, - search_queries: Vec::new(), started_at: info.timestamp, - 
search_results_generated_at: None, - search_results_executed_at: None, finished_at: None, + metadata: Vec::new(), }); cx.notify(); @@ -156,7 +141,7 @@ impl Zeta2ContextView { fn handle_context_retrieval_finished( &mut self, - info: ZetaContextRetrievalDebugInfo, + info: ZetaContextRetrievalFinishedDebugInfo, window: &mut Window, cx: &mut Context, ) { @@ -165,67 +150,72 @@ impl Zeta2ContextView { }; run.finished_at = Some(info.timestamp); + run.metadata = info.metadata; + + let project = self.project.clone(); + let related_files = self + .zeta + .read(cx) + .context_for_project(&self.project, cx) + .to_vec(); + let editor = run.editor.clone(); let multibuffer = run.editor.read(cx).buffer().clone(); - multibuffer.update(cx, |multibuffer, cx| { - multibuffer.clear(cx); - let context = self.zeta.read(cx).context_for_project(&self.project); - let mut paths = Vec::new(); - for (buffer, ranges) in context { - let path = PathKey::for_buffer(&buffer, cx); - let snapshot = buffer.read(cx).snapshot(); - let ranges = ranges - .iter() - .map(|range| range.to_point(&snapshot)) - .collect::>(); - paths.push((path, buffer, ranges)); - } + if self.current_ix + 2 == self.runs.len() { + self.current_ix += 1; + } - for (path, buffer, ranges) in paths { - multibuffer.set_excerpts_for_path(path, buffer, ranges, 0, cx); + cx.spawn_in(window, async move |this, cx| { + let mut paths = Vec::new(); + for related_file in related_files { + let (buffer, point_ranges): (_, Vec<_>) = + if let Some(buffer) = related_file.buffer.upgrade() { + let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?; + + ( + buffer, + related_file + .excerpts + .iter() + .map(|excerpt| excerpt.anchor_range.to_point(&snapshot)) + .collect(), + ) + } else { + ( + project + .update(cx, |project, cx| { + project.open_buffer(related_file.path.clone(), cx) + })? 
+ .await?, + related_file + .excerpts + .iter() + .map(|excerpt| excerpt.point_range.clone()) + .collect(), + ) + }; + cx.update(|_, cx| { + let path = PathKey::for_buffer(&buffer, cx); + paths.push((path, buffer, point_ranges)); + })?; } - }); - - run.editor.update(cx, |editor, cx| { - editor.move_to_beginning(&Default::default(), window, cx); - }); - - cx.notify(); - } - - fn handle_search_queries_generated( - &mut self, - info: ZetaSearchQueryDebugInfo, - _window: &mut Window, - cx: &mut Context, - ) { - let Some(run) = self.runs.back_mut() else { - return; - }; - run.search_results_generated_at = Some(info.timestamp); - run.search_queries = info.search_queries; - cx.notify(); - } + multibuffer.update(cx, |multibuffer, cx| { + multibuffer.clear(cx); - fn handle_search_queries_executed( - &mut self, - info: ZetaContextRetrievalDebugInfo, - _window: &mut Window, - cx: &mut Context, - ) { - if self.current_ix + 2 == self.runs.len() { - // Switch to latest when the queries are executed - self.current_ix += 1; - } + for (path, buffer, ranges) in paths { + multibuffer.set_excerpts_for_path(path, buffer, ranges, 0, cx); + } + })?; - let Some(run) = self.runs.back_mut() else { - return; - }; + editor.update_in(cx, |editor, window, cx| { + editor.move_to_beginning(&Default::default(), window, cx); + })?; - run.search_results_executed_at = Some(info.timestamp); - cx.notify(); + this.update(cx, |_, cx| cx.notify()) + }) + .detach(); } fn handle_go_back( @@ -254,8 +244,11 @@ impl Zeta2ContextView { } fn render_informational_footer(&self, cx: &mut Context<'_, Zeta2ContextView>) -> ui::Div { - let is_latest = self.runs.len() == self.current_ix + 1; let run = &self.runs[self.current_ix]; + let new_run_started = self + .runs + .back() + .map_or(false, |latest_run| latest_run.finished_at.is_none()); h_flex() .p_2() @@ -264,114 +257,65 @@ impl Zeta2ContextView { .text_xs() .border_t_1() .gap_2() + .child(v_flex().h_full().flex_1().child({ + let t0 = run.started_at; + let mut 
table = ui::Table::<2>::new().width(ui::px(300.)).no_ui_font(); + for (key, value) in &run.metadata { + table = table.row([key.into_any_element(), value.clone().into_any_element()]) + } + table = table.row([ + "Total Time".into_any_element(), + format!("{} ms", (run.finished_at.unwrap_or(t0) - t0).as_millis()) + .into_any_element(), + ]); + table + })) .child( - v_flex().h_full().flex_1().children( - run.search_queries - .iter() - .enumerate() - .flat_map(|(ix, query)| { - std::iter::once(ListHeader::new(query.glob.clone()).into_any_element()) - .chain(query.syntax_node.iter().enumerate().map( - move |(regex_ix, regex)| { - ListItem::new(ix * 100 + regex_ix) - .start_slot( - Icon::new(IconName::MagnifyingGlass) - .color(Color::Muted) - .size(IconSize::Small), - ) - .child(regex.clone()) - .into_any_element() - }, + v_flex().h_full().text_align(TextAlign::Right).child( + h_flex() + .justify_end() + .child( + IconButton::new("go-back", IconName::ChevronLeft) + .disabled(self.current_ix == 0 || self.runs.len() < 2) + .tooltip(ui::Tooltip::for_action_title( + "Go to previous run", + &Zeta2ContextGoBack, )) - .chain(query.content.as_ref().map(move |regex| { - ListItem::new(ix * 100 + query.syntax_node.len()) - .start_slot( - Icon::new(IconName::MagnifyingGlass) - .color(Color::Muted) - .size(IconSize::Small), + .on_click(cx.listener(|this, _, window, cx| { + this.handle_go_back(&Zeta2ContextGoBack, window, cx); + })), + ) + .child( + div() + .child(format!("{}/{}", self.current_ix + 1, self.runs.len())) + .map(|this| { + if new_run_started { + this.with_animation( + "pulsating-count", + Animation::new(Duration::from_secs(2)) + .repeat() + .with_easing(pulsating_between(0.4, 0.8)), + |label, delta| label.opacity(delta), ) - .child(regex.clone()) .into_any_element() - })) - }), + } else { + this.into_any_element() + } + }), + ) + .child( + IconButton::new("go-forward", IconName::ChevronRight) + .disabled(self.current_ix + 1 == self.runs.len()) + 
.tooltip(ui::Tooltip::for_action_title( + "Go to next run", + &Zeta2ContextGoBack, + )) + .on_click(cx.listener(|this, _, window, cx| { + this.handle_go_forward(&Zeta2ContextGoForward, window, cx); + })), + ), ), ) - .child( - v_flex() - .h_full() - .text_align(TextAlign::Right) - .child( - h_flex() - .justify_end() - .child( - IconButton::new("go-back", IconName::ChevronLeft) - .disabled(self.current_ix == 0 || self.runs.len() < 2) - .tooltip(ui::Tooltip::for_action_title( - "Go to previous run", - &Zeta2ContextGoBack, - )) - .on_click(cx.listener(|this, _, window, cx| { - this.handle_go_back(&Zeta2ContextGoBack, window, cx); - })), - ) - .child( - div() - .child(format!("{}/{}", self.current_ix + 1, self.runs.len())) - .map(|this| { - if self.runs.back().is_some_and(|back| { - back.search_results_executed_at.is_none() - }) { - this.with_animation( - "pulsating-count", - Animation::new(Duration::from_secs(2)) - .repeat() - .with_easing(pulsating_between(0.4, 0.8)), - |label, delta| label.opacity(delta), - ) - .into_any_element() - } else { - this.into_any_element() - } - }), - ) - .child( - IconButton::new("go-forward", IconName::ChevronRight) - .disabled(self.current_ix + 1 == self.runs.len()) - .tooltip(ui::Tooltip::for_action_title( - "Go to next run", - &Zeta2ContextGoBack, - )) - .on_click(cx.listener(|this, _, window, cx| { - this.handle_go_forward(&Zeta2ContextGoForward, window, cx); - })), - ), - ) - .map(|mut div| { - let pending_message = |div: ui::Div, msg: &'static str| { - if is_latest { - return div.child(msg); - } else { - return div.child("Canceled"); - } - }; - - let t0 = run.started_at; - let Some(t1) = run.search_results_generated_at else { - return pending_message(div, "Planning search..."); - }; - div = div.child(format!("Planned search: {:>5} ms", (t1 - t0).as_millis())); - - let Some(t2) = run.search_results_executed_at else { - return pending_message(div, "Running search..."); - }; - div = div.child(format!("Ran search: {:>5} ms", (t2 - 
t1).as_millis())); - - div.child(format!( - "Total: {:>5} ms", - (run.finished_at.unwrap_or(t0) - t0).as_millis() - )) - }), - ) } } diff --git a/crates/zeta2_tools/src/zeta2_tools.rs b/crates/zeta2_tools/src/zeta2_tools.rs index 4e650f2405d63feab010c5c9b73efc75bd576af6..26d68b075153557ab50ed0a231c5d45f0bb9646c 100644 --- a/crates/zeta2_tools/src/zeta2_tools.rs +++ b/crates/zeta2_tools/src/zeta2_tools.rs @@ -108,6 +108,7 @@ pub struct Zeta2Inspector { pub enum ContextModeState { Llm, + Lsp, Syntax { max_retrieved_declarations: Entity, }, @@ -222,6 +223,9 @@ impl Zeta2Inspector { ), }; } + ContextMode::Lsp(_) => { + self.context_mode = ContextModeState::Lsp; + } } cx.notify(); } @@ -302,6 +306,9 @@ impl Zeta2Inspector { ContextModeState::Syntax { max_retrieved_declarations, } => number_input_value(max_retrieved_declarations, cx), + ContextModeState::Lsp => { + zeta::DEFAULT_SYNTAX_CONTEXT_OPTIONS.max_retrieved_declarations + } }; ContextMode::Syntax(EditPredictionContextOptions { @@ -310,6 +317,7 @@ impl Zeta2Inspector { ..context_options }) } + ContextMode::Lsp(excerpt_options) => ContextMode::Lsp(excerpt_options), }; this.set_zeta_options( @@ -656,6 +664,7 @@ impl Zeta2Inspector { ContextModeState::Syntax { max_retrieved_declarations, } => Some(max_retrieved_declarations.clone()), + ContextModeState::Lsp => None, }) .child(self.max_prompt_bytes_input.clone()) .child(self.render_prompt_format_dropdown(window, cx)), @@ -679,6 +688,7 @@ impl Zeta2Inspector { match &self.context_mode { ContextModeState::Llm => "LLM-based", ContextModeState::Syntax { .. 
} => "Syntax", + ContextModeState::Lsp => "LSP-based", }, ContextMenu::build(window, cx, move |menu, _window, _cx| { menu.item( @@ -695,6 +705,7 @@ impl Zeta2Inspector { this.zeta.read(cx).options().clone(); match current_options.context.clone() { ContextMode::Agentic(_) => {} + ContextMode::Lsp(_) => {} ContextMode::Syntax(context_options) => { let options = ZetaOptions { context: ContextMode::Agentic( @@ -739,6 +750,7 @@ impl Zeta2Inspector { this.set_zeta_options(options, cx); } ContextMode::Syntax(_) => {} + ContextMode::Lsp(_) => {} } }) .ok(); diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs index a9c8b5998cdd32a94a71f1894dfbdc40c22abaed..42c0ea185f4401a11c2798f9402a59829f8463df 100644 --- a/crates/zeta_cli/src/main.rs +++ b/crates/zeta_cli/src/main.rs @@ -21,15 +21,12 @@ use ::util::paths::PathStyle; use anyhow::{Result, anyhow}; use clap::{Args, Parser, Subcommand, ValueEnum}; use cloud_llm_client::predict_edits_v3; -use edit_prediction_context::{ - EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions, -}; +use edit_prediction_context::EditPredictionExcerptOptions; use gpui::{Application, AsyncApp, Entity, prelude::*}; use language::{Bias, Buffer, BufferSnapshot, Point}; use metrics::delta_chr_f; -use project::{Project, Worktree}; +use project::{Project, Worktree, lsp_store::OpenLspBufferHandle}; use reqwest_client::ReqwestClient; -use serde_json::json; use std::io::{self}; use std::time::Duration; use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc}; @@ -97,7 +94,7 @@ struct ContextArgs { enum ContextProvider { Zeta1, #[default] - Syntax, + Zeta2, } #[derive(Clone, Debug, Args)] @@ -204,19 +201,12 @@ enum PredictionProvider { Sweep, } -fn zeta2_args_to_options(args: &Zeta2Args, omit_excerpt_overlaps: bool) -> zeta::ZetaOptions { +fn zeta2_args_to_options(args: &Zeta2Args) -> zeta::ZetaOptions { zeta::ZetaOptions { - context: ContextMode::Syntax(EditPredictionContextOptions { - 
max_retrieved_declarations: args.max_retrieved_definitions, - use_imports: !args.disable_imports_gathering, - excerpt: EditPredictionExcerptOptions { - max_bytes: args.max_excerpt_bytes, - min_bytes: args.min_excerpt_bytes, - target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes, - }, - score: EditPredictionScoreOptions { - omit_excerpt_overlaps, - }, + context: ContextMode::Lsp(EditPredictionExcerptOptions { + max_bytes: args.max_excerpt_bytes, + min_bytes: args.min_excerpt_bytes, + target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes, }), max_diagnostic_bytes: args.max_diagnostic_bytes, max_prompt_bytes: args.max_prompt_bytes, @@ -295,6 +285,7 @@ struct LoadedContext { worktree: Entity, project: Entity, buffer: Entity, + lsp_open_handle: Option, } async fn load_context( @@ -330,7 +321,7 @@ async fn load_context( .await?; let mut ready_languages = HashSet::default(); - let (_lsp_open_handle, buffer) = if *use_language_server { + let (lsp_open_handle, buffer) = if *use_language_server { let (lsp_open_handle, _, buffer) = open_buffer_with_language_server( project.clone(), worktree.clone(), @@ -377,10 +368,11 @@ async fn load_context( worktree, project, buffer, + lsp_open_handle, }) } -async fn zeta2_syntax_context( +async fn zeta2_context( args: ContextArgs, app_state: &Arc, cx: &mut AsyncApp, @@ -390,6 +382,7 @@ async fn zeta2_syntax_context( project, buffer, clipped_cursor, + lsp_open_handle: _handle, .. 
} = load_context(&args, app_state, cx).await?; @@ -406,30 +399,26 @@ async fn zeta2_syntax_context( zeta::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx) }); let indexing_done_task = zeta.update(cx, |zeta, cx| { - zeta.set_options(zeta2_args_to_options(&args.zeta2_args, true)); + zeta.set_options(zeta2_args_to_options(&args.zeta2_args)); zeta.register_buffer(&buffer, &project, cx); zeta.wait_for_initial_indexing(&project, cx) }); cx.spawn(async move |cx| { indexing_done_task.await?; - let request = zeta - .update(cx, |zeta, cx| { - let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor); - zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx) - })? - .await?; - - let (prompt_string, section_labels) = cloud_zeta2_prompt::build_prompt(&request)?; - - match args.zeta2_args.output_format { - OutputFormat::Prompt => anyhow::Ok(prompt_string), - OutputFormat::Request => anyhow::Ok(serde_json::to_string_pretty(&request)?), - OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({ - "request": request, - "prompt": prompt_string, - "section_labels": section_labels, - }))?), - } + let updates_rx = zeta.update(cx, |zeta, cx| { + let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor); + zeta.set_use_context(true); + zeta.refresh_context_if_needed(&project, &buffer, cursor, cx); + zeta.project_context_updates(&project).unwrap() + })?; + + updates_rx.recv().await.ok(); + + let context = zeta.update(cx, |zeta, cx| { + zeta.context_for_project(&project, cx).to_vec() + })?; + + anyhow::Ok(serde_json::to_string_pretty(&context).unwrap()) }) })? 
.await?; @@ -482,7 +471,6 @@ fn main() { None => { if args.printenv { ::util::shell_env::print_env(); - return; } else { panic!("Expected a command"); } @@ -494,7 +482,7 @@ fn main() { arguments.extension, arguments.limit, arguments.skip, - zeta2_args_to_options(&arguments.zeta2_args, false), + zeta2_args_to_options(&arguments.zeta2_args), cx, ) .await; @@ -507,10 +495,8 @@ fn main() { zeta1_context(context_args, &app_state, cx).await.unwrap(); serde_json::to_string_pretty(&context.body).unwrap() } - ContextProvider::Syntax => { - zeta2_syntax_context(context_args, &app_state, cx) - .await - .unwrap() + ContextProvider::Zeta2 => { + zeta2_context(context_args, &app_state, cx).await.unwrap() } }; println!("{}", result); diff --git a/crates/zeta_cli/src/predict.rs b/crates/zeta_cli/src/predict.rs index 99fe65cfa3221a1deb18e767e8faa8e1a1fca0ac..9fefc5ce97672796f79482e23acca3599aa1ff44 100644 --- a/crates/zeta_cli/src/predict.rs +++ b/crates/zeta_cli/src/predict.rs @@ -136,8 +136,7 @@ pub async fn perform_predict( let result = result.clone(); async move { let mut start_time = None; - let mut search_queries_generated_at = None; - let mut search_queries_executed_at = None; + let mut retrieval_finished_at = None; while let Some(event) = debug_rx.next().await { match event { zeta::ZetaDebugInfo::ContextRetrievalStarted(info) => { @@ -147,17 +146,17 @@ pub async fn perform_predict( &info.search_prompt, )?; } - zeta::ZetaDebugInfo::SearchQueriesGenerated(info) => { - search_queries_generated_at = Some(info.timestamp); - fs::write( - example_run_dir.join("search_queries.json"), - serde_json::to_string_pretty(&info.search_queries).unwrap(), - )?; - } - zeta::ZetaDebugInfo::SearchQueriesExecuted(info) => { - search_queries_executed_at = Some(info.timestamp); + zeta::ZetaDebugInfo::ContextRetrievalFinished(info) => { + retrieval_finished_at = Some(info.timestamp); + for (key, value) in &info.metadata { + if *key == "search_queries" { + fs::write( + 
example_run_dir.join("search_queries.json"), + value.as_bytes(), + )?; + } + } } - zeta::ZetaDebugInfo::ContextRetrievalFinished(_info) => {} zeta::ZetaDebugInfo::EditPredictionRequested(request) => { let prediction_started_at = Instant::now(); start_time.get_or_insert(prediction_started_at); @@ -200,13 +199,8 @@ pub async fn perform_predict( let mut result = result.lock().unwrap(); result.generated_len = response.chars().count(); - - result.planning_search_time = - Some(search_queries_generated_at.unwrap() - start_time.unwrap()); - result.running_search_time = Some( - search_queries_executed_at.unwrap() - - search_queries_generated_at.unwrap(), - ); + result.retrieval_time = + retrieval_finished_at.unwrap() - start_time.unwrap(); result.prediction_time = prediction_finished_at - prediction_started_at; result.total_time = prediction_finished_at - start_time.unwrap(); @@ -219,7 +213,12 @@ pub async fn perform_predict( }); zeta.update(cx, |zeta, cx| { - zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx) + zeta.refresh_context_with_agentic_retrieval( + project.clone(), + cursor_buffer.clone(), + cursor_anchor, + cx, + ) })? .await?; } @@ -321,8 +320,7 @@ pub struct PredictionDetails { pub diff: String, pub excerpts: Vec, pub excerpts_text: String, // TODO: contains the worktree root path. 
Drop this field and compute it on the fly - pub planning_search_time: Option, - pub running_search_time: Option, + pub retrieval_time: Duration, pub prediction_time: Duration, pub total_time: Duration, pub run_example_dir: PathBuf, @@ -336,8 +334,7 @@ impl PredictionDetails { diff: Default::default(), excerpts: Default::default(), excerpts_text: Default::default(), - planning_search_time: Default::default(), - running_search_time: Default::default(), + retrieval_time: Default::default(), prediction_time: Default::default(), total_time: Default::default(), run_example_dir, @@ -357,28 +354,20 @@ impl PredictionDetails { } pub fn to_markdown(&self) -> String { - let inference_time = self.planning_search_time.unwrap_or_default() + self.prediction_time; - format!( "## Excerpts\n\n\ {}\n\n\ ## Prediction\n\n\ {}\n\n\ ## Time\n\n\ - Planning searches: {}ms\n\ - Running searches: {}ms\n\ - Making Prediction: {}ms\n\n\ - -------------------\n\n\ - Total: {}ms\n\ - Inference: {}ms ({:.2}%)\n", + Retrieval: {}ms\n\ + Prediction: {}ms\n\n\ + Total: {}ms\n", self.excerpts_text, self.diff, - self.planning_search_time.unwrap_or_default().as_millis(), - self.running_search_time.unwrap_or_default().as_millis(), + self.retrieval_time.as_millis(), self.prediction_time.as_millis(), self.total_time.as_millis(), - inference_time.as_millis(), - (inference_time.as_millis() as f64 / self.total_time.as_millis() as f64) * 100. 
) } } diff --git a/crates/zeta_cli/src/util.rs b/crates/zeta_cli/src/util.rs index 699c1c743f67e09ef5ca7211c385114802d4ab32..f4a51d94585f82da008ac832dc62392c365738fd 100644 --- a/crates/zeta_cli/src/util.rs +++ b/crates/zeta_cli/src/util.rs @@ -2,7 +2,8 @@ use anyhow::{Result, anyhow}; use futures::channel::mpsc; use futures::{FutureExt as _, StreamExt as _}; use gpui::{AsyncApp, Entity, Task}; -use language::{Buffer, LanguageId, LanguageServerId, ParseStatus}; +use language::{Buffer, LanguageId, LanguageNotFound, LanguageServerId, ParseStatus}; +use project::lsp_store::OpenLspBufferHandle; use project::{Project, ProjectPath, Worktree}; use std::collections::HashSet; use std::sync::Arc; @@ -40,7 +41,7 @@ pub async fn open_buffer_with_language_server( path: Arc, ready_languages: &mut HashSet, cx: &mut AsyncApp, -) -> Result<(Entity>, LanguageServerId, Entity)> { +) -> Result<(OpenLspBufferHandle, LanguageServerId, Entity)> { let buffer = open_buffer(project.clone(), worktree, path.clone(), cx).await?; let (lsp_open_handle, path_style) = project.update(cx, |project, cx| { @@ -50,6 +51,17 @@ pub async fn open_buffer_with_language_server( ) })?; + let language_registry = project.read_with(cx, |project, _| project.languages().clone())?; + let result = language_registry + .load_language_for_file_path(path.as_std_path()) + .await; + + if let Err(error) = result + && !error.is::() + { + anyhow::bail!(error); + } + let Some(language_id) = buffer.read_with(cx, |buffer, _cx| { buffer.language().map(|language| language.id()) })? 
@@ -57,9 +69,9 @@ pub async fn open_buffer_with_language_server( return Err(anyhow!("No language for {}", path.display(path_style))); }; - let log_prefix = path.display(path_style); + let log_prefix = format!("{} | ", path.display(path_style)); if !ready_languages.contains(&language_id) { - wait_for_lang_server(&project, &buffer, log_prefix.into_owned(), cx).await?; + wait_for_lang_server(&project, &buffer, log_prefix, cx).await?; ready_languages.insert(language_id); } @@ -95,7 +107,7 @@ pub fn wait_for_lang_server( log_prefix: String, cx: &mut AsyncApp, ) -> Task> { - println!("{}⏵ Waiting for language server", log_prefix); + eprintln!("{}⏵ Waiting for language server", log_prefix); let (mut tx, mut rx) = mpsc::channel(1); @@ -137,7 +149,7 @@ pub fn wait_for_lang_server( .. } = event { - println!("{}⟲ {message}", log_prefix) + eprintln!("{}⟲ {message}", log_prefix) } } }), @@ -162,7 +174,7 @@ pub fn wait_for_lang_server( cx.spawn(async move |cx| { if !has_lang_server { // some buffers never have a language server, so this aborts quickly in that case. - let timeout = cx.background_executor().timer(Duration::from_secs(5)); + let timeout = cx.background_executor().timer(Duration::from_secs(500)); futures::select! { _ = added_rx.next() => {}, _ = timeout.fuse() => { @@ -173,7 +185,7 @@ pub fn wait_for_lang_server( let timeout = cx.background_executor().timer(Duration::from_secs(60 * 5)); let result = futures::select! 
{ _ = rx.next() => { - println!("{}⚑ Language server idle", log_prefix); + eprintln!("{}⚑ Language server idle", log_prefix); anyhow::Ok(()) }, _ = timeout.fuse() => { From 42583c1141b68f655335769f4770b3ceea84c263 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Thu, 4 Dec 2025 15:56:57 -0800 Subject: [PATCH 10/81] Reorganize edit prediction code and remove old experiments (#44187) Release Notes: - N/A --------- Co-authored-by: Agus Zubiaga Co-authored-by: Ben Kunkle --- Cargo.lock | 1272 +----- Cargo.toml | 18 +- assets/keymaps/default-linux.json | 16 +- assets/keymaps/default-macos.json | 16 +- assets/keymaps/default-windows.json | 14 +- .../cloud_llm_client/src/predict_edits_v3.rs | 88 +- crates/cloud_zeta2_prompt/Cargo.toml | 5 - .../src/cloud_zeta2_prompt.rs | 680 +-- .../src/retrieval_prompt.rs | 244 -- crates/codestral/Cargo.toml | 2 +- crates/codestral/src/codestral.rs | 13 +- crates/copilot/Cargo.toml | 2 +- crates/copilot/src/copilot.rs | 4 +- ...rs => copilot_edit_prediction_delegate.rs} | 22 +- crates/edit_prediction/Cargo.toml | 62 + .../license_examples/0bsd.txt | 0 .../license_examples/apache-2.0-ex0.txt | 0 .../license_examples/apache-2.0-ex1.txt | 0 .../license_examples/apache-2.0-ex2.txt | 0 .../license_examples/apache-2.0-ex3.txt | 0 .../license_examples/apache-2.0-ex4.txt | 0 .../license_examples/bsd-1-clause.txt | 0 .../license_examples/bsd-2-clause-ex0.txt | 0 .../license_examples/bsd-3-clause-ex0.txt | 0 .../license_examples/bsd-3-clause-ex1.txt | 0 .../license_examples/bsd-3-clause-ex2.txt | 0 .../license_examples/bsd-3-clause-ex3.txt | 0 .../license_examples/bsd-3-clause-ex4.txt | 0 .../license_examples/isc.txt | 0 .../license_examples/mit-ex0.txt | 0 .../license_examples/mit-ex1.txt | 0 .../license_examples/mit-ex2.txt | 0 .../license_examples/mit-ex3.txt | 0 .../license_examples/upl-1.0.txt | 0 .../license_examples/zlib-ex0.txt | 0 .../license_patterns/0bsd-pattern | 0 .../license_patterns/apache-2.0-pattern | 0 
.../apache-2.0-reference-pattern | 0 .../license_patterns/bsd-pattern | 0 .../license_patterns/isc-pattern | 0 .../license_patterns/mit-pattern | 0 .../license_patterns/upl-1.0-pattern | 0 .../license_patterns/zlib-pattern | 0 crates/edit_prediction/src/edit_prediction.rs | 2041 ++++++++- .../src/edit_prediction_tests.rs | 1806 ++++++++ .../src/license_detection.rs | 0 .../src/onboarding_modal.rs | 0 .../src/prediction.rs | 2 +- .../{zeta => edit_prediction}/src/sweep_ai.rs | 4 +- crates/{zeta => edit_prediction}/src/udiff.rs | 0 .../src/xml_edits.rs | 0 .../src/zed_edit_prediction_delegate.rs} | 114 +- crates/{zeta => edit_prediction}/src/zeta1.rs | 20 +- .../src/zeta1/input_excerpt.rs | 0 crates/edit_prediction/src/zeta2.rs | 358 ++ .../Cargo.toml | 11 +- .../LICENSE-GPL | 0 .../build.rs | 0 .../src/evaluate.rs | 14 +- .../src/example.rs | 4 +- .../src/headless.rs | 0 .../src/main.rs | 83 +- .../src/metrics.rs | 4 +- .../src/paths.rs | 0 .../src/predict.rs | 76 +- .../src/source_location.rs | 0 .../src/util.rs | 0 crates/edit_prediction_context/Cargo.toml | 23 +- .../src/assemble_excerpts.rs | 0 .../src/declaration.rs | 350 -- .../src/declaration_scoring.rs | 539 --- .../src/edit_prediction_context.rs | 736 ++-- .../src/edit_prediction_context_tests.rs | 0 crates/edit_prediction_context/src/excerpt.rs | 73 +- .../src/fake_definition_lsp.rs | 0 crates/edit_prediction_context/src/imports.rs | 1319 ------ crates/edit_prediction_context/src/outline.rs | 126 - .../edit_prediction_context/src/reference.rs | 173 - .../src/syntax_index.rs | 1069 ----- .../src/text_similarity.rs | 314 -- crates/edit_prediction_context2/Cargo.toml | 42 - .../src/edit_prediction_context2.rs | 465 -- crates/edit_prediction_types/Cargo.toml | 17 + .../LICENSE-GPL | 0 .../src/edit_prediction_types.rs | 298 ++ .../Cargo.toml | 16 +- .../{zeta => edit_prediction_ui}/LICENSE-GPL | 0 .../src/edit_prediction_button.rs | 50 +- .../src/edit_prediction_context_view.rs} | 73 +- 
.../src/edit_prediction_ui.rs | 128 + .../src/rate_prediction_modal.rs | 59 +- .../src/sweep_api_token_modal.rs | 9 +- crates/editor/Cargo.toml | 2 +- crates/editor/src/edit_prediction_tests.rs | 64 +- crates/editor/src/editor.rs | 24 +- crates/editor/src/editor_tests.rs | 6 +- crates/feature_flags/src/flags.rs | 6 - crates/supermaven/Cargo.toml | 2 +- crates/supermaven/src/supermaven.rs | 4 +- ...=> supermaven_edit_prediction_delegate.rs} | 14 +- crates/zed/Cargo.toml | 5 +- crates/zed/src/main.rs | 4 +- crates/zed/src/zed.rs | 8 +- .../zed/src/zed/edit_prediction_registry.rs | 46 +- crates/zeta/Cargo.toml | 85 - crates/zeta/src/retrieval_search.rs | 490 --- crates/zeta/src/zeta.rs | 3890 ----------------- crates/zeta/src/zeta_tests.rs | 671 --- crates/zeta2_tools/Cargo.toml | 48 - crates/zeta2_tools/LICENSE-GPL | 1 - crates/zeta2_tools/src/zeta2_tools.rs | 1035 ----- crates/zeta_cli/LICENSE-GPL | 1 - crates/zeta_cli/src/syntax_retrieval_stats.rs | 1260 ------ 113 files changed, 5529 insertions(+), 15011 deletions(-) delete mode 100644 crates/cloud_zeta2_prompt/src/retrieval_prompt.rs rename crates/copilot/src/{copilot_completion_provider.rs => copilot_edit_prediction_delegate.rs} (98%) rename crates/{zeta => edit_prediction}/license_examples/0bsd.txt (100%) rename crates/{zeta => edit_prediction}/license_examples/apache-2.0-ex0.txt (100%) rename crates/{zeta => edit_prediction}/license_examples/apache-2.0-ex1.txt (100%) rename crates/{zeta => edit_prediction}/license_examples/apache-2.0-ex2.txt (100%) rename crates/{zeta => edit_prediction}/license_examples/apache-2.0-ex3.txt (100%) rename crates/{zeta => edit_prediction}/license_examples/apache-2.0-ex4.txt (100%) rename crates/{zeta => edit_prediction}/license_examples/bsd-1-clause.txt (100%) rename crates/{zeta => edit_prediction}/license_examples/bsd-2-clause-ex0.txt (100%) rename crates/{zeta => edit_prediction}/license_examples/bsd-3-clause-ex0.txt (100%) rename crates/{zeta => 
edit_prediction}/license_examples/bsd-3-clause-ex1.txt (100%) rename crates/{zeta => edit_prediction}/license_examples/bsd-3-clause-ex2.txt (100%) rename crates/{zeta => edit_prediction}/license_examples/bsd-3-clause-ex3.txt (100%) rename crates/{zeta => edit_prediction}/license_examples/bsd-3-clause-ex4.txt (100%) rename crates/{zeta => edit_prediction}/license_examples/isc.txt (100%) rename crates/{zeta => edit_prediction}/license_examples/mit-ex0.txt (100%) rename crates/{zeta => edit_prediction}/license_examples/mit-ex1.txt (100%) rename crates/{zeta => edit_prediction}/license_examples/mit-ex2.txt (100%) rename crates/{zeta => edit_prediction}/license_examples/mit-ex3.txt (100%) rename crates/{zeta => edit_prediction}/license_examples/upl-1.0.txt (100%) rename crates/{zeta => edit_prediction}/license_examples/zlib-ex0.txt (100%) rename crates/{zeta => edit_prediction}/license_patterns/0bsd-pattern (100%) rename crates/{zeta => edit_prediction}/license_patterns/apache-2.0-pattern (100%) rename crates/{zeta => edit_prediction}/license_patterns/apache-2.0-reference-pattern (100%) rename crates/{zeta => edit_prediction}/license_patterns/bsd-pattern (100%) rename crates/{zeta => edit_prediction}/license_patterns/isc-pattern (100%) rename crates/{zeta => edit_prediction}/license_patterns/mit-pattern (100%) rename crates/{zeta => edit_prediction}/license_patterns/upl-1.0-pattern (100%) rename crates/{zeta => edit_prediction}/license_patterns/zlib-pattern (100%) create mode 100644 crates/edit_prediction/src/edit_prediction_tests.rs rename crates/{zeta => edit_prediction}/src/license_detection.rs (100%) rename crates/{zeta => edit_prediction}/src/onboarding_modal.rs (100%) rename crates/{zeta => edit_prediction}/src/prediction.rs (99%) rename crates/{zeta => edit_prediction}/src/sweep_ai.rs (99%) rename crates/{zeta => edit_prediction}/src/udiff.rs (100%) rename crates/{zeta => edit_prediction}/src/xml_edits.rs (100%) rename crates/{zeta/src/provider.rs => 
edit_prediction/src/zed_edit_prediction_delegate.rs} (58%) rename crates/{zeta => edit_prediction}/src/zeta1.rs (96%) rename crates/{zeta => edit_prediction}/src/zeta1/input_excerpt.rs (100%) create mode 100644 crates/edit_prediction/src/zeta2.rs rename crates/{zeta_cli => edit_prediction_cli}/Cargo.toml (84%) rename crates/{edit_prediction_button => edit_prediction_cli}/LICENSE-GPL (100%) rename crates/{zeta_cli => edit_prediction_cli}/build.rs (100%) rename crates/{zeta_cli => edit_prediction_cli}/src/evaluate.rs (98%) rename crates/{zeta_cli => edit_prediction_cli}/src/example.rs (99%) rename crates/{zeta_cli => edit_prediction_cli}/src/headless.rs (100%) rename crates/{zeta_cli => edit_prediction_cli}/src/main.rs (84%) rename crates/{zeta_cli => edit_prediction_cli}/src/metrics.rs (99%) rename crates/{zeta_cli => edit_prediction_cli}/src/paths.rs (100%) rename crates/{zeta_cli => edit_prediction_cli}/src/predict.rs (85%) rename crates/{zeta_cli => edit_prediction_cli}/src/source_location.rs (100%) rename crates/{zeta_cli => edit_prediction_cli}/src/util.rs (100%) rename crates/{edit_prediction_context2 => edit_prediction_context}/src/assemble_excerpts.rs (100%) delete mode 100644 crates/edit_prediction_context/src/declaration.rs delete mode 100644 crates/edit_prediction_context/src/declaration_scoring.rs rename crates/{edit_prediction_context2 => edit_prediction_context}/src/edit_prediction_context_tests.rs (100%) rename crates/{edit_prediction_context2 => edit_prediction_context}/src/fake_definition_lsp.rs (100%) delete mode 100644 crates/edit_prediction_context/src/imports.rs delete mode 100644 crates/edit_prediction_context/src/outline.rs delete mode 100644 crates/edit_prediction_context/src/reference.rs delete mode 100644 crates/edit_prediction_context/src/syntax_index.rs delete mode 100644 crates/edit_prediction_context/src/text_similarity.rs delete mode 100644 crates/edit_prediction_context2/Cargo.toml delete mode 100644 
crates/edit_prediction_context2/src/edit_prediction_context2.rs create mode 100644 crates/edit_prediction_types/Cargo.toml rename crates/{edit_prediction_context2 => edit_prediction_types}/LICENSE-GPL (100%) create mode 100644 crates/edit_prediction_types/src/edit_prediction_types.rs rename crates/{edit_prediction_button => edit_prediction_ui}/Cargo.toml (77%) rename crates/{zeta => edit_prediction_ui}/LICENSE-GPL (100%) rename crates/{edit_prediction_button => edit_prediction_ui}/src/edit_prediction_button.rs (97%) rename crates/{zeta2_tools/src/zeta2_context_view.rs => edit_prediction_ui/src/edit_prediction_context_view.rs} (85%) create mode 100644 crates/edit_prediction_ui/src/edit_prediction_ui.rs rename crates/{zeta => edit_prediction_ui}/src/rate_prediction_modal.rs (95%) rename crates/{edit_prediction_button => edit_prediction_ui}/src/sweep_api_token_modal.rs (92%) rename crates/supermaven/src/{supermaven_completion_provider.rs => supermaven_edit_prediction_delegate.rs} (95%) delete mode 100644 crates/zeta/Cargo.toml delete mode 100644 crates/zeta/src/retrieval_search.rs delete mode 100644 crates/zeta/src/zeta.rs delete mode 100644 crates/zeta/src/zeta_tests.rs delete mode 100644 crates/zeta2_tools/Cargo.toml delete mode 120000 crates/zeta2_tools/LICENSE-GPL delete mode 100644 crates/zeta2_tools/src/zeta2_tools.rs delete mode 120000 crates/zeta_cli/LICENSE-GPL delete mode 100644 crates/zeta_cli/src/syntax_retrieval_stats.rs diff --git a/Cargo.lock b/Cargo.lock index 6d41fbe96fac878f496e93461c180e1c184216d6..885fbe77fd17a90e4cc948d4c40166d41a26cd35 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -211,7 +211,7 @@ dependencies = [ "worktree", "zed_env_vars", "zlog", - "zstd 0.11.2+zstd.1.5.2", + "zstd", ] [[package]] @@ -680,21 +680,6 @@ dependencies = [ "syn 2.0.106", ] -[[package]] -name = "argminmax" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70f13d10a41ac8d2ec79ee34178d61e6f47a29c2edfe7ef1721c7383b0359e65" 
-dependencies = [ - "num-traits", -] - -[[package]] -name = "array-init-cursor" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed51fe0f224d1d4ea768be38c51f9f831dee9d05c163c11fba0b8c44387b1fc3" - [[package]] name = "arraydeque" version = "0.5.1" @@ -1278,15 +1263,6 @@ dependencies = [ "num-traits", ] -[[package]] -name = "atoi_simd" -version = "0.16.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2a49e05797ca52e312a0c658938b7d00693ef037799ef7187678f212d7684cf" -dependencies = [ - "debug_unsafe", -] - [[package]] name = "atomic" version = "0.5.3" @@ -2070,26 +2046,6 @@ dependencies = [ "serde", ] -[[package]] -name = "bincode" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36eaf5d7b090263e8150820482d5d93cd964a81e4019913c972f4edcc6edb740" -dependencies = [ - "bincode_derive", - "serde", - "unty", -] - -[[package]] -name = "bincode_derive" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf95709a440f45e986983918d0e8a1f30a9b1df04918fc828670606804ac3c09" -dependencies = [ - "virtue", -] - [[package]] name = "bindgen" version = "0.71.1" @@ -2242,19 +2198,6 @@ dependencies = [ "profiling", ] -[[package]] -name = "blake3" -version = "1.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3888aaa89e4b2a40fca9848e400f6a658a5a3978de7be858e209cafa8be9a4a0" -dependencies = [ - "arrayref", - "arrayvec", - "cc", - "cfg-if", - "constant_time_eq 0.3.1", -] - [[package]] name = "block" version = "0.1.6" @@ -2344,12 +2287,6 @@ dependencies = [ "syn 2.0.106", ] -[[package]] -name = "boxcar" -version = "0.2.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36f64beae40a84da1b4b26ff2761a5b895c12adc41dc25aaee1c4f2bbfe97a6e" - [[package]] name = "breadcrumbs" version = "0.1.0" @@ -2516,9 +2453,6 @@ name = "bytes" version = "1.10.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" -dependencies = [ - "serde", -] [[package]] name = "bytes-utils" @@ -2805,15 +2739,6 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" -[[package]] -name = "castaway" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dec551ab6e7578819132c713a93c022a05d60159dc86e7a7050223577484c55a" -dependencies = [ - "rustversion", -] - [[package]] name = "cbc" version = "0.1.2" @@ -2942,16 +2867,6 @@ dependencies = [ "windows-link 0.2.1", ] -[[package]] -name = "chrono-tz" -version = "0.10.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6139a8597ed92cf816dfb33f5dd6cf0bb93a6adc938f11039f371bc5bcd26c3" -dependencies = [ - "chrono", - "phf 0.12.1", -] - [[package]] name = "chunked_transfer" version = "1.5.0" @@ -3201,12 +3116,7 @@ dependencies = [ "anyhow", "cloud_llm_client", "indoc", - "ordered-float 2.10.1", - "rustc-hash 2.1.1", - "schemars", "serde", - "serde_json", - "strum 0.27.2", ] [[package]] @@ -3314,8 +3224,8 @@ name = "codestral" version = "0.1.0" dependencies = [ "anyhow", - "edit_prediction", "edit_prediction_context", + "edit_prediction_types", "futures 0.3.31", "gpui", "http_client", @@ -3505,17 +3415,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "comfy-table" -version = "7.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b03b7db8e0b4b2fdad6c551e634134e99ec000e5c8c3b6856c65e8bbaded7a3b" -dependencies = [ - "crossterm", - "unicode-segmentation", - "unicode-width", -] - [[package]] name = "command-fds" version = "0.3.2" @@ -3569,21 +3468,6 @@ dependencies = [ "workspace", ] -[[package]] -name = "compact_str" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "3fdb1325a1cece981e8a296ab8f0f9b63ae357bd0784a9faaf548cc7b480707a" -dependencies = [ - "castaway", - "cfg-if", - "itoa", - "rustversion", - "ryu", - "serde", - "static_assertions", -] - [[package]] name = "component" version = "0.1.0" @@ -3689,12 +3573,6 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" -[[package]] -name = "constant_time_eq" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" - [[package]] name = "context_server" version = "0.1.0" @@ -3747,7 +3625,7 @@ dependencies = [ "command_palette_hooks", "ctor", "dirs 4.0.0", - "edit_prediction", + "edit_prediction_types", "editor", "fs", "futures 0.3.31", @@ -4160,7 +4038,7 @@ dependencies = [ name = "crashes" version = "0.1.0" dependencies = [ - "bincode 1.3.3", + "bincode", "cfg-if", "crash-handler", "extension_host", @@ -4174,7 +4052,7 @@ dependencies = [ "smol", "system_specs", "windows 0.61.3", - "zstd 0.11.2+zstd.1.5.2", + "zstd", ] [[package]] @@ -4319,29 +4197,6 @@ version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" -[[package]] -name = "crossterm" -version = "0.29.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8b9f2e4c67f833b660cdb0a3523065869fb35570177239812ed4c905aeff87b" -dependencies = [ - "bitflags 2.9.4", - "crossterm_winapi", - "document-features", - "parking_lot", - "rustix 1.1.2", - "winapi", -] - -[[package]] -name = "crossterm_winapi" -version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "acdd7c62a3665c7f6830a51635d9ac9b23ed385797f70a83bb8bafe9c572ab2b" -dependencies = [ - "winapi", -] - [[package]] name = "crunchy" version = "0.2.4" @@ -4696,12 +4551,6 @@ dependencies 
= [ "util", ] -[[package]] -name = "debug_unsafe" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85d3cef41d236720ed453e102153a53e4cc3d2fde848c0078a50cf249e8e3e5b" - [[package]] name = "debugger_tools" version = "0.1.0" @@ -5109,15 +4958,6 @@ dependencies = [ "zlog", ] -[[package]] -name = "document-features" -version = "0.2.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95249b50c6c185bee49034bcb378a49dc2b5dff0be90ff6616d31d64febab05d" -dependencies = [ - "litrs", -] - [[package]] name = "documented" version = "0.9.2" @@ -5267,86 +5107,112 @@ dependencies = [ name = "edit_prediction" version = "0.1.0" dependencies = [ - "client", - "gpui", - "language", -] - -[[package]] -name = "edit_prediction_button" -version = "0.1.0" -dependencies = [ + "ai_onboarding", "anyhow", + "arrayvec", + "brotli", "client", + "clock", + "cloud_api_types", "cloud_llm_client", - "codestral", + "cloud_zeta2_prompt", + "collections", "copilot", - "edit_prediction", - "editor", + "credentials_provider", + "ctor", + "db", + "edit_prediction_context", + "edit_prediction_types", "feature_flags", "fs", "futures 0.3.31", "gpui", "indoc", + "itertools 0.14.0", "language", + "language_model", + "log", "lsp", "menu", - "paths", + "open_ai", + "parking_lot", + "postage", + "pretty_assertions", "project", + "rand 0.9.2", "regex", + "release_channel", + "semver", + "serde", "serde_json", "settings", - "supermaven", + "smol", + "strsim", + "strum 0.27.2", "telemetry", - "theme", + "telemetry_events", + "thiserror 2.0.17", "ui", - "ui_input", "util", + "uuid", "workspace", + "worktree", "zed_actions", - "zeta", + "zlog", ] [[package]] -name = "edit_prediction_context" +name = "edit_prediction_cli" version = "0.1.0" dependencies = [ "anyhow", - "arrayvec", + "chrono", "clap", + "client", "cloud_llm_client", + "cloud_zeta2_prompt", "collections", + "debug_adapter_extension", + "edit_prediction", + 
"edit_prediction_context", + "extension", + "fs", "futures 0.3.31", "gpui", - "hashbrown 0.15.5", + "gpui_tokio", "indoc", - "itertools 0.14.0", "language", + "language_extension", + "language_model", + "language_models", + "languages", "log", - "ordered-float 2.10.1", - "postage", + "node_runtime", + "paths", "pretty_assertions", "project", - "regex", + "prompt_store", + "pulldown-cmark 0.12.2", + "release_channel", + "reqwest_client", "serde", "serde_json", "settings", - "slotmap", - "strum 0.27.2", - "text", - "tree-sitter", - "tree-sitter-c", - "tree-sitter-cpp", - "tree-sitter-go", + "shellexpand 2.1.2", + "smol", + "terminal_view", + "toml 0.8.23", "util", + "watch", "zlog", ] [[package]] -name = "edit_prediction_context2" +name = "edit_prediction_context" version = "0.1.0" dependencies = [ "anyhow", + "cloud_llm_client", "collections", "env_logger 0.11.8", "futures 0.3.31", @@ -5368,6 +5234,56 @@ dependencies = [ "zlog", ] +[[package]] +name = "edit_prediction_types" +version = "0.1.0" +dependencies = [ + "client", + "gpui", + "language", +] + +[[package]] +name = "edit_prediction_ui" +version = "0.1.0" +dependencies = [ + "anyhow", + "buffer_diff", + "client", + "cloud_llm_client", + "cloud_zeta2_prompt", + "codestral", + "command_palette_hooks", + "copilot", + "edit_prediction", + "edit_prediction_types", + "editor", + "feature_flags", + "fs", + "futures 0.3.31", + "gpui", + "indoc", + "language", + "lsp", + "markdown", + "menu", + "multi_buffer", + "paths", + "project", + "regex", + "serde_json", + "settings", + "supermaven", + "telemetry", + "text", + "theme", + "ui", + "ui_input", + "util", + "workspace", + "zed_actions", +] + [[package]] name = "editor" version = "0.1.0" @@ -5384,7 +5300,7 @@ dependencies = [ "ctor", "dap", "db", - "edit_prediction", + "edit_prediction_types", "emojis", "feature_flags", "file_icons", @@ -5723,14 +5639,8 @@ dependencies = [ ] [[package]] -name = "ethnum" -version = "1.5.2" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca81e6b4777c89fd810c25a4be2b1bd93ea034fbe58e6a75216a34c6b82c539b" - -[[package]] -name = "euclid" -version = "0.22.11" +name = "euclid" +version = "0.22.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ad9cdb4b747e485a12abb0e6566612956c7a1bafa3bdb8d682c5b6d403589e48" dependencies = [ @@ -6012,12 +5922,6 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" -[[package]] -name = "fallible-streaming-iterator" -version = "0.1.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" - [[package]] name = "fancy-regex" version = "0.16.2" @@ -6029,12 +5933,6 @@ dependencies = [ "regex-syntax", ] -[[package]] -name = "fast-float2" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8eb564c5c7423d25c886fb561d1e4ee69f72354d16918afa32c08811f6b6a55" - [[package]] name = "fast-srgb8" version = "1.0.0" @@ -6210,7 +6108,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc5a4e564e38c699f2880d3fda590bedc2e69f3f84cd48b457bd892ce61d0aa9" dependencies = [ "crc32fast", - "libz-rs-sys", "miniz_oxide", ] @@ -6467,16 +6364,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "fs4" -version = "0.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8640e34b88f7652208ce9e88b1a37a2ae95227d84abec377ccd3c5cfeb141ed4" -dependencies = [ - "rustix 1.1.2", - "windows-sys 0.59.0", -] - [[package]] name = "fs_benchmarks" version = "0.1.0" @@ -7540,7 +7427,6 @@ dependencies = [ "allocator-api2", "equivalent", "foldhash 0.1.5", - "rayon", "serde", ] @@ -7652,7 +7538,7 @@ version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"13c255bdf46e07fb840d120a36dcc81f385140d7191c76a7391672675c01a55d" dependencies = [ - "bincode 1.3.3", + "bincode", "byteorder", "heed-traits", "serde", @@ -8412,7 +8298,7 @@ version = "0.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6fb8251fb7bcd9ccd3725ed8deae9fe7db8e586495c9eb5b0c52e6233e5e75ea" dependencies = [ - "bincode 1.3.3", + "bincode", "crossbeam-channel", "fnv", "lazy_static", @@ -9256,15 +9142,6 @@ dependencies = [ "webrtc-sys", ] -[[package]] -name = "libz-rs-sys" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "840db8cf39d9ec4dd794376f38acc40d0fc65eec2a8f484f7fd375b84602becd" -dependencies = [ - "zlib-rs", -] - [[package]] name = "libz-sys" version = "1.1.22" @@ -9327,12 +9204,6 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956" -[[package]] -name = "litrs" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f5e54036fe321fd421e10d732f155734c4e4afd610dd556d9a82833ab3ee0bed" - [[package]] name = "livekit" version = "0.7.8" @@ -9624,25 +9495,6 @@ dependencies = [ "num-traits", ] -[[package]] -name = "lz4" -version = "1.28.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a20b523e860d03443e98350ceaac5e71c6ba89aea7d960769ec3ce37f4de5af4" -dependencies = [ - "lz4-sys", -] - -[[package]] -name = "lz4-sys" -version = "1.11.1+lz4-1.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6bd8c0d6c6ed0cd30b3652886bb8711dc4bb01d637a68105a3d5158039b418e6" -dependencies = [ - "cc", - "libc", -] - [[package]] name = "mac" version = "0.1.1" @@ -10505,15 +10357,6 @@ name = "notify-types" version = "2.0.0" source = "git+https://github.com/zed-industries/notify.git?rev=b4588b2e5aee68f4c0e100f140e808cbce7b1419#b4588b2e5aee68f4c0e100f140e808cbce7b1419" -[[package]] -name 
= "now" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d89e9874397a1f0a52fc1f197a8effd9735223cb2390e9dcc83ac6cd02923d0" -dependencies = [ - "chrono", -] - [[package]] name = "ntapi" version = "0.4.1" @@ -10909,41 +10752,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "object_store" -version = "0.12.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c1be0c6c22ec0817cdc77d3842f721a17fd30ab6965001415b5402a74e6b740" -dependencies = [ - "async-trait", - "base64 0.22.1", - "bytes 1.10.1", - "chrono", - "form_urlencoded", - "futures 0.3.31", - "http 1.3.1", - "http-body-util", - "humantime", - "hyper 1.7.0", - "itertools 0.14.0", - "parking_lot", - "percent-encoding", - "quick-xml 0.38.3", - "rand 0.9.2", - "reqwest 0.12.24", - "ring", - "serde", - "serde_json", - "serde_urlencoded", - "thiserror 2.0.17", - "tokio", - "tracing", - "url", - "walkdir", - "wasm-bindgen-futures", - "web-time", -] - [[package]] name = "ollama" version = "0.1.0" @@ -12184,16 +11992,6 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4596b6d070b27117e987119b4dac604f3c58cfb0b191112e24771b2faeac1a6" -[[package]] -name = "planus" -version = "1.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3daf8e3d4b712abe1d690838f6e29fb76b76ea19589c4afa39ec30e12f62af71" -dependencies = [ - "array-init-cursor", - "hashbrown 0.15.5", -] - [[package]] name = "plist" version = "1.8.0" @@ -12261,520 +12059,6 @@ dependencies = [ "miniz_oxide", ] -[[package]] -name = "polars" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5f7feb5d56b954e691dff22a8b2d78d77433dcc93c35fe21c3777fdc121b697" -dependencies = [ - "getrandom 0.2.16", - "getrandom 0.3.4", - "polars-arrow", - "polars-core", - "polars-error", - "polars-io", - "polars-lazy", - "polars-ops", - "polars-parquet", - "polars-sql", - 
"polars-time", - "polars-utils", - "version_check", -] - -[[package]] -name = "polars-arrow" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32b4fed2343961b3eea3db2cee165540c3e1ad9d5782350cc55a9e76cf440148" -dependencies = [ - "atoi_simd", - "bitflags 2.9.4", - "bytemuck", - "chrono", - "chrono-tz", - "dyn-clone", - "either", - "ethnum", - "getrandom 0.2.16", - "getrandom 0.3.4", - "hashbrown 0.15.5", - "itoa", - "lz4", - "num-traits", - "polars-arrow-format", - "polars-error", - "polars-schema", - "polars-utils", - "serde", - "simdutf8", - "streaming-iterator", - "strum_macros 0.27.2", - "version_check", - "zstd 0.13.3", -] - -[[package]] -name = "polars-arrow-format" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a556ac0ee744e61e167f34c1eb0013ce740e0ee6cd8c158b2ec0b518f10e6675" -dependencies = [ - "planus", - "serde", -] - -[[package]] -name = "polars-compute" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "138785beda4e4a90a025219f09d0d15a671b2be9091513ede58e05db6ad4413f" -dependencies = [ - "atoi_simd", - "bytemuck", - "chrono", - "either", - "fast-float2", - "hashbrown 0.15.5", - "itoa", - "num-traits", - "polars-arrow", - "polars-error", - "polars-utils", - "rand 0.9.2", - "ryu", - "serde", - "skiplist", - "strength_reduce", - "strum_macros 0.27.2", - "version_check", -] - -[[package]] -name = "polars-core" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e77b1f08ef6dbb032bb1d0d3365464be950df9905f6827a95b24c4ca5518901d" -dependencies = [ - "bitflags 2.9.4", - "boxcar", - "bytemuck", - "chrono", - "chrono-tz", - "comfy-table", - "either", - "hashbrown 0.15.5", - "indexmap", - "itoa", - "num-traits", - "polars-arrow", - "polars-compute", - "polars-dtype", - "polars-error", - "polars-row", - "polars-schema", - "polars-utils", - "rand 0.9.2", - "rand_distr", - 
"rayon", - "regex", - "serde", - "serde_json", - "strum_macros 0.27.2", - "uuid", - "version_check", - "xxhash-rust", -] - -[[package]] -name = "polars-dtype" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89c43d0ea57168be4546c4d8064479ed8b29a9c79c31a0c7c367ee734b9b7158" -dependencies = [ - "boxcar", - "hashbrown 0.15.5", - "polars-arrow", - "polars-error", - "polars-utils", - "serde", - "uuid", -] - -[[package]] -name = "polars-error" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9cb5d98f59f8b94673ee391840440ad9f0d2170afced95fc98aa86f895563c0" -dependencies = [ - "object_store", - "parking_lot", - "polars-arrow-format", - "regex", - "signal-hook", - "simdutf8", -] - -[[package]] -name = "polars-expr" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "343931b818cf136349135ba11dbc18c27683b52c3477b1ba8ca606cf5ab1965c" -dependencies = [ - "bitflags 2.9.4", - "hashbrown 0.15.5", - "num-traits", - "polars-arrow", - "polars-compute", - "polars-core", - "polars-io", - "polars-ops", - "polars-plan", - "polars-row", - "polars-time", - "polars-utils", - "rand 0.9.2", - "rayon", - "recursive", -] - -[[package]] -name = "polars-io" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10388c64b8155122488229a881d1c6f4fdc393bc988e764ab51b182fcb2307e4" -dependencies = [ - "async-trait", - "atoi_simd", - "blake3", - "bytes 1.10.1", - "chrono", - "fast-float2", - "fs4", - "futures 0.3.31", - "glob", - "hashbrown 0.15.5", - "home", - "itoa", - "memchr", - "memmap2", - "num-traits", - "object_store", - "percent-encoding", - "polars-arrow", - "polars-core", - "polars-error", - "polars-parquet", - "polars-schema", - "polars-time", - "polars-utils", - "rayon", - "regex", - "reqwest 0.12.24", - "ryu", - "serde", - "serde_json", - "simdutf8", - "tokio", - "tokio-util", - "url", -] - 
-[[package]] -name = "polars-lazy" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fb6e2c6c2fa4ea0c660df1c06cf56960c81e7c2683877995bae3d4e3d408147" -dependencies = [ - "bitflags 2.9.4", - "chrono", - "either", - "memchr", - "polars-arrow", - "polars-compute", - "polars-core", - "polars-expr", - "polars-io", - "polars-mem-engine", - "polars-ops", - "polars-plan", - "polars-stream", - "polars-time", - "polars-utils", - "rayon", - "version_check", -] - -[[package]] -name = "polars-mem-engine" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20a856e98e253587c28d8132a5e7e5a75cb2c44731ca090f1481d45f1d123771" -dependencies = [ - "futures 0.3.31", - "memmap2", - "polars-arrow", - "polars-core", - "polars-error", - "polars-expr", - "polars-io", - "polars-ops", - "polars-plan", - "polars-time", - "polars-utils", - "rayon", - "recursive", - "tokio", -] - -[[package]] -name = "polars-ops" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "acf6062173fdc9ba05775548beb66e76643a148d9aeadc9984ed712bc4babd76" -dependencies = [ - "argminmax", - "base64 0.22.1", - "bytemuck", - "chrono", - "chrono-tz", - "either", - "hashbrown 0.15.5", - "hex", - "indexmap", - "libm", - "memchr", - "num-traits", - "polars-arrow", - "polars-compute", - "polars-core", - "polars-error", - "polars-schema", - "polars-utils", - "rayon", - "regex", - "regex-syntax", - "strum_macros 0.27.2", - "unicode-normalization", - "unicode-reverse", - "version_check", -] - -[[package]] -name = "polars-parquet" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc1d769180dec070df0dc4b89299b364bf2cfe32b218ecc4ddd8f1a49ae60669" -dependencies = [ - "async-stream", - "base64 0.22.1", - "brotli", - "bytemuck", - "ethnum", - "flate2", - "futures 0.3.31", - "hashbrown 0.15.5", - "lz4", - "num-traits", - "polars-arrow", - 
"polars-compute", - "polars-error", - "polars-parquet-format", - "polars-utils", - "serde", - "simdutf8", - "snap", - "streaming-decompression", - "zstd 0.13.3", -] - -[[package]] -name = "polars-parquet-format" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c025243dcfe8dbc57e94d9f82eb3bef10b565ab180d5b99bed87fd8aea319ce1" -dependencies = [ - "async-trait", - "futures 0.3.31", -] - -[[package]] -name = "polars-plan" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cd3a2e33ae4484fe407ab2d2ba5684f0889d1ccf3ad6b844103c03638e6d0a0" -dependencies = [ - "bitflags 2.9.4", - "bytemuck", - "bytes 1.10.1", - "chrono", - "chrono-tz", - "either", - "futures 0.3.31", - "hashbrown 0.15.5", - "memmap2", - "num-traits", - "percent-encoding", - "polars-arrow", - "polars-compute", - "polars-core", - "polars-error", - "polars-io", - "polars-ops", - "polars-parquet", - "polars-time", - "polars-utils", - "rayon", - "recursive", - "regex", - "sha2", - "strum_macros 0.27.2", - "version_check", -] - -[[package]] -name = "polars-row" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18734f17e0e348724df3ae65f3ee744c681117c04b041cac969dfceb05edabc0" -dependencies = [ - "bitflags 2.9.4", - "bytemuck", - "polars-arrow", - "polars-compute", - "polars-dtype", - "polars-error", - "polars-utils", -] - -[[package]] -name = "polars-schema" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e6c1ab13e04d5167661a9854ed1ea0482b2ed9b8a0f1118dabed7cd994a85e3" -dependencies = [ - "indexmap", - "polars-error", - "polars-utils", - "serde", - "version_check", -] - -[[package]] -name = "polars-sql" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4e7766da02cc1d464994404d3e88a7a0ccd4933df3627c325480fbd9bbc0a11" -dependencies = [ - "bitflags 2.9.4", - "hex", - 
"polars-core", - "polars-error", - "polars-lazy", - "polars-ops", - "polars-plan", - "polars-time", - "polars-utils", - "rand 0.9.2", - "regex", - "serde", - "sqlparser", -] - -[[package]] -name = "polars-stream" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31f6c6ca1ea01f9dea424d167e4f33f5ec44cd67fbfac9efd40575ed20521f14" -dependencies = [ - "async-channel 2.5.0", - "async-trait", - "atomic-waker", - "bitflags 2.9.4", - "crossbeam-channel", - "crossbeam-deque", - "crossbeam-queue", - "crossbeam-utils", - "futures 0.3.31", - "memmap2", - "parking_lot", - "percent-encoding", - "pin-project-lite", - "polars-arrow", - "polars-core", - "polars-error", - "polars-expr", - "polars-io", - "polars-mem-engine", - "polars-ops", - "polars-parquet", - "polars-plan", - "polars-utils", - "rand 0.9.2", - "rayon", - "recursive", - "slotmap", - "tokio", - "tokio-util", - "version_check", -] - -[[package]] -name = "polars-time" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6a3a6e279a7a984a0b83715660f9e880590c6129ec2104396bfa710bcd76dee" -dependencies = [ - "atoi_simd", - "bytemuck", - "chrono", - "chrono-tz", - "now", - "num-traits", - "polars-arrow", - "polars-compute", - "polars-core", - "polars-error", - "polars-ops", - "polars-utils", - "rayon", - "regex", - "strum_macros 0.27.2", -] - -[[package]] -name = "polars-utils" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57b267021b0e5422d7fbc70fd79e51b9f9a8466c585779373a18b0199e973f29" -dependencies = [ - "bincode 2.0.1", - "bytemuck", - "bytes 1.10.1", - "compact_str", - "either", - "flate2", - "foldhash 0.1.5", - "hashbrown 0.15.5", - "indexmap", - "libc", - "memmap2", - "num-traits", - "polars-error", - "rand 0.9.2", - "raw-cpuid 11.6.0", - "rayon", - "regex", - "rmp-serde", - "serde", - "serde_json", - "serde_stacker", - "slotmap", - "stacker", - "uuid", - "version_check", 
-] - [[package]] name = "polling" version = "3.11.0" @@ -13526,7 +12810,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42a232e7487fc2ef313d96dde7948e7a3c05101870d8985e4fd8d26aedd27b89" dependencies = [ "memchr", - "serde", ] [[package]] @@ -13860,26 +13143,6 @@ dependencies = [ "zed_actions", ] -[[package]] -name = "recursive" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0786a43debb760f491b1bc0269fe5e84155353c67482b9e60d0cfb596054b43e" -dependencies = [ - "recursive-proc-macro-impl", - "stacker", -] - -[[package]] -name = "recursive-proc-macro-impl" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" -dependencies = [ - "quote", - "syn 2.0.106", -] - [[package]] name = "redox_syscall" version = "0.2.16" @@ -14236,35 +13499,26 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "h2 0.4.12", "http 1.3.1", "http-body 1.0.1", "http-body-util", "hyper 1.7.0", - "hyper-rustls 0.27.7", "hyper-util", "js-sys", "log", "percent-encoding", "pin-project-lite", - "quinn", - "rustls 0.23.33", - "rustls-native-certs 0.8.2", - "rustls-pki-types", "serde", "serde_json", "serde_urlencoded", "sync_wrapper 1.0.2", "tokio", - "tokio-rustls 0.26.2", - "tokio-util", "tower 0.5.2", "tower-http 0.6.6", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", - "wasm-streams", "web-sys", ] @@ -14387,17 +13641,6 @@ dependencies = [ "paste", ] -[[package]] -name = "rmp-serde" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52e599a477cf9840e92f2cde9a7189e67b42c57532749bf90aea6ec10facd4db" -dependencies = [ - "byteorder", - "rmp", - "serde", -] - [[package]] name = "rmpv" version = "1.3.0" @@ -14467,7 +13710,7 @@ dependencies = [ "tracing", "util", "zlog", - "zstd 0.11.2+zstd.1.5.2", + "zstd", ] [[package]] @@ -15359,17 
+14602,6 @@ dependencies = [ "serde_core", ] -[[package]] -name = "serde_stacker" -version = "0.1.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4936375d50c4be7eff22293a9344f8e46f323ed2b3c243e52f89138d9bb0f4a" -dependencies = [ - "serde", - "serde_core", - "stacker", -] - [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -15711,16 +14943,6 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" -[[package]] -name = "skiplist" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f354fd282d3177c2951004953e2fdc4cb342fa159bbee8b829852b6a081c8ea1" -dependencies = [ - "rand 0.9.2", - "thiserror 2.0.17", -] - [[package]] name = "skrifa" version = "0.37.0" @@ -15796,12 +15018,6 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd538fb6910ac1099850255cf94a94df6551fbdd602454387d0adb2d1ca6dead" -[[package]] -name = "snap" -version = "1.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" - [[package]] name = "snippet" version = "0.1.0" @@ -15848,26 +15064,6 @@ dependencies = [ "workspace", ] -[[package]] -name = "soa-rs" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b75ae4668062b095fda87ba54118697bed601f07f6c68bf50289a25ca0c8c935" -dependencies = [ - "soa-rs-derive", -] - -[[package]] -name = "soa-rs-derive" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c09121507da587d3434e5929ce3321162f36bd3eff403873cb163c06b176913" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.106", -] - [[package]] name = "socket2" version = "0.5.10" @@ -15987,15 +15183,6 @@ dependencies = [ "unicode_categories", ] -[[package]] -name = "sqlparser" 
-version = "0.53.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05a528114c392209b3264855ad491fcce534b94a38771b0a0b97a79379275ce8" -dependencies = [ - "log", -] - [[package]] name = "sqlx" version = "0.8.6" @@ -16297,15 +15484,6 @@ dependencies = [ "ui", ] -[[package]] -name = "streaming-decompression" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf6cc3b19bfb128a8ad11026086e31d3ce9ad23f8ea37354b31383a187c44cf3" -dependencies = [ - "fallible-streaming-iterator", -] - [[package]] name = "streaming-iterator" version = "0.1.9" @@ -16447,7 +15625,7 @@ dependencies = [ "anyhow", "client", "collections", - "edit_prediction", + "edit_prediction_types", "editor", "env_logger 0.11.8", "futures 0.3.31", @@ -17691,7 +16869,6 @@ dependencies = [ "futures-core", "futures-io", "futures-sink", - "futures-util", "pin-project-lite", "tokio", ] @@ -18547,15 +17724,6 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e70f2a8b45122e719eb623c01822704c4e0907e7e426a05927e1a1cfff5b75d0" -[[package]] -name = "unicode-reverse" -version = "1.0.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b6f4888ebc23094adfb574fdca9fdc891826287a6397d2cd28802ffd6f20c76" -dependencies = [ - "unicode-segmentation", -] - [[package]] name = "unicode-script" version = "0.5.7" @@ -18616,12 +17784,6 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" -[[package]] -name = "unty" -version = "0.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d49784317cd0d1ee7ec5c716dd598ec5b4483ea832a2dced265471cc0f690ae" - [[package]] name = "url" version = "2.5.7" @@ -18897,12 +18059,6 @@ dependencies = [ "settings", ] -[[package]] -name = "virtue" -version = "0.0.18" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "051eb1abcf10076295e815102942cc58f9d5e3b4560e46e53c21e8ff6f3af7b1" - [[package]] name = "vscode_theme" version = "0.2.0" @@ -21058,12 +20214,6 @@ dependencies = [ "toml_edit 0.22.27", ] -[[package]] -name = "xxhash-rust" -version = "0.8.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdd20c5420375476fbd4394763288da7eb0cc0b8c11deed431a91562af7335d3" - [[package]] name = "yaml-rust2" version = "0.8.1" @@ -21251,7 +20401,7 @@ dependencies = [ "audio", "auto_update", "auto_update_ui", - "bincode 1.3.3", + "bincode", "breadcrumbs", "call", "channel", @@ -21273,7 +20423,8 @@ dependencies = [ "debugger_tools", "debugger_ui", "diagnostics", - "edit_prediction_button", + "edit_prediction", + "edit_prediction_ui", "editor", "env_logger 0.11.8", "extension", @@ -21384,8 +20535,6 @@ dependencies = [ "zed-reqwest", "zed_actions", "zed_env_vars", - "zeta", - "zeta2_tools", "zlog", "zlog_settings", ] @@ -21697,151 +20846,6 @@ dependencies = [ "syn 2.0.106", ] -[[package]] -name = "zeta" -version = "0.1.0" -dependencies = [ - "ai_onboarding", - "anyhow", - "arrayvec", - "brotli", - "buffer_diff", - "client", - "clock", - "cloud_api_types", - "cloud_llm_client", - "cloud_zeta2_prompt", - "collections", - "command_palette_hooks", - "copilot", - "credentials_provider", - "ctor", - "db", - "edit_prediction", - "edit_prediction_context", - "edit_prediction_context2", - "editor", - "feature_flags", - "fs", - "futures 0.3.31", - "gpui", - "indoc", - "itertools 0.14.0", - "language", - "language_model", - "log", - "lsp", - "markdown", - "menu", - "open_ai", - "parking_lot", - "postage", - "pretty_assertions", - "project", - "rand 0.9.2", - "regex", - "release_channel", - "semver", - "serde", - "serde_json", - "settings", - "smol", - "strsim", - "strum 0.27.2", - "telemetry", - "telemetry_events", - "theme", - "thiserror 2.0.17", - "ui", - "util", - "uuid", - "workspace", - "worktree", - 
"zed_actions", - "zlog", -] - -[[package]] -name = "zeta2_tools" -version = "0.1.0" -dependencies = [ - "anyhow", - "clap", - "client", - "cloud_llm_client", - "collections", - "edit_prediction_context", - "editor", - "feature_flags", - "futures 0.3.31", - "gpui", - "indoc", - "language", - "multi_buffer", - "pretty_assertions", - "project", - "serde", - "serde_json", - "settings", - "telemetry", - "text", - "ui", - "ui_input", - "util", - "workspace", - "zeta", - "zlog", -] - -[[package]] -name = "zeta_cli" -version = "0.1.0" -dependencies = [ - "anyhow", - "chrono", - "clap", - "client", - "cloud_llm_client", - "cloud_zeta2_prompt", - "collections", - "debug_adapter_extension", - "edit_prediction_context", - "extension", - "fs", - "futures 0.3.31", - "gpui", - "gpui_tokio", - "indoc", - "language", - "language_extension", - "language_model", - "language_models", - "languages", - "log", - "node_runtime", - "ordered-float 2.10.1", - "paths", - "polars", - "pretty_assertions", - "project", - "prompt_store", - "pulldown-cmark 0.12.2", - "release_channel", - "reqwest_client", - "serde", - "serde_json", - "settings", - "shellexpand 2.1.2", - "smol", - "soa-rs", - "terminal_view", - "toml 0.8.23", - "util", - "watch", - "zeta", - "zlog", -] - [[package]] name = "zip" version = "0.6.6" @@ -21851,7 +20855,7 @@ dependencies = [ "aes", "byteorder", "bzip2", - "constant_time_eq 0.1.5", + "constant_time_eq", "crc32fast", "crossbeam-utils", "flate2", @@ -21859,7 +20863,7 @@ dependencies = [ "pbkdf2 0.11.0", "sha1", "time", - "zstd 0.11.2+zstd.1.5.2", + "zstd", ] [[package]] @@ -21877,12 +20881,6 @@ dependencies = [ "thiserror 1.0.69", ] -[[package]] -name = "zlib-rs" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f06ae92f42f5e5c42443fd094f245eb656abf56dd7cce9b8b263236565e00f2" - [[package]] name = "zlog" version = "0.1.0" @@ -21910,16 +20908,7 @@ version = "0.11.2+zstd.1.5.2" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4" dependencies = [ - "zstd-safe 5.0.2+zstd.1.5.2", -] - -[[package]] -name = "zstd" -version = "0.13.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a" -dependencies = [ - "zstd-safe 7.2.4", + "zstd-safe", ] [[package]] @@ -21932,15 +20921,6 @@ dependencies = [ "zstd-sys", ] -[[package]] -name = "zstd-safe" -version = "7.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d" -dependencies = [ - "zstd-sys", -] - [[package]] name = "zstd-sys" version = "2.0.16+zstd.1.5.7" diff --git a/Cargo.toml b/Cargo.toml index 62a44dbf35fefbf02a1b570146b0bf24cea6dcd8..83bc42e353f6462148abe15327373a3d57a029e8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -54,10 +54,9 @@ members = [ "crates/diagnostics", "crates/docs_preprocessor", "crates/edit_prediction", - "crates/edit_prediction_button", + "crates/edit_prediction_types", + "crates/edit_prediction_ui", "crates/edit_prediction_context", - "crates/edit_prediction_context2", - "crates/zeta2_tools", "crates/editor", "crates/eval", "crates/eval_utils", @@ -202,8 +201,7 @@ members = [ "crates/zed", "crates/zed_actions", "crates/zed_env_vars", - "crates/zeta", - "crates/zeta_cli", + "crates/edit_prediction_cli", "crates/zlog", "crates/zlog_settings", @@ -314,11 +312,9 @@ http_client = { path = "crates/http_client" } http_client_tls = { path = "crates/http_client_tls" } icons = { path = "crates/icons" } image_viewer = { path = "crates/image_viewer" } -edit_prediction = { path = "crates/edit_prediction" } -edit_prediction_button = { path = "crates/edit_prediction_button" } +edit_prediction_types = { path = "crates/edit_prediction_types" } +edit_prediction_ui = { path = "crates/edit_prediction_ui" } edit_prediction_context 
= { path = "crates/edit_prediction_context" } -edit_prediction_context2 = { path = "crates/edit_prediction_context2" } -zeta2_tools = { path = "crates/zeta2_tools" } inspector_ui = { path = "crates/inspector_ui" } install_cli = { path = "crates/install_cli" } journal = { path = "crates/journal" } @@ -435,7 +431,7 @@ x_ai = { path = "crates/x_ai" } zed = { path = "crates/zed" } zed_actions = { path = "crates/zed_actions" } zed_env_vars = { path = "crates/zed_env_vars" } -zeta = { path = "crates/zeta" } +edit_prediction = { path = "crates/edit_prediction" } zlog = { path = "crates/zlog" } zlog_settings = { path = "crates/zlog_settings" } @@ -830,7 +826,7 @@ feature_flags = { codegen-units = 1 } file_icons = { codegen-units = 1 } fsevent = { codegen-units = 1 } image_viewer = { codegen-units = 1 } -edit_prediction_button = { codegen-units = 1 } +edit_prediction_ui = { codegen-units = 1 } install_cli = { codegen-units = 1 } journal = { codegen-units = 1 } json_schema_store = { codegen-units = 1 } diff --git a/assets/keymaps/default-linux.json b/assets/keymaps/default-linux.json index 5de5b9daae27113807cb6e97eda335a419f18ac9..0b001f31790c7f8211a6b44d227c15a6ff605ca4 100644 --- a/assets/keymaps/default-linux.json +++ b/assets/keymaps/default-linux.json @@ -41,7 +41,7 @@ "ctrl-f11": "debugger::StepInto", "shift-f11": "debugger::StepOut", "f11": "zed::ToggleFullScreen", - "ctrl-alt-z": "edit_prediction::RateCompletions", + "ctrl-alt-z": "edit_prediction::RatePredictions", "ctrl-alt-shift-i": "edit_prediction::ToggleMenu", "ctrl-alt-l": "lsp_tool::ToggleMenu" } @@ -1322,18 +1322,10 @@ } }, { - "context": "Zeta2Feedback > Editor", + "context": "EditPredictionContext > Editor", "bindings": { - "enter": "editor::Newline", - "ctrl-enter up": "dev::Zeta2RatePredictionPositive", - "ctrl-enter down": "dev::Zeta2RatePredictionNegative" - } - }, - { - "context": "Zeta2Context > Editor", - "bindings": { - "alt-left": "dev::Zeta2ContextGoBack", - "alt-right": 
"dev::Zeta2ContextGoForward" + "alt-left": "dev::EditPredictionContextGoBack", + "alt-right": "dev::EditPredictionContextGoForward" } }, { diff --git a/assets/keymaps/default-macos.json b/assets/keymaps/default-macos.json index 2fadafb6ca95f81de28165b23e4063dc7a0c38d8..e4595242d570628e2e70c43b66d14a0f9820512b 100644 --- a/assets/keymaps/default-macos.json +++ b/assets/keymaps/default-macos.json @@ -47,7 +47,7 @@ "cmd-m": "zed::Minimize", "fn-f": "zed::ToggleFullScreen", "ctrl-cmd-f": "zed::ToggleFullScreen", - "ctrl-cmd-z": "edit_prediction::RateCompletions", + "ctrl-cmd-z": "edit_prediction::RatePredictions", "ctrl-cmd-i": "edit_prediction::ToggleMenu", "ctrl-cmd-l": "lsp_tool::ToggleMenu", "ctrl-cmd-c": "editor::DisplayCursorNames" @@ -1427,18 +1427,10 @@ } }, { - "context": "Zeta2Feedback > Editor", + "context": "EditPredictionContext > Editor", "bindings": { - "enter": "editor::Newline", - "cmd-enter up": "dev::Zeta2RatePredictionPositive", - "cmd-enter down": "dev::Zeta2RatePredictionNegative" - } - }, - { - "context": "Zeta2Context > Editor", - "bindings": { - "alt-left": "dev::Zeta2ContextGoBack", - "alt-right": "dev::Zeta2ContextGoForward" + "alt-left": "dev::EditPredictionContextGoBack", + "alt-right": "dev::EditPredictionContextGoForward" } }, { diff --git a/assets/keymaps/default-windows.json b/assets/keymaps/default-windows.json index 8cf77f65813701fd42e3a6948b660368a24fd4e4..b625e7c7018c0f4c8277fcf3f739a8f06361c4df 100644 --- a/assets/keymaps/default-windows.json +++ b/assets/keymaps/default-windows.json @@ -1341,18 +1341,10 @@ } }, { - "context": "Zeta2Feedback > Editor", + "context": "EditPredictionContext > Editor", "bindings": { - "enter": "editor::Newline", - "ctrl-enter up": "dev::Zeta2RatePredictionPositive", - "ctrl-enter down": "dev::Zeta2RatePredictionNegative" - } - }, - { - "context": "Zeta2Context > Editor", - "bindings": { - "alt-left": "dev::Zeta2ContextGoBack", - "alt-right": "dev::Zeta2ContextGoForward" + "alt-left": 
"dev::EditPredictionContextGoBack", + "alt-right": "dev::EditPredictionContextGoForward" } }, { diff --git a/crates/cloud_llm_client/src/predict_edits_v3.rs b/crates/cloud_llm_client/src/predict_edits_v3.rs index de8d69dc14870c5583679753c9a75a477e0cc759..9e590dc4cf48a82ecdda8b007c38ab15f3b602be 100644 --- a/crates/cloud_llm_client/src/predict_edits_v3.rs +++ b/crates/cloud_llm_client/src/predict_edits_v3.rs @@ -31,18 +31,10 @@ pub struct PredictEditsRequest { /// Within `signatures` pub excerpt_parent: Option, #[serde(skip_serializing_if = "Vec::is_empty", default)] - pub included_files: Vec, - #[serde(skip_serializing_if = "Vec::is_empty", default)] - pub signatures: Vec, - #[serde(skip_serializing_if = "Vec::is_empty", default)] - pub referenced_declarations: Vec, + pub related_files: Vec, pub events: Vec>, #[serde(default)] pub can_collect_data: bool, - #[serde(skip_serializing_if = "Vec::is_empty", default)] - pub diagnostic_groups: Vec, - #[serde(skip_serializing_if = "is_default", default)] - pub diagnostic_groups_truncated: bool, /// Info about the git repository state, only present when can_collect_data is true. 
#[serde(skip_serializing_if = "Option::is_none", default)] pub git_info: Option, @@ -58,7 +50,7 @@ pub struct PredictEditsRequest { } #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct IncludedFile { +pub struct RelatedFile { pub path: Arc, pub max_row: Line, pub excerpts: Vec, @@ -72,11 +64,9 @@ pub struct Excerpt { #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)] pub enum PromptFormat { - MarkedExcerpt, - LabeledSections, - NumLinesUniDiff, + /// XML old_tex/new_text OldTextNewText, - /// Prompt format intended for use via zeta_cli + /// Prompt format intended for use via edit_prediction_cli OnlySnippets, /// One-sentence instructions used in fine-tuned models Minimal, @@ -87,7 +77,7 @@ pub enum PromptFormat { } impl PromptFormat { - pub const DEFAULT: PromptFormat = PromptFormat::NumLinesUniDiff; + pub const DEFAULT: PromptFormat = PromptFormat::Minimal; } impl Default for PromptFormat { @@ -105,10 +95,7 @@ impl PromptFormat { impl std::fmt::Display for PromptFormat { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - PromptFormat::MarkedExcerpt => write!(f, "Marked Excerpt"), - PromptFormat::LabeledSections => write!(f, "Labeled Sections"), PromptFormat::OnlySnippets => write!(f, "Only Snippets"), - PromptFormat::NumLinesUniDiff => write!(f, "Numbered Lines / Unified Diff"), PromptFormat::OldTextNewText => write!(f, "Old Text / New Text"), PromptFormat::Minimal => write!(f, "Minimal"), PromptFormat::MinimalQwen => write!(f, "Minimal + Qwen FIM"), @@ -178,67 +165,6 @@ impl<'a> std::fmt::Display for DiffPathFmt<'a> { } } -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Signature { - pub text: String, - pub text_is_truncated: bool, - #[serde(skip_serializing_if = "Option::is_none", default)] - pub parent_index: Option, - /// Range of `text` within the file, possibly truncated according to `text_is_truncated`. 
The - /// file is implicitly the file that contains the descendant declaration or excerpt. - pub range: Range, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ReferencedDeclaration { - pub path: Arc, - pub text: String, - pub text_is_truncated: bool, - /// Range of `text` within file, possibly truncated according to `text_is_truncated` - pub range: Range, - /// Range within `text` - pub signature_range: Range, - /// Index within `signatures`. - #[serde(skip_serializing_if = "Option::is_none", default)] - pub parent_index: Option, - pub score_components: DeclarationScoreComponents, - pub signature_score: f32, - pub declaration_score: f32, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DeclarationScoreComponents { - pub is_same_file: bool, - pub is_referenced_nearby: bool, - pub is_referenced_in_breadcrumb: bool, - pub reference_count: usize, - pub same_file_declaration_count: usize, - pub declaration_count: usize, - pub reference_line_distance: u32, - pub declaration_line_distance: u32, - pub excerpt_vs_item_jaccard: f32, - pub excerpt_vs_signature_jaccard: f32, - pub adjacent_vs_item_jaccard: f32, - pub adjacent_vs_signature_jaccard: f32, - pub excerpt_vs_item_weighted_overlap: f32, - pub excerpt_vs_signature_weighted_overlap: f32, - pub adjacent_vs_item_weighted_overlap: f32, - pub adjacent_vs_signature_weighted_overlap: f32, - pub path_import_match_count: usize, - pub wildcard_path_import_match_count: usize, - pub import_similarity: f32, - pub max_import_similarity: f32, - pub normalized_import_similarity: f32, - pub wildcard_import_similarity: f32, - pub normalized_wildcard_import_similarity: f32, - pub included_by_others: usize, - pub includes_others: usize, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(transparent)] -pub struct DiagnosticGroup(pub Box); - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct PredictEditsResponse { pub request_id: Uuid, @@ -262,10 +188,6 @@ pub struct Edit { pub 
content: String, } -fn is_default(value: &T) -> bool { - *value == T::default() -} - #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)] pub struct Point { pub line: Line, diff --git a/crates/cloud_zeta2_prompt/Cargo.toml b/crates/cloud_zeta2_prompt/Cargo.toml index fa8246950f8d03029388e0276954de946efc2346..a15e3fe43c28349920433272c4040ccc58ff4cb4 100644 --- a/crates/cloud_zeta2_prompt/Cargo.toml +++ b/crates/cloud_zeta2_prompt/Cargo.toml @@ -15,9 +15,4 @@ path = "src/cloud_zeta2_prompt.rs" anyhow.workspace = true cloud_llm_client.workspace = true indoc.workspace = true -ordered-float.workspace = true -rustc-hash.workspace = true -schemars.workspace = true serde.workspace = true -serde_json.workspace = true -strum.workspace = true diff --git a/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs b/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs index d67190c17556c5eb8b901e9baad73cc2691a9c78..62bfa45f47d0fdfefa9fbd72320c0ddee71cbc47 100644 --- a/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs +++ b/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs @@ -1,20 +1,12 @@ -//! Zeta2 prompt planning and generation code shared with cloud. 
-pub mod retrieval_prompt; - -use anyhow::{Context as _, Result, anyhow}; +use anyhow::Result; use cloud_llm_client::predict_edits_v3::{ - self, DiffPathFmt, Event, Excerpt, IncludedFile, Line, Point, PromptFormat, - ReferencedDeclaration, + self, DiffPathFmt, Event, Excerpt, Line, Point, PromptFormat, RelatedFile, }; use indoc::indoc; -use ordered_float::OrderedFloat; -use rustc_hash::{FxHashMap, FxHashSet}; -use serde::Serialize; use std::cmp; use std::fmt::Write; +use std::path::Path; use std::sync::Arc; -use std::{cmp::Reverse, collections::BinaryHeap, ops::Range, path::Path}; -use strum::{EnumIter, IntoEnumIterator}; pub const DEFAULT_MAX_PROMPT_BYTES: usize = 10 * 1024; @@ -24,69 +16,6 @@ pub const EDITABLE_REGION_START_MARKER_WITH_NEWLINE: &str = "<|editable_region_s /// NOTE: Differs from zed version of constant - includes a newline pub const EDITABLE_REGION_END_MARKER_WITH_NEWLINE: &str = "<|editable_region_end|>\n"; -// TODO: use constants for markers? -const MARKED_EXCERPT_INSTRUCTIONS: &str = indoc! {" - You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location. - - The excerpt to edit will be wrapped in markers <|editable_region_start|> and <|editable_region_end|>. The cursor position is marked with <|user_cursor|>. Please respond with edited code for that region. - - Other code is provided for context, and `…` indicates when code has been skipped. - - ## Edit History - -"}; - -const LABELED_SECTIONS_INSTRUCTIONS: &str = indoc! {r#" - You are a code completion assistant and your task is to analyze user edits, and suggest an edit to one of the provided sections of code. - - Sections of code are grouped by file and then labeled by `<|section_N|>` (e.g `<|section_8|>`). 
- - The cursor position is marked with `<|user_cursor|>` and it will appear within a special section labeled `<|current_section|>`. Prefer editing the current section until no more changes are needed within it. - - Respond ONLY with the name of the section to edit on a single line, followed by all of the code that should replace that section. For example: - - <|current_section|> - for i in 0..16 { - println!("{i}"); - } - - ## Edit History - -"#}; - -const NUMBERED_LINES_INSTRUCTIONS: &str = indoc! {r#" - # Instructions - - You are an edit prediction agent in a code editor. - Your job is to predict the next edit that the user will make, - based on their last few edits and their current cursor location. - - ## Output Format - - You must briefly explain your understanding of the user's goal, in one - or two sentences, and then specify their next edit in the form of a - unified diff, like this: - - ``` - --- a/src/myapp/cli.py - +++ b/src/myapp/cli.py - @@ ... @@ - import os - import time - import sys - +from constants import LOG_LEVEL_WARNING - @@ ... @@ - config.headless() - config.set_interactive(false) - -config.set_log_level(LOG_L) - +config.set_log_level(LOG_LEVEL_WARNING) - config.set_use_color(True) - ``` - - ## Edit History - -"#}; - const STUDENT_MODEL_INSTRUCTIONS: &str = indoc! {r#" You are a code completion assistant that analyzes edit history to identify and systematically complete incomplete refactorings or patterns across the entire codebase. @@ -94,20 +23,6 @@ const STUDENT_MODEL_INSTRUCTIONS: &str = indoc! {r#" "#}; -const UNIFIED_DIFF_REMINDER: &str = indoc! {" - --- - - Analyze the edit history and the files, then provide the unified diff for your predicted edits. - Do not include the cursor marker in your output. - Your diff should include edited file paths in its file headers (lines beginning with `---` and `+++`). - Do not include line numbers in the hunk headers, use `@@ ... @@`. - Removed lines begin with `-`. - Added lines begin with `+`. 
- Context lines begin with an extra space. - Context and removed lines are used to match the target edit location, so make sure to include enough of them - to uniquely identify it amongst all excerpts of code provided. -"}; - const MINIMAL_PROMPT_REMINDER: &str = indoc! {" --- @@ -164,49 +79,25 @@ const OLD_TEXT_NEW_TEXT_REMINDER: &str = indoc! {r#" Remember that the edits in the edit history have already been applied. "#}; -pub fn build_prompt( - request: &predict_edits_v3::PredictEditsRequest, -) -> Result<(String, SectionLabels)> { - let mut section_labels = Default::default(); - +pub fn build_prompt(request: &predict_edits_v3::PredictEditsRequest) -> Result { let prompt_data = PromptData { events: request.events.clone(), cursor_point: request.cursor_point, cursor_path: request.excerpt_path.clone(), - included_files: request.included_files.clone(), + included_files: request.related_files.clone(), }; match request.prompt_format { PromptFormat::MinimalQwen => { - return Ok((MinimalQwenPrompt.render(&prompt_data), section_labels)); + return Ok(MinimalQwenPrompt.render(&prompt_data)); } PromptFormat::SeedCoder1120 => { - return Ok((SeedCoder1120Prompt.render(&prompt_data), section_labels)); + return Ok(SeedCoder1120Prompt.render(&prompt_data)); } _ => (), }; - let mut insertions = match request.prompt_format { - PromptFormat::MarkedExcerpt => vec![ - ( - Point { - line: request.excerpt_line_range.start, - column: 0, - }, - EDITABLE_REGION_START_MARKER_WITH_NEWLINE, - ), - (request.cursor_point, CURSOR_MARKER), - ( - Point { - line: request.excerpt_line_range.end, - column: 0, - }, - EDITABLE_REGION_END_MARKER_WITH_NEWLINE, - ), - ], - PromptFormat::LabeledSections - | PromptFormat::NumLinesUniDiff - | PromptFormat::Minimal - | PromptFormat::OldTextNewText => { + let insertions = match request.prompt_format { + PromptFormat::Minimal | PromptFormat::OldTextNewText => { vec![(request.cursor_point, CURSOR_MARKER)] } PromptFormat::OnlySnippets => vec![], @@ -215,9 +106,6 
@@ pub fn build_prompt( }; let mut prompt = match request.prompt_format { - PromptFormat::MarkedExcerpt => MARKED_EXCERPT_INSTRUCTIONS.to_string(), - PromptFormat::LabeledSections => LABELED_SECTIONS_INSTRUCTIONS.to_string(), - PromptFormat::NumLinesUniDiff => NUMBERED_LINES_INSTRUCTIONS.to_string(), PromptFormat::OldTextNewText => XML_TAGS_INSTRUCTIONS.to_string(), PromptFormat::OnlySnippets => String::new(), PromptFormat::Minimal => STUDENT_MODEL_INSTRUCTIONS.to_string(), @@ -247,7 +135,7 @@ pub fn build_prompt( You can only edit exactly this part of the file. We prepend line numbers (e.g., `123|`); they are not part of the file.) "}, - PromptFormat::NumLinesUniDiff | PromptFormat::OldTextNewText => indoc! {" + PromptFormat::OldTextNewText => indoc! {" ## Code Excerpts Here is some excerpts of code that you should take into account to predict the next edit. @@ -263,64 +151,51 @@ pub fn build_prompt( Lines starting with `…` indicate omitted line ranges. These may appear inside multi-line code constructs. "}, - _ => indoc! {" + PromptFormat::OnlySnippets | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => { + indoc! {" ## Code Excerpts The cursor marker <|user_cursor|> indicates the current user cursor position. The file is in current state, edits from edit history have been applied. 
- "}, + "} + } }; prompt.push_str(excerpts_preamble); prompt.push('\n'); - if !request.referenced_declarations.is_empty() || !request.signatures.is_empty() { - let syntax_based_prompt = SyntaxBasedPrompt::populate(request)?; - section_labels = syntax_based_prompt.write(&mut insertions, &mut prompt)?; - } else { - if request.prompt_format == PromptFormat::LabeledSections { - anyhow::bail!("PromptFormat::LabeledSections cannot be used with ContextMode::Llm"); - } - - let include_line_numbers = matches!( - request.prompt_format, - PromptFormat::NumLinesUniDiff | PromptFormat::Minimal - ); - for related_file in &request.included_files { - if request.prompt_format == PromptFormat::Minimal { - write_codeblock_with_filename( - &related_file.path, - &related_file.excerpts, - if related_file.path == request.excerpt_path { - &insertions - } else { - &[] - }, - related_file.max_row, - include_line_numbers, - &mut prompt, - ); - } else { - write_codeblock( - &related_file.path, - &related_file.excerpts, - if related_file.path == request.excerpt_path { - &insertions - } else { - &[] - }, - related_file.max_row, - include_line_numbers, - &mut prompt, - ); - } + let include_line_numbers = matches!(request.prompt_format, PromptFormat::Minimal); + for related_file in &request.related_files { + if request.prompt_format == PromptFormat::Minimal { + write_codeblock_with_filename( + &related_file.path, + &related_file.excerpts, + if related_file.path == request.excerpt_path { + &insertions + } else { + &[] + }, + related_file.max_row, + include_line_numbers, + &mut prompt, + ); + } else { + write_codeblock( + &related_file.path, + &related_file.excerpts, + if related_file.path == request.excerpt_path { + &insertions + } else { + &[] + }, + related_file.max_row, + include_line_numbers, + &mut prompt, + ); } } match request.prompt_format { - PromptFormat::NumLinesUniDiff => { - prompt.push_str(UNIFIED_DIFF_REMINDER); - } PromptFormat::OldTextNewText => { 
prompt.push_str(OLD_TEXT_NEW_TEXT_REMINDER); } @@ -330,7 +205,7 @@ pub fn build_prompt( _ => {} } - Ok((prompt, section_labels)) + Ok(prompt) } pub fn generation_params(prompt_format: PromptFormat) -> GenerationParams { @@ -444,476 +319,11 @@ pub fn push_events(output: &mut String, events: &[Arc]) writeln!(output, "`````\n").unwrap(); } -pub struct SyntaxBasedPrompt<'a> { - request: &'a predict_edits_v3::PredictEditsRequest, - /// Snippets to include in the prompt. These may overlap - they are merged / deduplicated in - /// `to_prompt_string`. - snippets: Vec>, - budget_used: usize, -} - -#[derive(Clone, Debug)] -pub struct PlannedSnippet<'a> { - path: Arc, - range: Range, - text: &'a str, - // TODO: Indicate this in the output - #[allow(dead_code)] - text_is_truncated: bool, -} - -#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)] -pub enum DeclarationStyle { - Signature, - Declaration, -} - -#[derive(Default, Clone, Debug, Serialize)] -pub struct SectionLabels { - pub excerpt_index: usize, - pub section_ranges: Vec<(Arc, Range)>, -} - -impl<'a> SyntaxBasedPrompt<'a> { - /// Greedy one-pass knapsack algorithm to populate the prompt plan. Does the following: - /// - /// Initializes a priority queue by populating it with each snippet, finding the - /// DeclarationStyle that minimizes `score_density = score / snippet.range(style).len()`. When a - /// "signature" snippet is popped, insert an entry for the "declaration" variant that reflects - /// the cost of upgrade. - /// - /// TODO: Implement an early halting condition. One option might be to have another priority - /// queue where the score is the size, and update it accordingly. Another option might be to - /// have some simpler heuristic like bailing after N failed insertions, or based on how much - /// budget is left. - /// - /// TODO: Has the current known sources of imprecision: - /// - /// * Does not consider snippet overlap when ranking. 
For example, it might add a field to the - /// plan even though the containing struct is already included. - /// - /// * Does not consider cost of signatures when ranking snippets - this is tricky since - /// signatures may be shared by multiple snippets. - /// - /// * Does not include file paths / other text when considering max_bytes. - pub fn populate(request: &'a predict_edits_v3::PredictEditsRequest) -> Result { - let mut this = Self { - request, - snippets: Vec::new(), - budget_used: request.excerpt.len(), - }; - let mut included_parents = FxHashSet::default(); - let additional_parents = this.additional_parent_signatures( - &request.excerpt_path, - request.excerpt_parent, - &included_parents, - )?; - this.add_parents(&mut included_parents, additional_parents); - - let max_bytes = request.prompt_max_bytes.unwrap_or(DEFAULT_MAX_PROMPT_BYTES); - - if this.budget_used > max_bytes { - return Err(anyhow!( - "Excerpt + signatures size of {} already exceeds budget of {}", - this.budget_used, - max_bytes - )); - } - - #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] - struct QueueEntry { - score_density: OrderedFloat, - declaration_index: usize, - style: DeclarationStyle, - } - - // Initialize priority queue with the best score for each snippet. 
- let mut queue: BinaryHeap = BinaryHeap::new(); - for (declaration_index, declaration) in request.referenced_declarations.iter().enumerate() { - let (style, score_density) = DeclarationStyle::iter() - .map(|style| { - ( - style, - OrderedFloat(declaration_score_density(&declaration, style)), - ) - }) - .max_by_key(|(_, score_density)| *score_density) - .unwrap(); - queue.push(QueueEntry { - score_density, - declaration_index, - style, - }); - } - - // Knapsack selection loop - while let Some(queue_entry) = queue.pop() { - let Some(declaration) = request - .referenced_declarations - .get(queue_entry.declaration_index) - else { - return Err(anyhow!( - "Invalid declaration index {}", - queue_entry.declaration_index - )); - }; - - let mut additional_bytes = declaration_size(declaration, queue_entry.style); - if this.budget_used + additional_bytes > max_bytes { - continue; - } - - let additional_parents = this.additional_parent_signatures( - &declaration.path, - declaration.parent_index, - &mut included_parents, - )?; - additional_bytes += additional_parents - .iter() - .map(|(_, snippet)| snippet.text.len()) - .sum::(); - if this.budget_used + additional_bytes > max_bytes { - continue; - } - - this.budget_used += additional_bytes; - this.add_parents(&mut included_parents, additional_parents); - let planned_snippet = match queue_entry.style { - DeclarationStyle::Signature => { - let Some(text) = declaration.text.get(declaration.signature_range.clone()) - else { - return Err(anyhow!( - "Invalid declaration signature_range {:?} with text.len() = {}", - declaration.signature_range, - declaration.text.len() - )); - }; - let signature_start_line = declaration.range.start - + Line( - declaration.text[..declaration.signature_range.start] - .lines() - .count() as u32, - ); - let signature_end_line = signature_start_line - + Line( - declaration.text - [declaration.signature_range.start..declaration.signature_range.end] - .lines() - .count() as u32, - ); - let range = 
signature_start_line..signature_end_line; - - PlannedSnippet { - path: declaration.path.clone(), - range, - text, - text_is_truncated: declaration.text_is_truncated, - } - } - DeclarationStyle::Declaration => PlannedSnippet { - path: declaration.path.clone(), - range: declaration.range.clone(), - text: &declaration.text, - text_is_truncated: declaration.text_is_truncated, - }, - }; - this.snippets.push(planned_snippet); - - // When a Signature is consumed, insert an entry for Definition style. - if queue_entry.style == DeclarationStyle::Signature { - let signature_size = declaration_size(&declaration, DeclarationStyle::Signature); - let declaration_size = - declaration_size(&declaration, DeclarationStyle::Declaration); - let signature_score = declaration_score(&declaration, DeclarationStyle::Signature); - let declaration_score = - declaration_score(&declaration, DeclarationStyle::Declaration); - - let score_diff = declaration_score - signature_score; - let size_diff = declaration_size.saturating_sub(signature_size); - if score_diff > 0.0001 && size_diff > 0 { - queue.push(QueueEntry { - declaration_index: queue_entry.declaration_index, - score_density: OrderedFloat(score_diff / (size_diff as f32)), - style: DeclarationStyle::Declaration, - }); - } - } - } - - anyhow::Ok(this) - } - - fn add_parents( - &mut self, - included_parents: &mut FxHashSet, - snippets: Vec<(usize, PlannedSnippet<'a>)>, - ) { - for (parent_index, snippet) in snippets { - included_parents.insert(parent_index); - self.budget_used += snippet.text.len(); - self.snippets.push(snippet); - } - } - - fn additional_parent_signatures( - &self, - path: &Arc, - parent_index: Option, - included_parents: &FxHashSet, - ) -> Result)>> { - let mut results = Vec::new(); - self.additional_parent_signatures_impl(path, parent_index, included_parents, &mut results)?; - Ok(results) - } - - fn additional_parent_signatures_impl( - &self, - path: &Arc, - parent_index: Option, - included_parents: &FxHashSet, - results: 
&mut Vec<(usize, PlannedSnippet<'a>)>, - ) -> Result<()> { - let Some(parent_index) = parent_index else { - return Ok(()); - }; - if included_parents.contains(&parent_index) { - return Ok(()); - } - let Some(parent_signature) = self.request.signatures.get(parent_index) else { - return Err(anyhow!("Invalid parent index {}", parent_index)); - }; - results.push(( - parent_index, - PlannedSnippet { - path: path.clone(), - range: parent_signature.range.clone(), - text: &parent_signature.text, - text_is_truncated: parent_signature.text_is_truncated, - }, - )); - self.additional_parent_signatures_impl( - path, - parent_signature.parent_index, - included_parents, - results, - ) - } - - /// Renders the planned context. Each file starts with "```FILE_PATH\n` and ends with triple - /// backticks, with a newline after each file. Outputs a line with "..." between nonconsecutive - /// chunks. - pub fn write( - &'a self, - excerpt_file_insertions: &mut Vec<(Point, &'static str)>, - prompt: &mut String, - ) -> Result { - let mut file_to_snippets: FxHashMap<&'a std::path::Path, Vec<&PlannedSnippet<'a>>> = - FxHashMap::default(); - for snippet in &self.snippets { - file_to_snippets - .entry(&snippet.path) - .or_default() - .push(snippet); - } - - // Reorder so that file with cursor comes last - let mut file_snippets = Vec::new(); - let mut excerpt_file_snippets = Vec::new(); - for (file_path, snippets) in file_to_snippets { - if file_path == self.request.excerpt_path.as_ref() { - excerpt_file_snippets = snippets; - } else { - file_snippets.push((file_path, snippets, false)); - } - } - let excerpt_snippet = PlannedSnippet { - path: self.request.excerpt_path.clone(), - range: self.request.excerpt_line_range.clone(), - text: &self.request.excerpt, - text_is_truncated: false, - }; - excerpt_file_snippets.push(&excerpt_snippet); - file_snippets.push((&self.request.excerpt_path, excerpt_file_snippets, true)); - - let section_labels = - self.push_file_snippets(prompt, 
excerpt_file_insertions, file_snippets)?; - - Ok(section_labels) - } - - fn push_file_snippets( - &self, - output: &mut String, - excerpt_file_insertions: &mut Vec<(Point, &'static str)>, - file_snippets: Vec<(&'a Path, Vec<&'a PlannedSnippet>, bool)>, - ) -> Result { - let mut section_ranges = Vec::new(); - let mut excerpt_index = None; - - for (file_path, mut snippets, is_excerpt_file) in file_snippets { - snippets.sort_by_key(|s| (s.range.start, Reverse(s.range.end))); - - // TODO: What if the snippets get expanded too large to be editable? - let mut current_snippet: Option<(&PlannedSnippet, Range)> = None; - let mut disjoint_snippets: Vec<(&PlannedSnippet, Range)> = Vec::new(); - for snippet in snippets { - if let Some((_, current_snippet_range)) = current_snippet.as_mut() - && snippet.range.start <= current_snippet_range.end - { - current_snippet_range.end = current_snippet_range.end.max(snippet.range.end); - continue; - } - if let Some(current_snippet) = current_snippet.take() { - disjoint_snippets.push(current_snippet); - } - current_snippet = Some((snippet, snippet.range.clone())); - } - if let Some(current_snippet) = current_snippet.take() { - disjoint_snippets.push(current_snippet); - } - - writeln!(output, "`````path={}", file_path.display()).ok(); - let mut skipped_last_snippet = false; - for (snippet, range) in disjoint_snippets { - let section_index = section_ranges.len(); - - match self.request.prompt_format { - PromptFormat::MarkedExcerpt - | PromptFormat::OnlySnippets - | PromptFormat::OldTextNewText - | PromptFormat::Minimal - | PromptFormat::NumLinesUniDiff => { - if range.start.0 > 0 && !skipped_last_snippet { - output.push_str("…\n"); - } - } - PromptFormat::LabeledSections => { - if is_excerpt_file - && range.start <= self.request.excerpt_line_range.start - && range.end >= self.request.excerpt_line_range.end - { - writeln!(output, "<|current_section|>").ok(); - } else { - writeln!(output, "<|section_{}|>", section_index).ok(); - } - } - 
PromptFormat::MinimalQwen => unreachable!(), - PromptFormat::SeedCoder1120 => unreachable!(), - } - - let push_full_snippet = |output: &mut String| { - if self.request.prompt_format == PromptFormat::NumLinesUniDiff { - for (i, line) in snippet.text.lines().enumerate() { - writeln!(output, "{}|{}", i as u32 + range.start.0 + 1, line)?; - } - } else { - output.push_str(&snippet.text); - } - anyhow::Ok(()) - }; - - if is_excerpt_file { - if self.request.prompt_format == PromptFormat::OnlySnippets { - if range.start >= self.request.excerpt_line_range.start - && range.end <= self.request.excerpt_line_range.end - { - skipped_last_snippet = true; - } else { - skipped_last_snippet = false; - output.push_str(snippet.text); - } - } else if !excerpt_file_insertions.is_empty() { - let lines = snippet.text.lines().collect::>(); - let push_line = |output: &mut String, line_ix: usize| { - if self.request.prompt_format == PromptFormat::NumLinesUniDiff { - write!(output, "{}|", line_ix as u32 + range.start.0 + 1)?; - } - anyhow::Ok(writeln!(output, "{}", lines[line_ix])?) - }; - let mut last_line_ix = 0; - let mut insertion_ix = 0; - while insertion_ix < excerpt_file_insertions.len() { - let (point, insertion) = &excerpt_file_insertions[insertion_ix]; - let found = point.line >= range.start && point.line <= range.end; - if found { - excerpt_index = Some(section_index); - let insertion_line_ix = (point.line.0 - range.start.0) as usize; - for line_ix in last_line_ix..insertion_line_ix { - push_line(output, line_ix)?; - } - if let Some(next_line) = lines.get(insertion_line_ix) { - if self.request.prompt_format == PromptFormat::NumLinesUniDiff { - write!( - output, - "{}|", - insertion_line_ix as u32 + range.start.0 + 1 - )? 
- } - output.push_str(&next_line[..point.column as usize]); - output.push_str(insertion); - writeln!(output, "{}", &next_line[point.column as usize..])?; - } else { - writeln!(output, "{}", insertion)?; - } - last_line_ix = insertion_line_ix + 1; - excerpt_file_insertions.remove(insertion_ix); - continue; - } - insertion_ix += 1; - } - skipped_last_snippet = false; - for line_ix in last_line_ix..lines.len() { - push_line(output, line_ix)?; - } - } else { - skipped_last_snippet = false; - push_full_snippet(output)?; - } - } else { - skipped_last_snippet = false; - push_full_snippet(output)?; - } - - section_ranges.push((snippet.path.clone(), range)); - } - - output.push_str("`````\n\n"); - } - - Ok(SectionLabels { - // TODO: Clean this up - excerpt_index: match self.request.prompt_format { - PromptFormat::OnlySnippets => 0, - _ => excerpt_index.context("bug: no snippet found for excerpt")?, - }, - section_ranges, - }) - } -} - -fn declaration_score_density(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 { - declaration_score(declaration, style) / declaration_size(declaration, style) as f32 -} - -fn declaration_score(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 { - match style { - DeclarationStyle::Signature => declaration.signature_score, - DeclarationStyle::Declaration => declaration.declaration_score, - } -} - -fn declaration_size(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> usize { - match style { - DeclarationStyle::Signature => declaration.signature_range.len(), - DeclarationStyle::Declaration => declaration.text.len(), - } -} - struct PromptData { events: Vec>, cursor_point: Point, cursor_path: Arc, // TODO: make a common struct with cursor_point - included_files: Vec, + included_files: Vec, } #[derive(Default)] @@ -1051,7 +461,7 @@ impl SeedCoder1120Prompt { context } - fn fmt_fim(&self, file: &IncludedFile, cursor_point: Point) -> String { + fn fmt_fim(&self, file: &RelatedFile, cursor_point: 
Point) -> String { let mut buf = String::new(); const FIM_SUFFIX: &str = "<[fim-suffix]>"; const FIM_PREFIX: &str = "<[fim-prefix]>"; diff --git a/crates/cloud_zeta2_prompt/src/retrieval_prompt.rs b/crates/cloud_zeta2_prompt/src/retrieval_prompt.rs deleted file mode 100644 index fd35f63f03ff967491a28d817852f6622e4919ca..0000000000000000000000000000000000000000 --- a/crates/cloud_zeta2_prompt/src/retrieval_prompt.rs +++ /dev/null @@ -1,244 +0,0 @@ -use anyhow::Result; -use cloud_llm_client::predict_edits_v3::{self, Excerpt}; -use indoc::indoc; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; -use std::fmt::Write; - -use crate::{push_events, write_codeblock}; - -pub fn build_prompt(request: predict_edits_v3::PlanContextRetrievalRequest) -> Result { - let mut prompt = SEARCH_INSTRUCTIONS.to_string(); - - if !request.events.is_empty() { - writeln!(&mut prompt, "\n## User Edits\n\n")?; - push_events(&mut prompt, &request.events); - } - - writeln!(&mut prompt, "## Cursor context\n")?; - write_codeblock( - &request.excerpt_path, - &[Excerpt { - start_line: request.excerpt_line_range.start, - text: request.excerpt.into(), - }], - &[], - request.cursor_file_max_row, - true, - &mut prompt, - ); - - writeln!(&mut prompt, "{TOOL_USE_REMINDER}")?; - - Ok(prompt) -} - -/// Search for relevant code -/// -/// For the best results, run multiple queries at once with a single invocation of this tool. 
-#[derive(Clone, Deserialize, Serialize, JsonSchema)] -pub struct SearchToolInput { - /// An array of queries to run for gathering context relevant to the next prediction - #[schemars(length(max = 3))] - #[serde(deserialize_with = "deserialize_queries")] - pub queries: Box<[SearchToolQuery]>, -} - -fn deserialize_queries<'de, D>(deserializer: D) -> Result, D::Error> -where - D: serde::Deserializer<'de>, -{ - use serde::de::Error; - - #[derive(Deserialize)] - #[serde(untagged)] - enum QueryCollection { - Array(Box<[SearchToolQuery]>), - DoubleArray(Box<[Box<[SearchToolQuery]>]>), - Single(SearchToolQuery), - } - - #[derive(Deserialize)] - #[serde(untagged)] - enum MaybeDoubleEncoded { - SingleEncoded(QueryCollection), - DoubleEncoded(String), - } - - let result = MaybeDoubleEncoded::deserialize(deserializer)?; - - let normalized = match result { - MaybeDoubleEncoded::SingleEncoded(value) => value, - MaybeDoubleEncoded::DoubleEncoded(value) => { - serde_json::from_str(&value).map_err(D::Error::custom)? - } - }; - - Ok(match normalized { - QueryCollection::Array(items) => items, - QueryCollection::Single(search_tool_query) => Box::new([search_tool_query]), - QueryCollection::DoubleArray(double_array) => double_array.into_iter().flatten().collect(), - }) -} - -/// Search for relevant code by path, syntax hierarchy, and content. -#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Hash)] -pub struct SearchToolQuery { - /// 1. A glob pattern to match file paths in the codebase to search in. - pub glob: String, - /// 2. Regular expressions to match syntax nodes **by their first line** and hierarchy. - /// - /// Subsequent regexes match nodes within the full content of the nodes matched by the previous regexes. - /// - /// Example: Searching for a `User` class - /// ["class\s+User"] - /// - /// Example: Searching for a `get_full_name` method under a `User` class - /// ["class\s+User", "def\sget_full_name"] - /// - /// Skip this field to match on content alone. 
- #[schemars(length(max = 3))] - #[serde(default)] - pub syntax_node: Vec, - /// 3. An optional regular expression to match the final content that should appear in the results. - /// - /// - Content will be matched within all lines of the matched syntax nodes. - /// - If syntax node regexes are provided, this field can be skipped to include as much of the node itself as possible. - /// - If no syntax node regexes are provided, the content will be matched within the entire file. - pub content: Option, -} - -pub const TOOL_NAME: &str = "search"; - -const SEARCH_INSTRUCTIONS: &str = indoc! {r#" - You are part of an edit prediction system in a code editor. - Your role is to search for code that will serve as context for predicting the next edit. - - - Analyze the user's recent edits and current cursor context - - Use the `search` tool to find code that is relevant for predicting the next edit - - Focus on finding: - - Code patterns that might need similar changes based on the recent edits - - Functions, variables, types, and constants referenced in the current cursor context - - Related implementations, usages, or dependencies that may require consistent updates - - How items defined in the cursor excerpt are used or altered - - You will not be able to filter results or perform subsequent queries, so keep searches as targeted as possible - - Use `syntax_node` parameter whenever you're looking for a particular type, class, or function - - Avoid using wildcard globs if you already know the file path of the content you're looking for -"#}; - -const TOOL_USE_REMINDER: &str = indoc! {" - -- - Analyze the user's intent in one to two sentences, then call the `search` tool. -"}; - -#[cfg(test)] -mod tests { - use serde_json::json; - - use super::*; - - #[test] - fn test_deserialize_queries() { - let single_query_json = indoc! 
{r#"{ - "queries": { - "glob": "**/*.rs", - "syntax_node": ["fn test"], - "content": "assert" - } - }"#}; - - let flat_input: SearchToolInput = serde_json::from_str(single_query_json).unwrap(); - assert_eq!(flat_input.queries.len(), 1); - assert_eq!(flat_input.queries[0].glob, "**/*.rs"); - assert_eq!(flat_input.queries[0].syntax_node, vec!["fn test"]); - assert_eq!(flat_input.queries[0].content, Some("assert".to_string())); - - let flat_json = indoc! {r#"{ - "queries": [ - { - "glob": "**/*.rs", - "syntax_node": ["fn test"], - "content": "assert" - }, - { - "glob": "**/*.ts", - "syntax_node": [], - "content": null - } - ] - }"#}; - - let flat_input: SearchToolInput = serde_json::from_str(flat_json).unwrap(); - assert_eq!(flat_input.queries.len(), 2); - assert_eq!(flat_input.queries[0].glob, "**/*.rs"); - assert_eq!(flat_input.queries[0].syntax_node, vec!["fn test"]); - assert_eq!(flat_input.queries[0].content, Some("assert".to_string())); - assert_eq!(flat_input.queries[1].glob, "**/*.ts"); - assert_eq!(flat_input.queries[1].syntax_node.len(), 0); - assert_eq!(flat_input.queries[1].content, None); - - let nested_json = indoc! 
{r#"{ - "queries": [ - [ - { - "glob": "**/*.rs", - "syntax_node": ["fn test"], - "content": "assert" - } - ], - [ - { - "glob": "**/*.ts", - "syntax_node": [], - "content": null - } - ] - ] - }"#}; - - let nested_input: SearchToolInput = serde_json::from_str(nested_json).unwrap(); - - assert_eq!(nested_input.queries.len(), 2); - - assert_eq!(nested_input.queries[0].glob, "**/*.rs"); - assert_eq!(nested_input.queries[0].syntax_node, vec!["fn test"]); - assert_eq!(nested_input.queries[0].content, Some("assert".to_string())); - assert_eq!(nested_input.queries[1].glob, "**/*.ts"); - assert_eq!(nested_input.queries[1].syntax_node.len(), 0); - assert_eq!(nested_input.queries[1].content, None); - - let double_encoded_queries = serde_json::to_string(&json!({ - "queries": serde_json::to_string(&json!([ - { - "glob": "**/*.rs", - "syntax_node": ["fn test"], - "content": "assert" - }, - { - "glob": "**/*.ts", - "syntax_node": [], - "content": null - } - ])).unwrap() - })) - .unwrap(); - - let double_encoded_input: SearchToolInput = - serde_json::from_str(&double_encoded_queries).unwrap(); - - assert_eq!(double_encoded_input.queries.len(), 2); - - assert_eq!(double_encoded_input.queries[0].glob, "**/*.rs"); - assert_eq!(double_encoded_input.queries[0].syntax_node, vec!["fn test"]); - assert_eq!( - double_encoded_input.queries[0].content, - Some("assert".to_string()) - ); - assert_eq!(double_encoded_input.queries[1].glob, "**/*.ts"); - assert_eq!(double_encoded_input.queries[1].syntax_node.len(), 0); - assert_eq!(double_encoded_input.queries[1].content, None); - - // ### ERROR Switching from var declarations to lexical declarations [RUN 073] - // invalid search json {"queries": ["express/lib/response.js", "var\\s+[a-zA-Z_][a-zA-Z0-9_]*\\s*=.*;", "function.*\\(.*\\).*\\{.*\\}"]} - } -} diff --git a/crates/codestral/Cargo.toml b/crates/codestral/Cargo.toml index b402274a33530424349081da764a4b6766e419e9..7f3bf3b22dda8f9dbde1923c76855342c6cbac4c 100644 --- 
a/crates/codestral/Cargo.toml +++ b/crates/codestral/Cargo.toml @@ -10,7 +10,7 @@ path = "src/codestral.rs" [dependencies] anyhow.workspace = true -edit_prediction.workspace = true +edit_prediction_types.workspace = true edit_prediction_context.workspace = true futures.workspace = true gpui.workspace = true diff --git a/crates/codestral/src/codestral.rs b/crates/codestral/src/codestral.rs index 6a500acbf6ec5eea63c35a8deb83a8545cee497e..9bf0296ac357937cd1ad1470dba9a98864911de9 100644 --- a/crates/codestral/src/codestral.rs +++ b/crates/codestral/src/codestral.rs @@ -1,6 +1,6 @@ use anyhow::{Context as _, Result}; -use edit_prediction::{Direction, EditPrediction, EditPredictionProvider}; use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions}; +use edit_prediction_types::{Direction, EditPrediction, EditPredictionDelegate}; use futures::AsyncReadExt; use gpui::{App, Context, Entity, Task}; use http_client::HttpClient; @@ -43,17 +43,17 @@ impl CurrentCompletion { /// Attempts to adjust the edits based on changes made to the buffer since the completion was generated. /// Returns None if the user's edits conflict with the predicted edits. 
fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option, Arc)>> { - edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits) + edit_prediction_types::interpolate_edits(&self.snapshot, new_snapshot, &self.edits) } } -pub struct CodestralCompletionProvider { +pub struct CodestralEditPredictionDelegate { http_client: Arc, pending_request: Option>>, current_completion: Option, } -impl CodestralCompletionProvider { +impl CodestralEditPredictionDelegate { pub fn new(http_client: Arc) -> Self { Self { http_client, @@ -165,7 +165,7 @@ impl CodestralCompletionProvider { } } -impl EditPredictionProvider for CodestralCompletionProvider { +impl EditPredictionDelegate for CodestralEditPredictionDelegate { fn name() -> &'static str { "codestral" } @@ -174,7 +174,7 @@ impl EditPredictionProvider for CodestralCompletionProvider { "Codestral" } - fn show_completions_in_menu() -> bool { + fn show_predictions_in_menu() -> bool { true } @@ -239,7 +239,6 @@ impl EditPredictionProvider for CodestralCompletionProvider { cursor_point, &snapshot, &EXCERPT_OPTIONS, - None, ) .context("Line containing cursor doesn't fit in excerpt max bytes")?; diff --git a/crates/copilot/Cargo.toml b/crates/copilot/Cargo.toml index 0d3b19c0c7bd264f8ed10e53289376055f833307..459abda17573d66287e2c8ca0b995292acaf163b 100644 --- a/crates/copilot/Cargo.toml +++ b/crates/copilot/Cargo.toml @@ -33,7 +33,7 @@ fs.workspace = true futures.workspace = true gpui.workspace = true http_client.workspace = true -edit_prediction.workspace = true +edit_prediction_types.workspace = true language.workspace = true log.workspace = true lsp.workspace = true diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs index ed18a199bf2c08c8c046a8ad3e7f945b1340643e..6fbdeff807b65d22193ba7fdcb8e990f7184f70e 100644 --- a/crates/copilot/src/copilot.rs +++ b/crates/copilot/src/copilot.rs @@ -1,5 +1,5 @@ pub mod copilot_chat; -mod copilot_completion_provider; +mod 
copilot_edit_prediction_delegate; pub mod copilot_responses; pub mod request; mod sign_in; @@ -46,7 +46,7 @@ use util::rel_path::RelPath; use util::{ResultExt, fs::remove_matching}; use workspace::Workspace; -pub use crate::copilot_completion_provider::CopilotCompletionProvider; +pub use crate::copilot_edit_prediction_delegate::CopilotEditPredictionDelegate; pub use crate::sign_in::{CopilotCodeVerification, initiate_sign_in, reinstall_and_sign_in}; actions!( diff --git a/crates/copilot/src/copilot_completion_provider.rs b/crates/copilot/src/copilot_edit_prediction_delegate.rs similarity index 98% rename from crates/copilot/src/copilot_completion_provider.rs rename to crates/copilot/src/copilot_edit_prediction_delegate.rs index e92f0c7d7dd7e51c4a8fdc19f34bd6eb4189c097..961154dbeecad007f026f25eeac25de95d751d9e 100644 --- a/crates/copilot/src/copilot_completion_provider.rs +++ b/crates/copilot/src/copilot_edit_prediction_delegate.rs @@ -1,6 +1,6 @@ use crate::{Completion, Copilot}; use anyhow::Result; -use edit_prediction::{Direction, EditPrediction, EditPredictionProvider}; +use edit_prediction_types::{Direction, EditPrediction, EditPredictionDelegate}; use gpui::{App, Context, Entity, EntityId, Task}; use language::{Buffer, OffsetRangeExt, ToOffset, language_settings::AllLanguageSettings}; use settings::Settings; @@ -8,7 +8,7 @@ use std::{path::Path, time::Duration}; pub const COPILOT_DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75); -pub struct CopilotCompletionProvider { +pub struct CopilotEditPredictionDelegate { cycled: bool, buffer_id: Option, completions: Vec, @@ -19,7 +19,7 @@ pub struct CopilotCompletionProvider { copilot: Entity, } -impl CopilotCompletionProvider { +impl CopilotEditPredictionDelegate { pub fn new(copilot: Entity) -> Self { Self { cycled: false, @@ -47,7 +47,7 @@ impl CopilotCompletionProvider { } } -impl EditPredictionProvider for CopilotCompletionProvider { +impl EditPredictionDelegate for CopilotEditPredictionDelegate { fn name() -> 
&'static str { "copilot" } @@ -56,7 +56,7 @@ impl EditPredictionProvider for CopilotCompletionProvider { "Copilot" } - fn show_completions_in_menu() -> bool { + fn show_predictions_in_menu() -> bool { true } @@ -314,7 +314,7 @@ mod tests { cx, ) .await; - let copilot_provider = cx.new(|_| CopilotCompletionProvider::new(copilot)); + let copilot_provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot)); cx.update_editor(|editor, window, cx| { editor.set_edit_prediction_provider(Some(copilot_provider), window, cx) }); @@ -546,7 +546,7 @@ mod tests { cx, ) .await; - let copilot_provider = cx.new(|_| CopilotCompletionProvider::new(copilot)); + let copilot_provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot)); cx.update_editor(|editor, window, cx| { editor.set_edit_prediction_provider(Some(copilot_provider), window, cx) }); @@ -670,7 +670,7 @@ mod tests { cx, ) .await; - let copilot_provider = cx.new(|_| CopilotCompletionProvider::new(copilot)); + let copilot_provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot)); cx.update_editor(|editor, window, cx| { editor.set_edit_prediction_provider(Some(copilot_provider), window, cx) }); @@ -753,7 +753,7 @@ mod tests { window.focus(&editor.focus_handle(cx)); }) .unwrap(); - let copilot_provider = cx.new(|_| CopilotCompletionProvider::new(copilot)); + let copilot_provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot)); editor .update(cx, |editor, window, cx| { editor.set_edit_prediction_provider(Some(copilot_provider), window, cx) @@ -848,7 +848,7 @@ mod tests { cx, ) .await; - let copilot_provider = cx.new(|_| CopilotCompletionProvider::new(copilot)); + let copilot_provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot)); cx.update_editor(|editor, window, cx| { editor.set_edit_prediction_provider(Some(copilot_provider), window, cx) }); @@ -1000,7 +1000,7 @@ mod tests { window.focus(&editor.focus_handle(cx)) }) .unwrap(); - let copilot_provider = cx.new(|_| 
CopilotCompletionProvider::new(copilot)); + let copilot_provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot)); editor .update(cx, |editor, window, cx| { editor.set_edit_prediction_provider(Some(copilot_provider), window, cx) diff --git a/crates/edit_prediction/Cargo.toml b/crates/edit_prediction/Cargo.toml index 2c6888d14be49c857e7805fb63f9f9335ac32c8e..6e62cfa6f038671d595c5671de147cdc2125064d 100644 --- a/crates/edit_prediction/Cargo.toml +++ b/crates/edit_prediction/Cargo.toml @@ -11,7 +11,69 @@ workspace = true [lib] path = "src/edit_prediction.rs" +[features] +eval-support = [] + [dependencies] +ai_onboarding.workspace = true +anyhow.workspace = true +arrayvec.workspace = true +brotli.workspace = true client.workspace = true +cloud_llm_client.workspace = true +cloud_zeta2_prompt.workspace = true +collections.workspace = true +copilot.workspace = true +credentials_provider.workspace = true +db.workspace = true +edit_prediction_types.workspace = true +edit_prediction_context.workspace = true +feature_flags.workspace = true +fs.workspace = true +futures.workspace = true gpui.workspace = true +indoc.workspace = true +itertools.workspace = true language.workspace = true +language_model.workspace = true +log.workspace = true +lsp.workspace = true +menu.workspace = true +open_ai.workspace = true +postage.workspace = true +pretty_assertions.workspace = true +project.workspace = true +rand.workspace = true +regex.workspace = true +release_channel.workspace = true +semver.workspace = true +serde.workspace = true +serde_json.workspace = true +settings.workspace = true +smol.workspace = true +strsim.workspace = true +strum.workspace = true +telemetry.workspace = true +telemetry_events.workspace = true +thiserror.workspace = true +ui.workspace = true +util.workspace = true +uuid.workspace = true +workspace.workspace = true +worktree.workspace = true +zed_actions.workspace = true + +[dev-dependencies] +clock = { workspace = true, features = ["test-support"] } 
+cloud_api_types.workspace = true +cloud_llm_client = { workspace = true, features = ["test-support"] } +ctor.workspace = true +gpui = { workspace = true, features = ["test-support"] } +indoc.workspace = true +language = { workspace = true, features = ["test-support"] } +language_model = { workspace = true, features = ["test-support"] } +lsp.workspace = true +parking_lot.workspace = true +project = { workspace = true, features = ["test-support"] } +settings = { workspace = true, features = ["test-support"] } +zlog.workspace = true diff --git a/crates/zeta/license_examples/0bsd.txt b/crates/edit_prediction/license_examples/0bsd.txt similarity index 100% rename from crates/zeta/license_examples/0bsd.txt rename to crates/edit_prediction/license_examples/0bsd.txt diff --git a/crates/zeta/license_examples/apache-2.0-ex0.txt b/crates/edit_prediction/license_examples/apache-2.0-ex0.txt similarity index 100% rename from crates/zeta/license_examples/apache-2.0-ex0.txt rename to crates/edit_prediction/license_examples/apache-2.0-ex0.txt diff --git a/crates/zeta/license_examples/apache-2.0-ex1.txt b/crates/edit_prediction/license_examples/apache-2.0-ex1.txt similarity index 100% rename from crates/zeta/license_examples/apache-2.0-ex1.txt rename to crates/edit_prediction/license_examples/apache-2.0-ex1.txt diff --git a/crates/zeta/license_examples/apache-2.0-ex2.txt b/crates/edit_prediction/license_examples/apache-2.0-ex2.txt similarity index 100% rename from crates/zeta/license_examples/apache-2.0-ex2.txt rename to crates/edit_prediction/license_examples/apache-2.0-ex2.txt diff --git a/crates/zeta/license_examples/apache-2.0-ex3.txt b/crates/edit_prediction/license_examples/apache-2.0-ex3.txt similarity index 100% rename from crates/zeta/license_examples/apache-2.0-ex3.txt rename to crates/edit_prediction/license_examples/apache-2.0-ex3.txt diff --git a/crates/zeta/license_examples/apache-2.0-ex4.txt b/crates/edit_prediction/license_examples/apache-2.0-ex4.txt similarity 
index 100% rename from crates/zeta/license_examples/apache-2.0-ex4.txt rename to crates/edit_prediction/license_examples/apache-2.0-ex4.txt diff --git a/crates/zeta/license_examples/bsd-1-clause.txt b/crates/edit_prediction/license_examples/bsd-1-clause.txt similarity index 100% rename from crates/zeta/license_examples/bsd-1-clause.txt rename to crates/edit_prediction/license_examples/bsd-1-clause.txt diff --git a/crates/zeta/license_examples/bsd-2-clause-ex0.txt b/crates/edit_prediction/license_examples/bsd-2-clause-ex0.txt similarity index 100% rename from crates/zeta/license_examples/bsd-2-clause-ex0.txt rename to crates/edit_prediction/license_examples/bsd-2-clause-ex0.txt diff --git a/crates/zeta/license_examples/bsd-3-clause-ex0.txt b/crates/edit_prediction/license_examples/bsd-3-clause-ex0.txt similarity index 100% rename from crates/zeta/license_examples/bsd-3-clause-ex0.txt rename to crates/edit_prediction/license_examples/bsd-3-clause-ex0.txt diff --git a/crates/zeta/license_examples/bsd-3-clause-ex1.txt b/crates/edit_prediction/license_examples/bsd-3-clause-ex1.txt similarity index 100% rename from crates/zeta/license_examples/bsd-3-clause-ex1.txt rename to crates/edit_prediction/license_examples/bsd-3-clause-ex1.txt diff --git a/crates/zeta/license_examples/bsd-3-clause-ex2.txt b/crates/edit_prediction/license_examples/bsd-3-clause-ex2.txt similarity index 100% rename from crates/zeta/license_examples/bsd-3-clause-ex2.txt rename to crates/edit_prediction/license_examples/bsd-3-clause-ex2.txt diff --git a/crates/zeta/license_examples/bsd-3-clause-ex3.txt b/crates/edit_prediction/license_examples/bsd-3-clause-ex3.txt similarity index 100% rename from crates/zeta/license_examples/bsd-3-clause-ex3.txt rename to crates/edit_prediction/license_examples/bsd-3-clause-ex3.txt diff --git a/crates/zeta/license_examples/bsd-3-clause-ex4.txt b/crates/edit_prediction/license_examples/bsd-3-clause-ex4.txt similarity index 100% rename from 
crates/zeta/license_examples/bsd-3-clause-ex4.txt rename to crates/edit_prediction/license_examples/bsd-3-clause-ex4.txt diff --git a/crates/zeta/license_examples/isc.txt b/crates/edit_prediction/license_examples/isc.txt similarity index 100% rename from crates/zeta/license_examples/isc.txt rename to crates/edit_prediction/license_examples/isc.txt diff --git a/crates/zeta/license_examples/mit-ex0.txt b/crates/edit_prediction/license_examples/mit-ex0.txt similarity index 100% rename from crates/zeta/license_examples/mit-ex0.txt rename to crates/edit_prediction/license_examples/mit-ex0.txt diff --git a/crates/zeta/license_examples/mit-ex1.txt b/crates/edit_prediction/license_examples/mit-ex1.txt similarity index 100% rename from crates/zeta/license_examples/mit-ex1.txt rename to crates/edit_prediction/license_examples/mit-ex1.txt diff --git a/crates/zeta/license_examples/mit-ex2.txt b/crates/edit_prediction/license_examples/mit-ex2.txt similarity index 100% rename from crates/zeta/license_examples/mit-ex2.txt rename to crates/edit_prediction/license_examples/mit-ex2.txt diff --git a/crates/zeta/license_examples/mit-ex3.txt b/crates/edit_prediction/license_examples/mit-ex3.txt similarity index 100% rename from crates/zeta/license_examples/mit-ex3.txt rename to crates/edit_prediction/license_examples/mit-ex3.txt diff --git a/crates/zeta/license_examples/upl-1.0.txt b/crates/edit_prediction/license_examples/upl-1.0.txt similarity index 100% rename from crates/zeta/license_examples/upl-1.0.txt rename to crates/edit_prediction/license_examples/upl-1.0.txt diff --git a/crates/zeta/license_examples/zlib-ex0.txt b/crates/edit_prediction/license_examples/zlib-ex0.txt similarity index 100% rename from crates/zeta/license_examples/zlib-ex0.txt rename to crates/edit_prediction/license_examples/zlib-ex0.txt diff --git a/crates/zeta/license_patterns/0bsd-pattern b/crates/edit_prediction/license_patterns/0bsd-pattern similarity index 100% rename from 
crates/zeta/license_patterns/0bsd-pattern rename to crates/edit_prediction/license_patterns/0bsd-pattern diff --git a/crates/zeta/license_patterns/apache-2.0-pattern b/crates/edit_prediction/license_patterns/apache-2.0-pattern similarity index 100% rename from crates/zeta/license_patterns/apache-2.0-pattern rename to crates/edit_prediction/license_patterns/apache-2.0-pattern diff --git a/crates/zeta/license_patterns/apache-2.0-reference-pattern b/crates/edit_prediction/license_patterns/apache-2.0-reference-pattern similarity index 100% rename from crates/zeta/license_patterns/apache-2.0-reference-pattern rename to crates/edit_prediction/license_patterns/apache-2.0-reference-pattern diff --git a/crates/zeta/license_patterns/bsd-pattern b/crates/edit_prediction/license_patterns/bsd-pattern similarity index 100% rename from crates/zeta/license_patterns/bsd-pattern rename to crates/edit_prediction/license_patterns/bsd-pattern diff --git a/crates/zeta/license_patterns/isc-pattern b/crates/edit_prediction/license_patterns/isc-pattern similarity index 100% rename from crates/zeta/license_patterns/isc-pattern rename to crates/edit_prediction/license_patterns/isc-pattern diff --git a/crates/zeta/license_patterns/mit-pattern b/crates/edit_prediction/license_patterns/mit-pattern similarity index 100% rename from crates/zeta/license_patterns/mit-pattern rename to crates/edit_prediction/license_patterns/mit-pattern diff --git a/crates/zeta/license_patterns/upl-1.0-pattern b/crates/edit_prediction/license_patterns/upl-1.0-pattern similarity index 100% rename from crates/zeta/license_patterns/upl-1.0-pattern rename to crates/edit_prediction/license_patterns/upl-1.0-pattern diff --git a/crates/zeta/license_patterns/zlib-pattern b/crates/edit_prediction/license_patterns/zlib-pattern similarity index 100% rename from crates/zeta/license_patterns/zlib-pattern rename to crates/edit_prediction/license_patterns/zlib-pattern diff --git a/crates/edit_prediction/src/edit_prediction.rs 
b/crates/edit_prediction/src/edit_prediction.rs index 1984383a9691ae9373973a3eb9f00db4e7e795f2..ddb29d0796a6c6b24ee3914533b29b967d224ac8 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -1,298 +1,1911 @@ -use std::{ops::Range, sync::Arc}; +use anyhow::Result; +use arrayvec::ArrayVec; +use client::{Client, EditPredictionUsage, UserStore}; +use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat}; +use cloud_llm_client::{ + AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason, + EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, + MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, + ZED_VERSION_HEADER_NAME, +}; +use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES; +use collections::{HashMap, HashSet}; +use db::kvp::{Dismissable, KEY_VALUE_STORE}; +use edit_prediction_context::EditPredictionExcerptOptions; +use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile}; +use feature_flags::{FeatureFlag, FeatureFlagAppExt as _}; +use futures::{ + AsyncReadExt as _, FutureExt as _, StreamExt as _, + channel::{ + mpsc::{self, UnboundedReceiver}, + oneshot, + }, + select_biased, +}; +use gpui::BackgroundExecutor; +use gpui::{ + App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions, + http_client::{self, AsyncBody, Method}, + prelude::*, +}; +use language::language_settings::all_language_settings; +use language::{Anchor, Buffer, File, Point, ToPoint}; +use language::{BufferSnapshot, OffsetRangeExt}; +use language_model::{LlmApiToken, RefreshLlmTokenListener}; +use project::{Project, ProjectPath, WorktreeId}; +use release_channel::AppVersion; +use semver::Version; +use serde::de::DeserializeOwned; +use settings::{EditPredictionProvider, SettingsStore, update_settings_file}; +use std::collections::{VecDeque, hash_map}; +use 
workspace::Workspace; + +use std::ops::Range; +use std::path::Path; +use std::rc::Rc; +use std::str::FromStr as _; +use std::sync::{Arc, LazyLock}; +use std::time::{Duration, Instant}; +use std::{env, mem}; +use thiserror::Error; +use util::{RangeExt as _, ResultExt as _}; +use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification}; + +mod license_detection; +mod onboarding_modal; +mod prediction; +pub mod sweep_ai; +pub mod udiff; +mod xml_edits; +mod zed_edit_prediction_delegate; +pub mod zeta1; +pub mod zeta2; + +#[cfg(test)] +mod edit_prediction_tests; + +use crate::license_detection::LicenseDetectionWatcher; +use crate::onboarding_modal::ZedPredictModal; +pub use crate::prediction::EditPrediction; +pub use crate::prediction::EditPredictionId; +pub use crate::prediction::EditPredictionInputs; +use crate::prediction::EditPredictionResult; +pub use crate::sweep_ai::SweepAi; +pub use telemetry_events::EditPredictionRating; +pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate; + +actions!( + edit_prediction, + [ + /// Resets the edit prediction onboarding state. + ResetOnboarding, + /// Clears the edit prediction history. + ClearHistory, + ] +); + +/// Maximum number of events to track. +const EVENT_COUNT_MAX: usize = 6; +const CHANGE_GROUPING_LINE_SPAN: u32 = 8; +const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice"; +const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15); -use client::EditPredictionUsage; -use gpui::{App, Context, Entity, SharedString}; -use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt}; +pub struct SweepFeatureFlag; -// TODO: Find a better home for `Direction`. -// -// This should live in an ancestor crate of `editor` and `edit_prediction`, -// but at time of writing there isn't an obvious spot. 
-#[derive(Copy, Clone, PartialEq, Eq)] -pub enum Direction { - Prev, - Next, +impl FeatureFlag for SweepFeatureFlag { + const NAME: &str = "sweep-ai"; } -#[derive(Clone)] -pub enum EditPrediction { - /// Edits within the buffer that requested the prediction - Local { - id: Option, - edits: Vec<(Range, Arc)>, - edit_preview: Option, - }, - /// Jump to a different file from the one that requested the prediction - Jump { - id: Option, - snapshot: language::BufferSnapshot, - target: language::Anchor, +pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions { + context: EditPredictionExcerptOptions { + max_bytes: 512, + min_bytes: 128, + target_before_cursor_over_total_bytes: 0.5, }, + max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES, + prompt_format: PromptFormat::DEFAULT, +}; + +static USE_OLLAMA: LazyLock = + LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty())); + +static EDIT_PREDICTIONS_MODEL_ID: LazyLock = LazyLock::new(|| { + match env::var("ZED_ZETA2_MODEL").as_deref() { + Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten + Ok(model) => model, + Err(_) if *USE_OLLAMA => "qwen3-coder:30b", + Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten + } + .to_string() +}); +static PREDICT_EDITS_URL: LazyLock> = LazyLock::new(|| { + env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| { + if *USE_OLLAMA { + Some("http://localhost:11434/v1/chat/completions".into()) + } else { + None + } + }) +}); + +pub struct Zeta2FeatureFlag; + +impl FeatureFlag for Zeta2FeatureFlag { + const NAME: &'static str = "zeta2"; + + fn enabled_for_staff() -> bool { + true + } } -pub enum DataCollectionState { - /// The provider doesn't support data collection. - Unsupported, - /// Data collection is enabled. - Enabled { is_project_open_source: bool }, - /// Data collection is disabled or unanswered. 
- Disabled { is_project_open_source: bool }, +#[derive(Clone)] +struct EditPredictionStoreGlobal(Entity); + +impl Global for EditPredictionStoreGlobal {} + +pub struct EditPredictionStore { + client: Arc, + user_store: Entity, + llm_token: LlmApiToken, + _llm_token_subscription: Subscription, + projects: HashMap, + use_context: bool, + options: ZetaOptions, + update_required: bool, + debug_tx: Option>, + #[cfg(feature = "eval-support")] + eval_cache: Option>, + edit_prediction_model: EditPredictionModel, + pub sweep_ai: SweepAi, + data_collection_choice: DataCollectionChoice, + reject_predictions_tx: mpsc::UnboundedSender, + shown_predictions: VecDeque, + rated_predictions: HashSet, +} + +#[derive(Copy, Clone, Default, PartialEq, Eq)] +pub enum EditPredictionModel { + #[default] + Zeta1, + Zeta2, + Sweep, } -impl DataCollectionState { - pub fn is_supported(&self) -> bool { - !matches!(self, DataCollectionState::Unsupported) +#[derive(Debug, Clone, PartialEq)] +pub struct ZetaOptions { + pub context: EditPredictionExcerptOptions, + pub max_prompt_bytes: usize, + pub prompt_format: predict_edits_v3::PromptFormat, +} + +#[derive(Debug)] +pub enum DebugEvent { + ContextRetrievalStarted(ContextRetrievalStartedDebugEvent), + ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent), + EditPredictionRequested(EditPredictionRequestedDebugEvent), +} + +#[derive(Debug)] +pub struct ContextRetrievalStartedDebugEvent { + pub project_entity_id: EntityId, + pub timestamp: Instant, + pub search_prompt: String, +} + +#[derive(Debug)] +pub struct ContextRetrievalFinishedDebugEvent { + pub project_entity_id: EntityId, + pub timestamp: Instant, + pub metadata: Vec<(&'static str, SharedString)>, +} + +#[derive(Debug)] +pub struct EditPredictionRequestedDebugEvent { + pub inputs: EditPredictionInputs, + pub retrieval_time: Duration, + pub buffer: WeakEntity, + pub position: Anchor, + pub local_prompt: Result, + pub response_rx: oneshot::Receiver<(Result, Duration)>, +} + +pub type 
RequestDebugInfo = predict_edits_v3::DebugInfo; + +struct ProjectState { + events: VecDeque>, + last_event: Option, + recent_paths: VecDeque, + registered_buffers: HashMap, + current_prediction: Option, + next_pending_prediction_id: usize, + pending_predictions: ArrayVec, + context_updates_tx: smol::channel::Sender<()>, + context_updates_rx: smol::channel::Receiver<()>, + last_prediction_refresh: Option<(EntityId, Instant)>, + cancelled_predictions: HashSet, + context: Entity, + license_detection_watchers: HashMap>, + _subscription: gpui::Subscription, +} + +impl ProjectState { + pub fn events(&self, cx: &App) -> Vec> { + self.events + .iter() + .cloned() + .chain( + self.last_event + .as_ref() + .and_then(|event| event.finalize(&self.license_detection_watchers, cx)), + ) + .collect() } - pub fn is_enabled(&self) -> bool { - matches!(self, DataCollectionState::Enabled { .. }) + fn cancel_pending_prediction( + &mut self, + pending_prediction: PendingPrediction, + cx: &mut Context, + ) { + self.cancelled_predictions.insert(pending_prediction.id); + + cx.spawn(async move |this, cx| { + let Some(prediction_id) = pending_prediction.task.await else { + return; + }; + + this.update(cx, |this, _cx| { + this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false); + }) + .ok(); + }) + .detach() } +} + +#[derive(Debug, Clone)] +struct CurrentEditPrediction { + pub requested_by: PredictionRequestedBy, + pub prediction: EditPrediction, + pub was_shown: bool, +} + +impl CurrentEditPrediction { + fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool { + let Some(new_edits) = self + .prediction + .interpolate(&self.prediction.buffer.read(cx)) + else { + return false; + }; - pub fn is_project_open_source(&self) -> bool { + if self.prediction.buffer != old_prediction.prediction.buffer { + return true; + } + + let Some(old_edits) = old_prediction + .prediction + .interpolate(&old_prediction.prediction.buffer.read(cx)) + else { + return 
true; + }; + + let requested_by_buffer_id = self.requested_by.buffer_id(); + + // This reduces the occurrence of UI thrash from replacing edits + // + // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits. + if requested_by_buffer_id == Some(self.prediction.buffer.entity_id()) + && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id()) + && old_edits.len() == 1 + && new_edits.len() == 1 + { + let (old_range, old_text) = &old_edits[0]; + let (new_range, new_text) = &new_edits[0]; + new_range == old_range && new_text.starts_with(old_text.as_ref()) + } else { + true + } + } +} + +#[derive(Debug, Clone)] +enum PredictionRequestedBy { + DiagnosticsUpdate, + Buffer(EntityId), +} + +impl PredictionRequestedBy { + pub fn buffer_id(&self) -> Option { match self { - Self::Enabled { - is_project_open_source, - } - | Self::Disabled { - is_project_open_source, - } => *is_project_open_source, - _ => false, + PredictionRequestedBy::DiagnosticsUpdate => None, + PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id), } } } -pub trait EditPredictionProvider: 'static + Sized { - fn name() -> &'static str; - fn display_name() -> &'static str; - fn show_completions_in_menu() -> bool; - fn show_tab_accept_marker() -> bool { - false +#[derive(Debug)] +struct PendingPrediction { + id: usize, + task: Task>, +} + +/// A prediction from the perspective of a buffer. 
+#[derive(Debug)] +enum BufferEditPrediction<'a> { + Local { prediction: &'a EditPrediction }, + Jump { prediction: &'a EditPrediction }, +} + +#[cfg(test)] +impl std::ops::Deref for BufferEditPrediction<'_> { + type Target = EditPrediction; + + fn deref(&self) -> &Self::Target { + match self { + BufferEditPrediction::Local { prediction } => prediction, + BufferEditPrediction::Jump { prediction } => prediction, + } } - fn supports_jump_to_edit() -> bool { - true +} + +struct RegisteredBuffer { + snapshot: BufferSnapshot, + _subscriptions: [gpui::Subscription; 2], +} + +struct LastEvent { + old_snapshot: BufferSnapshot, + new_snapshot: BufferSnapshot, + end_edit_anchor: Option, +} + +impl LastEvent { + pub fn finalize( + &self, + license_detection_watchers: &HashMap>, + cx: &App, + ) -> Option> { + let path = buffer_path_with_id_fallback(&self.new_snapshot, cx); + let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx); + + let file = self.new_snapshot.file(); + let old_file = self.old_snapshot.file(); + + let in_open_source_repo = [file, old_file].iter().all(|file| { + file.is_some_and(|file| { + license_detection_watchers + .get(&file.worktree_id(cx)) + .is_some_and(|watcher| watcher.is_project_open_source()) + }) + }); + + let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text()); + + if path == old_path && diff.is_empty() { + None + } else { + Some(Arc::new(predict_edits_v3::Event::BufferChange { + old_path, + path, + diff, + in_open_source_repo, + // TODO: Actually detect if this edit was predicted or not + predicted: false, + })) + } } +} - fn data_collection_state(&self, _cx: &App) -> DataCollectionState { - DataCollectionState::Unsupported +fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc { + if let Some(file) = snapshot.file() { + file.full_path(cx).into() + } else { + Path::new(&format!("untitled-{}", snapshot.remote_id())).into() } +} - fn usage(&self, _cx: &App) -> Option { - None +impl 
EditPredictionStore { + pub fn try_global(cx: &App) -> Option> { + cx.try_global::() + .map(|global| global.0.clone()) } - fn toggle_data_collection(&mut self, _cx: &mut App) {} - fn is_enabled( - &self, + pub fn global( + client: &Arc, + user_store: &Entity, + cx: &mut App, + ) -> Entity { + cx.try_global::() + .map(|global| global.0.clone()) + .unwrap_or_else(|| { + let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx)); + cx.set_global(EditPredictionStoreGlobal(ep_store.clone())); + ep_store + }) + } + + pub fn new(client: Arc, user_store: Entity, cx: &mut Context) -> Self { + let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx); + let data_collection_choice = Self::load_data_collection_choice(); + + let llm_token = LlmApiToken::default(); + + let (reject_tx, reject_rx) = mpsc::unbounded(); + cx.background_spawn({ + let client = client.clone(); + let llm_token = llm_token.clone(); + let app_version = AppVersion::global(cx); + let background_executor = cx.background_executor().clone(); + async move { + Self::handle_rejected_predictions( + reject_rx, + client, + llm_token, + app_version, + background_executor, + ) + .await + } + }) + .detach(); + + let mut this = Self { + projects: HashMap::default(), + client, + user_store, + options: DEFAULT_OPTIONS, + use_context: false, + llm_token, + _llm_token_subscription: cx.subscribe( + &refresh_llm_token_listener, + |this, _listener, _event, cx| { + let client = this.client.clone(); + let llm_token = this.llm_token.clone(); + cx.spawn(async move |_this, _cx| { + llm_token.refresh(&client).await?; + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + }, + ), + update_required: false, + debug_tx: None, + #[cfg(feature = "eval-support")] + eval_cache: None, + edit_prediction_model: EditPredictionModel::Zeta2, + sweep_ai: SweepAi::new(cx), + data_collection_choice, + reject_predictions_tx: reject_tx, + rated_predictions: Default::default(), + shown_predictions: Default::default(), + }; + + 
this.enable_or_disable_context_retrieval(cx); + let weak_this = cx.weak_entity(); + cx.on_flags_ready(move |_, cx| { + weak_this + .update(cx, |this, cx| this.enable_or_disable_context_retrieval(cx)) + .ok(); + }) + .detach(); + cx.observe_global::(|this, cx| { + this.enable_or_disable_context_retrieval(cx); + }) + .detach(); + + this + } + + pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) { + self.edit_prediction_model = model; + } + + pub fn has_sweep_api_token(&self) -> bool { + self.sweep_ai + .api_token + .clone() + .now_or_never() + .flatten() + .is_some() + } + + #[cfg(feature = "eval-support")] + pub fn with_eval_cache(&mut self, cache: Arc) { + self.eval_cache = Some(cache); + } + + pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver { + let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded(); + self.debug_tx = Some(debug_watch_tx); + debug_watch_rx + } + + pub fn options(&self) -> &ZetaOptions { + &self.options + } + + pub fn set_options(&mut self, options: ZetaOptions) { + self.options = options; + } + + pub fn set_use_context(&mut self, use_context: bool) { + self.use_context = use_context; + } + + pub fn clear_history(&mut self) { + for project_state in self.projects.values_mut() { + project_state.events.clear(); + } + } + + pub fn context_for_project<'a>( + &'a self, + project: &Entity, + cx: &'a App, + ) -> &'a [RelatedFile] { + self.projects + .get(&project.entity_id()) + .map(|project| project.context.read(cx).related_files()) + .unwrap_or(&[]) + } + + pub fn usage(&self, cx: &App) -> Option { + if self.edit_prediction_model == EditPredictionModel::Zeta2 { + self.user_store.read(cx).edit_prediction_usage() + } else { + None + } + } + + pub fn register_project(&mut self, project: &Entity, cx: &mut Context) { + self.get_or_init_project(project, cx); + } + + pub fn register_buffer( + &mut self, buffer: &Entity, - cursor_position: language::Anchor, - cx: &App, - ) -> bool; - fn is_refreshing(&self, cx: &App) -> bool; - fn 
refresh( + project: &Entity, + cx: &mut Context, + ) { + let project_state = self.get_or_init_project(project, cx); + Self::register_buffer_impl(project_state, buffer, project, cx); + } + + fn get_or_init_project( &mut self, - buffer: Entity, - cursor_position: language::Anchor, - debounce: bool, + project: &Entity, cx: &mut Context, - ); - fn cycle( + ) -> &mut ProjectState { + let entity_id = project.entity_id(); + let (context_updates_tx, context_updates_rx) = smol::channel::unbounded(); + self.projects + .entry(entity_id) + .or_insert_with(|| ProjectState { + context: { + let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx)); + cx.subscribe( + &related_excerpt_store, + move |this, _, event, _| match event { + RelatedExcerptStoreEvent::StartedRefresh => { + if let Some(debug_tx) = this.debug_tx.clone() { + debug_tx + .unbounded_send(DebugEvent::ContextRetrievalStarted( + ContextRetrievalStartedDebugEvent { + project_entity_id: entity_id, + timestamp: Instant::now(), + search_prompt: String::new(), + }, + )) + .ok(); + } + } + RelatedExcerptStoreEvent::FinishedRefresh { + cache_hit_count, + cache_miss_count, + mean_definition_latency, + max_definition_latency, + } => { + if let Some(debug_tx) = this.debug_tx.clone() { + debug_tx + .unbounded_send(DebugEvent::ContextRetrievalFinished( + ContextRetrievalFinishedDebugEvent { + project_entity_id: entity_id, + timestamp: Instant::now(), + metadata: vec![ + ( + "Cache Hits", + format!( + "{}/{}", + cache_hit_count, + cache_hit_count + cache_miss_count + ) + .into(), + ), + ( + "Max LSP Time", + format!( + "{} ms", + max_definition_latency.as_millis() + ) + .into(), + ), + ( + "Mean LSP Time", + format!( + "{} ms", + mean_definition_latency.as_millis() + ) + .into(), + ), + ], + }, + )) + .ok(); + } + if let Some(project_state) = this.projects.get(&entity_id) { + project_state.context_updates_tx.send_blocking(()).ok(); + } + } + }, + ) + .detach(); + related_excerpt_store + }, + events: 
VecDeque::new(), + last_event: None, + recent_paths: VecDeque::new(), + context_updates_rx, + context_updates_tx, + registered_buffers: HashMap::default(), + current_prediction: None, + cancelled_predictions: HashSet::default(), + pending_predictions: ArrayVec::new(), + next_pending_prediction_id: 0, + last_prediction_refresh: None, + license_detection_watchers: HashMap::default(), + _subscription: cx.subscribe(&project, Self::handle_project_event), + }) + } + + pub fn project_context_updates( + &self, + project: &Entity, + ) -> Option> { + let project_state = self.projects.get(&project.entity_id())?; + Some(project_state.context_updates_rx.clone()) + } + + fn handle_project_event( &mut self, - buffer: Entity, - cursor_position: language::Anchor, - direction: Direction, + project: Entity, + event: &project::Event, cx: &mut Context, - ); - fn accept(&mut self, cx: &mut Context); - fn discard(&mut self, cx: &mut Context); - fn did_show(&mut self, _cx: &mut Context) {} - fn suggest( + ) { + // TODO [zeta2] init with recent paths + match event { + project::Event::ActiveEntryChanged(Some(active_entry_id)) => { + let Some(project_state) = self.projects.get_mut(&project.entity_id()) else { + return; + }; + let path = project.read(cx).path_for_entry(*active_entry_id, cx); + if let Some(path) = path { + if let Some(ix) = project_state + .recent_paths + .iter() + .position(|probe| probe == &path) + { + project_state.recent_paths.remove(ix); + } + project_state.recent_paths.push_front(path); + } + } + project::Event::DiagnosticsUpdated { .. 
} => { + if cx.has_flag::() { + self.refresh_prediction_from_diagnostics(project, cx); + } + } + _ => (), + } + } + + fn register_buffer_impl<'a>( + project_state: &'a mut ProjectState, + buffer: &Entity, + project: &Entity, + cx: &mut Context, + ) -> &'a mut RegisteredBuffer { + let buffer_id = buffer.entity_id(); + + if let Some(file) = buffer.read(cx).file() { + let worktree_id = file.worktree_id(cx); + if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) { + project_state + .license_detection_watchers + .entry(worktree_id) + .or_insert_with(|| { + let project_entity_id = project.entity_id(); + cx.observe_release(&worktree, move |this, _worktree, _cx| { + let Some(project_state) = this.projects.get_mut(&project_entity_id) + else { + return; + }; + project_state + .license_detection_watchers + .remove(&worktree_id); + }) + .detach(); + Rc::new(LicenseDetectionWatcher::new(&worktree, cx)) + }); + } + } + + match project_state.registered_buffers.entry(buffer_id) { + hash_map::Entry::Occupied(entry) => entry.into_mut(), + hash_map::Entry::Vacant(entry) => { + let snapshot = buffer.read(cx).snapshot(); + let project_entity_id = project.entity_id(); + entry.insert(RegisteredBuffer { + snapshot, + _subscriptions: [ + cx.subscribe(buffer, { + let project = project.downgrade(); + move |this, buffer, event, cx| { + if let language::BufferEvent::Edited = event + && let Some(project) = project.upgrade() + { + this.report_changes_for_buffer(&buffer, &project, cx); + } + } + }), + cx.observe_release(buffer, move |this, _buffer, _cx| { + let Some(project_state) = this.projects.get_mut(&project_entity_id) + else { + return; + }; + project_state.registered_buffers.remove(&buffer_id); + }), + ], + }) + } + } + } + + fn report_changes_for_buffer( &mut self, buffer: &Entity, - cursor_position: language::Anchor, + project: &Entity, cx: &mut Context, - ) -> Option; -} + ) { + let project_state = self.get_or_init_project(project, cx); + let registered_buffer = 
Self::register_buffer_impl(project_state, buffer, project, cx); + + let new_snapshot = buffer.read(cx).snapshot(); + if new_snapshot.version == registered_buffer.snapshot.version { + return; + } + + let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone()); + let end_edit_anchor = new_snapshot + .anchored_edits_since::(&old_snapshot.version) + .last() + .map(|(_, range)| range.end); + let events = &mut project_state.events; -pub trait EditPredictionProviderHandle { - fn name(&self) -> &'static str; - fn display_name(&self) -> &'static str; - fn is_enabled( + if let Some(LastEvent { + new_snapshot: last_new_snapshot, + end_edit_anchor: last_end_edit_anchor, + .. + }) = project_state.last_event.as_mut() + { + let is_next_snapshot_of_same_buffer = old_snapshot.remote_id() + == last_new_snapshot.remote_id() + && old_snapshot.version == last_new_snapshot.version; + + let should_coalesce = is_next_snapshot_of_same_buffer + && end_edit_anchor + .as_ref() + .zip(last_end_edit_anchor.as_ref()) + .is_some_and(|(a, b)| { + let a = a.to_point(&new_snapshot); + let b = b.to_point(&new_snapshot); + a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN + }); + + if should_coalesce { + *last_end_edit_anchor = end_edit_anchor; + *last_new_snapshot = new_snapshot; + return; + } + } + + if events.len() + 1 >= EVENT_COUNT_MAX { + events.pop_front(); + } + + if let Some(event) = project_state.last_event.take() { + events.extend(event.finalize(&project_state.license_detection_watchers, cx)); + } + + project_state.last_event = Some(LastEvent { + old_snapshot, + new_snapshot, + end_edit_anchor, + }); + } + + fn current_prediction_for_buffer( &self, buffer: &Entity, - cursor_position: language::Anchor, + project: &Entity, cx: &App, - ) -> bool; - fn show_completions_in_menu(&self) -> bool; - fn show_tab_accept_marker(&self) -> bool; - fn supports_jump_to_edit(&self) -> bool; - fn data_collection_state(&self, cx: &App) -> DataCollectionState; - fn usage(&self, cx: 
&App) -> Option; - fn toggle_data_collection(&self, cx: &mut App); - fn is_refreshing(&self, cx: &App) -> bool; - fn refresh( - &self, - buffer: Entity, - cursor_position: language::Anchor, - debounce: bool, - cx: &mut App, - ); - fn cycle( - &self, - buffer: Entity, - cursor_position: language::Anchor, - direction: Direction, - cx: &mut App, - ); - fn did_show(&self, cx: &mut App); - fn accept(&self, cx: &mut App); - fn discard(&self, cx: &mut App); - fn suggest( - &self, - buffer: &Entity, - cursor_position: language::Anchor, - cx: &mut App, - ) -> Option; -} + ) -> Option> { + let project_state = self.projects.get(&project.entity_id())?; -impl EditPredictionProviderHandle for Entity -where - T: EditPredictionProvider, -{ - fn name(&self) -> &'static str { - T::name() - } + let CurrentEditPrediction { + requested_by, + prediction, + .. + } = project_state.current_prediction.as_ref()?; - fn display_name(&self) -> &'static str { - T::display_name() - } + if prediction.targets_buffer(buffer.read(cx)) { + Some(BufferEditPrediction::Local { prediction }) + } else { + let show_jump = match requested_by { + PredictionRequestedBy::Buffer(requested_by_buffer_id) => { + requested_by_buffer_id == &buffer.entity_id() + } + PredictionRequestedBy::DiagnosticsUpdate => true, + }; - fn show_completions_in_menu(&self) -> bool { - T::show_completions_in_menu() + if show_jump { + Some(BufferEditPrediction::Jump { prediction }) + } else { + None + } + } } - fn show_tab_accept_marker(&self) -> bool { - T::show_tab_accept_marker() - } + fn accept_current_prediction(&mut self, project: &Entity, cx: &mut Context) { + match self.edit_prediction_model { + EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {} + EditPredictionModel::Sweep => return, + } + + let Some(project_state) = self.projects.get_mut(&project.entity_id()) else { + return; + }; + + let Some(prediction) = project_state.current_prediction.take() else { + return; + }; + let request_id = 
prediction.prediction.id.to_string(); + for pending_prediction in mem::take(&mut project_state.pending_predictions) { + project_state.cancel_pending_prediction(pending_prediction, cx); + } + + let client = self.client.clone(); + let llm_token = self.llm_token.clone(); + let app_version = AppVersion::global(cx); + cx.spawn(async move |this, cx| { + let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") { + http_client::Url::parse(&predict_edits_url)? + } else { + client + .http_client() + .build_zed_llm_url("/predict_edits/accept", &[])? + }; + + let response = cx + .background_spawn(Self::send_api_request::<()>( + move |builder| { + let req = builder.uri(url.as_ref()).body( + serde_json::to_string(&AcceptEditPredictionBody { + request_id: request_id.clone(), + })? + .into(), + ); + Ok(req?) + }, + client, + llm_token, + app_version, + )) + .await; - fn supports_jump_to_edit(&self) -> bool { - T::supports_jump_to_edit() + Self::handle_api_response(&this, response, cx)?; + anyhow::Ok(()) + }) + .detach_and_log_err(cx); } - fn data_collection_state(&self, cx: &App) -> DataCollectionState { - self.read(cx).data_collection_state(cx) + async fn handle_rejected_predictions( + rx: UnboundedReceiver, + client: Arc, + llm_token: LlmApiToken, + app_version: Version, + background_executor: BackgroundExecutor, + ) { + let mut rx = std::pin::pin!(rx.peekable()); + let mut batched = Vec::new(); + + while let Some(rejection) = rx.next().await { + batched.push(rejection); + + if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 { + select_biased! 
{ + next = rx.as_mut().peek().fuse() => { + if next.is_some() { + continue; + } + } + () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {}, + } + } + + let url = client + .http_client() + .build_zed_llm_url("/predict_edits/reject", &[]) + .unwrap(); + + let flush_count = batched + .len() + // in case items have accumulated after failure + .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST); + let start = batched.len() - flush_count; + + let body = RejectEditPredictionsBodyRef { + rejections: &batched[start..], + }; + + let result = Self::send_api_request::<()>( + |builder| { + let req = builder + .uri(url.as_ref()) + .body(serde_json::to_string(&body)?.into()); + anyhow::Ok(req?) + }, + client.clone(), + llm_token.clone(), + app_version.clone(), + ) + .await; + + if result.log_err().is_some() { + batched.drain(start..); + } + } } - fn usage(&self, cx: &App) -> Option { - self.read(cx).usage(cx) + fn reject_current_prediction( + &mut self, + reason: EditPredictionRejectReason, + project: &Entity, + ) { + if let Some(project_state) = self.projects.get_mut(&project.entity_id()) { + project_state.pending_predictions.clear(); + if let Some(prediction) = project_state.current_prediction.take() { + self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown); + } + }; } - fn toggle_data_collection(&self, cx: &mut App) { - self.update(cx, |this, cx| this.toggle_data_collection(cx)) + fn did_show_current_prediction(&mut self, project: &Entity, _cx: &mut Context) { + if let Some(project_state) = self.projects.get_mut(&project.entity_id()) { + if let Some(current_prediction) = project_state.current_prediction.as_mut() { + if !current_prediction.was_shown { + current_prediction.was_shown = true; + self.shown_predictions + .push_front(current_prediction.prediction.clone()); + if self.shown_predictions.len() > 50 { + let completion = self.shown_predictions.pop_back().unwrap(); + self.rated_predictions.remove(&completion.id); + } + } + } + } } - fn 
is_enabled( - &self, - buffer: &Entity, - cursor_position: language::Anchor, - cx: &App, - ) -> bool { - self.read(cx).is_enabled(buffer, cursor_position, cx) + fn reject_prediction( + &mut self, + prediction_id: EditPredictionId, + reason: EditPredictionRejectReason, + was_shown: bool, + ) { + match self.edit_prediction_model { + EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {} + EditPredictionModel::Sweep => return, + } + + self.reject_predictions_tx + .unbounded_send(EditPredictionRejection { + request_id: prediction_id.to_string(), + reason, + was_shown, + }) + .log_err(); } - fn is_refreshing(&self, cx: &App) -> bool { - self.read(cx).is_refreshing(cx) + fn is_refreshing(&self, project: &Entity) -> bool { + self.projects + .get(&project.entity_id()) + .is_some_and(|project_state| !project_state.pending_predictions.is_empty()) } - fn refresh( - &self, + pub fn refresh_prediction_from_buffer( + &mut self, + project: Entity, buffer: Entity, - cursor_position: language::Anchor, - debounce: bool, - cx: &mut App, + position: language::Anchor, + cx: &mut Context, ) { - self.update(cx, |this, cx| { - this.refresh(buffer, cursor_position, debounce, cx) + self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| { + let Some(request_task) = this + .update(cx, |this, cx| { + this.request_prediction( + &project, + &buffer, + position, + PredictEditsRequestTrigger::Other, + cx, + ) + }) + .log_err() + else { + return Task::ready(anyhow::Ok(None)); + }; + + cx.spawn(async move |_cx| { + request_task.await.map(|prediction_result| { + prediction_result.map(|prediction_result| { + ( + prediction_result, + PredictionRequestedBy::Buffer(buffer.entity_id()), + ) + }) + }) + }) }) } - fn cycle( - &self, - buffer: Entity, - cursor_position: language::Anchor, - direction: Direction, - cx: &mut App, + pub fn refresh_prediction_from_diagnostics( + &mut self, + project: Entity, + cx: &mut Context, ) { - self.update(cx, |this, cx| { - 
this.cycle(buffer, cursor_position, direction, cx) + let Some(project_state) = self.projects.get_mut(&project.entity_id()) else { + return; + }; + + // Prefer predictions from buffer + if project_state.current_prediction.is_some() { + return; + }; + + self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| { + let Some(open_buffer_task) = project + .update(cx, |project, cx| { + project + .active_entry() + .and_then(|entry| project.path_for_entry(entry, cx)) + .map(|path| project.open_buffer(path, cx)) + }) + .log_err() + .flatten() + else { + return Task::ready(anyhow::Ok(None)); + }; + + cx.spawn(async move |cx| { + let active_buffer = open_buffer_task.await?; + let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; + + let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location( + active_buffer, + &snapshot, + Default::default(), + Default::default(), + &project, + cx, + ) + .await? + else { + return anyhow::Ok(None); + }; + + let Some(prediction_result) = this + .update(cx, |this, cx| { + this.request_prediction( + &project, + &jump_buffer, + jump_position, + PredictEditsRequestTrigger::Diagnostics, + cx, + ) + })? + .await? 
+ else { + return anyhow::Ok(None); + }; + + this.update(cx, |this, cx| { + Some(( + if this + .get_or_init_project(&project, cx) + .current_prediction + .is_none() + { + prediction_result + } else { + EditPredictionResult { + id: prediction_result.id, + prediction: Err(EditPredictionRejectReason::CurrentPreferred), + } + }, + PredictionRequestedBy::DiagnosticsUpdate, + )) + }) + }) + }); + } + + #[cfg(not(test))] + pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300); + #[cfg(test)] + pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO; + + fn queue_prediction_refresh( + &mut self, + project: Entity, + throttle_entity: EntityId, + cx: &mut Context, + do_refresh: impl FnOnce( + WeakEntity, + &mut AsyncApp, + ) + -> Task>> + + 'static, + ) { + let project_state = self.get_or_init_project(&project, cx); + let pending_prediction_id = project_state.next_pending_prediction_id; + project_state.next_pending_prediction_id += 1; + let last_request = project_state.last_prediction_refresh; + + let task = cx.spawn(async move |this, cx| { + if let Some((last_entity, last_timestamp)) = last_request + && throttle_entity == last_entity + && let Some(timeout) = + (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now()) + { + cx.background_executor().timer(timeout).await; + } + + // If this task was cancelled before the throttle timeout expired, + // do not perform a request. 
+ let mut is_cancelled = true; + this.update(cx, |this, cx| { + let project_state = this.get_or_init_project(&project, cx); + if !project_state + .cancelled_predictions + .remove(&pending_prediction_id) + { + project_state.last_prediction_refresh = Some((throttle_entity, Instant::now())); + is_cancelled = false; + } + }) + .ok(); + if is_cancelled { + return None; + } + + let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten(); + let new_prediction_id = new_prediction_result + .as_ref() + .map(|(prediction, _)| prediction.id.clone()); + + // When a prediction completes, remove it from the pending list, and cancel + // any pending predictions that were enqueued before it. + this.update(cx, |this, cx| { + let project_state = this.get_or_init_project(&project, cx); + + let is_cancelled = project_state + .cancelled_predictions + .remove(&pending_prediction_id); + + let new_current_prediction = if !is_cancelled + && let Some((prediction_result, requested_by)) = new_prediction_result + { + match prediction_result.prediction { + Ok(prediction) => { + let new_prediction = CurrentEditPrediction { + requested_by, + prediction, + was_shown: false, + }; + + if let Some(current_prediction) = + project_state.current_prediction.as_ref() + { + if new_prediction.should_replace_prediction(¤t_prediction, cx) + { + this.reject_current_prediction( + EditPredictionRejectReason::Replaced, + &project, + ); + + Some(new_prediction) + } else { + this.reject_prediction( + new_prediction.prediction.id, + EditPredictionRejectReason::CurrentPreferred, + false, + ); + None + } + } else { + Some(new_prediction) + } + } + Err(reject_reason) => { + this.reject_prediction(prediction_result.id, reject_reason, false); + None + } + } + } else { + None + }; + + let project_state = this.get_or_init_project(&project, cx); + + if let Some(new_prediction) = new_current_prediction { + project_state.current_prediction = Some(new_prediction); + } + + let mut pending_predictions = 
mem::take(&mut project_state.pending_predictions); + for (ix, pending_prediction) in pending_predictions.iter().enumerate() { + if pending_prediction.id == pending_prediction_id { + pending_predictions.remove(ix); + for pending_prediction in pending_predictions.drain(0..ix) { + project_state.cancel_pending_prediction(pending_prediction, cx) + } + break; + } + } + this.get_or_init_project(&project, cx).pending_predictions = pending_predictions; + cx.notify(); + }) + .ok(); + + new_prediction_id + }); + + if project_state.pending_predictions.len() <= 1 { + project_state.pending_predictions.push(PendingPrediction { + id: pending_prediction_id, + task, + }); + } else if project_state.pending_predictions.len() == 2 { + let pending_prediction = project_state.pending_predictions.pop().unwrap(); + project_state.pending_predictions.push(PendingPrediction { + id: pending_prediction_id, + task, + }); + project_state.cancel_pending_prediction(pending_prediction, cx); + } + } + + pub fn request_prediction( + &mut self, + project: &Entity, + active_buffer: &Entity, + position: language::Anchor, + trigger: PredictEditsRequestTrigger, + cx: &mut Context, + ) -> Task>> { + self.request_prediction_internal( + project.clone(), + active_buffer.clone(), + position, + trigger, + cx.has_flag::(), + cx, + ) + } + + fn request_prediction_internal( + &mut self, + project: Entity, + active_buffer: Entity, + position: language::Anchor, + trigger: PredictEditsRequestTrigger, + allow_jump: bool, + cx: &mut Context, + ) -> Task>> { + const DIAGNOSTIC_LINES_RANGE: u32 = 20; + + self.get_or_init_project(&project, cx); + let project_state = self.projects.get(&project.entity_id()).unwrap(); + let events = project_state.events(cx); + let has_events = !events.is_empty(); + + let snapshot = active_buffer.read(cx).snapshot(); + let cursor_point = position.to_point(&snapshot); + let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE); + let diagnostic_search_end = 
cursor_point.row + DIAGNOSTIC_LINES_RANGE; + let diagnostic_search_range = + Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0); + + let related_files = if self.use_context { + self.context_for_project(&project, cx).to_vec() + } else { + Vec::new() + }; + + let task = match self.edit_prediction_model { + EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1( + self, + &project, + &active_buffer, + snapshot.clone(), + position, + events, + trigger, + cx, + ), + EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2( + self, + &project, + &active_buffer, + snapshot.clone(), + position, + events, + related_files, + trigger, + cx, + ), + EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep( + &project, + &active_buffer, + snapshot.clone(), + position, + events, + &project_state.recent_paths, + related_files, + diagnostic_search_range.clone(), + cx, + ), + }; + + cx.spawn(async move |this, cx| { + let prediction = task.await?; + + if prediction.is_none() && allow_jump { + let cursor_point = position.to_point(&snapshot); + if has_events + && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location( + active_buffer.clone(), + &snapshot, + diagnostic_search_range, + cursor_point, + &project, + cx, + ) + .await? + { + return this + .update(cx, |this, cx| { + this.request_prediction_internal( + project, + jump_buffer, + jump_position, + trigger, + false, + cx, + ) + })? 
+ .await; + } + + return anyhow::Ok(None); + } + + Ok(prediction) }) } - fn accept(&self, cx: &mut App) { - self.update(cx, |this, cx| this.accept(cx)) + async fn next_diagnostic_location( + active_buffer: Entity, + active_buffer_snapshot: &BufferSnapshot, + active_buffer_diagnostic_search_range: Range, + active_buffer_cursor_point: Point, + project: &Entity, + cx: &mut AsyncApp, + ) -> Result, language::Anchor)>> { + // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request + let mut jump_location = active_buffer_snapshot + .diagnostic_groups(None) + .into_iter() + .filter_map(|(_, group)| { + let range = &group.entries[group.primary_ix] + .range + .to_point(&active_buffer_snapshot); + if range.overlaps(&active_buffer_diagnostic_search_range) { + None + } else { + Some(range.start) + } + }) + .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row)) + .map(|position| { + ( + active_buffer.clone(), + active_buffer_snapshot.anchor_before(position), + ) + }); + + if jump_location.is_none() { + let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| { + let file = buffer.file()?; + + Some(ProjectPath { + worktree_id: file.worktree_id(cx), + path: file.path().clone(), + }) + })?; + + let buffer_task = project.update(cx, |project, cx| { + let (path, _, _) = project + .diagnostic_summaries(false, cx) + .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref()) + .max_by_key(|(path, _, _)| { + // find the buffer with errors that shares most parent directories + path.path + .components() + .zip( + active_buffer_path + .as_ref() + .map(|p| p.path.components()) + .unwrap_or_default(), + ) + .take_while(|(a, b)| a == b) + .count() + })?; + + Some(project.open_buffer(path, cx)) + })?; + + if let Some(buffer_task) = buffer_task { + let closest_buffer = buffer_task.await?; + + jump_location = closest_buffer + .read_with(cx, |buffer, _cx| { + buffer + .buffer_diagnostics(None) + .into_iter() + 
.min_by_key(|entry| entry.diagnostic.severity) + .map(|entry| entry.range.start) + })? + .map(|position| (closest_buffer, position)); + } + } + + anyhow::Ok(jump_location) } - fn discard(&self, cx: &mut App) { - self.update(cx, |this, cx| this.discard(cx)) + async fn send_raw_llm_request( + request: open_ai::Request, + client: Arc, + llm_token: LlmApiToken, + app_version: Version, + #[cfg(feature = "eval-support")] eval_cache: Option>, + #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind, + ) -> Result<(open_ai::Response, Option)> { + let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() { + http_client::Url::parse(&predict_edits_url)? + } else { + client + .http_client() + .build_zed_llm_url("/predict_edits/raw", &[])? + }; + + #[cfg(feature = "eval-support")] + let cache_key = if let Some(cache) = eval_cache { + use collections::FxHasher; + use std::hash::{Hash, Hasher}; + + let mut hasher = FxHasher::default(); + url.hash(&mut hasher); + let request_str = serde_json::to_string_pretty(&request)?; + request_str.hash(&mut hasher); + let hash = hasher.finish(); + + let key = (eval_cache_kind, hash); + if let Some(response_str) = cache.read(key) { + return Ok((serde_json::from_str(&response_str)?, None)); + } + + Some((cache, request_str, key)) + } else { + None + }; + + let (response, usage) = Self::send_api_request( + |builder| { + let req = builder + .uri(url.as_ref()) + .body(serde_json::to_string(&request)?.into()); + Ok(req?) 
+ }, + client, + llm_token, + app_version, + ) + .await?; + + #[cfg(feature = "eval-support")] + if let Some((cache, request, key)) = cache_key { + cache.write(key, &request, &serde_json::to_string_pretty(&response)?); + } + + Ok((response, usage)) } - fn did_show(&self, cx: &mut App) { - self.update(cx, |this, cx| this.did_show(cx)) + fn handle_api_response( + this: &WeakEntity, + response: Result<(T, Option)>, + cx: &mut gpui::AsyncApp, + ) -> Result { + match response { + Ok((data, usage)) => { + if let Some(usage) = usage { + this.update(cx, |this, cx| { + this.user_store.update(cx, |user_store, cx| { + user_store.update_edit_prediction_usage(usage, cx); + }); + }) + .ok(); + } + Ok(data) + } + Err(err) => { + if err.is::() { + cx.update(|cx| { + this.update(cx, |this, _cx| { + this.update_required = true; + }) + .ok(); + + let error_message: SharedString = err.to_string().into(); + show_app_notification( + NotificationId::unique::(), + cx, + move |cx| { + cx.new(|cx| { + ErrorMessagePrompt::new(error_message.clone(), cx) + .with_link_button("Update Zed", "https://zed.dev/releases") + }) + }, + ); + }) + .ok(); + } + Err(err) + } + } } - fn suggest( - &self, - buffer: &Entity, - cursor_position: language::Anchor, - cx: &mut App, - ) -> Option { - self.update(cx, |this, cx| this.suggest(buffer, cursor_position, cx)) - } -} - -/// Returns edits updated based on user edits since the old snapshot. None is returned if any user -/// edit is not a prefix of a predicted insertion. 
-pub fn interpolate_edits( - old_snapshot: &BufferSnapshot, - new_snapshot: &BufferSnapshot, - current_edits: &[(Range, Arc)], -) -> Option, Arc)>> { - let mut edits = Vec::new(); - - let mut model_edits = current_edits.iter().peekable(); - for user_edit in new_snapshot.edits_since::(&old_snapshot.version) { - while let Some((model_old_range, _)) = model_edits.peek() { - let model_old_range = model_old_range.to_offset(old_snapshot); - if model_old_range.end < user_edit.old.start { - let (model_old_range, model_new_text) = model_edits.next().unwrap(); - edits.push((model_old_range.clone(), model_new_text.clone())); + async fn send_api_request( + build: impl Fn(http_client::http::request::Builder) -> Result>, + client: Arc, + llm_token: LlmApiToken, + app_version: Version, + ) -> Result<(Res, Option)> + where + Res: DeserializeOwned, + { + let http_client = client.http_client(); + let mut token = llm_token.acquire(&client).await?; + let mut did_retry = false; + + loop { + let request_builder = http_client::Request::builder().method(Method::POST); + + let request = build( + request_builder + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", token)) + .header(ZED_VERSION_HEADER_NAME, app_version.to_string()), + )?; + + let mut response = http_client.send(request).await?; + + if let Some(minimum_required_version) = response + .headers() + .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME) + .and_then(|version| Version::from_str(version.to_str().ok()?).ok()) + { + anyhow::ensure!( + app_version >= minimum_required_version, + ZedUpdateRequiredError { + minimum_version: minimum_required_version + } + ); + } + + if response.status().is_success() { + let usage = EditPredictionUsage::from_headers(response.headers()).ok(); + + let mut body = Vec::new(); + response.body_mut().read_to_end(&mut body).await?; + return Ok((serde_json::from_slice(&body)?, usage)); + } else if !did_retry + && response + .headers() + 
.get(EXPIRED_LLM_TOKEN_HEADER_NAME) + .is_some() + { + did_retry = true; + token = llm_token.refresh(&client).await?; } else { - break; + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + anyhow::bail!( + "Request failed with status: {:?}\nBody: {}", + response.status(), + body + ); } } + } - if let Some((model_old_range, model_new_text)) = model_edits.peek() { - let model_old_offset_range = model_old_range.to_offset(old_snapshot); - if user_edit.old == model_old_offset_range { - let user_new_text = new_snapshot - .text_for_range(user_edit.new.clone()) - .collect::(); + pub fn refresh_context( + &mut self, + project: &Entity, + buffer: &Entity, + cursor_position: language::Anchor, + cx: &mut Context, + ) { + if self.use_context { + self.get_or_init_project(project, cx) + .context + .update(cx, |store, cx| { + store.refresh(buffer.clone(), cursor_position, cx); + }); + } + } - if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) { - if !model_suffix.is_empty() { - let anchor = old_snapshot.anchor_after(user_edit.old.end); - edits.push((anchor..anchor, model_suffix.into())); - } + fn is_file_open_source( + &self, + project: &Entity, + file: &Arc, + cx: &App, + ) -> bool { + if !file.is_local() || file.is_private() { + return false; + } + let Some(project_state) = self.projects.get(&project.entity_id()) else { + return false; + }; + project_state + .license_detection_watchers + .get(&file.worktree_id(cx)) + .as_ref() + .is_some_and(|watcher| watcher.is_project_open_source()) + } - model_edits.next(); - continue; + fn can_collect_file(&self, project: &Entity, file: &Arc, cx: &App) -> bool { + self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx) + } + + fn can_collect_events(&self, events: &[Arc]) -> bool { + if !self.data_collection_choice.is_enabled() { + return false; + } + events.iter().all(|event| { + matches!( + event.as_ref(), + Event::BufferChange { + in_open_source_repo: 
true, + .. } + ) + }) + } + + fn load_data_collection_choice() -> DataCollectionChoice { + let choice = KEY_VALUE_STORE + .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE) + .log_err() + .flatten(); + + match choice.as_deref() { + Some("true") => DataCollectionChoice::Enabled, + Some("false") => DataCollectionChoice::Disabled, + Some(_) => { + log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'"); + DataCollectionChoice::NotAnswered } + None => DataCollectionChoice::NotAnswered, + } + } + + fn toggle_data_collection_choice(&mut self, cx: &mut Context) { + self.data_collection_choice = self.data_collection_choice.toggle(); + let new_choice = self.data_collection_choice; + db::write_and_log(cx, move || { + KEY_VALUE_STORE.write_kvp( + ZED_PREDICT_DATA_COLLECTION_CHOICE.into(), + new_choice.is_enabled().to_string(), + ) + }); + } + + pub fn shown_predictions(&self) -> impl DoubleEndedIterator { + self.shown_predictions.iter() + } + + pub fn shown_completions_len(&self) -> usize { + self.shown_predictions.len() + } + + pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool { + self.rated_predictions.contains(id) + } + + pub fn rate_prediction( + &mut self, + prediction: &EditPrediction, + rating: EditPredictionRating, + feedback: String, + cx: &mut Context, + ) { + self.rated_predictions.insert(prediction.id.clone()); + telemetry::event!( + "Edit Prediction Rated", + rating, + inputs = prediction.inputs, + output = prediction.edit_preview.as_unified_diff(&prediction.edits), + feedback + ); + self.client.telemetry().flush_events().detach(); + cx.notify(); + } + + fn enable_or_disable_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) { + self.use_context = cx.has_flag::() + && all_language_settings(None, cx).edit_predictions.use_context; + } +} + +#[derive(Error, Debug)] +#[error( + "You must update to Zed version {minimum_version} or higher to continue using edit predictions." 
+)] +pub struct ZedUpdateRequiredError { + minimum_version: Version, +} + +#[cfg(feature = "eval-support")] +pub type EvalCacheKey = (EvalCacheEntryKind, u64); + +#[cfg(feature = "eval-support")] +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum EvalCacheEntryKind { + Context, + Search, + Prediction, +} + +#[cfg(feature = "eval-support")] +impl std::fmt::Display for EvalCacheEntryKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + EvalCacheEntryKind::Search => write!(f, "search"), + EvalCacheEntryKind::Context => write!(f, "context"), + EvalCacheEntryKind::Prediction => write!(f, "prediction"), + } + } +} + +#[cfg(feature = "eval-support")] +pub trait EvalCache: Send + Sync { + fn read(&self, key: EvalCacheKey) -> Option; + fn write(&self, key: EvalCacheKey, input: &str, value: &str); +} + +#[derive(Debug, Clone, Copy)] +pub enum DataCollectionChoice { + NotAnswered, + Enabled, + Disabled, +} + +impl DataCollectionChoice { + pub fn is_enabled(self) -> bool { + match self { + Self::Enabled => true, + Self::NotAnswered | Self::Disabled => false, } + } - return None; + pub fn is_answered(self) -> bool { + match self { + Self::Enabled | Self::Disabled => true, + Self::NotAnswered => false, + } } - edits.extend(model_edits.cloned()); + #[must_use] + pub fn toggle(&self) -> DataCollectionChoice { + match self { + Self::Enabled => Self::Disabled, + Self::Disabled => Self::Enabled, + Self::NotAnswered => Self::Enabled, + } + } +} + +impl From for DataCollectionChoice { + fn from(value: bool) -> Self { + match value { + true => DataCollectionChoice::Enabled, + false => DataCollectionChoice::Disabled, + } + } +} + +struct ZedPredictUpsell; + +impl Dismissable for ZedPredictUpsell { + const KEY: &'static str = "dismissed-edit-predict-upsell"; + + fn dismissed() -> bool { + // To make this backwards compatible with older versions of Zed, we + // check if the user has seen the previous Edit Prediction Onboarding + // before, by 
checking the data collection choice which was written to + // the database once the user clicked on "Accept and Enable" + if KEY_VALUE_STORE + .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE) + .log_err() + .is_some_and(|s| s.is_some()) + { + return true; + } + + KEY_VALUE_STORE + .read_kvp(Self::KEY) + .log_err() + .is_some_and(|s| s.is_some()) + } +} + +pub fn should_show_upsell_modal() -> bool { + !ZedPredictUpsell::dismissed() +} + +pub fn init(cx: &mut App) { + cx.observe_new(move |workspace: &mut Workspace, _, _cx| { + workspace.register_action( + move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| { + ZedPredictModal::toggle( + workspace, + workspace.user_store().clone(), + workspace.client().clone(), + window, + cx, + ) + }, + ); - if edits.is_empty() { None } else { Some(edits) } + workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| { + update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| { + settings + .project + .all_languages + .features + .get_or_insert_default() + .edit_prediction_provider = Some(EditPredictionProvider::None) + }); + }); + }) + .detach(); } diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs new file mode 100644 index 0000000000000000000000000000000000000000..8d5bad9ed8990769fd512ecfe523cf8d79aebca6 --- /dev/null +++ b/crates/edit_prediction/src/edit_prediction_tests.rs @@ -0,0 +1,1806 @@ +use super::*; +use crate::zeta1::MAX_EVENT_TOKENS; +use client::{UserStore, test::FakeServer}; +use clock::{FakeSystemClock, ReplicaId}; +use cloud_api_types::{CreateLlmTokenResponse, LlmToken}; +use cloud_llm_client::{ + EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse, + RejectEditPredictionsBody, +}; +use edit_prediction_context::Line; +use futures::{ + AsyncReadExt, StreamExt, + channel::{mpsc, oneshot}, +}; +use gpui::{ + Entity, TestAppContext, + 
http_client::{FakeHttpClient, Response}, +}; +use indoc::indoc; +use language::{Point, ToOffset as _}; +use lsp::LanguageServerId; +use open_ai::Usage; +use parking_lot::Mutex; +use pretty_assertions::{assert_eq, assert_matches}; +use project::{FakeFs, Project}; +use serde_json::json; +use settings::SettingsStore; +use std::{path::Path, sync::Arc, time::Duration}; +use util::{path, rel_path::rel_path}; +use uuid::Uuid; + +use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE}; + +#[gpui::test] +async fn test_current_state(cx: &mut TestAppContext) { + let (ep_store, mut requests) = init_test_with_fake_client(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "1.txt": "Hello!\nHow\nBye\n", + "2.txt": "Hola!\nComo\nAdios\n" + }), + ) + .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + ep_store.update(cx, |ep_store, cx| { + ep_store.register_project(&project, cx); + }); + + let buffer1 = project + .update(cx, |project, cx| { + let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap(); + project.set_active_path(Some(path.clone()), cx); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot1.anchor_before(language::Point::new(1, 3)); + + // Prediction for current file + + ep_store.update(cx, |ep_store, cx| { + ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx) + }); + let (_request, respond_tx) = requests.predict.next().await.unwrap(); + + respond_tx + .send(model_response(indoc! {r" + --- a/root/1.txt + +++ b/root/1.txt + @@ ... @@ + Hello! + -How + +How are you? 
+ Bye + "})) + .unwrap(); + + cx.run_until_parked(); + + ep_store.read_with(cx, |ep_store, cx| { + let prediction = ep_store + .current_prediction_for_buffer(&buffer1, &project, cx) + .unwrap(); + assert_matches!(prediction, BufferEditPrediction::Local { .. }); + }); + + ep_store.update(cx, |ep_store, _cx| { + ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project); + }); + + // Prediction for diagnostic in another file + + let diagnostic = lsp::Diagnostic { + range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)), + severity: Some(lsp::DiagnosticSeverity::ERROR), + message: "Sentence is incomplete".to_string(), + ..Default::default() + }; + + project.update(cx, |project, cx| { + project.lsp_store().update(cx, |lsp_store, cx| { + lsp_store + .update_diagnostics( + LanguageServerId(0), + lsp::PublishDiagnosticsParams { + uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(), + diagnostics: vec![diagnostic], + version: None, + }, + None, + language::DiagnosticSourceKind::Pushed, + &[], + cx, + ) + .unwrap(); + }); + }); + + let (_request, respond_tx) = requests.predict.next().await.unwrap(); + respond_tx + .send(model_response(indoc! {r#" + --- a/root/2.txt + +++ b/root/2.txt + Hola! + -Como + +Como estas? 
+ Adios + "#})) + .unwrap(); + cx.run_until_parked(); + + ep_store.read_with(cx, |ep_store, cx| { + let prediction = ep_store + .current_prediction_for_buffer(&buffer1, &project, cx) + .unwrap(); + assert_matches!( + prediction, + BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt")) + ); + }); + + let buffer2 = project + .update(cx, |project, cx| { + let path = project.find_project_path(path!("root/2.txt"), cx).unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + + ep_store.read_with(cx, |ep_store, cx| { + let prediction = ep_store + .current_prediction_for_buffer(&buffer2, &project, cx) + .unwrap(); + assert_matches!(prediction, BufferEditPrediction::Local { .. }); + }); +} + +#[gpui::test] +async fn test_simple_request(cx: &mut TestAppContext) { + let (ep_store, mut requests) = init_test_with_fake_client(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "foo.md": "Hello!\nHow\nBye\n" + }), + ) + .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + let buffer = project + .update(cx, |project, cx| { + let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot.anchor_before(language::Point::new(1, 3)); + + let prediction_task = ep_store.update(cx, |ep_store, cx| { + ep_store.request_prediction(&project, &buffer, position, Default::default(), cx) + }); + + let (_, respond_tx) = requests.predict.next().await.unwrap(); + + // TODO Put back when we have a structured request again + // assert_eq!( + // request.excerpt_path.as_ref(), + // Path::new(path!("root/foo.md")) + // ); + // assert_eq!( + // request.cursor_point, + // Point { + // line: Line(1), + // column: 3 + // } + // ); + + respond_tx + .send(model_response(indoc! 
{ r" + --- a/root/foo.md + +++ b/root/foo.md + @@ ... @@ + Hello! + -How + +How are you? + Bye + "})) + .unwrap(); + + let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap(); + + assert_eq!(prediction.edits.len(), 1); + assert_eq!( + prediction.edits[0].0.to_point(&snapshot).start, + language::Point::new(1, 3) + ); + assert_eq!(prediction.edits[0].1.as_ref(), " are you?"); +} + +#[gpui::test] +async fn test_request_events(cx: &mut TestAppContext) { + let (ep_store, mut requests) = init_test_with_fake_client(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "foo.md": "Hello!\n\nBye\n" + }), + ) + .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + let buffer = project + .update(cx, |project, cx| { + let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + + ep_store.update(cx, |ep_store, cx| { + ep_store.register_buffer(&buffer, &project, cx); + }); + + buffer.update(cx, |buffer, cx| { + buffer.edit(vec![(7..7, "How")], None, cx); + }); + + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot.anchor_before(language::Point::new(1, 3)); + + let prediction_task = ep_store.update(cx, |ep_store, cx| { + ep_store.request_prediction(&project, &buffer, position, Default::default(), cx) + }); + + let (request, respond_tx) = requests.predict.next().await.unwrap(); + + let prompt = prompt_from_request(&request); + assert!( + prompt.contains(indoc! {" + --- a/root/foo.md + +++ b/root/foo.md + @@ -1,3 +1,3 @@ + Hello! + - + +How + Bye + "}), + "{prompt}" + ); + + respond_tx + .send(model_response(indoc! {r#" + --- a/root/foo.md + +++ b/root/foo.md + @@ ... @@ + Hello! + -How + +How are you? 
+ Bye + "#})) + .unwrap(); + + let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap(); + + assert_eq!(prediction.edits.len(), 1); + assert_eq!( + prediction.edits[0].0.to_point(&snapshot).start, + language::Point::new(1, 3) + ); + assert_eq!(prediction.edits[0].1.as_ref(), " are you?"); +} + +#[gpui::test] +async fn test_empty_prediction(cx: &mut TestAppContext) { + let (ep_store, mut requests) = init_test_with_fake_client(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "foo.md": "Hello!\nHow\nBye\n" + }), + ) + .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + let buffer = project + .update(cx, |project, cx| { + let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot.anchor_before(language::Point::new(1, 3)); + + ep_store.update(cx, |ep_store, cx| { + ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + }); + + const NO_OP_DIFF: &str = indoc! { r" + --- a/root/foo.md + +++ b/root/foo.md + @@ ... @@ + Hello! 
+ -How + +How + Bye + "}; + + let (_, respond_tx) = requests.predict.next().await.unwrap(); + let response = model_response(NO_OP_DIFF); + let id = response.id.clone(); + respond_tx.send(response).unwrap(); + + cx.run_until_parked(); + + ep_store.read_with(cx, |ep_store, cx| { + assert!( + ep_store + .current_prediction_for_buffer(&buffer, &project, cx) + .is_none() + ); + }); + + // prediction is reported as rejected + let (reject_request, _) = requests.reject.next().await.unwrap(); + + assert_eq!( + &reject_request.rejections, + &[EditPredictionRejection { + request_id: id, + reason: EditPredictionRejectReason::Empty, + was_shown: false + }] + ); +} + +#[gpui::test] +async fn test_interpolated_empty(cx: &mut TestAppContext) { + let (ep_store, mut requests) = init_test_with_fake_client(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "foo.md": "Hello!\nHow\nBye\n" + }), + ) + .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + let buffer = project + .update(cx, |project, cx| { + let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot.anchor_before(language::Point::new(1, 3)); + + ep_store.update(cx, |ep_store, cx| { + ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + }); + + let (_, respond_tx) = requests.predict.next().await.unwrap(); + + buffer.update(cx, |buffer, cx| { + buffer.set_text("Hello!\nHow are you?\nBye", cx); + }); + + let response = model_response(SIMPLE_DIFF); + let id = response.id.clone(); + respond_tx.send(response).unwrap(); + + cx.run_until_parked(); + + ep_store.read_with(cx, |ep_store, cx| { + assert!( + ep_store + .current_prediction_for_buffer(&buffer, &project, cx) + .is_none() + ); + }); + + // prediction is reported as rejected + let (reject_request, _) = 
requests.reject.next().await.unwrap(); + + assert_eq!( + &reject_request.rejections, + &[EditPredictionRejection { + request_id: id, + reason: EditPredictionRejectReason::InterpolatedEmpty, + was_shown: false + }] + ); +} + +const SIMPLE_DIFF: &str = indoc! { r" + --- a/root/foo.md + +++ b/root/foo.md + @@ ... @@ + Hello! + -How + +How are you? + Bye +"}; + +#[gpui::test] +async fn test_replace_current(cx: &mut TestAppContext) { + let (ep_store, mut requests) = init_test_with_fake_client(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "foo.md": "Hello!\nHow\nBye\n" + }), + ) + .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + let buffer = project + .update(cx, |project, cx| { + let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot.anchor_before(language::Point::new(1, 3)); + + ep_store.update(cx, |ep_store, cx| { + ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + }); + + let (_, respond_tx) = requests.predict.next().await.unwrap(); + let first_response = model_response(SIMPLE_DIFF); + let first_id = first_response.id.clone(); + respond_tx.send(first_response).unwrap(); + + cx.run_until_parked(); + + ep_store.read_with(cx, |ep_store, cx| { + assert_eq!( + ep_store + .current_prediction_for_buffer(&buffer, &project, cx) + .unwrap() + .id + .0, + first_id + ); + }); + + // a second request is triggered + ep_store.update(cx, |ep_store, cx| { + ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + }); + + let (_, respond_tx) = requests.predict.next().await.unwrap(); + let second_response = model_response(SIMPLE_DIFF); + let second_id = second_response.id.clone(); + respond_tx.send(second_response).unwrap(); + + cx.run_until_parked(); + + 
ep_store.read_with(cx, |ep_store, cx| { + // second replaces first + assert_eq!( + ep_store + .current_prediction_for_buffer(&buffer, &project, cx) + .unwrap() + .id + .0, + second_id + ); + }); + + // first is reported as replaced + let (reject_request, _) = requests.reject.next().await.unwrap(); + + assert_eq!( + &reject_request.rejections, + &[EditPredictionRejection { + request_id: first_id, + reason: EditPredictionRejectReason::Replaced, + was_shown: false + }] + ); +} + +#[gpui::test] +async fn test_current_preferred(cx: &mut TestAppContext) { + let (ep_store, mut requests) = init_test_with_fake_client(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "foo.md": "Hello!\nHow\nBye\n" + }), + ) + .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + let buffer = project + .update(cx, |project, cx| { + let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot.anchor_before(language::Point::new(1, 3)); + + ep_store.update(cx, |ep_store, cx| { + ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + }); + + let (_, respond_tx) = requests.predict.next().await.unwrap(); + let first_response = model_response(SIMPLE_DIFF); + let first_id = first_response.id.clone(); + respond_tx.send(first_response).unwrap(); + + cx.run_until_parked(); + + ep_store.read_with(cx, |ep_store, cx| { + assert_eq!( + ep_store + .current_prediction_for_buffer(&buffer, &project, cx) + .unwrap() + .id + .0, + first_id + ); + }); + + // a second request is triggered + ep_store.update(cx, |ep_store, cx| { + ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + }); + + let (_, respond_tx) = requests.predict.next().await.unwrap(); + // worse than current prediction + let second_response = 
model_response(indoc! { r" + --- a/root/foo.md + +++ b/root/foo.md + @@ ... @@ + Hello! + -How + +How are + Bye + "}); + let second_id = second_response.id.clone(); + respond_tx.send(second_response).unwrap(); + + cx.run_until_parked(); + + ep_store.read_with(cx, |ep_store, cx| { + // first is preferred over second + assert_eq!( + ep_store + .current_prediction_for_buffer(&buffer, &project, cx) + .unwrap() + .id + .0, + first_id + ); + }); + + // second is reported as rejected + let (reject_request, _) = requests.reject.next().await.unwrap(); + + assert_eq!( + &reject_request.rejections, + &[EditPredictionRejection { + request_id: second_id, + reason: EditPredictionRejectReason::CurrentPreferred, + was_shown: false + }] + ); +} + +#[gpui::test] +async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) { + let (ep_store, mut requests) = init_test_with_fake_client(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "foo.md": "Hello!\nHow\nBye\n" + }), + ) + .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + let buffer = project + .update(cx, |project, cx| { + let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot.anchor_before(language::Point::new(1, 3)); + + // start two refresh tasks + ep_store.update(cx, |ep_store, cx| { + ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + }); + + let (_, respond_first) = requests.predict.next().await.unwrap(); + + ep_store.update(cx, |ep_store, cx| { + ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + }); + + let (_, respond_second) = requests.predict.next().await.unwrap(); + + // wait for throttle + cx.run_until_parked(); + + // second responds first + let second_response = 
model_response(SIMPLE_DIFF); + let second_id = second_response.id.clone(); + respond_second.send(second_response).unwrap(); + + cx.run_until_parked(); + + ep_store.read_with(cx, |ep_store, cx| { + // current prediction is second + assert_eq!( + ep_store + .current_prediction_for_buffer(&buffer, &project, cx) + .unwrap() + .id + .0, + second_id + ); + }); + + let first_response = model_response(SIMPLE_DIFF); + let first_id = first_response.id.clone(); + respond_first.send(first_response).unwrap(); + + cx.run_until_parked(); + + ep_store.read_with(cx, |ep_store, cx| { + // current prediction is still second, since first was cancelled + assert_eq!( + ep_store + .current_prediction_for_buffer(&buffer, &project, cx) + .unwrap() + .id + .0, + second_id + ); + }); + + // first is reported as rejected + let (reject_request, _) = requests.reject.next().await.unwrap(); + + cx.run_until_parked(); + + assert_eq!( + &reject_request.rejections, + &[EditPredictionRejection { + request_id: first_id, + reason: EditPredictionRejectReason::Canceled, + was_shown: false + }] + ); +} + +#[gpui::test] +async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) { + let (ep_store, mut requests) = init_test_with_fake_client(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "foo.md": "Hello!\nHow\nBye\n" + }), + ) + .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + let buffer = project + .update(cx, |project, cx| { + let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot.anchor_before(language::Point::new(1, 3)); + + // start two refresh tasks + ep_store.update(cx, |ep_store, cx| { + ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + }); + + let (_, respond_first) = 
requests.predict.next().await.unwrap(); + + ep_store.update(cx, |ep_store, cx| { + ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + }); + + let (_, respond_second) = requests.predict.next().await.unwrap(); + + // wait for throttle, so requests are sent + cx.run_until_parked(); + + ep_store.update(cx, |ep_store, cx| { + // start a third request + ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + + // 2 are pending, so 2nd is cancelled + assert_eq!( + ep_store + .get_or_init_project(&project, cx) + .cancelled_predictions + .iter() + .copied() + .collect::>(), + [1] + ); + }); + + // wait for throttle + cx.run_until_parked(); + + let (_, respond_third) = requests.predict.next().await.unwrap(); + + let first_response = model_response(SIMPLE_DIFF); + let first_id = first_response.id.clone(); + respond_first.send(first_response).unwrap(); + + cx.run_until_parked(); + + ep_store.read_with(cx, |ep_store, cx| { + // current prediction is first + assert_eq!( + ep_store + .current_prediction_for_buffer(&buffer, &project, cx) + .unwrap() + .id + .0, + first_id + ); + }); + + let cancelled_response = model_response(SIMPLE_DIFF); + let cancelled_id = cancelled_response.id.clone(); + respond_second.send(cancelled_response).unwrap(); + + cx.run_until_parked(); + + ep_store.read_with(cx, |ep_store, cx| { + // current prediction is still first, since second was cancelled + assert_eq!( + ep_store + .current_prediction_for_buffer(&buffer, &project, cx) + .unwrap() + .id + .0, + first_id + ); + }); + + let third_response = model_response(SIMPLE_DIFF); + let third_response_id = third_response.id.clone(); + respond_third.send(third_response).unwrap(); + + cx.run_until_parked(); + + ep_store.read_with(cx, |ep_store, cx| { + // third completes and replaces first + assert_eq!( + ep_store + .current_prediction_for_buffer(&buffer, &project, cx) + .unwrap() + .id + .0, + third_response_id + ); + }); + + // second 
is reported as rejected + let (reject_request, _) = requests.reject.next().await.unwrap(); + + cx.run_until_parked(); + + assert_eq!( + &reject_request.rejections, + &[ + EditPredictionRejection { + request_id: cancelled_id, + reason: EditPredictionRejectReason::Canceled, + was_shown: false + }, + EditPredictionRejection { + request_id: first_id, + reason: EditPredictionRejectReason::Replaced, + was_shown: false + } + ] + ); +} + +#[gpui::test] +async fn test_rejections_flushing(cx: &mut TestAppContext) { + let (ep_store, mut requests) = init_test_with_fake_client(cx); + + ep_store.update(cx, |ep_store, _cx| { + ep_store.reject_prediction( + EditPredictionId("test-1".into()), + EditPredictionRejectReason::Discarded, + false, + ); + ep_store.reject_prediction( + EditPredictionId("test-2".into()), + EditPredictionRejectReason::Canceled, + true, + ); + }); + + cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE); + cx.run_until_parked(); + + let (reject_request, respond_tx) = requests.reject.next().await.unwrap(); + respond_tx.send(()).unwrap(); + + // batched + assert_eq!(reject_request.rejections.len(), 2); + assert_eq!( + reject_request.rejections[0], + EditPredictionRejection { + request_id: "test-1".to_string(), + reason: EditPredictionRejectReason::Discarded, + was_shown: false + } + ); + assert_eq!( + reject_request.rejections[1], + EditPredictionRejection { + request_id: "test-2".to_string(), + reason: EditPredictionRejectReason::Canceled, + was_shown: true + } + ); + + // Reaching batch size limit sends without debounce + ep_store.update(cx, |ep_store, _cx| { + for i in 0..70 { + ep_store.reject_prediction( + EditPredictionId(format!("batch-{}", i).into()), + EditPredictionRejectReason::Discarded, + false, + ); + } + }); + + // First MAX/2 items are sent immediately + cx.run_until_parked(); + let (reject_request, respond_tx) = requests.reject.next().await.unwrap(); + respond_tx.send(()).unwrap(); + + assert_eq!(reject_request.rejections.len(), 50); + 
assert_eq!(reject_request.rejections[0].request_id, "batch-0"); + assert_eq!(reject_request.rejections[49].request_id, "batch-49"); + + // Remaining items are debounced with the next batch + cx.executor().advance_clock(Duration::from_secs(15)); + cx.run_until_parked(); + + let (reject_request, respond_tx) = requests.reject.next().await.unwrap(); + respond_tx.send(()).unwrap(); + + assert_eq!(reject_request.rejections.len(), 20); + assert_eq!(reject_request.rejections[0].request_id, "batch-50"); + assert_eq!(reject_request.rejections[19].request_id, "batch-69"); + + // Request failure + ep_store.update(cx, |ep_store, _cx| { + ep_store.reject_prediction( + EditPredictionId("retry-1".into()), + EditPredictionRejectReason::Discarded, + false, + ); + }); + + cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE); + cx.run_until_parked(); + + let (reject_request, _respond_tx) = requests.reject.next().await.unwrap(); + assert_eq!(reject_request.rejections.len(), 1); + assert_eq!(reject_request.rejections[0].request_id, "retry-1"); + // Simulate failure + drop(_respond_tx); + + // Add another rejection + ep_store.update(cx, |ep_store, _cx| { + ep_store.reject_prediction( + EditPredictionId("retry-2".into()), + EditPredictionRejectReason::Discarded, + false, + ); + }); + + cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE); + cx.run_until_parked(); + + // Retry should include both the failed item and the new one + let (reject_request, respond_tx) = requests.reject.next().await.unwrap(); + respond_tx.send(()).unwrap(); + + assert_eq!(reject_request.rejections.len(), 2); + assert_eq!(reject_request.rejections[0].request_id, "retry-1"); + assert_eq!(reject_request.rejections[1].request_id, "retry-2"); +} + +// Skipped until we start including diagnostics in prompt +// #[gpui::test] +// async fn test_request_diagnostics(cx: &mut TestAppContext) { +// let (ep_store, mut req_rx) = init_test_with_fake_client(cx); +// let fs = FakeFs::new(cx.executor()); +// fs.insert_tree( +// 
"/root", +// json!({ +// "foo.md": "Hello!\nBye" +// }), +// ) +// .await; +// let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + +// let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap(); +// let diagnostic = lsp::Diagnostic { +// range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)), +// severity: Some(lsp::DiagnosticSeverity::ERROR), +// message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(), +// ..Default::default() +// }; + +// project.update(cx, |project, cx| { +// project.lsp_store().update(cx, |lsp_store, cx| { +// // Create some diagnostics +// lsp_store +// .update_diagnostics( +// LanguageServerId(0), +// lsp::PublishDiagnosticsParams { +// uri: path_to_buffer_uri.clone(), +// diagnostics: vec![diagnostic], +// version: None, +// }, +// None, +// language::DiagnosticSourceKind::Pushed, +// &[], +// cx, +// ) +// .unwrap(); +// }); +// }); + +// let buffer = project +// .update(cx, |project, cx| { +// let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); +// project.open_buffer(path, cx) +// }) +// .await +// .unwrap(); + +// let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); +// let position = snapshot.anchor_before(language::Point::new(0, 0)); + +// let _prediction_task = ep_store.update(cx, |ep_store, cx| { +// ep_store.request_prediction(&project, &buffer, position, cx) +// }); + +// let (request, _respond_tx) = req_rx.next().await.unwrap(); + +// assert_eq!(request.diagnostic_groups.len(), 1); +// let value = serde_json::from_str::(request.diagnostic_groups[0].0.get()) +// .unwrap(); +// // We probably don't need all of this. 
TODO define a specific diagnostic type in predict_edits_v3 +// assert_eq!( +// value, +// json!({ +// "entries": [{ +// "range": { +// "start": 8, +// "end": 10 +// }, +// "diagnostic": { +// "source": null, +// "code": null, +// "code_description": null, +// "severity": 1, +// "message": "\"Hello\" deprecated. Use \"Hi\" instead", +// "markdown": null, +// "group_id": 0, +// "is_primary": true, +// "is_disk_based": false, +// "is_unnecessary": false, +// "source_kind": "Pushed", +// "data": null, +// "underline": true +// } +// }], +// "primary_ix": 0 +// }) +// ); +// } + +fn model_response(text: &str) -> open_ai::Response { + open_ai::Response { + id: Uuid::new_v4().to_string(), + object: "response".into(), + created: 0, + model: "model".into(), + choices: vec![open_ai::Choice { + index: 0, + message: open_ai::RequestMessage::Assistant { + content: Some(open_ai::MessageContent::Plain(text.to_string())), + tool_calls: vec![], + }, + finish_reason: None, + }], + usage: Usage { + prompt_tokens: 0, + completion_tokens: 0, + total_tokens: 0, + }, + } +} + +fn prompt_from_request(request: &open_ai::Request) -> &str { + assert_eq!(request.messages.len(), 1); + let open_ai::RequestMessage::User { + content: open_ai::MessageContent::Plain(content), + .. + } = &request.messages[0] + else { + panic!( + "Request does not have single user message of type Plain. 
{:#?}", + request + ); + }; + content +} + +struct RequestChannels { + predict: mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender)>, + reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>, +} + +fn init_test_with_fake_client( + cx: &mut TestAppContext, +) -> (Entity, RequestChannels) { + cx.update(move |cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + zlog::init_test(); + + let (predict_req_tx, predict_req_rx) = mpsc::unbounded(); + let (reject_req_tx, reject_req_rx) = mpsc::unbounded(); + + let http_client = FakeHttpClient::create({ + move |req| { + let uri = req.uri().path().to_string(); + let mut body = req.into_body(); + let predict_req_tx = predict_req_tx.clone(); + let reject_req_tx = reject_req_tx.clone(); + async move { + let resp = match uri.as_str() { + "/client/llm_tokens" => serde_json::to_string(&json!({ + "token": "test" + })) + .unwrap(), + "/predict_edits/raw" => { + let mut buf = Vec::new(); + body.read_to_end(&mut buf).await.ok(); + let req = serde_json::from_slice(&buf).unwrap(); + + let (res_tx, res_rx) = oneshot::channel(); + predict_req_tx.unbounded_send((req, res_tx)).unwrap(); + serde_json::to_string(&res_rx.await?).unwrap() + } + "/predict_edits/reject" => { + let mut buf = Vec::new(); + body.read_to_end(&mut buf).await.ok(); + let req = serde_json::from_slice(&buf).unwrap(); + + let (res_tx, res_rx) = oneshot::channel(); + reject_req_tx.unbounded_send((req, res_tx)).unwrap(); + serde_json::to_string(&res_rx.await?).unwrap() + } + _ => { + panic!("Unexpected path: {}", uri) + } + }; + + Ok(Response::builder().body(resp.into()).unwrap()) + } + } + }); + + let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx); + client.cloud_client().set_credentials(1, "test".into()); + + language_model::init(client.clone(), cx); + + let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); + let ep_store = EditPredictionStore::global(&client, 
&user_store, cx); + + ( + ep_store, + RequestChannels { + predict: predict_req_rx, + reject: reject_req_rx, + }, + ) + }) +} + +const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt"); + +#[gpui::test] +async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) { + let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx)); + let edits: Arc<[(Range, Arc)]> = cx.update(|cx| { + to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into() + }); + + let edit_preview = cx + .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx)) + .await; + + let completion = EditPrediction { + edits, + edit_preview, + buffer: buffer.clone(), + snapshot: cx.read(|cx| buffer.read(cx).snapshot()), + id: EditPredictionId("the-id".into()), + inputs: EditPredictionInputs { + events: Default::default(), + included_files: Default::default(), + cursor_point: cloud_llm_client::predict_edits_v3::Point { + line: Line(0), + column: 0, + }, + cursor_path: Path::new("").into(), + }, + buffer_snapshotted_at: Instant::now(), + response_received_at: Instant::now(), + }; + + cx.update(|cx| { + assert_eq!( + from_completion_edits( + &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(2..5, "REM".into()), (9..11, "".into())] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx)); + assert_eq!( + from_completion_edits( + &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(2..2, "REM".into()), (6..8, "".into())] + ); + + buffer.update(cx, |buffer, cx| buffer.undo(cx)); + assert_eq!( + from_completion_edits( + &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(2..5, "REM".into()), (9..11, "".into())] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx)); + assert_eq!( + from_completion_edits( + &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + 
vec![(3..3, "EM".into()), (7..9, "".into())] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx)); + assert_eq!( + from_completion_edits( + &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(4..4, "M".into()), (8..10, "".into())] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx)); + assert_eq!( + from_completion_edits( + &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(9..11, "".into())] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx)); + assert_eq!( + from_completion_edits( + &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(4..4, "M".into()), (8..10, "".into())] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx)); + assert_eq!( + from_completion_edits( + &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(4..4, "M".into())] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx)); + assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None); + }) +} + +#[gpui::test] +async fn test_clean_up_diff(cx: &mut TestAppContext) { + init_test(cx); + + assert_eq!( + apply_edit_prediction( + indoc! {" + fn main() { + let word_1 = \"lorem\"; + let range = word.len()..word.len(); + } + "}, + indoc! {" + <|editable_region_start|> + fn main() { + let word_1 = \"lorem\"; + let range = word_1.len()..word_1.len(); + } + + <|editable_region_end|> + "}, + cx, + ) + .await, + indoc! {" + fn main() { + let word_1 = \"lorem\"; + let range = word_1.len()..word_1.len(); + } + "}, + ); + + assert_eq!( + apply_edit_prediction( + indoc! {" + fn main() { + let story = \"the quick\" + } + "}, + indoc! {" + <|editable_region_start|> + fn main() { + let story = \"the quick brown fox jumps over the lazy dog\"; + } + + <|editable_region_end|> + "}, + cx, + ) + .await, + indoc! 
{" + fn main() { + let story = \"the quick brown fox jumps over the lazy dog\"; + } + "}, + ); +} + +#[gpui::test] +async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) { + init_test(cx); + + let buffer_content = "lorem\n"; + let completion_response = indoc! {" + ```animals.js + <|start_of_file|> + <|editable_region_start|> + lorem + ipsum + <|editable_region_end|> + ```"}; + + assert_eq!( + apply_edit_prediction(buffer_content, completion_response, cx).await, + "lorem\nipsum" + ); +} + +#[gpui::test] +async fn test_can_collect_data(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT })) + .await; + + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer(path!("/project/src/main.rs"), cx) + }) + .await + .unwrap(); + + let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await; + ep_store.update(cx, |ep_store, _cx| { + ep_store.data_collection_choice = DataCollectionChoice::Enabled + }); + + run_edit_prediction(&buffer, &project, &ep_store, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + true + ); + + ep_store.update(cx, |ep_store, _cx| { + ep_store.data_collection_choice = DataCollectionChoice::Disabled + }); + + run_edit_prediction(&buffer, &project, &ep_store, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); +} + +#[gpui::test] +async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + let project = Project::test(fs.clone(), [], cx).await; + + let buffer = cx.new(|_cx| { + Buffer::remote( + language::BufferId::new(1).unwrap(), + ReplicaId::new(1), + language::Capability::ReadWrite, + "fn main() {\n println!(\"Hello\");\n}", + ) + }); + + let 
(ep_store, captured_request, _) = make_test_ep_store(&project, cx).await; + ep_store.update(cx, |ep_store, _cx| { + ep_store.data_collection_choice = DataCollectionChoice::Enabled + }); + + run_edit_prediction(&buffer, &project, &ep_store, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); +} + +#[gpui::test] +async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/project"), + json!({ + "LICENSE": BSD_0_TXT, + ".env": "SECRET_KEY=secret" + }), + ) + .await; + + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer("/project/.env", cx) + }) + .await + .unwrap(); + + let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await; + ep_store.update(cx, |ep_store, _cx| { + ep_store.data_collection_choice = DataCollectionChoice::Enabled + }); + + run_edit_prediction(&buffer, &project, &ep_store, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); +} + +#[gpui::test] +async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + let project = Project::test(fs.clone(), [], cx).await; + let buffer = cx.new(|cx| Buffer::local("", cx)); + + let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await; + ep_store.update(cx, |ep_store, _cx| { + ep_store.data_collection_choice = DataCollectionChoice::Enabled + }); + + run_edit_prediction(&buffer, &project, &ep_store, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); +} + +#[gpui::test] +async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + 
fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" })) + .await; + + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer("/project/main.rs", cx) + }) + .await + .unwrap(); + + let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await; + ep_store.update(cx, |ep_store, _cx| { + ep_store.data_collection_choice = DataCollectionChoice::Enabled + }); + + run_edit_prediction(&buffer, &project, &ep_store, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); +} + +#[gpui::test] +async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/open_source_worktree"), + json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }), + ) + .await; + fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" })) + .await; + + let project = Project::test( + fs.clone(), + [ + path!("/open_source_worktree").as_ref(), + path!("/closed_source_worktree").as_ref(), + ], + cx, + ) + .await; + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx) + }) + .await + .unwrap(); + + let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await; + ep_store.update(cx, |ep_store, _cx| { + ep_store.data_collection_choice = DataCollectionChoice::Enabled + }); + + run_edit_prediction(&buffer, &project, &ep_store, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + true + ); + + let closed_source_file = project + .update(cx, |project, cx| { + let worktree2 = project + .worktree_for_root_name("closed_source_worktree", cx) + .unwrap(); + worktree2.update(cx, |worktree2, cx| { + worktree2.load_file(rel_path("main.rs"), cx) + }) + }) + .await + .unwrap() + .file; + + buffer.update(cx, |buffer, 
cx| { + buffer.file_updated(closed_source_file, cx); + }); + + run_edit_prediction(&buffer, &project, &ep_store, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); +} + +#[gpui::test] +async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/worktree1"), + json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }), + ) + .await; + fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" })) + .await; + + let project = Project::test( + fs.clone(), + [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()], + cx, + ) + .await; + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer(path!("/worktree1/main.rs"), cx) + }) + .await + .unwrap(); + let private_buffer = project + .update(cx, |project, cx| { + project.open_local_buffer(path!("/worktree2/file.rs"), cx) + }) + .await + .unwrap(); + + let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await; + ep_store.update(cx, |ep_store, _cx| { + ep_store.data_collection_choice = DataCollectionChoice::Enabled + }); + + run_edit_prediction(&buffer, &project, &ep_store, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + true + ); + + // this has a side effect of registering the buffer to watch for edits + run_edit_prediction(&private_buffer, &project, &ep_store, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); + + private_buffer.update(cx, |private_buffer, cx| { + private_buffer.edit([(0..0, "An edit for the history!")], None, cx); + }); + + run_edit_prediction(&buffer, &project, &ep_store, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); + + // make an edit that uses too many bytes, causing private_buffer edit to not be able to be + // included + 
buffer.update(cx, |buffer, cx| { + buffer.edit( + [( + 0..0, + " ".repeat(MAX_EVENT_TOKENS * zeta1::BYTES_PER_TOKEN_GUESS), + )], + None, + cx, + ); + }); + + run_edit_prediction(&buffer, &project, &ep_store, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + true + ); +} + +fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + }); +} + +async fn apply_edit_prediction( + buffer_content: &str, + completion_response: &str, + cx: &mut TestAppContext, +) -> String { + let fs = project::FakeFs::new(cx.executor()); + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); + let (ep_store, _, response) = make_test_ep_store(&project, cx).await; + *response.lock() = completion_response.to_string(); + let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await; + buffer.update(cx, |buffer, cx| { + buffer.edit(edit_prediction.edits.iter().cloned(), None, cx) + }); + buffer.read_with(cx, |buffer, _| buffer.text()) +} + +async fn run_edit_prediction( + buffer: &Entity, + project: &Entity, + ep_store: &Entity, + cx: &mut TestAppContext, +) -> EditPrediction { + let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0))); + ep_store.update(cx, |ep_store, cx| { + ep_store.register_buffer(buffer, &project, cx) + }); + cx.background_executor.run_until_parked(); + let prediction_task = ep_store.update(cx, |ep_store, cx| { + ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx) + }); + prediction_task.await.unwrap().unwrap().prediction.unwrap() +} + +async fn make_test_ep_store( + project: &Entity, + cx: &mut TestAppContext, +) -> ( + Entity, + Arc>>, + Arc>, +) { + let default_response = indoc! 
{" + ```main.rs + <|start_of_file|> + <|editable_region_start|> + hello world + <|editable_region_end|> + ```" + }; + let captured_request: Arc>> = Arc::new(Mutex::new(None)); + let completion_response: Arc> = + Arc::new(Mutex::new(default_response.to_string())); + let http_client = FakeHttpClient::create({ + let captured_request = captured_request.clone(); + let completion_response = completion_response.clone(); + let mut next_request_id = 0; + move |req| { + let captured_request = captured_request.clone(); + let completion_response = completion_response.clone(); + async move { + match (req.method(), req.uri().path()) { + (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&CreateLlmTokenResponse { + token: LlmToken("the-llm-token".to_string()), + }) + .unwrap() + .into(), + ) + .unwrap()), + (&Method::POST, "/predict_edits/v2") => { + let mut request_body = String::new(); + req.into_body().read_to_string(&mut request_body).await?; + *captured_request.lock() = + Some(serde_json::from_str(&request_body).unwrap()); + next_request_id += 1; + Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&PredictEditsResponse { + request_id: format!("request-{next_request_id}"), + output_excerpt: completion_response.lock().clone(), + }) + .unwrap() + .into(), + ) + .unwrap()) + } + _ => Ok(http_client::Response::builder() + .status(404) + .body("Not Found".into()) + .unwrap()), + } + } + } + }); + + let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx)); + cx.update(|cx| { + RefreshLlmTokenListener::register(client.clone(), cx); + }); + let _server = FakeServer::for_client(42, &client, cx).await; + + let ep_store = cx.new(|cx| { + let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx); + ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1); + + let worktrees = 
project.read(cx).worktrees(cx).collect::>(); + for worktree in worktrees { + let worktree_id = worktree.read(cx).id(); + ep_store + .get_or_init_project(project, cx) + .license_detection_watchers + .entry(worktree_id) + .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx))); + } + + ep_store + }); + + (ep_store, captured_request, completion_response) +} + +fn to_completion_edits( + iterator: impl IntoIterator, Arc)>, + buffer: &Entity, + cx: &App, +) -> Vec<(Range, Arc)> { + let buffer = buffer.read(cx); + iterator + .into_iter() + .map(|(range, text)| { + ( + buffer.anchor_after(range.start)..buffer.anchor_before(range.end), + text, + ) + }) + .collect() +} + +fn from_completion_edits( + editor_edits: &[(Range, Arc)], + buffer: &Entity, + cx: &App, +) -> Vec<(Range, Arc)> { + let buffer = buffer.read(cx); + editor_edits + .iter() + .map(|(range, text)| { + ( + range.start.to_offset(buffer)..range.end.to_offset(buffer), + text.clone(), + ) + }) + .collect() +} + +#[ctor::ctor] +fn init_logger() { + zlog::init_test(); +} diff --git a/crates/zeta/src/license_detection.rs b/crates/edit_prediction/src/license_detection.rs similarity index 100% rename from crates/zeta/src/license_detection.rs rename to crates/edit_prediction/src/license_detection.rs diff --git a/crates/zeta/src/onboarding_modal.rs b/crates/edit_prediction/src/onboarding_modal.rs similarity index 100% rename from crates/zeta/src/onboarding_modal.rs rename to crates/edit_prediction/src/onboarding_modal.rs diff --git a/crates/zeta/src/prediction.rs b/crates/edit_prediction/src/prediction.rs similarity index 99% rename from crates/zeta/src/prediction.rs rename to crates/edit_prediction/src/prediction.rs index fd3241730030fe8bdd95e2cae9ee87b406ade735..d169cf26e1dc4477554bfe8821ff5eae083a6124 100644 --- a/crates/zeta/src/prediction.rs +++ b/crates/edit_prediction/src/prediction.rs @@ -99,7 +99,7 @@ pub struct EditPrediction { #[derive(Debug, Clone, Serialize)] pub struct EditPredictionInputs 
{ pub events: Vec>, - pub included_files: Vec, + pub included_files: Vec, pub cursor_point: cloud_llm_client::predict_edits_v3::Point, pub cursor_path: Arc, } diff --git a/crates/zeta/src/sweep_ai.rs b/crates/edit_prediction/src/sweep_ai.rs similarity index 99% rename from crates/zeta/src/sweep_ai.rs rename to crates/edit_prediction/src/sweep_ai.rs index 0bc0d1d41e2393212f865e402912f6d760aa252e..4bb014c640cb489db29c800835a58febf91a7270 100644 --- a/crates/zeta/src/sweep_ai.rs +++ b/crates/edit_prediction/src/sweep_ai.rs @@ -1,7 +1,7 @@ use anyhow::{Context as _, Result}; use cloud_llm_client::predict_edits_v3::Event; use credentials_provider::CredentialsProvider; -use edit_prediction_context2::RelatedFile; +use edit_prediction_context::RelatedFile; use futures::{AsyncReadExt as _, FutureExt, future::Shared}; use gpui::{ App, AppContext as _, Entity, Task, @@ -197,7 +197,7 @@ impl SweepAi { let inputs = EditPredictionInputs { events, - included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile { + included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile { path: full_path.clone(), max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row), excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt { diff --git a/crates/zeta/src/udiff.rs b/crates/edit_prediction/src/udiff.rs similarity index 100% rename from crates/zeta/src/udiff.rs rename to crates/edit_prediction/src/udiff.rs diff --git a/crates/zeta/src/xml_edits.rs b/crates/edit_prediction/src/xml_edits.rs similarity index 100% rename from crates/zeta/src/xml_edits.rs rename to crates/edit_prediction/src/xml_edits.rs diff --git a/crates/zeta/src/provider.rs b/crates/edit_prediction/src/zed_edit_prediction_delegate.rs similarity index 58% rename from crates/zeta/src/provider.rs rename to crates/edit_prediction/src/zed_edit_prediction_delegate.rs index 019d780e579c079f745f56136bdbd3a4add76b50..91371d539beca012e2ded4e9ec8702c8db39bd8a 100644 --- a/crates/zeta/src/provider.rs 
+++ b/crates/edit_prediction/src/zed_edit_prediction_delegate.rs @@ -1,55 +1,56 @@ -use std::{cmp, sync::Arc, time::Duration}; +use std::{cmp, sync::Arc}; use client::{Client, UserStore}; use cloud_llm_client::EditPredictionRejectReason; -use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider}; +use edit_prediction_types::{DataCollectionState, Direction, EditPredictionDelegate}; use gpui::{App, Entity, prelude::*}; -use language::ToPoint as _; +use language::{Buffer, ToPoint as _}; use project::Project; -use crate::{BufferEditPrediction, Zeta, ZetaEditPredictionModel}; +use crate::{BufferEditPrediction, EditPredictionModel, EditPredictionStore}; -pub struct ZetaEditPredictionProvider { - zeta: Entity, +pub struct ZedEditPredictionDelegate { + store: Entity, project: Entity, + singleton_buffer: Option>, } -impl ZetaEditPredictionProvider { - pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300); - +impl ZedEditPredictionDelegate { pub fn new( project: Entity, + singleton_buffer: Option>, client: &Arc, user_store: &Entity, cx: &mut Context, ) -> Self { - let zeta = Zeta::global(client, user_store, cx); - zeta.update(cx, |zeta, cx| { - zeta.register_project(&project, cx); + let store = EditPredictionStore::global(client, user_store, cx); + store.update(cx, |store, cx| { + store.register_project(&project, cx); }); - cx.observe(&zeta, |_this, _zeta, cx| { + cx.observe(&store, |_this, _ep_store, cx| { cx.notify(); }) .detach(); Self { project: project, - zeta, + store: store, + singleton_buffer, } } } -impl EditPredictionProvider for ZetaEditPredictionProvider { +impl EditPredictionDelegate for ZedEditPredictionDelegate { fn name() -> &'static str { - "zed-predict2" + "zed-predict" } fn display_name() -> &'static str { - "Zed's Edit Predictions 2" + "Zed's Edit Predictions" } - fn show_completions_in_menu() -> bool { + fn show_predictions_in_menu() -> bool { true } @@ -57,17 +58,38 @@ impl EditPredictionProvider for 
ZetaEditPredictionProvider { true } - fn data_collection_state(&self, _cx: &App) -> DataCollectionState { - // TODO [zeta2] - DataCollectionState::Unsupported + fn data_collection_state(&self, cx: &App) -> DataCollectionState { + if let Some(buffer) = &self.singleton_buffer + && let Some(file) = buffer.read(cx).file() + { + let is_project_open_source = + self.store + .read(cx) + .is_file_open_source(&self.project, file, cx); + if self.store.read(cx).data_collection_choice.is_enabled() { + DataCollectionState::Enabled { + is_project_open_source, + } + } else { + DataCollectionState::Disabled { + is_project_open_source, + } + } + } else { + return DataCollectionState::Disabled { + is_project_open_source: false, + }; + } } - fn toggle_data_collection(&mut self, _cx: &mut App) { - // TODO [zeta2] + fn toggle_data_collection(&mut self, cx: &mut App) { + self.store.update(cx, |store, cx| { + store.toggle_data_collection_choice(cx); + }); } fn usage(&self, cx: &App) -> Option { - self.zeta.read(cx).usage(cx) + self.store.read(cx).usage(cx) } fn is_enabled( @@ -76,16 +98,16 @@ impl EditPredictionProvider for ZetaEditPredictionProvider { _cursor_position: language::Anchor, cx: &App, ) -> bool { - let zeta = self.zeta.read(cx); - if zeta.edit_prediction_model == ZetaEditPredictionModel::Sweep { - zeta.has_sweep_api_token() + let store = self.store.read(cx); + if store.edit_prediction_model == EditPredictionModel::Sweep { + store.has_sweep_api_token() } else { true } } fn is_refreshing(&self, cx: &App) -> bool { - self.zeta.read(cx).is_refreshing(&self.project) + self.store.read(cx).is_refreshing(&self.project) } fn refresh( @@ -95,24 +117,24 @@ impl EditPredictionProvider for ZetaEditPredictionProvider { _debounce: bool, cx: &mut Context, ) { - let zeta = self.zeta.read(cx); + let store = self.store.read(cx); - if zeta.user_store.read_with(cx, |user_store, _cx| { + if store.user_store.read_with(cx, |user_store, _cx| { user_store.account_too_young() || 
user_store.has_overdue_invoices() }) { return; } - if let Some(current) = zeta.current_prediction_for_buffer(&buffer, &self.project, cx) + if let Some(current) = store.current_prediction_for_buffer(&buffer, &self.project, cx) && let BufferEditPrediction::Local { prediction } = current && prediction.interpolate(buffer.read(cx)).is_some() { return; } - self.zeta.update(cx, |zeta, cx| { - zeta.refresh_context_if_needed(&self.project, &buffer, cursor_position, cx); - zeta.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx) + self.store.update(cx, |store, cx| { + store.refresh_context(&self.project, &buffer, cursor_position, cx); + store.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx) }); } @@ -126,20 +148,20 @@ impl EditPredictionProvider for ZetaEditPredictionProvider { } fn accept(&mut self, cx: &mut Context) { - self.zeta.update(cx, |zeta, cx| { - zeta.accept_current_prediction(&self.project, cx); + self.store.update(cx, |store, cx| { + store.accept_current_prediction(&self.project, cx); }); } fn discard(&mut self, cx: &mut Context) { - self.zeta.update(cx, |zeta, _cx| { - zeta.reject_current_prediction(EditPredictionRejectReason::Discarded, &self.project); + self.store.update(cx, |store, _cx| { + store.reject_current_prediction(EditPredictionRejectReason::Discarded, &self.project); }); } fn did_show(&mut self, cx: &mut Context) { - self.zeta.update(cx, |zeta, cx| { - zeta.did_show_current_prediction(&self.project, cx); + self.store.update(cx, |store, cx| { + store.did_show_current_prediction(&self.project, cx); }); } @@ -148,16 +170,16 @@ impl EditPredictionProvider for ZetaEditPredictionProvider { buffer: &Entity, cursor_position: language::Anchor, cx: &mut Context, - ) -> Option { + ) -> Option { let prediction = - self.zeta + self.store .read(cx) .current_prediction_for_buffer(buffer, &self.project, cx)?; let prediction = match prediction { BufferEditPrediction::Local { prediction } => prediction, 
BufferEditPrediction::Jump { prediction } => { - return Some(edit_prediction::EditPrediction::Jump { + return Some(edit_prediction_types::EditPrediction::Jump { id: Some(prediction.id.to_string().into()), snapshot: prediction.snapshot.clone(), target: prediction.edits.first().unwrap().0.start, @@ -169,8 +191,8 @@ impl EditPredictionProvider for ZetaEditPredictionProvider { let snapshot = buffer.snapshot(); let Some(edits) = prediction.interpolate(&snapshot) else { - self.zeta.update(cx, |zeta, _cx| { - zeta.reject_current_prediction( + self.store.update(cx, |store, _cx| { + store.reject_current_prediction( EditPredictionRejectReason::InterpolatedEmpty, &self.project, ); @@ -208,7 +230,7 @@ impl EditPredictionProvider for ZetaEditPredictionProvider { } } - Some(edit_prediction::EditPrediction::Local { + Some(edit_prediction_types::EditPrediction::Local { id: Some(prediction.id.to_string().into()), edits: edits[edit_start_ix..edit_end_ix].to_vec(), edit_preview: Some(prediction.edit_preview.clone()), diff --git a/crates/zeta/src/zeta1.rs b/crates/edit_prediction/src/zeta1.rs similarity index 96% rename from crates/zeta/src/zeta1.rs rename to crates/edit_prediction/src/zeta1.rs index 0be5fad301242c51c4ad58c60a6d2fcb3441ea08..06248603464563db12fa55a90c9c0bccf153c5f4 100644 --- a/crates/zeta/src/zeta1.rs +++ b/crates/edit_prediction/src/zeta1.rs @@ -3,7 +3,7 @@ mod input_excerpt; use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant}; use crate::{ - EditPredictionId, ZedUpdateRequiredError, Zeta, + EditPredictionId, EditPredictionStore, ZedUpdateRequiredError, prediction::{EditPredictionInputs, EditPredictionResult}, }; use anyhow::{Context as _, Result}; @@ -30,23 +30,23 @@ pub(crate) const MAX_REWRITE_TOKENS: usize = 350; pub(crate) const MAX_EVENT_TOKENS: usize = 500; pub(crate) fn request_prediction_with_zeta1( - zeta: &mut Zeta, + store: &mut EditPredictionStore, project: &Entity, buffer: &Entity, snapshot: BufferSnapshot, position: 
language::Anchor, events: Vec>, trigger: PredictEditsRequestTrigger, - cx: &mut Context, + cx: &mut Context, ) -> Task>> { let buffer = buffer.clone(); let buffer_snapshotted_at = Instant::now(); - let client = zeta.client.clone(); - let llm_token = zeta.llm_token.clone(); + let client = store.client.clone(); + let llm_token = store.llm_token.clone(); let app_version = AppVersion::global(cx); let (git_info, can_collect_file) = if let Some(file) = snapshot.file() { - let can_collect_file = zeta.can_collect_file(project, file, cx); + let can_collect_file = store.can_collect_file(project, file, cx); let git_info = if can_collect_file { git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx) } else { @@ -102,7 +102,7 @@ pub(crate) fn request_prediction_with_zeta1( let http_client = client.http_client(); - let response = Zeta::send_api_request::( + let response = EditPredictionStore::send_api_request::( |request| { let uri = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") { predict_edits_url @@ -124,7 +124,7 @@ pub(crate) fn request_prediction_with_zeta1( let inputs = EditPredictionInputs { events: included_events.into(), - included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile { + included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile { path: full_path.clone(), max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row), excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt { @@ -155,8 +155,8 @@ pub(crate) fn request_prediction_with_zeta1( Err(err) => { if err.is::() { cx.update(|cx| { - this.update(cx, |zeta, _cx| { - zeta.update_required = true; + this.update(cx, |ep_store, _cx| { + ep_store.update_required = true; }) .ok(); diff --git a/crates/zeta/src/zeta1/input_excerpt.rs b/crates/edit_prediction/src/zeta1/input_excerpt.rs similarity index 100% rename from crates/zeta/src/zeta1/input_excerpt.rs rename to crates/edit_prediction/src/zeta1/input_excerpt.rs diff --git 
a/crates/edit_prediction/src/zeta2.rs b/crates/edit_prediction/src/zeta2.rs new file mode 100644 index 0000000000000000000000000000000000000000..4808f38fc529b1c34212dd0198d15fb03a0baddf --- /dev/null +++ b/crates/edit_prediction/src/zeta2.rs @@ -0,0 +1,358 @@ +#[cfg(feature = "eval-support")] +use crate::EvalCacheEntryKind; +use crate::prediction::EditPredictionResult; +use crate::{ + DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionId, EditPredictionInputs, + EditPredictionRequestedDebugEvent, EditPredictionStore, +}; +use anyhow::{Result, anyhow, bail}; +use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat}; +use cloud_llm_client::{EditPredictionRejectReason, PredictEditsRequestTrigger}; +use cloud_zeta2_prompt::CURSOR_MARKER; +use edit_prediction_context::{EditPredictionExcerpt, Line}; +use edit_prediction_context::{RelatedExcerpt, RelatedFile}; +use futures::channel::oneshot; +use gpui::{Entity, Task, prelude::*}; +use language::{Anchor, BufferSnapshot}; +use language::{Buffer, Point, ToOffset as _, ToPoint}; +use project::{Project, ProjectItem as _}; +use release_channel::AppVersion; +use std::{ + env, + path::Path, + sync::Arc, + time::{Duration, Instant}, +}; + +pub fn request_prediction_with_zeta2( + store: &mut EditPredictionStore, + project: &Entity, + active_buffer: &Entity, + active_snapshot: BufferSnapshot, + position: Anchor, + events: Vec>, + mut included_files: Vec, + trigger: PredictEditsRequestTrigger, + cx: &mut Context, +) -> Task>> { + let options = store.options.clone(); + let buffer_snapshotted_at = Instant::now(); + + let Some((excerpt_path, active_project_path)) = active_snapshot + .file() + .map(|file| -> Arc { file.full_path(cx).into() }) + .zip(active_buffer.read(cx).project_path(cx)) + else { + return Task::ready(Err(anyhow!("No file path for excerpt"))); + }; + + let client = store.client.clone(); + let llm_token = store.llm_token.clone(); + let app_version = AppVersion::global(cx); + let debug_tx = 
store.debug_tx.clone(); + + let file = active_buffer.read(cx).file(); + + let active_file_full_path = file.as_ref().map(|f| f.full_path(cx)); + + // TODO data collection + let can_collect_data = file + .as_ref() + .map_or(false, |file| store.can_collect_file(project, file, cx)); + + #[cfg(feature = "eval-support")] + let eval_cache = store.eval_cache.clone(); + + let request_task = cx.background_spawn({ + let active_buffer = active_buffer.clone(); + async move { + let cursor_offset = position.to_offset(&active_snapshot); + let cursor_point = cursor_offset.to_point(&active_snapshot); + + let before_retrieval = Instant::now(); + + let excerpt_options = options.context; + + let Some(excerpt) = EditPredictionExcerpt::select_from_buffer( + cursor_point, + &active_snapshot, + &excerpt_options, + ) else { + return Ok((None, None)); + }; + + let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start) + ..active_snapshot.anchor_before(excerpt.range.end); + let related_excerpt = RelatedExcerpt { + anchor_range: excerpt_anchor_range.clone(), + point_range: Point::new(excerpt.line_range.start.0, 0) + ..Point::new(excerpt.line_range.end.0, 0), + text: active_snapshot.as_rope().slice(excerpt.range), + }; + + if let Some(buffer_ix) = included_files + .iter() + .position(|file| file.buffer.entity_id() == active_buffer.entity_id()) + { + let file = &mut included_files[buffer_ix]; + file.excerpts.push(related_excerpt); + file.merge_excerpts(); + let last_ix = included_files.len() - 1; + included_files.swap(buffer_ix, last_ix); + } else { + let active_file = RelatedFile { + path: active_project_path, + buffer: active_buffer.downgrade(), + excerpts: vec![related_excerpt], + max_row: active_snapshot.max_point().row, + }; + included_files.push(active_file); + } + + let included_files = included_files + .iter() + .map(|related_file| predict_edits_v3::RelatedFile { + path: Arc::from(related_file.path.path.as_std_path()), + max_row: Line(related_file.max_row), + excerpts: 
related_file + .excerpts + .iter() + .map(|excerpt| predict_edits_v3::Excerpt { + start_line: Line(excerpt.point_range.start.row), + text: excerpt.text.to_string().into(), + }) + .collect(), + }) + .collect::>(); + + let cloud_request = predict_edits_v3::PredictEditsRequest { + excerpt_path, + excerpt: String::new(), + excerpt_line_range: Line(0)..Line(0), + excerpt_range: 0..0, + cursor_point: predict_edits_v3::Point { + line: predict_edits_v3::Line(cursor_point.row), + column: cursor_point.column, + }, + related_files: included_files, + events, + can_collect_data, + debug_info: debug_tx.is_some(), + prompt_max_bytes: Some(options.max_prompt_bytes), + prompt_format: options.prompt_format, + excerpt_parent: None, + git_info: None, + trigger, + }; + + let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request); + + let inputs = EditPredictionInputs { + included_files: cloud_request.related_files, + events: cloud_request.events, + cursor_point: cloud_request.cursor_point, + cursor_path: cloud_request.excerpt_path, + }; + + let retrieval_time = Instant::now() - before_retrieval; + + let debug_response_tx = if let Some(debug_tx) = &debug_tx { + let (response_tx, response_rx) = oneshot::channel(); + + debug_tx + .unbounded_send(DebugEvent::EditPredictionRequested( + EditPredictionRequestedDebugEvent { + inputs: inputs.clone(), + retrieval_time, + buffer: active_buffer.downgrade(), + local_prompt: match prompt_result.as_ref() { + Ok(prompt) => Ok(prompt.clone()), + Err(err) => Err(err.to_string()), + }, + position, + response_rx, + }, + )) + .ok(); + Some(response_tx) + } else { + None + }; + + if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() { + if let Some(debug_response_tx) = debug_response_tx { + debug_response_tx + .send((Err("Request skipped".to_string()), Duration::ZERO)) + .ok(); + } + anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set") + } + + let prompt = prompt_result?; + let generation_params = + 
cloud_zeta2_prompt::generation_params(cloud_request.prompt_format); + let request = open_ai::Request { + model: EDIT_PREDICTIONS_MODEL_ID.clone(), + messages: vec![open_ai::RequestMessage::User { + content: open_ai::MessageContent::Plain(prompt), + }], + stream: false, + max_completion_tokens: None, + stop: generation_params.stop.unwrap_or_default(), + temperature: generation_params.temperature.unwrap_or(0.7), + tool_choice: None, + parallel_tool_calls: None, + tools: vec![], + prompt_cache_key: None, + reasoning_effort: None, + }; + + log::trace!("Sending edit prediction request"); + + let before_request = Instant::now(); + let response = EditPredictionStore::send_raw_llm_request( + request, + client, + llm_token, + app_version, + #[cfg(feature = "eval-support")] + eval_cache, + #[cfg(feature = "eval-support")] + EvalCacheEntryKind::Prediction, + ) + .await; + let received_response_at = Instant::now(); + let request_time = received_response_at - before_request; + + log::trace!("Got edit prediction response"); + + if let Some(debug_response_tx) = debug_response_tx { + debug_response_tx + .send(( + response + .as_ref() + .map_err(|err| err.to_string()) + .map(|response| response.0.clone()), + request_time, + )) + .ok(); + } + + let (res, usage) = response?; + let request_id = EditPredictionId(res.id.clone().into()); + let Some(mut output_text) = text_from_response(res) else { + return Ok((Some((request_id, None)), usage)); + }; + + if output_text.contains(CURSOR_MARKER) { + log::trace!("Stripping out {CURSOR_MARKER} from response"); + output_text = output_text.replace(CURSOR_MARKER, ""); + } + + let get_buffer_from_context = |path: &Path| { + if Some(path) == active_file_full_path.as_deref() { + Some(( + &active_snapshot, + std::slice::from_ref(&excerpt_anchor_range), + )) + } else { + None + } + }; + + let (_, edits) = match options.prompt_format { + PromptFormat::Minimal | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => { + if output_text.contains("--- 
a/\n+++ b/\nNo edits") { + let edits = vec![]; + (&active_snapshot, edits) + } else { + crate::udiff::parse_diff(&output_text, get_buffer_from_context).await? + } + } + PromptFormat::OldTextNewText => { + crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context).await? + } + _ => { + bail!("unsupported prompt format {}", options.prompt_format) + } + }; + + anyhow::Ok(( + Some(( + request_id, + Some(( + inputs, + active_buffer, + active_snapshot.clone(), + edits, + received_response_at, + )), + )), + usage, + )) + } + }); + + cx.spawn(async move |this, cx| { + let Some((id, prediction)) = + EditPredictionStore::handle_api_response(&this, request_task.await, cx)? + else { + return Ok(None); + }; + + let Some((inputs, edited_buffer, edited_buffer_snapshot, edits, received_response_at)) = + prediction + else { + return Ok(Some(EditPredictionResult { + id, + prediction: Err(EditPredictionRejectReason::Empty), + })); + }; + + Ok(Some( + EditPredictionResult::new( + id, + &edited_buffer, + &edited_buffer_snapshot, + edits.into(), + buffer_snapshotted_at, + received_response_at, + inputs, + cx, + ) + .await, + )) + }) +} + +pub fn text_from_response(mut res: open_ai::Response) -> Option { + let choice = res.choices.pop()?; + let output_text = match choice.message { + open_ai::RequestMessage::Assistant { + content: Some(open_ai::MessageContent::Plain(content)), + .. + } => content, + open_ai::RequestMessage::Assistant { + content: Some(open_ai::MessageContent::Multipart(mut content)), + .. + } => { + if content.is_empty() { + log::error!("No output from Baseten completion response"); + return None; + } + + match content.remove(0) { + open_ai::MessagePart::Text { text } => text, + open_ai::MessagePart::Image { .. 
} => { + log::error!("Expected text, got an image"); + return None; + } + } + } + _ => { + log::error!("Invalid response message: {:?}", choice.message); + return None; + } + }; + Some(output_text) +} diff --git a/crates/zeta_cli/Cargo.toml b/crates/edit_prediction_cli/Cargo.toml similarity index 84% rename from crates/zeta_cli/Cargo.toml rename to crates/edit_prediction_cli/Cargo.toml index 2dbca537f55377e84f306e13649dfb71ccf2f181..d1b0b3f912ed2143b6c75ae39e94c2f7780ec4fe 100644 --- a/crates/zeta_cli/Cargo.toml +++ b/crates/edit_prediction_cli/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "zeta_cli" +name = "edit_prediction_cli" version = "0.1.0" edition.workspace = true publish.workspace = true @@ -9,7 +9,7 @@ license = "GPL-3.0-or-later" workspace = true [[bin]] -name = "zeta" +name = "ep_cli" path = "src/main.rs" [dependencies] @@ -19,7 +19,7 @@ chrono.workspace = true clap.workspace = true client.workspace = true cloud_llm_client.workspace= true -cloud_zeta2_prompt.workspace= true +cloud_zeta2_prompt.workspace = true collections.workspace = true debug_adapter_extension.workspace = true edit_prediction_context.workspace = true @@ -35,9 +35,7 @@ language_models.workspace = true languages = { workspace = true, features = ["load-grammars"] } log.workspace = true node_runtime.workspace = true -ordered-float.workspace = true paths.workspace = true -polars = { version = "0.51", features = ["lazy", "dtype-struct", "parquet"] } project.workspace = true prompt_store.workspace = true pulldown-cmark.workspace = true @@ -48,12 +46,11 @@ serde_json.workspace = true settings.workspace = true shellexpand.workspace = true smol.workspace = true -soa-rs = "0.8.1" terminal_view.workspace = true toml.workspace = true util.workspace = true watch.workspace = true -zeta = { workspace = true, features = ["eval-support"] } +edit_prediction = { workspace = true, features = ["eval-support"] } zlog.workspace = true [dev-dependencies] diff --git a/crates/edit_prediction_button/LICENSE-GPL 
b/crates/edit_prediction_cli/LICENSE-GPL similarity index 100% rename from crates/edit_prediction_button/LICENSE-GPL rename to crates/edit_prediction_cli/LICENSE-GPL diff --git a/crates/zeta_cli/build.rs b/crates/edit_prediction_cli/build.rs similarity index 100% rename from crates/zeta_cli/build.rs rename to crates/edit_prediction_cli/build.rs diff --git a/crates/zeta_cli/src/evaluate.rs b/crates/edit_prediction_cli/src/evaluate.rs similarity index 98% rename from crates/zeta_cli/src/evaluate.rs rename to crates/edit_prediction_cli/src/evaluate.rs index 043844768557ad46f61d5fd0d809e1e85c62574f..686c8ce7e7865f265d6bf17e51ca9477194e5252 100644 --- a/crates/zeta_cli/src/evaluate.rs +++ b/crates/edit_prediction_cli/src/evaluate.rs @@ -6,17 +6,17 @@ use std::{ }; use anyhow::Result; +use edit_prediction::{EditPredictionStore, udiff::DiffLine}; use gpui::{AsyncApp, Entity}; use project::Project; use util::ResultExt as _; -use zeta::{Zeta, udiff::DiffLine}; use crate::{ EvaluateArguments, PredictionOptions, example::{Example, NamedExample}, headless::ZetaCliAppState, paths::print_run_data_dir, - predict::{PredictionDetails, perform_predict, setup_zeta}, + predict::{PredictionDetails, perform_predict, setup_store}, }; #[derive(Debug)] @@ -45,7 +45,7 @@ pub async fn run_evaluate( let project = example.setup_project(&app_state, cx).await.unwrap(); let providers = (0..args.repetitions) - .map(|_| setup_zeta(args.options.provider, &project, &app_state, cx).unwrap()) + .map(|_| setup_store(args.options.provider, &project, &app_state, cx).unwrap()) .collect::>(); let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap(); @@ -53,7 +53,7 @@ pub async fn run_evaluate( let tasks = providers .into_iter() .enumerate() - .map(move |(repetition_ix, zeta)| { + .map(move |(repetition_ix, store)| { let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16); let example = example.clone(); let project = project.clone(); @@ -65,7 +65,7 @@ pub async fn 
run_evaluate( example, repetition_ix, project, - zeta, + store, options, !args.skip_prediction, cx, @@ -154,7 +154,7 @@ pub async fn run_evaluate_one( example: NamedExample, repetition_ix: Option, project: Entity, - zeta: Entity, + store: Entity, prediction_options: PredictionOptions, predict: bool, cx: &mut AsyncApp, @@ -162,7 +162,7 @@ pub async fn run_evaluate_one( let predict_result = perform_predict( example.clone(), project, - zeta, + store, repetition_ix, prediction_options, cx, diff --git a/crates/zeta_cli/src/example.rs b/crates/edit_prediction_cli/src/example.rs similarity index 99% rename from crates/zeta_cli/src/example.rs rename to crates/edit_prediction_cli/src/example.rs index a9d4c4f47c5a05d4198b1cffaee51e14a122e88d..2f52b89c552b65072f753432eb63b656624fdf61 100644 --- a/crates/zeta_cli/src/example.rs +++ b/crates/edit_prediction_cli/src/example.rs @@ -14,6 +14,7 @@ use anyhow::{Context as _, Result, anyhow}; use clap::ValueEnum; use cloud_zeta2_prompt::CURSOR_MARKER; use collections::HashMap; +use edit_prediction::udiff::OpenedBuffers; use futures::{ AsyncWriteExt as _, lock::{Mutex, OwnedMutexGuard}, @@ -25,7 +26,6 @@ use project::{Project, ProjectPath}; use pulldown_cmark::CowStr; use serde::{Deserialize, Serialize}; use util::{paths::PathStyle, rel_path::RelPath}; -use zeta::udiff::OpenedBuffers; use crate::paths::{REPOS_DIR, WORKTREES_DIR}; @@ -481,7 +481,7 @@ impl NamedExample { project: &Entity, cx: &mut AsyncApp, ) -> Result> { - zeta::udiff::apply_diff(&self.example.edit_history, project, cx).await + edit_prediction::udiff::apply_diff(&self.example.edit_history, project, cx).await } } diff --git a/crates/zeta_cli/src/headless.rs b/crates/edit_prediction_cli/src/headless.rs similarity index 100% rename from crates/zeta_cli/src/headless.rs rename to crates/edit_prediction_cli/src/headless.rs diff --git a/crates/zeta_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs similarity index 84% rename from crates/zeta_cli/src/main.rs rename to 
crates/edit_prediction_cli/src/main.rs index 42c0ea185f4401a11c2798f9402a59829f8463df..f2887b98a0ce829a58374fdd10c3e346b6f5d16a 100644 --- a/crates/zeta_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -5,7 +5,6 @@ mod metrics; mod paths; mod predict; mod source_location; -mod syntax_retrieval_stats; mod util; use crate::{ @@ -14,13 +13,13 @@ use crate::{ headless::ZetaCliAppState, predict::run_predict, source_location::SourceLocation, - syntax_retrieval_stats::retrieval_stats, util::{open_buffer, open_buffer_with_language_server}, }; use ::util::paths::PathStyle; use anyhow::{Result, anyhow}; use clap::{Args, Parser, Subcommand, ValueEnum}; use cloud_llm_client::predict_edits_v3; +use edit_prediction::udiff::DiffLine; use edit_prediction_context::EditPredictionExcerptOptions; use gpui::{Application, AsyncApp, Entity, prelude::*}; use language::{Bias, Buffer, BufferSnapshot, Point}; @@ -28,10 +27,7 @@ use metrics::delta_chr_f; use project::{Project, Worktree, lsp_store::OpenLspBufferHandle}; use reqwest_client::ReqwestClient; use std::io::{self}; -use std::time::Duration; use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc}; -use zeta::ContextMode; -use zeta::udiff::DiffLine; #[derive(Parser, Debug)] #[command(name = "zeta")] @@ -45,7 +41,6 @@ struct ZetaCliArgs { #[derive(Subcommand, Debug)] enum Command { Context(ContextArgs), - ContextStats(ContextStatsArgs), Predict(PredictArguments), Eval(EvaluateArguments), ConvertExample { @@ -60,20 +55,6 @@ enum Command { Clean, } -#[derive(Debug, Args)] -struct ContextStatsArgs { - #[arg(long)] - worktree: PathBuf, - #[arg(long)] - extension: Option, - #[arg(long)] - limit: Option, - #[arg(long)] - skip: Option, - #[clap(flatten)] - zeta2_args: Zeta2Args, -} - #[derive(Debug, Args)] struct ContextArgs { #[arg(long)] @@ -201,28 +182,22 @@ enum PredictionProvider { Sweep, } -fn zeta2_args_to_options(args: &Zeta2Args) -> zeta::ZetaOptions { - zeta::ZetaOptions { - context: 
ContextMode::Lsp(EditPredictionExcerptOptions { +fn zeta2_args_to_options(args: &Zeta2Args) -> edit_prediction::ZetaOptions { + edit_prediction::ZetaOptions { + context: EditPredictionExcerptOptions { max_bytes: args.max_excerpt_bytes, min_bytes: args.min_excerpt_bytes, target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes, - }), - max_diagnostic_bytes: args.max_diagnostic_bytes, + }, max_prompt_bytes: args.max_prompt_bytes, prompt_format: args.prompt_format.into(), - file_indexing_parallelism: args.file_indexing_parallelism, - buffer_change_grouping_interval: Duration::ZERO, } } #[derive(clap::ValueEnum, Default, Debug, Clone, Copy)] enum PromptFormat { - MarkedExcerpt, - LabeledSections, OnlySnippets, #[default] - NumberedLines, OldTextNewText, Minimal, MinimalQwen, @@ -232,10 +207,7 @@ enum PromptFormat { impl Into for PromptFormat { fn into(self) -> predict_edits_v3::PromptFormat { match self { - Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt, - Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections, Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets, - Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff, Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText, Self::Minimal => predict_edits_v3::PromptFormat::Minimal, Self::MinimalQwen => predict_edits_v3::PromptFormat::MinimalQwen, @@ -395,27 +367,29 @@ async fn zeta2_context( .await; let output = cx .update(|cx| { - let zeta = cx.new(|cx| { - zeta::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx) + let store = cx.new(|cx| { + edit_prediction::EditPredictionStore::new( + app_state.client.clone(), + app_state.user_store.clone(), + cx, + ) }); - let indexing_done_task = zeta.update(cx, |zeta, cx| { - zeta.set_options(zeta2_args_to_options(&args.zeta2_args)); - zeta.register_buffer(&buffer, &project, cx); - zeta.wait_for_initial_indexing(&project, cx) + store.update(cx, 
|store, cx| { + store.set_options(zeta2_args_to_options(&args.zeta2_args)); + store.register_buffer(&buffer, &project, cx); }); cx.spawn(async move |cx| { - indexing_done_task.await?; - let updates_rx = zeta.update(cx, |zeta, cx| { + let updates_rx = store.update(cx, |store, cx| { let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor); - zeta.set_use_context(true); - zeta.refresh_context_if_needed(&project, &buffer, cursor, cx); - zeta.project_context_updates(&project).unwrap() + store.set_use_context(true); + store.refresh_context(&project, &buffer, cursor, cx); + store.project_context_updates(&project).unwrap() })?; updates_rx.recv().await.ok(); - let context = zeta.update(cx, |zeta, cx| { - zeta.context_for_project(&project, cx).to_vec() + let context = store.update(cx, |store, cx| { + store.context_for_project(&project, cx).to_vec() })?; anyhow::Ok(serde_json::to_string_pretty(&context).unwrap()) @@ -430,7 +404,7 @@ async fn zeta1_context( args: ContextArgs, app_state: &Arc, cx: &mut AsyncApp, -) -> Result { +) -> Result { let LoadedContext { full_path_str, snapshot, @@ -445,7 +419,7 @@ async fn zeta1_context( let prompt_for_events = move || (events, 0); cx.update(|cx| { - zeta::zeta1::gather_context( + edit_prediction::zeta1::gather_context( full_path_str, &snapshot, clipped_cursor, @@ -475,19 +449,6 @@ fn main() { panic!("Expected a command"); } } - Some(Command::ContextStats(arguments)) => { - let result = retrieval_stats( - arguments.worktree, - app_state, - arguments.extension, - arguments.limit, - arguments.skip, - zeta2_args_to_options(&arguments.zeta2_args), - cx, - ) - .await; - println!("{}", result.unwrap()); - } Some(Command::Context(context_args)) => { let result = match context_args.provider { ContextProvider::Zeta1 => { diff --git a/crates/zeta_cli/src/metrics.rs b/crates/edit_prediction_cli/src/metrics.rs similarity index 99% rename from crates/zeta_cli/src/metrics.rs rename to crates/edit_prediction_cli/src/metrics.rs index 
dd08459678eef6d04a6b656d19a4572d51a5b5c1..0fdb7fb535df12d00341997a64a96b97867f6f28 100644 --- a/crates/zeta_cli/src/metrics.rs +++ b/crates/edit_prediction_cli/src/metrics.rs @@ -1,5 +1,5 @@ use collections::{HashMap, HashSet}; -use zeta::udiff::DiffLine; +use edit_prediction::udiff::DiffLine; type Counts = HashMap; type CountsDelta = HashMap; @@ -287,7 +287,7 @@ fn count_ngrams(text: &str, n: usize) -> Counts { #[cfg(test)] mod test { use super::*; - use zeta::udiff::DiffLine; + use edit_prediction::udiff::DiffLine; #[test] fn test_delta_chr_f_perfect_match() { diff --git a/crates/zeta_cli/src/paths.rs b/crates/edit_prediction_cli/src/paths.rs similarity index 100% rename from crates/zeta_cli/src/paths.rs rename to crates/edit_prediction_cli/src/paths.rs diff --git a/crates/zeta_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs similarity index 85% rename from crates/zeta_cli/src/predict.rs rename to crates/edit_prediction_cli/src/predict.rs index 9fefc5ce97672796f79482e23acca3599aa1ff44..db1fed70d82a1a19713dfe54dfd6cea2bfa03d5d 100644 --- a/crates/zeta_cli/src/predict.rs +++ b/crates/edit_prediction_cli/src/predict.rs @@ -7,6 +7,7 @@ use crate::{ use ::serde::Serialize; use anyhow::{Context, Result, anyhow}; use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock}; +use edit_prediction::{EditPredictionStore, EvalCache, EvalCacheEntryKind, EvalCacheKey}; use futures::StreamExt as _; use gpui::{AppContext, AsyncApp, Entity}; use project::Project; @@ -18,7 +19,6 @@ use std::path::PathBuf; use std::sync::Arc; use std::sync::Mutex; use std::time::{Duration, Instant}; -use zeta::{EvalCache, EvalCacheEntryKind, EvalCacheKey, Zeta}; pub async fn run_predict( args: PredictArguments, @@ -27,9 +27,9 @@ pub async fn run_predict( ) { let example = NamedExample::load(args.example_path).unwrap(); let project = example.setup_project(app_state, cx).await.unwrap(); - let zeta = setup_zeta(args.options.provider, &project, app_state, cx).unwrap(); + let store = 
setup_store(args.options.provider, &project, app_state, cx).unwrap(); let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap(); - let result = perform_predict(example, project, zeta, None, args.options, cx) + let result = perform_predict(example, project, store, None, args.options, cx) .await .unwrap(); result.write(args.format, std::io::stdout()).unwrap(); @@ -37,45 +37,50 @@ pub async fn run_predict( print_run_data_dir(true, std::io::stdout().is_terminal()); } -pub fn setup_zeta( +pub fn setup_store( provider: PredictionProvider, project: &Entity, app_state: &Arc, cx: &mut AsyncApp, -) -> Result> { - let zeta = - cx.new(|cx| zeta::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx))?; +) -> Result> { + let store = cx.new(|cx| { + edit_prediction::EditPredictionStore::new( + app_state.client.clone(), + app_state.user_store.clone(), + cx, + ) + })?; - zeta.update(cx, |zeta, _cx| { + store.update(cx, |store, _cx| { let model = match provider { - PredictionProvider::Zeta1 => zeta::ZetaEditPredictionModel::Zeta1, - PredictionProvider::Zeta2 => zeta::ZetaEditPredictionModel::Zeta2, - PredictionProvider::Sweep => zeta::ZetaEditPredictionModel::Sweep, + PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1, + PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2, + PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep, }; - zeta.set_edit_prediction_model(model); + store.set_edit_prediction_model(model); })?; let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?; cx.subscribe(&buffer_store, { let project = project.clone(); - let zeta = zeta.clone(); + let store = store.clone(); move |_, event, cx| match event { BufferStoreEvent::BufferAdded(buffer) => { - zeta.update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx)); + store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx)); } _ => {} } })? 
.detach(); - anyhow::Ok(zeta) + anyhow::Ok(store) } pub async fn perform_predict( example: NamedExample, project: Entity, - zeta: Entity, + store: Entity, repetition_ix: Option, options: PredictionOptions, cx: &mut AsyncApp, @@ -108,8 +113,8 @@ pub async fn perform_predict( std::os::windows::fs::symlink_dir(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR) .context("creating latest link")?; - zeta.update(cx, |zeta, _cx| { - zeta.with_eval_cache(Arc::new(RunCache { + store.update(cx, |store, _cx| { + store.with_eval_cache(Arc::new(RunCache { example_run_dir: example_run_dir.clone(), cache_mode, })); @@ -121,16 +126,16 @@ pub async fn perform_predict( let prompt_format = options.zeta2.prompt_format; - zeta.update(cx, |zeta, _cx| { - let mut options = zeta.options().clone(); + store.update(cx, |store, _cx| { + let mut options = store.options().clone(); options.prompt_format = prompt_format.into(); - zeta.set_options(options); + store.set_options(options); })?; let mut debug_task = gpui::Task::ready(Ok(())); if options.provider == crate::PredictionProvider::Zeta2 { - let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?; + let mut debug_rx = store.update(cx, |store, _| store.debug_info())?; debug_task = cx.background_spawn({ let result = result.clone(); @@ -139,14 +144,14 @@ pub async fn perform_predict( let mut retrieval_finished_at = None; while let Some(event) = debug_rx.next().await { match event { - zeta::ZetaDebugInfo::ContextRetrievalStarted(info) => { + edit_prediction::DebugEvent::ContextRetrievalStarted(info) => { start_time = Some(info.timestamp); fs::write( example_run_dir.join("search_prompt.md"), &info.search_prompt, )?; } - zeta::ZetaDebugInfo::ContextRetrievalFinished(info) => { + edit_prediction::DebugEvent::ContextRetrievalFinished(info) => { retrieval_finished_at = Some(info.timestamp); for (key, value) in &info.metadata { if *key == "search_queries" { @@ -157,7 +162,7 @@ pub async fn perform_predict( } } } - 
zeta::ZetaDebugInfo::EditPredictionRequested(request) => { + edit_prediction::DebugEvent::EditPredictionRequested(request) => { let prediction_started_at = Instant::now(); start_time.get_or_insert(prediction_started_at); let prompt = request.local_prompt.unwrap_or_default(); @@ -193,7 +198,8 @@ pub async fn perform_predict( let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?; - let response = zeta::text_from_response(response).unwrap_or_default(); + let response = edit_prediction::zeta2::text_from_response(response) + .unwrap_or_default(); let prediction_finished_at = Instant::now(); fs::write(example_run_dir.join("prediction_response.md"), &response)?; @@ -212,20 +218,14 @@ pub async fn perform_predict( } }); - zeta.update(cx, |zeta, cx| { - zeta.refresh_context_with_agentic_retrieval( - project.clone(), - cursor_buffer.clone(), - cursor_anchor, - cx, - ) - })? - .await?; + store.update(cx, |store, cx| { + store.refresh_context(&project, &cursor_buffer, cursor_anchor, cx) + })?; } - let prediction = zeta - .update(cx, |zeta, cx| { - zeta.request_prediction( + let prediction = store + .update(cx, |store, cx| { + store.request_prediction( &project, &cursor_buffer, cursor_anchor, diff --git a/crates/zeta_cli/src/source_location.rs b/crates/edit_prediction_cli/src/source_location.rs similarity index 100% rename from crates/zeta_cli/src/source_location.rs rename to crates/edit_prediction_cli/src/source_location.rs diff --git a/crates/zeta_cli/src/util.rs b/crates/edit_prediction_cli/src/util.rs similarity index 100% rename from crates/zeta_cli/src/util.rs rename to crates/edit_prediction_cli/src/util.rs diff --git a/crates/edit_prediction_context/Cargo.toml b/crates/edit_prediction_context/Cargo.toml index 6976831b8cbbe2b998f713ff65f1585f28fc3005..f113c3c46075ca70e61d8d07947d37502e8528e8 100644 --- a/crates/edit_prediction_context/Cargo.toml +++ b/crates/edit_prediction_context/Cargo.toml @@ -12,41 +12,32 @@ workspace = true path = 
"src/edit_prediction_context.rs" [dependencies] +parking_lot.workspace = true anyhow.workspace = true -arrayvec.workspace = true cloud_llm_client.workspace = true collections.workspace = true futures.workspace = true gpui.workspace = true -hashbrown.workspace = true -indoc.workspace = true -itertools.workspace = true language.workspace = true -log.workspace = true -ordered-float.workspace = true -postage.workspace = true +lsp.workspace = true project.workspace = true -regex.workspace = true +log.workspace = true serde.workspace = true -slotmap.workspace = true -strum.workspace = true -text.workspace = true +smallvec.workspace = true tree-sitter.workspace = true util.workspace = true [dev-dependencies] -clap.workspace = true +env_logger.workspace = true +indoc.workspace = true futures.workspace = true gpui = { workspace = true, features = ["test-support"] } -indoc.workspace = true language = { workspace = true, features = ["test-support"] } +lsp = { workspace = true, features = ["test-support"] } pretty_assertions.workspace = true project = {workspace= true, features = ["test-support"]} serde_json.workspace = true settings = {workspace= true, features = ["test-support"]} text = { workspace = true, features = ["test-support"] } -tree-sitter-c.workspace = true -tree-sitter-cpp.workspace = true -tree-sitter-go.workspace = true util = { workspace = true, features = ["test-support"] } zlog.workspace = true diff --git a/crates/edit_prediction_context2/src/assemble_excerpts.rs b/crates/edit_prediction_context/src/assemble_excerpts.rs similarity index 100% rename from crates/edit_prediction_context2/src/assemble_excerpts.rs rename to crates/edit_prediction_context/src/assemble_excerpts.rs diff --git a/crates/edit_prediction_context/src/declaration.rs b/crates/edit_prediction_context/src/declaration.rs deleted file mode 100644 index cc32640425ecc563b1f24a6c695be1c13199cd73..0000000000000000000000000000000000000000 --- a/crates/edit_prediction_context/src/declaration.rs +++ 
/dev/null @@ -1,350 +0,0 @@ -use cloud_llm_client::predict_edits_v3::{self, Line}; -use language::{Language, LanguageId}; -use project::ProjectEntryId; -use std::ops::Range; -use std::sync::Arc; -use std::{borrow::Cow, path::Path}; -use text::{Bias, BufferId, Rope}; -use util::paths::{path_ends_with, strip_path_suffix}; -use util::rel_path::RelPath; - -use crate::outline::OutlineDeclaration; - -#[derive(Debug, Clone, Eq, PartialEq, Hash)] -pub struct Identifier { - pub name: Arc, - pub language_id: LanguageId, -} - -slotmap::new_key_type! { - pub struct DeclarationId; -} - -#[derive(Debug, Clone)] -pub enum Declaration { - File { - project_entry_id: ProjectEntryId, - declaration: FileDeclaration, - cached_path: CachedDeclarationPath, - }, - Buffer { - project_entry_id: ProjectEntryId, - buffer_id: BufferId, - rope: Rope, - declaration: BufferDeclaration, - cached_path: CachedDeclarationPath, - }, -} - -const ITEM_TEXT_TRUNCATION_LENGTH: usize = 1024; - -impl Declaration { - pub fn identifier(&self) -> &Identifier { - match self { - Declaration::File { declaration, .. } => &declaration.identifier, - Declaration::Buffer { declaration, .. } => &declaration.identifier, - } - } - - pub fn parent(&self) -> Option { - match self { - Declaration::File { declaration, .. } => declaration.parent, - Declaration::Buffer { declaration, .. } => declaration.parent, - } - } - - pub fn as_buffer(&self) -> Option<&BufferDeclaration> { - match self { - Declaration::File { .. } => None, - Declaration::Buffer { declaration, .. } => Some(declaration), - } - } - - pub fn as_file(&self) -> Option<&FileDeclaration> { - match self { - Declaration::Buffer { .. } => None, - Declaration::File { declaration, .. } => Some(declaration), - } - } - - pub fn project_entry_id(&self) -> ProjectEntryId { - match self { - Declaration::File { - project_entry_id, .. - } => *project_entry_id, - Declaration::Buffer { - project_entry_id, .. 
- } => *project_entry_id, - } - } - - pub fn cached_path(&self) -> &CachedDeclarationPath { - match self { - Declaration::File { cached_path, .. } => cached_path, - Declaration::Buffer { cached_path, .. } => cached_path, - } - } - - pub fn item_range(&self) -> Range { - match self { - Declaration::File { declaration, .. } => declaration.item_range.clone(), - Declaration::Buffer { declaration, .. } => declaration.item_range.clone(), - } - } - - pub fn item_line_range(&self) -> Range { - match self { - Declaration::File { declaration, .. } => declaration.item_line_range.clone(), - Declaration::Buffer { - declaration, rope, .. - } => { - Line(rope.offset_to_point(declaration.item_range.start).row) - ..Line(rope.offset_to_point(declaration.item_range.end).row) - } - } - } - - pub fn item_text(&self) -> (Cow<'_, str>, bool) { - match self { - Declaration::File { declaration, .. } => ( - declaration.text.as_ref().into(), - declaration.text_is_truncated, - ), - Declaration::Buffer { - rope, declaration, .. - } => ( - rope.chunks_in_range(declaration.item_range.clone()) - .collect::>(), - declaration.item_range_is_truncated, - ), - } - } - - pub fn signature_text(&self) -> (Cow<'_, str>, bool) { - match self { - Declaration::File { declaration, .. } => ( - declaration.text[self.signature_range_in_item_text()].into(), - declaration.signature_is_truncated, - ), - Declaration::Buffer { - rope, declaration, .. - } => ( - rope.chunks_in_range(declaration.signature_range.clone()) - .collect::>(), - declaration.signature_range_is_truncated, - ), - } - } - - pub fn signature_range(&self) -> Range { - match self { - Declaration::File { declaration, .. } => declaration.signature_range.clone(), - Declaration::Buffer { declaration, .. } => declaration.signature_range.clone(), - } - } - - pub fn signature_line_range(&self) -> Range { - match self { - Declaration::File { declaration, .. } => declaration.signature_line_range.clone(), - Declaration::Buffer { - declaration, rope, .. 
- } => { - Line(rope.offset_to_point(declaration.signature_range.start).row) - ..Line(rope.offset_to_point(declaration.signature_range.end).row) - } - } - } - - pub fn signature_range_in_item_text(&self) -> Range { - let signature_range = self.signature_range(); - let item_range = self.item_range(); - signature_range.start.saturating_sub(item_range.start) - ..(signature_range.end.saturating_sub(item_range.start)).min(item_range.len()) - } -} - -fn expand_range_to_line_boundaries_and_truncate( - range: &Range, - limit: usize, - rope: &Rope, -) -> (Range, Range, bool) { - let mut point_range = rope.offset_to_point(range.start)..rope.offset_to_point(range.end); - point_range.start.column = 0; - point_range.end.row += 1; - point_range.end.column = 0; - - let mut item_range = - rope.point_to_offset(point_range.start)..rope.point_to_offset(point_range.end); - let is_truncated = item_range.len() > limit; - if is_truncated { - item_range.end = item_range.start + limit; - } - item_range.end = rope.clip_offset(item_range.end, Bias::Left); - - let line_range = - predict_edits_v3::Line(point_range.start.row)..predict_edits_v3::Line(point_range.end.row); - (item_range, line_range, is_truncated) -} - -#[derive(Debug, Clone)] -pub struct FileDeclaration { - pub parent: Option, - pub identifier: Identifier, - /// offset range of the declaration in the file, expanded to line boundaries and truncated - pub item_range: Range, - /// line range of the declaration in the file, potentially truncated - pub item_line_range: Range, - /// text of `item_range` - pub text: Arc, - /// whether `text` was truncated - pub text_is_truncated: bool, - /// offset range of the signature in the file, expanded to line boundaries and truncated - pub signature_range: Range, - /// line range of the signature in the file, truncated - pub signature_line_range: Range, - /// whether `signature` was truncated - pub signature_is_truncated: bool, -} - -impl FileDeclaration { - pub fn from_outline(declaration: 
OutlineDeclaration, rope: &Rope) -> FileDeclaration { - let (item_range_in_file, item_line_range_in_file, text_is_truncated) = - expand_range_to_line_boundaries_and_truncate( - &declaration.item_range, - ITEM_TEXT_TRUNCATION_LENGTH, - rope, - ); - - let (mut signature_range_in_file, signature_line_range, mut signature_is_truncated) = - expand_range_to_line_boundaries_and_truncate( - &declaration.signature_range, - ITEM_TEXT_TRUNCATION_LENGTH, - rope, - ); - - if signature_range_in_file.start < item_range_in_file.start { - signature_range_in_file.start = item_range_in_file.start; - signature_is_truncated = true; - } - if signature_range_in_file.end > item_range_in_file.end { - signature_range_in_file.end = item_range_in_file.end; - signature_is_truncated = true; - } - - FileDeclaration { - parent: None, - identifier: declaration.identifier, - signature_range: signature_range_in_file, - signature_line_range, - signature_is_truncated, - text: rope - .chunks_in_range(item_range_in_file.clone()) - .collect::() - .into(), - text_is_truncated, - item_range: item_range_in_file, - item_line_range: item_line_range_in_file, - } - } -} - -#[derive(Debug, Clone)] -pub struct BufferDeclaration { - pub parent: Option, - pub identifier: Identifier, - pub item_range: Range, - pub item_range_is_truncated: bool, - pub signature_range: Range, - pub signature_range_is_truncated: bool, -} - -impl BufferDeclaration { - pub fn from_outline(declaration: OutlineDeclaration, rope: &Rope) -> Self { - let (item_range, _item_line_range, item_range_is_truncated) = - expand_range_to_line_boundaries_and_truncate( - &declaration.item_range, - ITEM_TEXT_TRUNCATION_LENGTH, - rope, - ); - let (signature_range, _signature_line_range, signature_range_is_truncated) = - expand_range_to_line_boundaries_and_truncate( - &declaration.signature_range, - ITEM_TEXT_TRUNCATION_LENGTH, - rope, - ); - Self { - parent: None, - identifier: declaration.identifier, - item_range, - item_range_is_truncated, - 
signature_range, - signature_range_is_truncated, - } - } -} - -#[derive(Debug, Clone)] -pub struct CachedDeclarationPath { - pub worktree_abs_path: Arc, - pub rel_path: Arc, - /// The relative path of the file, possibly stripped according to `import_path_strip_regex`. - pub rel_path_after_regex_stripping: Arc, -} - -impl CachedDeclarationPath { - pub fn new( - worktree_abs_path: Arc, - path: &Arc, - language: Option<&Arc>, - ) -> Self { - let rel_path = path.clone(); - let rel_path_after_regex_stripping = if let Some(language) = language - && let Some(strip_regex) = language.config().import_path_strip_regex.as_ref() - && let Ok(stripped) = RelPath::unix(&Path::new( - strip_regex.replace_all(rel_path.as_unix_str(), "").as_ref(), - )) { - Arc::from(stripped) - } else { - rel_path.clone() - }; - CachedDeclarationPath { - worktree_abs_path, - rel_path, - rel_path_after_regex_stripping, - } - } - - #[cfg(test)] - pub fn new_for_test(worktree_abs_path: &str, rel_path: &str) -> Self { - let rel_path: Arc = util::rel_path::rel_path(rel_path).into(); - CachedDeclarationPath { - worktree_abs_path: std::path::PathBuf::from(worktree_abs_path).into(), - rel_path_after_regex_stripping: rel_path.clone(), - rel_path, - } - } - - pub fn ends_with_posix_path(&self, path: &Path) -> bool { - if path.as_os_str().len() <= self.rel_path_after_regex_stripping.as_unix_str().len() { - path_ends_with(self.rel_path_after_regex_stripping.as_std_path(), path) - } else { - if let Some(remaining) = - strip_path_suffix(path, self.rel_path_after_regex_stripping.as_std_path()) - { - path_ends_with(&self.worktree_abs_path, remaining) - } else { - false - } - } - } - - pub fn equals_absolute_path(&self, path: &Path) -> bool { - if let Some(remaining) = - strip_path_suffix(path, &self.rel_path_after_regex_stripping.as_std_path()) - { - self.worktree_abs_path.as_ref() == remaining - } else { - false - } - } -} diff --git a/crates/edit_prediction_context/src/declaration_scoring.rs 
b/crates/edit_prediction_context/src/declaration_scoring.rs deleted file mode 100644 index 48a823362769770c836b44e7d8a6c1942d3a1196..0000000000000000000000000000000000000000 --- a/crates/edit_prediction_context/src/declaration_scoring.rs +++ /dev/null @@ -1,539 +0,0 @@ -use cloud_llm_client::predict_edits_v3::DeclarationScoreComponents; -use collections::HashMap; -use language::BufferSnapshot; -use ordered_float::OrderedFloat; -use project::ProjectEntryId; -use serde::Serialize; -use std::{cmp::Reverse, ops::Range, path::Path, sync::Arc}; -use strum::EnumIter; -use text::{Point, ToPoint}; -use util::RangeExt as _; - -use crate::{ - CachedDeclarationPath, Declaration, EditPredictionExcerpt, Identifier, - imports::{Import, Imports, Module}, - reference::{Reference, ReferenceRegion}, - syntax_index::SyntaxIndexState, - text_similarity::{Occurrences, jaccard_similarity, weighted_overlap_coefficient}, -}; - -const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16; - -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct EditPredictionScoreOptions { - pub omit_excerpt_overlaps: bool, -} - -#[derive(Clone, Debug)] -pub struct ScoredDeclaration { - /// identifier used by the local reference - pub identifier: Identifier, - pub declaration: Declaration, - pub components: DeclarationScoreComponents, -} - -#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)] -pub enum DeclarationStyle { - Signature, - Declaration, -} - -#[derive(Clone, Debug, Serialize, Default)] -pub struct DeclarationScores { - pub signature: f32, - pub declaration: f32, - pub retrieval: f32, -} - -impl ScoredDeclaration { - /// Returns the score for this declaration with the specified style. 
- pub fn score(&self, style: DeclarationStyle) -> f32 { - // TODO: handle truncation - - // Score related to how likely this is the correct declaration, range 0 to 1 - let retrieval = self.retrieval_score(); - - // Score related to the distance between the reference and cursor, range 0 to 1 - let distance_score = if self.components.is_referenced_nearby { - 1.0 / (1.0 + self.components.reference_line_distance as f32 / 10.0).powf(2.0) - } else { - // same score as ~14 lines away, rationale is to not overly penalize references from parent signatures - 0.5 - }; - - // For now instead of linear combination, the scores are just multiplied together. - let combined_score = 10.0 * retrieval * distance_score; - - match style { - DeclarationStyle::Signature => { - combined_score * self.components.excerpt_vs_signature_weighted_overlap - } - DeclarationStyle::Declaration => { - 2.0 * combined_score * self.components.excerpt_vs_item_weighted_overlap - } - } - } - - pub fn retrieval_score(&self) -> f32 { - let mut score = if self.components.is_same_file { - 10.0 / self.components.same_file_declaration_count as f32 - } else if self.components.path_import_match_count > 0 { - 3.0 - } else if self.components.wildcard_path_import_match_count > 0 { - 1.0 - } else if self.components.normalized_import_similarity > 0.0 { - self.components.normalized_import_similarity - } else if self.components.normalized_wildcard_import_similarity > 0.0 { - 0.5 * self.components.normalized_wildcard_import_similarity - } else { - 1.0 / self.components.declaration_count as f32 - }; - score *= 1. + self.components.included_by_others as f32 / 2.; - score *= 1. + self.components.includes_others as f32 / 4.; - score - } - - pub fn size(&self, style: DeclarationStyle) -> usize { - match &self.declaration { - Declaration::File { declaration, .. 
} => match style { - DeclarationStyle::Signature => declaration.signature_range.len(), - DeclarationStyle::Declaration => declaration.text.len(), - }, - Declaration::Buffer { declaration, .. } => match style { - DeclarationStyle::Signature => declaration.signature_range.len(), - DeclarationStyle::Declaration => declaration.item_range.len(), - }, - } - } - - pub fn score_density(&self, style: DeclarationStyle) -> f32 { - self.score(style) / self.size(style) as f32 - } -} - -pub fn scored_declarations( - options: &EditPredictionScoreOptions, - index: &SyntaxIndexState, - excerpt: &EditPredictionExcerpt, - excerpt_occurrences: &Occurrences, - adjacent_occurrences: &Occurrences, - imports: &Imports, - identifier_to_references: HashMap>, - cursor_offset: usize, - current_buffer: &BufferSnapshot, -) -> Vec { - let cursor_point = cursor_offset.to_point(¤t_buffer); - - let mut wildcard_import_occurrences = Vec::new(); - let mut wildcard_import_paths = Vec::new(); - for wildcard_import in imports.wildcard_modules.iter() { - match wildcard_import { - Module::Namespace(namespace) => { - wildcard_import_occurrences.push(namespace.occurrences()) - } - Module::SourceExact(path) => wildcard_import_paths.push(path), - Module::SourceFuzzy(path) => { - wildcard_import_occurrences.push(Occurrences::from_path(&path)) - } - } - } - - let mut scored_declarations = Vec::new(); - let mut project_entry_id_to_outline_ranges: HashMap>> = - HashMap::default(); - for (identifier, references) in identifier_to_references { - let mut import_occurrences = Vec::new(); - let mut import_paths = Vec::new(); - let mut found_external_identifier: Option<&Identifier> = None; - - if let Some(imports) = imports.identifier_to_imports.get(&identifier) { - // only use alias when it's the only import, could be generalized if some language - // has overlapping aliases - // - // TODO: when an aliased declaration is included in the prompt, should include the - // aliasing in the prompt. 
- // - // TODO: For SourceFuzzy consider having componentwise comparison that pays - // attention to ordering. - if let [ - Import::Alias { - module, - external_identifier, - }, - ] = imports.as_slice() - { - match module { - Module::Namespace(namespace) => { - import_occurrences.push(namespace.occurrences()) - } - Module::SourceExact(path) => import_paths.push(path), - Module::SourceFuzzy(path) => { - import_occurrences.push(Occurrences::from_path(&path)) - } - } - found_external_identifier = Some(&external_identifier); - } else { - for import in imports { - match import { - Import::Direct { module } => match module { - Module::Namespace(namespace) => { - import_occurrences.push(namespace.occurrences()) - } - Module::SourceExact(path) => import_paths.push(path), - Module::SourceFuzzy(path) => { - import_occurrences.push(Occurrences::from_path(&path)) - } - }, - Import::Alias { .. } => {} - } - } - } - } - - let identifier_to_lookup = found_external_identifier.unwrap_or(&identifier); - // TODO: update this to be able to return more declarations? Especially if there is the - // ability to quickly filter a large list (based on imports) - let identifier_declarations = index - .declarations_for_identifier::(&identifier_to_lookup); - let declaration_count = identifier_declarations.len(); - - if declaration_count == 0 { - continue; - } - - // TODO: option to filter out other candidates when same file / import match - let mut checked_declarations = Vec::with_capacity(declaration_count); - for (declaration_id, declaration) in identifier_declarations { - match declaration { - Declaration::Buffer { - buffer_id, - declaration: buffer_declaration, - .. 
- } => { - if buffer_id == ¤t_buffer.remote_id() { - let already_included_in_prompt = - range_intersection(&buffer_declaration.item_range, &excerpt.range) - .is_some() - || excerpt - .parent_declarations - .iter() - .any(|(excerpt_parent, _)| excerpt_parent == &declaration_id); - if !options.omit_excerpt_overlaps || !already_included_in_prompt { - let declaration_line = buffer_declaration - .item_range - .start - .to_point(current_buffer) - .row; - let declaration_line_distance = - (cursor_point.row as i32 - declaration_line as i32).unsigned_abs(); - checked_declarations.push(CheckedDeclaration { - declaration, - same_file_line_distance: Some(declaration_line_distance), - path_import_match_count: 0, - wildcard_path_import_match_count: 0, - }); - } - continue; - } else { - } - } - Declaration::File { .. } => {} - } - let declaration_path = declaration.cached_path(); - let path_import_match_count = import_paths - .iter() - .filter(|import_path| { - declaration_path_matches_import(&declaration_path, import_path) - }) - .count(); - let wildcard_path_import_match_count = wildcard_import_paths - .iter() - .filter(|import_path| { - declaration_path_matches_import(&declaration_path, import_path) - }) - .count(); - checked_declarations.push(CheckedDeclaration { - declaration, - same_file_line_distance: None, - path_import_match_count, - wildcard_path_import_match_count, - }); - } - - let mut max_import_similarity = 0.0; - let mut max_wildcard_import_similarity = 0.0; - - let mut scored_declarations_for_identifier = Vec::with_capacity(checked_declarations.len()); - for checked_declaration in checked_declarations { - let same_file_declaration_count = - index.file_declaration_count(checked_declaration.declaration); - - let declaration = score_declaration( - &identifier, - &references, - checked_declaration, - same_file_declaration_count, - declaration_count, - &excerpt_occurrences, - &adjacent_occurrences, - &import_occurrences, - &wildcard_import_occurrences, - cursor_point, 
- current_buffer, - ); - - if declaration.components.import_similarity > max_import_similarity { - max_import_similarity = declaration.components.import_similarity; - } - - if declaration.components.wildcard_import_similarity > max_wildcard_import_similarity { - max_wildcard_import_similarity = declaration.components.wildcard_import_similarity; - } - - project_entry_id_to_outline_ranges - .entry(declaration.declaration.project_entry_id()) - .or_default() - .push(declaration.declaration.item_range()); - scored_declarations_for_identifier.push(declaration); - } - - if max_import_similarity > 0.0 || max_wildcard_import_similarity > 0.0 { - for declaration in scored_declarations_for_identifier.iter_mut() { - if max_import_similarity > 0.0 { - declaration.components.max_import_similarity = max_import_similarity; - declaration.components.normalized_import_similarity = - declaration.components.import_similarity / max_import_similarity; - } - if max_wildcard_import_similarity > 0.0 { - declaration.components.normalized_wildcard_import_similarity = - declaration.components.wildcard_import_similarity - / max_wildcard_import_similarity; - } - } - } - - scored_declarations.extend(scored_declarations_for_identifier); - } - - // TODO: Inform this via import / retrieval scores of outline items - // TODO: Consider using a sweepline - for scored_declaration in scored_declarations.iter_mut() { - let project_entry_id = scored_declaration.declaration.project_entry_id(); - let Some(ranges) = project_entry_id_to_outline_ranges.get(&project_entry_id) else { - continue; - }; - for range in ranges { - if range.contains_inclusive(&scored_declaration.declaration.item_range()) { - scored_declaration.components.included_by_others += 1 - } else if scored_declaration - .declaration - .item_range() - .contains_inclusive(range) - { - scored_declaration.components.includes_others += 1 - } - } - } - - scored_declarations.sort_unstable_by_key(|declaration| { - Reverse(OrderedFloat( - 
declaration.score(DeclarationStyle::Declaration), - )) - }); - - scored_declarations -} - -struct CheckedDeclaration<'a> { - declaration: &'a Declaration, - same_file_line_distance: Option, - path_import_match_count: usize, - wildcard_path_import_match_count: usize, -} - -fn declaration_path_matches_import( - declaration_path: &CachedDeclarationPath, - import_path: &Arc, -) -> bool { - if import_path.is_absolute() { - declaration_path.equals_absolute_path(import_path) - } else { - declaration_path.ends_with_posix_path(import_path) - } -} - -fn range_intersection(a: &Range, b: &Range) -> Option> { - let start = a.start.clone().max(b.start.clone()); - let end = a.end.clone().min(b.end.clone()); - if start < end { - Some(Range { start, end }) - } else { - None - } -} - -fn score_declaration( - identifier: &Identifier, - references: &[Reference], - checked_declaration: CheckedDeclaration, - same_file_declaration_count: usize, - declaration_count: usize, - excerpt_occurrences: &Occurrences, - adjacent_occurrences: &Occurrences, - import_occurrences: &[Occurrences], - wildcard_import_occurrences: &[Occurrences], - cursor: Point, - current_buffer: &BufferSnapshot, -) -> ScoredDeclaration { - let CheckedDeclaration { - declaration, - same_file_line_distance, - path_import_match_count, - wildcard_path_import_match_count, - } = checked_declaration; - - let is_referenced_nearby = references - .iter() - .any(|r| r.region == ReferenceRegion::Nearby); - let is_referenced_in_breadcrumb = references - .iter() - .any(|r| r.region == ReferenceRegion::Breadcrumb); - let reference_count = references.len(); - let reference_line_distance = references - .iter() - .map(|r| { - let reference_line = r.range.start.to_point(current_buffer).row as i32; - (cursor.row as i32 - reference_line).unsigned_abs() - }) - .min() - .unwrap(); - - let is_same_file = same_file_line_distance.is_some(); - let declaration_line_distance = same_file_line_distance.unwrap_or(u32::MAX); - - let 
item_source_occurrences = Occurrences::within_string(&declaration.item_text().0); - let item_signature_occurrences = Occurrences::within_string(&declaration.signature_text().0); - let excerpt_vs_item_jaccard = jaccard_similarity(excerpt_occurrences, &item_source_occurrences); - let excerpt_vs_signature_jaccard = - jaccard_similarity(excerpt_occurrences, &item_signature_occurrences); - let adjacent_vs_item_jaccard = - jaccard_similarity(adjacent_occurrences, &item_source_occurrences); - let adjacent_vs_signature_jaccard = - jaccard_similarity(adjacent_occurrences, &item_signature_occurrences); - - let excerpt_vs_item_weighted_overlap = - weighted_overlap_coefficient(excerpt_occurrences, &item_source_occurrences); - let excerpt_vs_signature_weighted_overlap = - weighted_overlap_coefficient(excerpt_occurrences, &item_signature_occurrences); - let adjacent_vs_item_weighted_overlap = - weighted_overlap_coefficient(adjacent_occurrences, &item_source_occurrences); - let adjacent_vs_signature_weighted_overlap = - weighted_overlap_coefficient(adjacent_occurrences, &item_signature_occurrences); - - let mut import_similarity = 0f32; - let mut wildcard_import_similarity = 0f32; - if !import_occurrences.is_empty() || !wildcard_import_occurrences.is_empty() { - let cached_path = declaration.cached_path(); - let path_occurrences = Occurrences::from_worktree_path( - cached_path - .worktree_abs_path - .file_name() - .map(|f| f.to_string_lossy()), - &cached_path.rel_path, - ); - import_similarity = import_occurrences - .iter() - .map(|namespace_occurrences| { - OrderedFloat(jaccard_similarity(namespace_occurrences, &path_occurrences)) - }) - .max() - .map(|similarity| similarity.into_inner()) - .unwrap_or_default(); - - // TODO: Consider something other than max - wildcard_import_similarity = wildcard_import_occurrences - .iter() - .map(|namespace_occurrences| { - OrderedFloat(jaccard_similarity(namespace_occurrences, &path_occurrences)) - }) - .max() - .map(|similarity| 
similarity.into_inner()) - .unwrap_or_default(); - } - - // TODO: Consider adding declaration_file_count - let score_components = DeclarationScoreComponents { - is_same_file, - is_referenced_nearby, - is_referenced_in_breadcrumb, - reference_line_distance, - declaration_line_distance, - reference_count, - same_file_declaration_count, - declaration_count, - excerpt_vs_item_jaccard, - excerpt_vs_signature_jaccard, - adjacent_vs_item_jaccard, - adjacent_vs_signature_jaccard, - excerpt_vs_item_weighted_overlap, - excerpt_vs_signature_weighted_overlap, - adjacent_vs_item_weighted_overlap, - adjacent_vs_signature_weighted_overlap, - path_import_match_count, - wildcard_path_import_match_count, - import_similarity, - max_import_similarity: 0.0, - normalized_import_similarity: 0.0, - wildcard_import_similarity, - normalized_wildcard_import_similarity: 0.0, - included_by_others: 0, - includes_others: 0, - }; - - ScoredDeclaration { - identifier: identifier.clone(), - declaration: declaration.clone(), - components: score_components, - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_declaration_path_matches() { - let declaration_path = - CachedDeclarationPath::new_for_test("/home/user/project", "src/maths.ts"); - - assert!(declaration_path_matches_import( - &declaration_path, - &Path::new("maths.ts").into() - )); - - assert!(declaration_path_matches_import( - &declaration_path, - &Path::new("project/src/maths.ts").into() - )); - - assert!(declaration_path_matches_import( - &declaration_path, - &Path::new("user/project/src/maths.ts").into() - )); - - assert!(declaration_path_matches_import( - &declaration_path, - &Path::new("/home/user/project/src/maths.ts").into() - )); - - assert!(!declaration_path_matches_import( - &declaration_path, - &Path::new("other.ts").into() - )); - - assert!(!declaration_path_matches_import( - &declaration_path, - &Path::new("/home/user/project/src/other.ts").into() - )); - } -} diff --git 
a/crates/edit_prediction_context/src/edit_prediction_context.rs b/crates/edit_prediction_context/src/edit_prediction_context.rs index 65623a825c2f7e2db42b98174748e5f04fb91d2a..e316c5a052acd241e7d33356bd0d5dfa5fd075bd 100644 --- a/crates/edit_prediction_context/src/edit_prediction_context.rs +++ b/crates/edit_prediction_context/src/edit_prediction_context.rs @@ -1,335 +1,469 @@ -mod declaration; -mod declaration_scoring; +use crate::assemble_excerpts::assemble_excerpts; +use anyhow::Result; +use collections::HashMap; +use futures::{FutureExt, StreamExt as _, channel::mpsc, future}; +use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity}; +use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, Rope, ToOffset as _}; +use project::{LocationLink, Project, ProjectPath}; +use serde::{Serialize, Serializer}; +use smallvec::SmallVec; +use std::{ + collections::hash_map, + ops::Range, + sync::Arc, + time::{Duration, Instant}, +}; +use util::{RangeExt as _, ResultExt}; + +mod assemble_excerpts; +#[cfg(test)] +mod edit_prediction_context_tests; mod excerpt; -mod imports; -mod outline; -mod reference; -mod syntax_index; -pub mod text_similarity; +#[cfg(test)] +mod fake_definition_lsp; -use std::{path::Path, sync::Arc}; +pub use cloud_llm_client::predict_edits_v3::Line; +pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText}; -use cloud_llm_client::predict_edits_v3; -use collections::HashMap; -use gpui::{App, AppContext as _, Entity, Task}; -use language::BufferSnapshot; -use text::{Point, ToOffset as _}; - -pub use declaration::*; -pub use declaration_scoring::*; -pub use excerpt::*; -pub use imports::*; -pub use reference::*; -pub use syntax_index::*; - -pub use predict_edits_v3::Line; - -#[derive(Clone, Debug, PartialEq)] -pub struct EditPredictionContextOptions { - pub use_imports: bool, - pub excerpt: EditPredictionExcerptOptions, - pub score: EditPredictionScoreOptions, - pub 
max_retrieved_declarations: u8, +pub struct RelatedExcerptStore { + project: WeakEntity, + related_files: Vec, + cache: HashMap>, + update_tx: mpsc::UnboundedSender<(Entity, Anchor)>, +} + +pub enum RelatedExcerptStoreEvent { + StartedRefresh, + FinishedRefresh { + cache_hit_count: usize, + cache_miss_count: usize, + mean_definition_latency: Duration, + max_definition_latency: Duration, + }, +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +struct Identifier { + pub name: String, + pub range: Range, +} + +enum DefinitionTask { + CacheHit(Arc), + CacheMiss(Task>>>), +} + +#[derive(Debug)] +struct CacheEntry { + definitions: SmallVec<[CachedDefinition; 1]>, } #[derive(Clone, Debug)] -pub struct EditPredictionContext { - pub excerpt: EditPredictionExcerpt, - pub excerpt_text: EditPredictionExcerptText, - pub cursor_point: Point, - pub declarations: Vec, +struct CachedDefinition { + path: ProjectPath, + buffer: Entity, + anchor_range: Range, +} + +#[derive(Clone, Debug, Serialize)] +pub struct RelatedFile { + #[serde(serialize_with = "serialize_project_path")] + pub path: ProjectPath, + #[serde(skip)] + pub buffer: WeakEntity, + pub excerpts: Vec, + pub max_row: u32, } -impl EditPredictionContext { - pub fn gather_context_in_background( - cursor_point: Point, - buffer: BufferSnapshot, - options: EditPredictionContextOptions, - syntax_index: Option>, - cx: &mut App, - ) -> Task> { - let parent_abs_path = project::File::from_dyn(buffer.file()).and_then(|f| { - let mut path = f.worktree.read(cx).absolutize(&f.path); - if path.pop() { Some(path) } else { None } +impl RelatedFile { + pub fn merge_excerpts(&mut self) { + self.excerpts.sort_unstable_by(|a, b| { + a.point_range + .start + .cmp(&b.point_range.start) + .then(b.point_range.end.cmp(&a.point_range.end)) }); - if let Some(syntax_index) = syntax_index { - let index_state = - syntax_index.read_with(cx, |index, _cx| Arc::downgrade(index.state())); - cx.background_spawn(async move { - let parent_abs_path = 
parent_abs_path.as_deref(); - let index_state = index_state.upgrade()?; - let index_state = index_state.lock().await; - Self::gather_context( - cursor_point, - &buffer, - parent_abs_path, - &options, - Some(&index_state), - ) - }) - } else { - cx.background_spawn(async move { - let parent_abs_path = parent_abs_path.as_deref(); - Self::gather_context(cursor_point, &buffer, parent_abs_path, &options, None) - }) + let mut index = 1; + while index < self.excerpts.len() { + if self.excerpts[index - 1] + .point_range + .end + .cmp(&self.excerpts[index].point_range.start) + .is_ge() + { + let removed = self.excerpts.remove(index); + if removed + .point_range + .end + .cmp(&self.excerpts[index - 1].point_range.end) + .is_gt() + { + self.excerpts[index - 1].point_range.end = removed.point_range.end; + self.excerpts[index - 1].anchor_range.end = removed.anchor_range.end; + } + } else { + index += 1; + } } } +} - pub fn gather_context( - cursor_point: Point, - buffer: &BufferSnapshot, - parent_abs_path: Option<&Path>, - options: &EditPredictionContextOptions, - index_state: Option<&SyntaxIndexState>, - ) -> Option { - let imports = if options.use_imports { - Imports::gather(&buffer, parent_abs_path) - } else { - Imports::default() - }; - Self::gather_context_with_references_fn( - cursor_point, - buffer, - &imports, - options, - index_state, - references_in_excerpt, - ) - } +#[derive(Clone, Debug, Serialize)] +pub struct RelatedExcerpt { + #[serde(skip)] + pub anchor_range: Range, + #[serde(serialize_with = "serialize_point_range")] + pub point_range: Range, + #[serde(serialize_with = "serialize_rope")] + pub text: Rope, +} - pub fn gather_context_with_references_fn( - cursor_point: Point, - buffer: &BufferSnapshot, - imports: &Imports, - options: &EditPredictionContextOptions, - index_state: Option<&SyntaxIndexState>, - get_references: impl FnOnce( - &EditPredictionExcerpt, - &EditPredictionExcerptText, - &BufferSnapshot, - ) -> HashMap>, - ) -> Option { - let excerpt = 
EditPredictionExcerpt::select_from_buffer( - cursor_point, - buffer, - &options.excerpt, - index_state, - )?; - let excerpt_text = excerpt.text(buffer); - - let declarations = if options.max_retrieved_declarations > 0 - && let Some(index_state) = index_state - { - let excerpt_occurrences = - text_similarity::Occurrences::within_string(&excerpt_text.body); - - let adjacent_start = Point::new(cursor_point.row.saturating_sub(2), 0); - let adjacent_end = Point::new(cursor_point.row + 1, 0); - let adjacent_occurrences = text_similarity::Occurrences::within_string( - &buffer - .text_for_range(adjacent_start..adjacent_end) - .collect::(), - ); +fn serialize_project_path( + project_path: &ProjectPath, + serializer: S, +) -> Result { + project_path.path.serialize(serializer) +} - let cursor_offset_in_file = cursor_point.to_offset(buffer); +fn serialize_rope(rope: &Rope, serializer: S) -> Result { + rope.to_string().serialize(serializer) +} - let references = get_references(&excerpt, &excerpt_text, buffer); +fn serialize_point_range( + range: &Range, + serializer: S, +) -> Result { + [ + [range.start.row, range.start.column], + [range.end.row, range.end.column], + ] + .serialize(serializer) +} - let mut declarations = scored_declarations( - &options.score, - &index_state, - &excerpt, - &excerpt_occurrences, - &adjacent_occurrences, - &imports, - references, - cursor_offset_in_file, - buffer, - ); - // TODO [zeta2] if we need this when we ship, we should probably do it in a smarter way - declarations.truncate(options.max_retrieved_declarations as usize); - declarations - } else { - vec![] - }; +const DEBOUNCE_DURATION: Duration = Duration::from_millis(100); + +impl EventEmitter for RelatedExcerptStore {} + +impl RelatedExcerptStore { + pub fn new(project: &Entity, cx: &mut Context) -> Self { + let (update_tx, mut update_rx) = mpsc::unbounded::<(Entity, Anchor)>(); + cx.spawn(async move |this, cx| { + let executor = cx.background_executor().clone(); + while let Some((mut 
buffer, mut position)) = update_rx.next().await { + let mut timer = executor.timer(DEBOUNCE_DURATION).fuse(); + loop { + futures::select_biased! { + next = update_rx.next() => { + if let Some((new_buffer, new_position)) = next { + buffer = new_buffer; + position = new_position; + timer = executor.timer(DEBOUNCE_DURATION).fuse(); + } else { + return anyhow::Ok(()); + } + } + _ = timer => break, + } + } - Some(Self { - excerpt, - excerpt_text, - cursor_point, - declarations, + Self::fetch_excerpts(this.clone(), buffer, position, cx).await?; + } + anyhow::Ok(()) }) + .detach_and_log_err(cx); + + RelatedExcerptStore { + project: project.downgrade(), + update_tx, + related_files: Vec::new(), + cache: Default::default(), + } } -} -#[cfg(test)] -mod tests { - use super::*; - use std::sync::Arc; - - use gpui::{Entity, TestAppContext}; - use indoc::indoc; - use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust}; - use project::{FakeFs, Project}; - use serde_json::json; - use settings::SettingsStore; - use util::path; - - use crate::{EditPredictionExcerptOptions, SyntaxIndex}; - - #[gpui::test] - async fn test_call_site(cx: &mut TestAppContext) { - let (project, index, _rust_lang_id) = init_test(cx).await; - - let buffer = project - .update(cx, |project, cx| { - let project_path = project.find_project_path("c.rs", cx).unwrap(); - project.open_buffer(project_path, cx) - }) - .await - .unwrap(); - - cx.run_until_parked(); - - // first process_data call site - let cursor_point = language::Point::new(8, 21); - let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot()); - - let context = cx - .update(|cx| { - EditPredictionContext::gather_context_in_background( - cursor_point, - buffer_snapshot, - EditPredictionContextOptions { - use_imports: true, - excerpt: EditPredictionExcerptOptions { - max_bytes: 60, - min_bytes: 10, - target_before_cursor_over_total_bytes: 0.5, - }, - score: EditPredictionScoreOptions { - 
omit_excerpt_overlaps: true, - }, - max_retrieved_declarations: u8::MAX, - }, - Some(index.clone()), - cx, - ) - }) - .await - .unwrap(); - - let mut snippet_identifiers = context - .declarations - .iter() - .map(|snippet| snippet.identifier.name.as_ref()) - .collect::>(); - snippet_identifiers.sort(); - assert_eq!(snippet_identifiers, vec!["main", "process_data"]); - drop(buffer); + pub fn refresh(&mut self, buffer: Entity, position: Anchor, _: &mut Context) { + self.update_tx.unbounded_send((buffer, position)).ok(); } - async fn init_test( - cx: &mut TestAppContext, - ) -> (Entity, Entity, LanguageId) { - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - }); + pub fn related_files(&self) -> &[RelatedFile] { + &self.related_files + } - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - path!("/root"), - json!({ - "a.rs": indoc! {r#" - fn main() { - let x = 1; - let y = 2; - let z = add(x, y); - println!("Result: {}", z); - } + async fn fetch_excerpts( + this: WeakEntity, + buffer: Entity, + position: Anchor, + cx: &mut AsyncApp, + ) -> Result<()> { + let (project, snapshot) = this.read_with(cx, |this, cx| { + (this.project.upgrade(), buffer.read(cx).snapshot()) + })?; + let Some(project) = project else { + return Ok(()); + }; - fn add(a: i32, b: i32) -> i32 { - a + b - } - "#}, - "b.rs": indoc! 
{" - pub struct Config { - pub name: String, - pub value: i32, - } + let file = snapshot.file().cloned(); + if let Some(file) = &file { + log::debug!("retrieving_context buffer:{}", file.path().as_unix_str()); + } - impl Config { - pub fn new(name: String, value: i32) -> Self { - Config { name, value } + this.update(cx, |_, cx| { + cx.emit(RelatedExcerptStoreEvent::StartedRefresh); + })?; + + let identifiers = cx + .background_spawn(async move { identifiers_for_position(&snapshot, position) }) + .await; + + let async_cx = cx.clone(); + let start_time = Instant::now(); + let futures = this.update(cx, |this, cx| { + identifiers + .into_iter() + .filter_map(|identifier| { + let task = if let Some(entry) = this.cache.get(&identifier) { + DefinitionTask::CacheHit(entry.clone()) + } else { + DefinitionTask::CacheMiss( + this.project + .update(cx, |project, cx| { + project.definitions(&buffer, identifier.range.start, cx) + }) + .ok()?, + ) + }; + + let cx = async_cx.clone(); + let project = project.clone(); + Some(async move { + match task { + DefinitionTask::CacheHit(cache_entry) => { + Some((identifier, cache_entry, None)) + } + DefinitionTask::CacheMiss(task) => { + let locations = task.await.log_err()??; + let duration = start_time.elapsed(); + cx.update(|cx| { + ( + identifier, + Arc::new(CacheEntry { + definitions: locations + .into_iter() + .filter_map(|location| { + process_definition(location, &project, cx) + }) + .collect(), + }), + Some(duration), + ) + }) + .ok() + } } - } - "}, - "c.rs": indoc! {r#" - use std::collections::HashMap; - - fn main() { - let args: Vec = std::env::args().collect(); - let data: Vec = args[1..] 
- .iter() - .filter_map(|s| s.parse().ok()) - .collect(); - let result = process_data(data); - println!("{:?}", result); - } + }) + }) + .collect::>() + })?; + + let mut cache_hit_count = 0; + let mut cache_miss_count = 0; + let mut mean_definition_latency = Duration::ZERO; + let mut max_definition_latency = Duration::ZERO; + let mut new_cache = HashMap::default(); + new_cache.reserve(futures.len()); + for (identifier, entry, duration) in future::join_all(futures).await.into_iter().flatten() { + new_cache.insert(identifier, entry); + if let Some(duration) = duration { + cache_miss_count += 1; + mean_definition_latency += duration; + max_definition_latency = max_definition_latency.max(duration); + } else { + cache_hit_count += 1; + } + } + mean_definition_latency /= cache_miss_count.max(1) as u32; - fn process_data(data: Vec) -> HashMap { - let mut counts = HashMap::new(); - for value in data { - *counts.entry(value).or_insert(0) += 1; - } - counts - } + let (new_cache, related_files) = rebuild_related_files(new_cache, cx).await?; - #[cfg(test)] - mod tests { - use super::*; + if let Some(file) = &file { + log::debug!( + "finished retrieving context buffer:{}, latency:{:?}", + file.path().as_unix_str(), + start_time.elapsed() + ); + } - #[test] - fn test_process_data() { - let data = vec![1, 2, 2, 3]; - let result = process_data(data); - assert_eq!(result.get(&2), Some(&2)); - } - } - "#} - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _| project.languages().clone()); - let lang = rust_lang(); - let lang_id = lang.id(); - language_registry.add(Arc::new(lang)); - - let file_indexing_parallelism = 2; - let index = cx.new(|cx| SyntaxIndex::new(&project, file_indexing_parallelism, cx)); - cx.run_until_parked(); - - (project, index, lang_id) + this.update(cx, |this, cx| { + this.cache = new_cache; + this.related_files = related_files; + 
cx.emit(RelatedExcerptStoreEvent::FinishedRefresh { + cache_hit_count, + cache_miss_count, + mean_definition_latency, + max_definition_latency, + }); + })?; + + anyhow::Ok(()) } +} - fn rust_lang() -> Language { - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - ) - .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm")) - .unwrap() - .with_outline_query(include_str!("../../languages/src/rust/outline.scm")) - .unwrap() +async fn rebuild_related_files( + new_entries: HashMap>, + cx: &mut AsyncApp, +) -> Result<(HashMap>, Vec)> { + let mut snapshots = HashMap::default(); + for entry in new_entries.values() { + for definition in &entry.definitions { + if let hash_map::Entry::Vacant(e) = snapshots.entry(definition.buffer.entity_id()) { + definition + .buffer + .read_with(cx, |buffer, _| buffer.parsing_idle())? 
+ .await; + e.insert( + definition + .buffer + .read_with(cx, |buffer, _| buffer.snapshot())?, + ); + } + } } + + Ok(cx + .background_spawn(async move { + let mut files = Vec::::new(); + let mut ranges_by_buffer = HashMap::<_, Vec>>::default(); + let mut paths_by_buffer = HashMap::default(); + for entry in new_entries.values() { + for definition in &entry.definitions { + let Some(snapshot) = snapshots.get(&definition.buffer.entity_id()) else { + continue; + }; + paths_by_buffer.insert(definition.buffer.entity_id(), definition.path.clone()); + ranges_by_buffer + .entry(definition.buffer.clone()) + .or_default() + .push(definition.anchor_range.to_point(snapshot)); + } + } + + for (buffer, ranges) in ranges_by_buffer { + let Some(snapshot) = snapshots.get(&buffer.entity_id()) else { + continue; + }; + let Some(project_path) = paths_by_buffer.get(&buffer.entity_id()) else { + continue; + }; + let excerpts = assemble_excerpts(snapshot, ranges); + files.push(RelatedFile { + path: project_path.clone(), + buffer: buffer.downgrade(), + excerpts, + max_row: snapshot.max_point().row, + }); + } + + files.sort_by_key(|file| file.path.clone()); + (new_entries, files) + }) + .await) +} + +fn process_definition( + location: LocationLink, + project: &Entity, + cx: &mut App, +) -> Option { + let buffer = location.target.buffer.read(cx); + let anchor_range = location.target.range; + let file = buffer.file()?; + let worktree = project.read(cx).worktree_for_id(file.worktree_id(cx), cx)?; + if worktree.read(cx).is_single_file() { + return None; + } + Some(CachedDefinition { + path: ProjectPath { + worktree_id: file.worktree_id(cx), + path: file.path().clone(), + }, + buffer: location.target.buffer, + anchor_range, + }) +} + +/// Gets all of the identifiers that are present in the given line, and its containing +/// outline items. 
+fn identifiers_for_position(buffer: &BufferSnapshot, position: Anchor) -> Vec { + let offset = position.to_offset(buffer); + let point = buffer.offset_to_point(offset); + + let line_range = Point::new(point.row, 0)..Point::new(point.row + 1, 0).min(buffer.max_point()); + let mut ranges = vec![line_range.to_offset(&buffer)]; + + // Include the range of the outline item itself, but not its body. + let outline_items = buffer.outline_items_as_offsets_containing(offset..offset, false, None); + for item in outline_items { + if let Some(body_range) = item.body_range(&buffer) { + ranges.push(item.range.start..body_range.start.to_offset(&buffer)); + } else { + ranges.push(item.range.clone()); + } + } + + ranges.sort_by(|a, b| a.start.cmp(&b.start).then(b.end.cmp(&a.end))); + ranges.dedup_by(|a, b| { + if a.start <= b.end { + b.start = b.start.min(a.start); + b.end = b.end.max(a.end); + true + } else { + false + } + }); + + let mut identifiers = Vec::new(); + let outer_range = + ranges.first().map_or(0, |r| r.start)..ranges.last().map_or(buffer.len(), |r| r.end); + + let mut captures = buffer + .syntax + .captures(outer_range.clone(), &buffer.text, |grammar| { + grammar + .highlights_config + .as_ref() + .map(|config| &config.query) + }); + + for range in ranges { + captures.set_byte_range(range.start..outer_range.end); + + let mut last_range = None; + while let Some(capture) = captures.peek() { + let node_range = capture.node.byte_range(); + if node_range.start > range.end { + break; + } + let config = captures.grammars()[capture.grammar_index] + .highlights_config + .as_ref(); + + if let Some(config) = config + && config.identifier_capture_indices.contains(&capture.index) + && range.contains_inclusive(&node_range) + && Some(&node_range) != last_range.as_ref() + { + let name = buffer.text_for_range(node_range.clone()).collect(); + identifiers.push(Identifier { + range: buffer.anchor_after(node_range.start) + ..buffer.anchor_before(node_range.end), + name, + }); + 
last_range = Some(node_range); + } + + captures.advance(); + } + } + + identifiers } diff --git a/crates/edit_prediction_context2/src/edit_prediction_context_tests.rs b/crates/edit_prediction_context/src/edit_prediction_context_tests.rs similarity index 100% rename from crates/edit_prediction_context2/src/edit_prediction_context_tests.rs rename to crates/edit_prediction_context/src/edit_prediction_context_tests.rs diff --git a/crates/edit_prediction_context/src/excerpt.rs b/crates/edit_prediction_context/src/excerpt.rs index 7a4bb73edfa131b620a930d7f0e1c0da77e0afe6..55a3d8f03b277d0ce40f1d2ac947c55abf93f1c9 100644 --- a/crates/edit_prediction_context/src/excerpt.rs +++ b/crates/edit_prediction_context/src/excerpt.rs @@ -1,11 +1,9 @@ -use language::{BufferSnapshot, LanguageId}; +use cloud_llm_client::predict_edits_v3::Line; +use language::{BufferSnapshot, LanguageId, Point, ToOffset as _, ToPoint as _}; use std::ops::Range; -use text::{Point, ToOffset as _, ToPoint as _}; use tree_sitter::{Node, TreeCursor}; use util::RangeExt; -use crate::{BufferDeclaration, Line, declaration::DeclarationId, syntax_index::SyntaxIndexState}; - // TODO: // // - Test parent signatures @@ -31,19 +29,16 @@ pub struct EditPredictionExcerptOptions { pub target_before_cursor_over_total_bytes: f32, } -// TODO: consider merging these #[derive(Debug, Clone)] pub struct EditPredictionExcerpt { pub range: Range, pub line_range: Range, - pub parent_declarations: Vec<(DeclarationId, Range)>, pub size: usize, } #[derive(Debug, Clone)] pub struct EditPredictionExcerptText { pub body: String, - pub parent_signatures: Vec, pub language_id: Option, } @@ -52,17 +47,8 @@ impl EditPredictionExcerpt { let body = buffer .text_for_range(self.range.clone()) .collect::(); - let parent_signatures = self - .parent_declarations - .iter() - .map(|(_, range)| buffer.text_for_range(range.clone()).collect::()) - .collect(); let language_id = buffer.language().map(|l| l.id()); - EditPredictionExcerptText { - body, - 
parent_signatures, - language_id, - } + EditPredictionExcerptText { body, language_id } } /// Selects an excerpt around a buffer position, attempting to choose logical boundaries based @@ -79,7 +65,6 @@ impl EditPredictionExcerpt { query_point: Point, buffer: &BufferSnapshot, options: &EditPredictionExcerptOptions, - syntax_index: Option<&SyntaxIndexState>, ) -> Option { if buffer.len() <= options.max_bytes { log::debug!( @@ -89,11 +74,7 @@ impl EditPredictionExcerpt { ); let offset_range = 0..buffer.len(); let line_range = Line(0)..Line(buffer.max_point().row); - return Some(EditPredictionExcerpt::new( - offset_range, - line_range, - Vec::new(), - )); + return Some(EditPredictionExcerpt::new(offset_range, line_range)); } let query_offset = query_point.to_offset(buffer); @@ -104,19 +85,10 @@ impl EditPredictionExcerpt { return None; } - let parent_declarations = if let Some(syntax_index) = syntax_index { - syntax_index - .buffer_declarations_containing_range(buffer.remote_id(), query_range.clone()) - .collect() - } else { - Vec::new() - }; - let excerpt_selector = ExcerptSelector { query_offset, query_range, query_line_range: Line(query_line_range.start)..Line(query_line_range.end), - parent_declarations: &parent_declarations, buffer, options, }; @@ -139,20 +111,10 @@ impl EditPredictionExcerpt { excerpt_selector.select_lines() } - fn new( - range: Range, - line_range: Range, - parent_declarations: Vec<(DeclarationId, Range)>, - ) -> Self { - let size = range.len() - + parent_declarations - .iter() - .map(|(_, range)| range.len()) - .sum::(); + fn new(range: Range, line_range: Range) -> Self { Self { + size: range.len(), range, - parent_declarations, - size, line_range, } } @@ -162,14 +124,7 @@ impl EditPredictionExcerpt { // this is an issue because parent_signature_ranges may be incorrect log::error!("bug: with_expanded_range called with disjoint range"); } - let mut parent_declarations = Vec::with_capacity(self.parent_declarations.len()); - for (declaration_id, 
range) in &self.parent_declarations { - if !range.contains_inclusive(&new_range) { - break; - } - parent_declarations.push((*declaration_id, range.clone())); - } - Self::new(new_range, new_line_range, parent_declarations) + Self::new(new_range, new_line_range) } fn parent_signatures_size(&self) -> usize { @@ -181,7 +136,6 @@ struct ExcerptSelector<'a> { query_offset: usize, query_range: Range, query_line_range: Range, - parent_declarations: &'a [(DeclarationId, &'a BufferDeclaration)], buffer: &'a BufferSnapshot, options: &'a EditPredictionExcerptOptions, } @@ -409,13 +363,7 @@ impl<'a> ExcerptSelector<'a> { } fn make_excerpt(&self, range: Range, line_range: Range) -> EditPredictionExcerpt { - let parent_declarations = self - .parent_declarations - .iter() - .filter(|(_, declaration)| declaration.item_range.contains_inclusive(&range)) - .map(|(id, declaration)| (*id, declaration.signature_range.clone())) - .collect(); - EditPredictionExcerpt::new(range, line_range, parent_declarations) + EditPredictionExcerpt::new(range, line_range) } /// Returns `true` if the `forward` excerpt is a better choice than the `backward` excerpt. 
@@ -506,9 +454,8 @@ mod tests { let buffer = create_buffer(&text, cx); let cursor_point = cursor.to_point(&buffer); - let excerpt = - EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options, None) - .expect("Should select an excerpt"); + let excerpt = EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options) + .expect("Should select an excerpt"); pretty_assertions::assert_eq!( generate_marked_text(&text, std::slice::from_ref(&excerpt.range), false), generate_marked_text(&text, &[expected_excerpt], false) diff --git a/crates/edit_prediction_context2/src/fake_definition_lsp.rs b/crates/edit_prediction_context/src/fake_definition_lsp.rs similarity index 100% rename from crates/edit_prediction_context2/src/fake_definition_lsp.rs rename to crates/edit_prediction_context/src/fake_definition_lsp.rs diff --git a/crates/edit_prediction_context/src/imports.rs b/crates/edit_prediction_context/src/imports.rs deleted file mode 100644 index 70f175159340ddb9a6f26f23db0c1b3c843e7b96..0000000000000000000000000000000000000000 --- a/crates/edit_prediction_context/src/imports.rs +++ /dev/null @@ -1,1319 +0,0 @@ -use collections::HashMap; -use language::BufferSnapshot; -use language::ImportsConfig; -use language::Language; -use std::ops::Deref; -use std::path::Path; -use std::sync::Arc; -use std::{borrow::Cow, ops::Range}; -use text::OffsetRangeExt as _; -use util::RangeExt; -use util::paths::PathStyle; - -use crate::Identifier; -use crate::text_similarity::Occurrences; - -// TODO: Write documentation for extension authors. The @import capture must match before or in the -// same pattern as all all captures it contains - -// Future improvements to consider: -// -// * Distinguish absolute vs relative paths in captures. `#include "maths.h"` is relative whereas -// `#include ` is not. -// -// * Provide the name used when importing whole modules (see tests with "named_module" in the name). 
-// To be useful, will require parsing of identifier qualification. -// -// * Scoping for imports that aren't at the top level -// -// * Only scan a prefix of the file, when possible. This could look like having query matches that -// indicate it reached a declaration that is not allowed in the import section. -// -// * Support directly parsing to occurrences instead of storing namespaces / paths. Types should be -// generic on this, so that tests etc can still use strings. Could do similar in syntax index. -// -// * Distinguish different types of namespaces when known. E.g. "name.type" capture. Once capture -// names are more open-ended like this may make sense to build and cache a jump table (direct -// dispatch from capture index). -// -// * There are a few "Language specific:" comments on behavior that gets applied to all languages. -// Would be cleaner to be conditional on the language or otherwise configured. - -#[derive(Debug, Clone, Default)] -pub struct Imports { - pub identifier_to_imports: HashMap>, - pub wildcard_modules: Vec, -} - -#[derive(Debug, Clone)] -pub enum Import { - Direct { - module: Module, - }, - Alias { - module: Module, - external_identifier: Identifier, - }, -} - -#[derive(Debug, Clone)] -pub enum Module { - SourceExact(Arc), - SourceFuzzy(Arc), - Namespace(Namespace), -} - -impl Module { - fn empty() -> Self { - Module::Namespace(Namespace::default()) - } - - fn push_range( - &mut self, - range: &ModuleRange, - snapshot: &BufferSnapshot, - language: &Language, - parent_abs_path: Option<&Path>, - ) -> usize { - if range.is_empty() { - return 0; - } - - match range { - ModuleRange::Source(range) => { - if let Self::Namespace(namespace) = self - && namespace.0.is_empty() - { - let path = snapshot.text_for_range(range.clone()).collect::>(); - - let path = if let Some(strip_regex) = - language.config().import_path_strip_regex.as_ref() - { - strip_regex.replace_all(&path, "") - } else { - path - }; - - let path = Path::new(path.as_ref()); - 
if (path.starts_with(".") || path.starts_with("..")) - && let Some(parent_abs_path) = parent_abs_path - && let Ok(abs_path) = - util::paths::normalize_lexically(&parent_abs_path.join(path)) - { - *self = Self::SourceExact(abs_path.into()); - } else { - *self = Self::SourceFuzzy(path.into()); - }; - } else if matches!(self, Self::SourceExact(_)) - || matches!(self, Self::SourceFuzzy(_)) - { - log::warn!("bug in imports query: encountered multiple @source matches"); - } else { - log::warn!( - "bug in imports query: encountered both @namespace and @source match" - ); - } - } - ModuleRange::Namespace(range) => { - if let Self::Namespace(namespace) = self { - let segment = range_text(snapshot, range); - if language.config().ignored_import_segments.contains(&segment) { - return 0; - } else { - namespace.0.push(segment); - return 1; - } - } else { - log::warn!( - "bug in imports query: encountered both @namespace and @source match" - ); - } - } - } - 0 - } -} - -#[derive(Debug, Clone)] -enum ModuleRange { - Source(Range), - Namespace(Range), -} - -impl Deref for ModuleRange { - type Target = Range; - - fn deref(&self) -> &Self::Target { - match self { - ModuleRange::Source(range) => range, - ModuleRange::Namespace(range) => range, - } - } -} - -#[derive(Debug, Clone, PartialEq, Eq, Default)] -pub struct Namespace(pub Vec>); - -impl Namespace { - pub fn occurrences(&self) -> Occurrences { - Occurrences::from_identifiers(&self.0) - } -} - -impl Imports { - pub fn gather(snapshot: &BufferSnapshot, parent_abs_path: Option<&Path>) -> Self { - // Query to match different import patterns - let mut matches = snapshot - .syntax - .matches(0..snapshot.len(), &snapshot.text, |grammar| { - grammar.imports_config().map(|imports| &imports.query) - }); - - let mut detached_nodes: Vec = Vec::new(); - let mut identifier_to_imports = HashMap::default(); - let mut wildcard_modules = Vec::new(); - let mut import_range = None; - - while let Some(query_match) = matches.peek() { - let 
ImportsConfig { - query: _, - import_ix, - name_ix, - namespace_ix, - source_ix, - list_ix, - wildcard_ix, - alias_ix, - } = matches.grammars()[query_match.grammar_index] - .imports_config() - .unwrap(); - - let mut new_import_range = None; - let mut alias_range = None; - let mut modules = Vec::new(); - let mut content: Option<(Range, ContentKind)> = None; - for capture in query_match.captures { - let capture_range = capture.node.byte_range(); - - if capture.index == *import_ix { - new_import_range = Some(capture_range); - } else if Some(capture.index) == *namespace_ix { - modules.push(ModuleRange::Namespace(capture_range)); - } else if Some(capture.index) == *source_ix { - modules.push(ModuleRange::Source(capture_range)); - } else if Some(capture.index) == *alias_ix { - alias_range = Some(capture_range); - } else { - let mut found_content = None; - if Some(capture.index) == *name_ix { - found_content = Some((capture_range, ContentKind::Name)); - } else if Some(capture.index) == *list_ix { - found_content = Some((capture_range, ContentKind::List)); - } else if Some(capture.index) == *wildcard_ix { - found_content = Some((capture_range, ContentKind::Wildcard)); - } - if let Some((found_content_range, found_kind)) = found_content { - if let Some((_, old_kind)) = content { - let point = found_content_range.to_point(snapshot); - log::warn!( - "bug in {} imports query: unexpected multiple captures of {} and {} ({}:{}:{})", - query_match.language.name(), - old_kind.capture_name(), - found_kind.capture_name(), - snapshot - .file() - .map(|p| p.path().display(PathStyle::Posix)) - .unwrap_or_default(), - point.start.row + 1, - point.start.column + 1 - ); - } - content = Some((found_content_range, found_kind)); - } - } - } - - if let Some(new_import_range) = new_import_range { - log::trace!("starting new import {:?}", new_import_range); - Self::gather_from_import_statement( - &detached_nodes, - &snapshot, - parent_abs_path, - &mut identifier_to_imports, - &mut 
wildcard_modules, - ); - detached_nodes.clear(); - import_range = Some(new_import_range.clone()); - } - - if let Some((content, content_kind)) = content { - if import_range - .as_ref() - .is_some_and(|import_range| import_range.contains_inclusive(&content)) - { - detached_nodes.push(DetachedNode { - modules, - content: content.clone(), - content_kind, - alias: alias_range.unwrap_or(0..0), - language: query_match.language.clone(), - }); - } else { - log::trace!( - "filtered out match not inside import range: {content_kind:?} at {content:?}" - ); - } - } - - matches.advance(); - } - - Self::gather_from_import_statement( - &detached_nodes, - &snapshot, - parent_abs_path, - &mut identifier_to_imports, - &mut wildcard_modules, - ); - - Imports { - identifier_to_imports, - wildcard_modules, - } - } - - fn gather_from_import_statement( - detached_nodes: &[DetachedNode], - snapshot: &BufferSnapshot, - parent_abs_path: Option<&Path>, - identifier_to_imports: &mut HashMap>, - wildcard_modules: &mut Vec, - ) { - let mut trees = Vec::new(); - - for detached_node in detached_nodes { - if let Some(node) = Self::attach_node(detached_node.into(), &mut trees) { - trees.push(node); - } - log::trace!( - "Attached node to tree\n{:#?}\nAttach result:\n{:#?}", - detached_node, - trees - .iter() - .map(|tree| tree.debug(snapshot)) - .collect::>() - ); - } - - for tree in &trees { - let mut module = Module::empty(); - Self::gather_from_tree( - tree, - snapshot, - parent_abs_path, - &mut module, - identifier_to_imports, - wildcard_modules, - ); - } - } - - fn attach_node(mut node: ImportTree, trees: &mut Vec) -> Option { - let mut tree_index = 0; - while tree_index < trees.len() { - let tree = &mut trees[tree_index]; - if !node.content.is_empty() && node.content == tree.content { - // multiple matches can apply to the same name/list/wildcard. This keeps the queries - // simpler by combining info from these matches. 
- if tree.module.is_empty() { - tree.module = node.module; - tree.module_children = node.module_children; - } - if tree.alias.is_empty() { - tree.alias = node.alias; - } - return None; - } else if !node.module.is_empty() && node.module.contains_inclusive(&tree.range()) { - node.module_children.push(trees.remove(tree_index)); - continue; - } else if !node.content.is_empty() && node.content.contains_inclusive(&tree.content) { - node.content_children.push(trees.remove(tree_index)); - continue; - } else if !tree.content.is_empty() && tree.content.contains_inclusive(&node.content) { - if let Some(node) = Self::attach_node(node, &mut tree.content_children) { - tree.content_children.push(node); - } - return None; - } - tree_index += 1; - } - Some(node) - } - - fn gather_from_tree( - tree: &ImportTree, - snapshot: &BufferSnapshot, - parent_abs_path: Option<&Path>, - current_module: &mut Module, - identifier_to_imports: &mut HashMap>, - wildcard_modules: &mut Vec, - ) { - let mut pop_count = 0; - - if tree.module_children.is_empty() { - pop_count += - current_module.push_range(&tree.module, snapshot, &tree.language, parent_abs_path); - } else { - for child in &tree.module_children { - pop_count += Self::extend_namespace_from_tree( - child, - snapshot, - parent_abs_path, - current_module, - ); - } - }; - - if tree.content_children.is_empty() && !tree.content.is_empty() { - match tree.content_kind { - ContentKind::Name | ContentKind::List => { - if tree.alias.is_empty() { - identifier_to_imports - .entry(Identifier { - language_id: tree.language.id(), - name: range_text(snapshot, &tree.content), - }) - .or_default() - .push(Import::Direct { - module: current_module.clone(), - }); - } else { - let alias_name: Arc = range_text(snapshot, &tree.alias); - let external_name = range_text(snapshot, &tree.content); - // Language specific: skip "_" aliases for Rust - if alias_name.as_ref() != "_" { - identifier_to_imports - .entry(Identifier { - language_id: tree.language.id(), - name: 
alias_name, - }) - .or_default() - .push(Import::Alias { - module: current_module.clone(), - external_identifier: Identifier { - language_id: tree.language.id(), - name: external_name, - }, - }); - } - } - } - ContentKind::Wildcard => wildcard_modules.push(current_module.clone()), - } - } else { - for child in &tree.content_children { - Self::gather_from_tree( - child, - snapshot, - parent_abs_path, - current_module, - identifier_to_imports, - wildcard_modules, - ); - } - } - - if pop_count > 0 { - match current_module { - Module::SourceExact(_) | Module::SourceFuzzy(_) => { - log::warn!( - "bug in imports query: encountered both @namespace and @source match" - ); - } - Module::Namespace(namespace) => { - namespace.0.drain(namespace.0.len() - pop_count..); - } - } - } - } - - fn extend_namespace_from_tree( - tree: &ImportTree, - snapshot: &BufferSnapshot, - parent_abs_path: Option<&Path>, - module: &mut Module, - ) -> usize { - let mut pop_count = 0; - if tree.module_children.is_empty() { - pop_count += module.push_range(&tree.module, snapshot, &tree.language, parent_abs_path); - } else { - for child in &tree.module_children { - pop_count += - Self::extend_namespace_from_tree(child, snapshot, parent_abs_path, module); - } - } - if tree.content_children.is_empty() { - pop_count += module.push_range( - &ModuleRange::Namespace(tree.content.clone()), - snapshot, - &tree.language, - parent_abs_path, - ); - } else { - for child in &tree.content_children { - pop_count += - Self::extend_namespace_from_tree(child, snapshot, parent_abs_path, module); - } - } - pop_count - } -} - -fn range_text(snapshot: &BufferSnapshot, range: &Range) -> Arc { - snapshot - .text_for_range(range.clone()) - .collect::>() - .into() -} - -#[derive(Debug)] -struct DetachedNode { - modules: Vec, - content: Range, - content_kind: ContentKind, - alias: Range, - language: Arc, -} - -#[derive(Debug, Clone, Copy)] -enum ContentKind { - Name, - Wildcard, - List, -} - -impl ContentKind { - fn 
capture_name(&self) -> &'static str { - match self { - ContentKind::Name => "name", - ContentKind::Wildcard => "wildcard", - ContentKind::List => "list", - } - } -} - -#[derive(Debug)] -struct ImportTree { - module: ModuleRange, - /// When non-empty, provides namespace / source info which should be used instead of `module`. - module_children: Vec, - content: Range, - /// When non-empty, provides content which should be used instead of `content`. - content_children: Vec, - content_kind: ContentKind, - alias: Range, - language: Arc, -} - -impl ImportTree { - fn range(&self) -> Range { - self.module.start.min(self.content.start)..self.module.end.max(self.content.end) - } - - #[allow(dead_code)] - fn debug<'a>(&'a self, snapshot: &'a BufferSnapshot) -> ImportTreeDebug<'a> { - ImportTreeDebug { - tree: self, - snapshot, - } - } - - fn from_module_range(module: &ModuleRange, language: Arc) -> Self { - ImportTree { - module: module.clone(), - module_children: Vec::new(), - content: 0..0, - content_children: Vec::new(), - content_kind: ContentKind::Name, - alias: 0..0, - language, - } - } -} - -impl From<&DetachedNode> for ImportTree { - fn from(value: &DetachedNode) -> Self { - let module; - let module_children; - match value.modules.len() { - 0 => { - module = ModuleRange::Namespace(0..0); - module_children = Vec::new(); - } - 1 => { - module = value.modules[0].clone(); - module_children = Vec::new(); - } - _ => { - module = ModuleRange::Namespace( - value.modules.first().unwrap().start..value.modules.last().unwrap().end, - ); - module_children = value - .modules - .iter() - .map(|module| ImportTree::from_module_range(module, value.language.clone())) - .collect(); - } - } - - ImportTree { - module, - module_children, - content: value.content.clone(), - content_children: Vec::new(), - content_kind: value.content_kind, - alias: value.alias.clone(), - language: value.language.clone(), - } - } -} - -struct ImportTreeDebug<'a> { - tree: &'a ImportTree, - snapshot: &'a 
BufferSnapshot, -} - -impl std::fmt::Debug for ImportTreeDebug<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("ImportTree") - .field("module_range", &self.tree.module) - .field("module_text", &range_text(self.snapshot, &self.tree.module)) - .field( - "module_children", - &self - .tree - .module_children - .iter() - .map(|child| child.debug(&self.snapshot)) - .collect::>(), - ) - .field("content_range", &self.tree.content) - .field( - "content_text", - &range_text(self.snapshot, &self.tree.content), - ) - .field( - "content_children", - &self - .tree - .content_children - .iter() - .map(|child| child.debug(&self.snapshot)) - .collect::>(), - ) - .field("content_kind", &self.tree.content_kind) - .field("alias_range", &self.tree.alias) - .field("alias_text", &range_text(self.snapshot, &self.tree.alias)) - .finish() - } -} - -#[cfg(test)] -mod test { - use std::path::PathBuf; - use std::sync::{Arc, LazyLock}; - - use super::*; - use collections::HashSet; - use gpui::{TestAppContext, prelude::*}; - use indoc::indoc; - use language::{ - Buffer, Language, LanguageConfig, tree_sitter_python, tree_sitter_rust, - tree_sitter_typescript, - }; - use regex::Regex; - - #[gpui::test] - fn test_rust_simple(cx: &mut TestAppContext) { - check_imports( - &RUST, - "use std::collections::HashMap;", - &[&["std", "collections", "HashMap"]], - cx, - ); - - check_imports( - &RUST, - "pub use std::collections::HashMap;", - &[&["std", "collections", "HashMap"]], - cx, - ); - - check_imports( - &RUST, - "use std::collections::{HashMap, HashSet};", - &[ - &["std", "collections", "HashMap"], - &["std", "collections", "HashSet"], - ], - cx, - ); - } - - #[gpui::test] - fn test_rust_nested(cx: &mut TestAppContext) { - check_imports( - &RUST, - "use std::{any::TypeId, collections::{HashMap, HashSet}};", - &[ - &["std", "any", "TypeId"], - &["std", "collections", "HashMap"], - &["std", "collections", "HashSet"], - ], - cx, - ); - - check_imports( - 
&RUST, - "use a::b::c::{d::e::F, g::h::I};", - &[ - &["a", "b", "c", "d", "e", "F"], - &["a", "b", "c", "g", "h", "I"], - ], - cx, - ); - } - - #[gpui::test] - fn test_rust_multiple_imports(cx: &mut TestAppContext) { - check_imports( - &RUST, - indoc! {" - use std::collections::HashMap; - use std::any::{TypeId, Any}; - "}, - &[ - &["std", "collections", "HashMap"], - &["std", "any", "TypeId"], - &["std", "any", "Any"], - ], - cx, - ); - - check_imports( - &RUST, - indoc! {" - use std::collections::HashSet; - - fn main() { - let unqualified = HashSet::new(); - let qualified = std::collections::HashMap::new(); - } - - use std::any::TypeId; - "}, - &[ - &["std", "collections", "HashSet"], - &["std", "any", "TypeId"], - ], - cx, - ); - } - - #[gpui::test] - fn test_rust_wildcard(cx: &mut TestAppContext) { - check_imports(&RUST, "use prelude::*;", &[&["prelude", "WILDCARD"]], cx); - - check_imports( - &RUST, - "use zed::prelude::*;", - &[&["zed", "prelude", "WILDCARD"]], - cx, - ); - - check_imports(&RUST, "use prelude::{*};", &[&["prelude", "WILDCARD"]], cx); - - check_imports( - &RUST, - "use prelude::{File, *};", - &[&["prelude", "File"], &["prelude", "WILDCARD"]], - cx, - ); - - check_imports( - &RUST, - "use zed::{App, prelude::*};", - &[&["zed", "App"], &["zed", "prelude", "WILDCARD"]], - cx, - ); - } - - #[gpui::test] - fn test_rust_alias(cx: &mut TestAppContext) { - check_imports( - &RUST, - "use std::io::Result as IoResult;", - &[&["std", "io", "Result AS IoResult"]], - cx, - ); - } - - #[gpui::test] - fn test_rust_crate_and_super(cx: &mut TestAppContext) { - check_imports(&RUST, "use crate::a::b::c;", &[&["a", "b", "c"]], cx); - check_imports(&RUST, "use super::a::b::c;", &[&["a", "b", "c"]], cx); - // TODO: Consider stripping leading "::". Not done for now because for the text similarity matching usecase this - // is fine. 
- check_imports(&RUST, "use ::a::b::c;", &[&["::a", "b", "c"]], cx); - } - - #[gpui::test] - fn test_typescript_imports(cx: &mut TestAppContext) { - let parent_abs_path = PathBuf::from("/home/user/project"); - - check_imports_with_file_abs_path( - Some(&parent_abs_path), - &TYPESCRIPT, - r#"import "./maths.js";"#, - &[&["SOURCE /home/user/project/maths", "WILDCARD"]], - cx, - ); - - check_imports_with_file_abs_path( - Some(&parent_abs_path), - &TYPESCRIPT, - r#"import "../maths.js";"#, - &[&["SOURCE /home/user/maths", "WILDCARD"]], - cx, - ); - - check_imports_with_file_abs_path( - Some(&parent_abs_path), - &TYPESCRIPT, - r#"import RandomNumberGenerator, { pi as π } from "./maths.js";"#, - &[ - &["SOURCE /home/user/project/maths", "RandomNumberGenerator"], - &["SOURCE /home/user/project/maths", "pi AS π"], - ], - cx, - ); - - check_imports_with_file_abs_path( - Some(&parent_abs_path), - &TYPESCRIPT, - r#"import { pi, phi, absolute } from "./maths.js";"#, - &[ - &["SOURCE /home/user/project/maths", "pi"], - &["SOURCE /home/user/project/maths", "phi"], - &["SOURCE /home/user/project/maths", "absolute"], - ], - cx, - ); - - // index.js is removed by import_path_strip_regex - check_imports_with_file_abs_path( - Some(&parent_abs_path), - &TYPESCRIPT, - r#"import { pi, phi, absolute } from "./maths/index.js";"#, - &[ - &["SOURCE /home/user/project/maths", "pi"], - &["SOURCE /home/user/project/maths", "phi"], - &["SOURCE /home/user/project/maths", "absolute"], - ], - cx, - ); - - check_imports_with_file_abs_path( - Some(&parent_abs_path), - &TYPESCRIPT, - r#"import type { SomeThing } from "./some-module.js";"#, - &[&["SOURCE /home/user/project/some-module", "SomeThing"]], - cx, - ); - - check_imports_with_file_abs_path( - Some(&parent_abs_path), - &TYPESCRIPT, - r#"import { type SomeThing, OtherThing } from "./some-module.js";"#, - &[ - &["SOURCE /home/user/project/some-module", "SomeThing"], - &["SOURCE /home/user/project/some-module", "OtherThing"], - ], - cx, - ); - - 
// index.js is removed by import_path_strip_regex - check_imports_with_file_abs_path( - Some(&parent_abs_path), - &TYPESCRIPT, - r#"import { type SomeThing, OtherThing } from "./some-module/index.js";"#, - &[ - &["SOURCE /home/user/project/some-module", "SomeThing"], - &["SOURCE /home/user/project/some-module", "OtherThing"], - ], - cx, - ); - - // fuzzy paths - check_imports_with_file_abs_path( - Some(&parent_abs_path), - &TYPESCRIPT, - r#"import { type SomeThing, OtherThing } from "@my-app/some-module.js";"#, - &[ - &["SOURCE FUZZY @my-app/some-module", "SomeThing"], - &["SOURCE FUZZY @my-app/some-module", "OtherThing"], - ], - cx, - ); - } - - #[gpui::test] - fn test_typescript_named_module_imports(cx: &mut TestAppContext) { - let parent_abs_path = PathBuf::from("/home/user/project"); - - // TODO: These should provide the name that the module is bound to. - // For now instead these are treated as unqualified wildcard imports. - check_imports_with_file_abs_path( - Some(&parent_abs_path), - &TYPESCRIPT, - r#"import * as math from "./maths.js";"#, - // &[&["/home/user/project/maths.js", "WILDCARD AS math"]], - &[&["SOURCE /home/user/project/maths", "WILDCARD"]], - cx, - ); - check_imports_with_file_abs_path( - Some(&parent_abs_path), - &TYPESCRIPT, - r#"import math = require("./maths");"#, - // &[&["/home/user/project/maths", "WILDCARD AS math"]], - &[&["SOURCE /home/user/project/maths", "WILDCARD"]], - cx, - ); - } - - #[gpui::test] - fn test_python_imports(cx: &mut TestAppContext) { - check_imports(&PYTHON, "from math import pi", &[&["math", "pi"]], cx); - - check_imports( - &PYTHON, - "from math import pi, sin, cos", - &[&["math", "pi"], &["math", "sin"], &["math", "cos"]], - cx, - ); - - check_imports(&PYTHON, "from math import *", &[&["math", "WILDCARD"]], cx); - - check_imports( - &PYTHON, - "from math import foo.bar.baz", - &[&["math", "foo", "bar", "baz"]], - cx, - ); - - check_imports( - &PYTHON, - "from math import pi as PI", - &[&["math", "pi AS PI"]], - 
cx, - ); - - check_imports( - &PYTHON, - "from serializers.json import JsonSerializer", - &[&["serializers", "json", "JsonSerializer"]], - cx, - ); - - check_imports( - &PYTHON, - "from custom.serializers import json, xml, yaml", - &[ - &["custom", "serializers", "json"], - &["custom", "serializers", "xml"], - &["custom", "serializers", "yaml"], - ], - cx, - ); - } - - #[gpui::test] - fn test_python_named_module_imports(cx: &mut TestAppContext) { - // TODO: These should provide the name that the module is bound to. - // For now instead these are treated as unqualified wildcard imports. - // - // check_imports(&PYTHON, "import math", &[&["math", "WILDCARD as math"]], cx); - // check_imports(&PYTHON, "import math as maths", &[&["math", "WILDCARD AS maths"]], cx); - // - // Something like: - // - // (import_statement - // name: [ - // (dotted_name - // (identifier)* @namespace - // (identifier) @name.module .) - // (aliased_import - // name: (dotted_name - // ((identifier) ".")* @namespace - // (identifier) @name.module .) - // alias: (identifier) @alias) - // ]) @import - - check_imports(&PYTHON, "import math", &[&["math", "WILDCARD"]], cx); - - check_imports( - &PYTHON, - "import math as maths", - &[&["math", "WILDCARD"]], - cx, - ); - - check_imports(&PYTHON, "import a.b.c", &[&["a", "b", "c", "WILDCARD"]], cx); - - check_imports( - &PYTHON, - "import a.b.c as d", - &[&["a", "b", "c", "WILDCARD"]], - cx, - ); - } - - #[gpui::test] - fn test_python_package_relative_imports(cx: &mut TestAppContext) { - // TODO: These should provide info about the dir they are relative to, to provide more - // precise resolution. Instead, fuzzy matching is used as usual. - - check_imports(&PYTHON, "from . 
import math", &[&["math"]], cx); - - check_imports(&PYTHON, "from .a import math", &[&["a", "math"]], cx); - - check_imports( - &PYTHON, - "from ..a.b import math", - &[&["a", "b", "math"]], - cx, - ); - - check_imports( - &PYTHON, - "from ..a.b import *", - &[&["a", "b", "WILDCARD"]], - cx, - ); - } - - #[gpui::test] - fn test_c_imports(cx: &mut TestAppContext) { - let parent_abs_path = PathBuf::from("/home/user/project"); - - // TODO: Distinguish that these are not relative to current path - check_imports_with_file_abs_path( - Some(&parent_abs_path), - &C, - r#"#include "#, - &[&["SOURCE FUZZY math.h", "WILDCARD"]], - cx, - ); - - // TODO: These should be treated as relative, but don't start with ./ or ../ - check_imports_with_file_abs_path( - Some(&parent_abs_path), - &C, - r#"#include "math.h""#, - &[&["SOURCE FUZZY math.h", "WILDCARD"]], - cx, - ); - } - - #[gpui::test] - fn test_cpp_imports(cx: &mut TestAppContext) { - let parent_abs_path = PathBuf::from("/home/user/project"); - - // TODO: Distinguish that these are not relative to current path - check_imports_with_file_abs_path( - Some(&parent_abs_path), - &CPP, - r#"#include "#, - &[&["SOURCE FUZZY math.h", "WILDCARD"]], - cx, - ); - - // TODO: These should be treated as relative, but don't start with ./ or ../ - check_imports_with_file_abs_path( - Some(&parent_abs_path), - &CPP, - r#"#include "math.h""#, - &[&["SOURCE FUZZY math.h", "WILDCARD"]], - cx, - ); - } - - #[gpui::test] - fn test_go_imports(cx: &mut TestAppContext) { - check_imports( - &GO, - r#"import . "lib/math""#, - &[&["lib/math", "WILDCARD"]], - cx, - ); - - // not included, these are only for side-effects - check_imports(&GO, r#"import _ "lib/math""#, &[], cx); - } - - #[gpui::test] - fn test_go_named_module_imports(cx: &mut TestAppContext) { - // TODO: These should provide the name that the module is bound to. - // For now instead these are treated as unqualified wildcard imports. 
- - check_imports( - &GO, - r#"import "lib/math""#, - &[&["lib/math", "WILDCARD"]], - cx, - ); - check_imports( - &GO, - r#"import m "lib/math""#, - &[&["lib/math", "WILDCARD"]], - cx, - ); - } - - #[track_caller] - fn check_imports( - language: &Arc, - source: &str, - expected: &[&[&str]], - cx: &mut TestAppContext, - ) { - check_imports_with_file_abs_path(None, language, source, expected, cx); - } - - #[track_caller] - fn check_imports_with_file_abs_path( - parent_abs_path: Option<&Path>, - language: &Arc, - source: &str, - expected: &[&[&str]], - cx: &mut TestAppContext, - ) { - let buffer = cx.new(|cx| { - let mut buffer = Buffer::local(source, cx); - buffer.set_language(Some(language.clone()), cx); - buffer - }); - cx.run_until_parked(); - - let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); - - let imports = Imports::gather(&snapshot, parent_abs_path); - let mut actual_symbols = imports - .identifier_to_imports - .iter() - .flat_map(|(identifier, imports)| { - imports - .iter() - .map(|import| import.to_identifier_parts(identifier.name.as_ref())) - }) - .chain( - imports - .wildcard_modules - .iter() - .map(|module| module.to_identifier_parts("WILDCARD")), - ) - .collect::>(); - let mut expected_symbols = expected - .iter() - .map(|expected| expected.iter().map(|s| s.to_string()).collect::>()) - .collect::>(); - actual_symbols.sort(); - expected_symbols.sort(); - if actual_symbols != expected_symbols { - let top_layer = snapshot.syntax_layers().next().unwrap(); - panic!( - "Expected imports: {:?}\n\ - Actual imports: {:?}\n\ - Tree:\n{}", - expected_symbols, - actual_symbols, - tree_to_string(&top_layer.node()), - ); - } - } - - fn tree_to_string(node: &tree_sitter::Node) -> String { - let mut cursor = node.walk(); - let mut result = String::new(); - let mut depth = 0; - 'outer: loop { - result.push_str(&" ".repeat(depth)); - if let Some(field_name) = cursor.field_name() { - result.push_str(field_name); - result.push_str(": "); - } - if 
cursor.node().is_named() { - result.push_str(cursor.node().kind()); - } else { - result.push('"'); - result.push_str(cursor.node().kind()); - result.push('"'); - } - result.push('\n'); - - if cursor.goto_first_child() { - depth += 1; - continue; - } - if cursor.goto_next_sibling() { - continue; - } - while cursor.goto_parent() { - depth -= 1; - if cursor.goto_next_sibling() { - continue 'outer; - } - } - break; - } - result - } - - static RUST: LazyLock> = LazyLock::new(|| { - Arc::new( - Language::new( - LanguageConfig { - name: "Rust".into(), - ignored_import_segments: HashSet::from_iter(["crate".into(), "super".into()]), - import_path_strip_regex: Some(Regex::new("/(lib|mod)\\.rs$").unwrap()), - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - ) - .with_imports_query(include_str!("../../languages/src/rust/imports.scm")) - .unwrap(), - ) - }); - - static TYPESCRIPT: LazyLock> = LazyLock::new(|| { - Arc::new( - Language::new( - LanguageConfig { - name: "TypeScript".into(), - import_path_strip_regex: Some(Regex::new("(?:/index)?\\.[jt]s$").unwrap()), - ..Default::default() - }, - Some(tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into()), - ) - .with_imports_query(include_str!("../../languages/src/typescript/imports.scm")) - .unwrap(), - ) - }); - - static PYTHON: LazyLock> = LazyLock::new(|| { - Arc::new( - Language::new( - LanguageConfig { - name: "Python".into(), - import_path_strip_regex: Some(Regex::new("/__init__\\.py$").unwrap()), - ..Default::default() - }, - Some(tree_sitter_python::LANGUAGE.into()), - ) - .with_imports_query(include_str!("../../languages/src/python/imports.scm")) - .unwrap(), - ) - }); - - // TODO: Ideally should use actual language configurations - static C: LazyLock> = LazyLock::new(|| { - Arc::new( - Language::new( - LanguageConfig { - name: "C".into(), - import_path_strip_regex: Some(Regex::new("^<|>$").unwrap()), - ..Default::default() - }, - Some(tree_sitter_c::LANGUAGE.into()), - ) - 
.with_imports_query(include_str!("../../languages/src/c/imports.scm")) - .unwrap(), - ) - }); - - static CPP: LazyLock> = LazyLock::new(|| { - Arc::new( - Language::new( - LanguageConfig { - name: "C++".into(), - import_path_strip_regex: Some(Regex::new("^<|>$").unwrap()), - ..Default::default() - }, - Some(tree_sitter_cpp::LANGUAGE.into()), - ) - .with_imports_query(include_str!("../../languages/src/cpp/imports.scm")) - .unwrap(), - ) - }); - - static GO: LazyLock> = LazyLock::new(|| { - Arc::new( - Language::new( - LanguageConfig { - name: "Go".into(), - ..Default::default() - }, - Some(tree_sitter_go::LANGUAGE.into()), - ) - .with_imports_query(include_str!("../../languages/src/go/imports.scm")) - .unwrap(), - ) - }); - - impl Import { - fn to_identifier_parts(&self, identifier: &str) -> Vec { - match self { - Import::Direct { module } => module.to_identifier_parts(identifier), - Import::Alias { - module, - external_identifier: external_name, - } => { - module.to_identifier_parts(&format!("{} AS {}", external_name.name, identifier)) - } - } - } - } - - impl Module { - fn to_identifier_parts(&self, identifier: &str) -> Vec { - match self { - Self::Namespace(namespace) => namespace.to_identifier_parts(identifier), - Self::SourceExact(path) => { - vec![ - format!("SOURCE {}", path.display().to_string().replace("\\", "/")), - identifier.to_string(), - ] - } - Self::SourceFuzzy(path) => { - vec![ - format!( - "SOURCE FUZZY {}", - path.display().to_string().replace("\\", "/") - ), - identifier.to_string(), - ] - } - } - } - } - - impl Namespace { - fn to_identifier_parts(&self, identifier: &str) -> Vec { - self.0 - .iter() - .map(|chunk| chunk.to_string()) - .chain(std::iter::once(identifier.to_string())) - .collect::>() - } - } -} diff --git a/crates/edit_prediction_context/src/outline.rs b/crates/edit_prediction_context/src/outline.rs deleted file mode 100644 index ec02c869dfae4cb861206cb801c285462e734f36..0000000000000000000000000000000000000000 --- 
a/crates/edit_prediction_context/src/outline.rs +++ /dev/null @@ -1,126 +0,0 @@ -use language::{BufferSnapshot, SyntaxMapMatches}; -use std::{cmp::Reverse, ops::Range}; - -use crate::declaration::Identifier; - -// TODO: -// -// * how to handle multiple name captures? for now last one wins -// -// * annotation ranges -// -// * new "signature" capture for outline queries -// -// * Check parent behavior of "int x, y = 0" declarations in a test - -pub struct OutlineDeclaration { - pub parent_index: Option, - pub identifier: Identifier, - pub item_range: Range, - pub signature_range: Range, -} - -pub fn declarations_in_buffer(buffer: &BufferSnapshot) -> Vec { - declarations_overlapping_range(0..buffer.len(), buffer) -} - -pub fn declarations_overlapping_range( - range: Range, - buffer: &BufferSnapshot, -) -> Vec { - let mut declarations = OutlineIterator::new(range, buffer).collect::>(); - declarations.sort_unstable_by_key(|item| (item.item_range.start, Reverse(item.item_range.end))); - - let mut parent_stack: Vec<(usize, Range)> = Vec::new(); - for (index, declaration) in declarations.iter_mut().enumerate() { - while let Some((top_parent_index, top_parent_range)) = parent_stack.last() { - if declaration.item_range.start >= top_parent_range.end { - parent_stack.pop(); - } else { - declaration.parent_index = Some(*top_parent_index); - break; - } - } - parent_stack.push((index, declaration.item_range.clone())); - } - declarations -} - -/// Iterates outline items without being ordered w.r.t. nested items and without populating -/// `parent`. 
-pub struct OutlineIterator<'a> { - buffer: &'a BufferSnapshot, - matches: SyntaxMapMatches<'a>, -} - -impl<'a> OutlineIterator<'a> { - pub fn new(range: Range, buffer: &'a BufferSnapshot) -> Self { - let matches = buffer.syntax.matches(range, &buffer.text, |grammar| { - grammar.outline_config.as_ref().map(|c| &c.query) - }); - - Self { buffer, matches } - } -} - -impl<'a> Iterator for OutlineIterator<'a> { - type Item = OutlineDeclaration; - - fn next(&mut self) -> Option { - while let Some(mat) = self.matches.peek() { - let config = self.matches.grammars()[mat.grammar_index] - .outline_config - .as_ref() - .unwrap(); - - let mut name_range = None; - let mut item_range = None; - let mut signature_start = None; - let mut signature_end = None; - - let mut add_to_signature = |range: Range| { - if signature_start.is_none() { - signature_start = Some(range.start); - } - signature_end = Some(range.end); - }; - - for capture in mat.captures { - let range = capture.node.byte_range(); - if capture.index == config.name_capture_ix { - name_range = Some(range.clone()); - add_to_signature(range); - } else if Some(capture.index) == config.context_capture_ix - || Some(capture.index) == config.extra_context_capture_ix - { - add_to_signature(range); - } else if capture.index == config.item_capture_ix { - item_range = Some(range.clone()); - } - } - - let language_id = mat.language.id(); - self.matches.advance(); - - if let Some(name_range) = name_range - && let Some(item_range) = item_range - && let Some(signature_start) = signature_start - && let Some(signature_end) = signature_end - { - let name = self - .buffer - .text_for_range(name_range) - .collect::() - .into(); - - return Some(OutlineDeclaration { - identifier: Identifier { name, language_id }, - item_range: item_range, - signature_range: signature_start..signature_end, - parent_index: None, - }); - } - } - None - } -} diff --git a/crates/edit_prediction_context/src/reference.rs 
b/crates/edit_prediction_context/src/reference.rs deleted file mode 100644 index 699adf1d8036802a7a4b9e34ca8e8094e4f97458..0000000000000000000000000000000000000000 --- a/crates/edit_prediction_context/src/reference.rs +++ /dev/null @@ -1,173 +0,0 @@ -use collections::HashMap; -use language::BufferSnapshot; -use std::ops::Range; -use util::RangeExt; - -use crate::{ - declaration::Identifier, - excerpt::{EditPredictionExcerpt, EditPredictionExcerptText}, -}; - -#[derive(Debug, Clone)] -pub struct Reference { - pub identifier: Identifier, - pub range: Range, - pub region: ReferenceRegion, -} - -#[derive(Copy, Clone, Debug, Eq, PartialEq)] -pub enum ReferenceRegion { - Breadcrumb, - Nearby, -} - -pub fn references_in_excerpt( - excerpt: &EditPredictionExcerpt, - excerpt_text: &EditPredictionExcerptText, - snapshot: &BufferSnapshot, -) -> HashMap> { - let mut references = references_in_range( - excerpt.range.clone(), - excerpt_text.body.as_str(), - ReferenceRegion::Nearby, - snapshot, - ); - - for ((_, range), text) in excerpt - .parent_declarations - .iter() - .zip(excerpt_text.parent_signatures.iter()) - { - references.extend(references_in_range( - range.clone(), - text.as_str(), - ReferenceRegion::Breadcrumb, - snapshot, - )); - } - - let mut identifier_to_references: HashMap> = HashMap::default(); - for reference in references { - identifier_to_references - .entry(reference.identifier.clone()) - .or_insert_with(Vec::new) - .push(reference); - } - identifier_to_references -} - -/// Finds all nodes which have a "variable" match from the highlights query within the offset range. 
-pub fn references_in_range( - range: Range, - range_text: &str, - reference_region: ReferenceRegion, - buffer: &BufferSnapshot, -) -> Vec { - let mut matches = buffer - .syntax - .matches(range.clone(), &buffer.text, |grammar| { - grammar - .highlights_config - .as_ref() - .map(|config| &config.query) - }); - - let mut references = Vec::new(); - let mut last_added_range = None; - while let Some(mat) = matches.peek() { - let config = matches.grammars()[mat.grammar_index] - .highlights_config - .as_ref(); - - if let Some(config) = config { - for capture in mat.captures { - if config.identifier_capture_indices.contains(&capture.index) { - let node_range = capture.node.byte_range(); - - // sometimes multiple highlight queries match - this deduplicates them - if Some(node_range.clone()) == last_added_range { - continue; - } - - if !range.contains_inclusive(&node_range) { - continue; - } - - let identifier_text = - &range_text[node_range.start - range.start..node_range.end - range.start]; - - references.push(Reference { - identifier: Identifier { - name: identifier_text.into(), - language_id: mat.language.id(), - }, - range: node_range.clone(), - region: reference_region, - }); - last_added_range = Some(node_range); - } - } - } - - matches.advance(); - } - references -} - -#[cfg(test)] -mod test { - use gpui::{TestAppContext, prelude::*}; - use indoc::indoc; - use language::{BufferSnapshot, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust}; - - use crate::reference::{ReferenceRegion, references_in_range}; - - #[gpui::test] - fn test_identifier_node_truncated(cx: &mut TestAppContext) { - let code = indoc! 
{ r#" - fn main() { - add(1, 2); - } - - fn add(a: i32, b: i32) -> i32 { - a + b - } - "# }; - let buffer = create_buffer(code, cx); - - let range = 0..35; - let references = references_in_range( - range.clone(), - &code[range], - ReferenceRegion::Breadcrumb, - &buffer, - ); - assert_eq!(references.len(), 2); - assert_eq!(references[0].identifier.name.as_ref(), "main"); - assert_eq!(references[1].identifier.name.as_ref(), "add"); - } - - fn create_buffer(text: &str, cx: &mut TestAppContext) -> BufferSnapshot { - let buffer = - cx.new(|cx| language::Buffer::local(text, cx).with_language(rust_lang().into(), cx)); - buffer.read_with(cx, |buffer, _| buffer.snapshot()) - } - - fn rust_lang() -> Language { - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - ) - .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm")) - .unwrap() - .with_outline_query(include_str!("../../languages/src/rust/outline.scm")) - .unwrap() - } -} diff --git a/crates/edit_prediction_context/src/syntax_index.rs b/crates/edit_prediction_context/src/syntax_index.rs deleted file mode 100644 index f489a083341b66c7cca3cdad76a9c7ea16fdc959..0000000000000000000000000000000000000000 --- a/crates/edit_prediction_context/src/syntax_index.rs +++ /dev/null @@ -1,1069 +0,0 @@ -use anyhow::{Result, anyhow}; -use collections::{HashMap, HashSet}; -use futures::channel::mpsc; -use futures::lock::Mutex; -use futures::{FutureExt as _, StreamExt, future}; -use gpui::{App, AppContext as _, AsyncApp, Context, Entity, Task, WeakEntity}; -use itertools::Itertools; - -use language::{Buffer, BufferEvent}; -use postage::stream::Stream as _; -use project::buffer_store::{BufferStore, BufferStoreEvent}; -use project::worktree_store::{WorktreeStore, WorktreeStoreEvent}; -use project::{PathChange, Project, ProjectEntryId, 
ProjectPath}; -use slotmap::SlotMap; -use std::iter; -use std::ops::{DerefMut, Range}; -use std::sync::Arc; -use text::BufferId; -use util::{RangeExt as _, debug_panic, some_or_debug_panic}; - -use crate::CachedDeclarationPath; -use crate::declaration::{ - BufferDeclaration, Declaration, DeclarationId, FileDeclaration, Identifier, -}; -use crate::outline::declarations_in_buffer; - -// TODO -// -// * Also queue / debounce buffer changes. A challenge for this is that use of -// `buffer_declarations_containing_range` assumes that the index is always immediately up to date. -// -// * Add a per language configuration for skipping indexing. -// -// * Handle tsx / ts / js referencing each-other - -// Potential future improvements: -// -// * Prevent indexing of a large file from blocking the queue. -// -// * Send multiple selected excerpt ranges. Challenge is that excerpt ranges influence which -// references are present and their scores. -// -// * Include single-file worktrees / non visible worktrees? E.g. go to definition that resolves to a -// file in a build dependency. Should not be editable in that case - but how to distinguish the case -// where it should be editable? - -// Potential future optimizations: -// -// * Index files on multiple threads in Zed (currently only parallel for the CLI). Adding some kind -// of priority system to the background executor could help - it's single threaded for now to avoid -// interfering with other work. -// -// * Parse files directly instead of loading into a Rope. -// -// - This would allow the task handling dirty_files to be done entirely on the background executor. -// -// - Make SyntaxMap generic to handle embedded languages? Will also need to find line boundaries, -// but that can be done by scanning characters in the flat representation. -// -// * Use something similar to slotmap without key versions. 
-// -// * Concurrent slotmap - -pub struct SyntaxIndex { - state: Arc>, - project: WeakEntity, - initial_file_indexing_done_rx: postage::watch::Receiver, - _file_indexing_task: Option>, -} - -pub struct SyntaxIndexState { - declarations: SlotMap, - identifiers: HashMap>, - files: HashMap, - buffers: HashMap, - dirty_files: HashMap, - dirty_files_tx: mpsc::Sender<()>, -} - -#[derive(Debug, Default)] -struct FileState { - declarations: Vec, -} - -#[derive(Default)] -struct BufferState { - declarations: Vec, - task: Option>, -} - -impl SyntaxIndex { - pub fn new( - project: &Entity, - file_indexing_parallelism: usize, - cx: &mut Context, - ) -> Self { - assert!(file_indexing_parallelism > 0); - let (dirty_files_tx, mut dirty_files_rx) = mpsc::channel::<()>(1); - let (mut initial_file_indexing_done_tx, initial_file_indexing_done_rx) = - postage::watch::channel(); - - let initial_state = SyntaxIndexState { - declarations: SlotMap::default(), - identifiers: HashMap::default(), - files: HashMap::default(), - buffers: HashMap::default(), - dirty_files: HashMap::default(), - dirty_files_tx, - }; - let mut this = Self { - project: project.downgrade(), - state: Arc::new(Mutex::new(initial_state)), - initial_file_indexing_done_rx, - _file_indexing_task: None, - }; - - let worktree_store = project.read(cx).worktree_store(); - let initial_worktree_snapshots = worktree_store - .read(cx) - .worktrees() - .map(|w| w.read(cx).snapshot()) - .collect::>(); - this._file_indexing_task = Some(cx.spawn(async move |this, cx| { - let snapshots_file_count = initial_worktree_snapshots - .iter() - .map(|worktree| worktree.file_count()) - .sum::(); - if snapshots_file_count > 0 { - let chunk_size = snapshots_file_count.div_ceil(file_indexing_parallelism); - let chunk_count = snapshots_file_count.div_ceil(chunk_size); - let file_chunks = initial_worktree_snapshots - .iter() - .flat_map(|worktree| { - let worktree_id = worktree.id(); - worktree.files(false, 0).map(move |entry| { - ( - entry.id, - 
ProjectPath { - worktree_id, - path: entry.path.clone(), - }, - ) - }) - }) - .chunks(chunk_size); - - let mut tasks = Vec::with_capacity(chunk_count); - for chunk in file_chunks.into_iter() { - tasks.push(Self::update_dirty_files( - &this, - chunk.into_iter().collect(), - cx.clone(), - )); - } - futures::future::join_all(tasks).await; - log::info!("Finished initial file indexing"); - } - - *initial_file_indexing_done_tx.borrow_mut() = true; - - let Ok(state) = this.read_with(cx, |this, _cx| Arc::downgrade(&this.state)) else { - return; - }; - while dirty_files_rx.next().await.is_some() { - let Some(state) = state.upgrade() else { - return; - }; - let mut state = state.lock().await; - let was_underused = state.dirty_files.capacity() > 255 - && state.dirty_files.len() * 8 < state.dirty_files.capacity(); - let dirty_files = state.dirty_files.drain().collect::>(); - if was_underused { - state.dirty_files.shrink_to_fit(); - } - drop(state); - if dirty_files.is_empty() { - continue; - } - - let chunk_size = dirty_files.len().div_ceil(file_indexing_parallelism); - let chunk_count = dirty_files.len().div_ceil(chunk_size); - let mut tasks = Vec::with_capacity(chunk_count); - let chunks = dirty_files.into_iter().chunks(chunk_size); - for chunk in chunks.into_iter() { - tasks.push(Self::update_dirty_files( - &this, - chunk.into_iter().collect(), - cx.clone(), - )); - } - futures::future::join_all(tasks).await; - } - })); - - cx.subscribe(&worktree_store, Self::handle_worktree_store_event) - .detach(); - - let buffer_store = project.read(cx).buffer_store().clone(); - for buffer in buffer_store.read(cx).buffers().collect::>() { - this.register_buffer(&buffer, cx); - } - cx.subscribe(&buffer_store, Self::handle_buffer_store_event) - .detach(); - - this - } - - async fn update_dirty_files( - this: &WeakEntity, - dirty_files: Vec<(ProjectEntryId, ProjectPath)>, - mut cx: AsyncApp, - ) { - for (entry_id, project_path) in dirty_files { - let Ok(task) = this.update(&mut cx, |this, 
cx| { - this.update_file(entry_id, project_path, cx) - }) else { - return; - }; - task.await; - } - } - - pub fn wait_for_initial_file_indexing(&self, cx: &App) -> Task> { - if *self.initial_file_indexing_done_rx.borrow() { - Task::ready(Ok(())) - } else { - let mut rx = self.initial_file_indexing_done_rx.clone(); - cx.background_spawn(async move { - loop { - match rx.recv().await { - Some(true) => return Ok(()), - Some(false) => {} - None => { - return Err(anyhow!( - "SyntaxIndex dropped while waiting for initial file indexing" - )); - } - } - } - }) - } - } - - pub fn indexed_file_paths(&self, cx: &App) -> Task> { - let state = self.state.clone(); - let project = self.project.clone(); - - cx.spawn(async move |cx| { - let state = state.lock().await; - let Some(project) = project.upgrade() else { - return vec![]; - }; - project - .read_with(cx, |project, cx| { - state - .files - .keys() - .filter_map(|entry_id| project.path_for_entry(*entry_id, cx)) - .collect() - }) - .unwrap_or_default() - }) - } - - fn handle_worktree_store_event( - &mut self, - _worktree_store: Entity, - event: &WorktreeStoreEvent, - cx: &mut Context, - ) { - use WorktreeStoreEvent::*; - match event { - WorktreeUpdatedEntries(worktree_id, updated_entries_set) => { - let state = Arc::downgrade(&self.state); - let worktree_id = *worktree_id; - let updated_entries_set = updated_entries_set.clone(); - cx.background_spawn(async move { - let Some(state) = state.upgrade() else { return }; - let mut state = state.lock().await; - for (path, entry_id, path_change) in updated_entries_set.iter() { - if let PathChange::Removed = path_change { - state.files.remove(entry_id); - state.dirty_files.remove(entry_id); - } else { - let project_path = ProjectPath { - worktree_id, - path: path.clone(), - }; - state.dirty_files.insert(*entry_id, project_path); - } - } - match state.dirty_files_tx.try_send(()) { - Err(err) if err.is_disconnected() => { - log::error!("bug: syntax indexing queue is disconnected"); - } - 
_ => {} - } - }) - .detach(); - } - WorktreeDeletedEntry(_worktree_id, project_entry_id) => { - let project_entry_id = *project_entry_id; - self.with_state(cx, move |state| { - state.files.remove(&project_entry_id); - }) - } - _ => {} - } - } - - fn handle_buffer_store_event( - &mut self, - _buffer_store: Entity, - event: &BufferStoreEvent, - cx: &mut Context, - ) { - use BufferStoreEvent::*; - match event { - BufferAdded(buffer) => self.register_buffer(buffer, cx), - BufferOpened { .. } - | BufferChangedFilePath { .. } - | BufferDropped { .. } - | SharedBufferClosed { .. } => {} - } - } - - pub fn state(&self) -> &Arc> { - &self.state - } - - fn with_state(&self, cx: &mut App, f: impl FnOnce(&mut SyntaxIndexState) + Send + 'static) { - if let Some(mut state) = self.state.try_lock() { - f(&mut state); - return; - } - let state = Arc::downgrade(&self.state); - cx.background_spawn(async move { - let Some(state) = state.upgrade() else { - return; - }; - let mut state = state.lock().await; - f(&mut state) - }) - .detach(); - } - - fn register_buffer(&self, buffer: &Entity, cx: &mut Context) { - let buffer_id = buffer.read(cx).remote_id(); - cx.observe_release(buffer, move |this, _buffer, cx| { - this.with_state(cx, move |state| { - if let Some(buffer_state) = state.buffers.remove(&buffer_id) { - SyntaxIndexState::remove_buffer_declarations( - &buffer_state.declarations, - &mut state.declarations, - &mut state.identifiers, - ); - } - }) - }) - .detach(); - cx.subscribe(buffer, Self::handle_buffer_event).detach(); - - self.update_buffer(buffer.clone(), cx); - } - - fn handle_buffer_event( - &mut self, - buffer: Entity, - event: &BufferEvent, - cx: &mut Context, - ) { - match event { - BufferEvent::Edited | - // paths are cached and so should be updated - BufferEvent::FileHandleChanged => self.update_buffer(buffer, cx), - _ => {} - } - } - - fn update_buffer(&self, buffer_entity: Entity, cx: &mut Context) { - let buffer = buffer_entity.read(cx); - if 
buffer.language().is_none() { - return; - } - - let Some((project_entry_id, cached_path)) = project::File::from_dyn(buffer.file()) - .and_then(|f| { - let project_entry_id = f.project_entry_id()?; - let cached_path = CachedDeclarationPath::new( - f.worktree.read(cx).abs_path(), - &f.path, - buffer.language(), - ); - Some((project_entry_id, cached_path)) - }) - else { - return; - }; - let buffer_id = buffer.remote_id(); - - let mut parse_status = buffer.parse_status(); - let snapshot_task = cx.spawn({ - let weak_buffer = buffer_entity.downgrade(); - async move |_, cx| { - while *parse_status.borrow() != language::ParseStatus::Idle { - parse_status.changed().await?; - } - weak_buffer.read_with(cx, |buffer, _cx| buffer.snapshot()) - } - }); - - let state = Arc::downgrade(&self.state); - let task = cx.background_spawn(async move { - // TODO: How to handle errors? - let Ok(snapshot) = snapshot_task.await else { - return; - }; - let rope = snapshot.text.as_rope(); - - let declarations = declarations_in_buffer(&snapshot) - .into_iter() - .map(|item| { - ( - item.parent_index, - BufferDeclaration::from_outline(item, &rope), - ) - }) - .collect::>(); - - let Some(state) = state.upgrade() else { - return; - }; - let mut state = state.lock().await; - let state = state.deref_mut(); - - let buffer_state = state - .buffers - .entry(buffer_id) - .or_insert_with(Default::default); - - SyntaxIndexState::remove_buffer_declarations( - &buffer_state.declarations, - &mut state.declarations, - &mut state.identifiers, - ); - - let mut new_ids = Vec::with_capacity(declarations.len()); - state.declarations.reserve(declarations.len()); - for (parent_index, mut declaration) in declarations { - declaration.parent = - parent_index.and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied())); - - let identifier = declaration.identifier.clone(); - let declaration_id = state.declarations.insert(Declaration::Buffer { - rope: rope.clone(), - buffer_id, - declaration, - project_entry_id, - 
cached_path: cached_path.clone(), - }); - new_ids.push(declaration_id); - - state - .identifiers - .entry(identifier) - .or_default() - .insert(declaration_id); - } - - buffer_state.declarations = new_ids; - }); - - self.with_state(cx, move |state| { - state - .buffers - .entry(buffer_id) - .or_insert_with(Default::default) - .task = Some(task) - }); - } - - fn update_file( - &mut self, - entry_id: ProjectEntryId, - project_path: ProjectPath, - cx: &mut Context, - ) -> Task<()> { - let Some(project) = self.project.upgrade() else { - return Task::ready(()); - }; - let project = project.read(cx); - - let language_registry = project.languages(); - let Some(available_language) = - language_registry.language_for_file_path(project_path.path.as_std_path()) - else { - return Task::ready(()); - }; - let language = if let Some(Ok(Ok(language))) = language_registry - .load_language(&available_language) - .now_or_never() - { - if language - .grammar() - .is_none_or(|grammar| grammar.outline_config.is_none()) - { - return Task::ready(()); - } - future::Either::Left(async { Ok(language) }) - } else { - let language_registry = language_registry.clone(); - future::Either::Right(async move { - anyhow::Ok( - language_registry - .load_language(&available_language) - .await??, - ) - }) - }; - - let Some(worktree) = project.worktree_for_id(project_path.worktree_id, cx) else { - return Task::ready(()); - }; - - let snapshot_task = worktree.update(cx, |worktree, cx| { - let load_task = worktree.load_file(&project_path.path, cx); - let worktree_abs_path = worktree.abs_path(); - cx.spawn(async move |_this, cx| { - let loaded_file = load_task.await?; - let language = language.await?; - - let buffer = cx.new(|cx| { - let mut buffer = Buffer::local(loaded_file.text, cx); - buffer.set_language(Some(language.clone()), cx); - buffer - })?; - - let mut parse_status = buffer.read_with(cx, |buffer, _| buffer.parse_status())?; - while *parse_status.borrow() != language::ParseStatus::Idle { - 
parse_status.changed().await?; - } - - let cached_path = CachedDeclarationPath::new( - worktree_abs_path, - &project_path.path, - Some(&language), - ); - - let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; - - anyhow::Ok((snapshot, cached_path)) - }) - }); - - let state = Arc::downgrade(&self.state); - cx.background_spawn(async move { - // TODO: How to handle errors? - let Ok((snapshot, cached_path)) = snapshot_task.await else { - return; - }; - let rope = snapshot.as_rope(); - let declarations = declarations_in_buffer(&snapshot) - .into_iter() - .map(|item| (item.parent_index, FileDeclaration::from_outline(item, rope))) - .collect::>(); - - let Some(state) = state.upgrade() else { - return; - }; - let mut state = state.lock().await; - let state = state.deref_mut(); - - let file_state = state.files.entry(entry_id).or_insert_with(Default::default); - for old_declaration_id in &file_state.declarations { - let Some(declaration) = state.declarations.remove(*old_declaration_id) else { - debug_panic!("declaration not found"); - continue; - }; - if let Some(identifier_declarations) = - state.identifiers.get_mut(declaration.identifier()) - { - identifier_declarations.remove(old_declaration_id); - } - } - - let mut new_ids = Vec::with_capacity(declarations.len()); - state.declarations.reserve(declarations.len()); - for (parent_index, mut declaration) in declarations { - declaration.parent = - parent_index.and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied())); - - let identifier = declaration.identifier.clone(); - let declaration_id = state.declarations.insert(Declaration::File { - project_entry_id: entry_id, - declaration, - cached_path: cached_path.clone(), - }); - new_ids.push(declaration_id); - - state - .identifiers - .entry(identifier) - .or_default() - .insert(declaration_id); - } - file_state.declarations = new_ids; - }) - } -} - -impl SyntaxIndexState { - pub fn declaration(&self, id: DeclarationId) -> Option<&Declaration> { - 
self.declarations.get(id) - } - - /// Returns declarations for the identifier. If the limit is exceeded, returns an empty vector. - /// - /// TODO: Consider doing some pre-ranking and instead truncating when N is exceeded. - pub fn declarations_for_identifier( - &self, - identifier: &Identifier, - ) -> Vec<(DeclarationId, &Declaration)> { - // make sure to not have a large stack allocation - assert!(N < 32); - - let Some(declaration_ids) = self.identifiers.get(&identifier) else { - return vec![]; - }; - - let mut result = Vec::with_capacity(N); - let mut included_buffer_entry_ids = arrayvec::ArrayVec::<_, N>::new(); - let mut file_declarations = Vec::new(); - - for declaration_id in declaration_ids { - let declaration = self.declarations.get(*declaration_id); - let Some(declaration) = some_or_debug_panic(declaration) else { - continue; - }; - match declaration { - Declaration::Buffer { - project_entry_id, .. - } => { - included_buffer_entry_ids.push(*project_entry_id); - result.push((*declaration_id, declaration)); - if result.len() == N { - return Vec::new(); - } - } - Declaration::File { - project_entry_id, .. - } => { - if !included_buffer_entry_ids.contains(&project_entry_id) { - file_declarations.push((*declaration_id, declaration)); - } - } - } - } - - for (declaration_id, declaration) in file_declarations { - match declaration { - Declaration::File { - project_entry_id, .. - } => { - if !included_buffer_entry_ids.contains(&project_entry_id) { - result.push((declaration_id, declaration)); - - if result.len() == N { - return Vec::new(); - } - } - } - Declaration::Buffer { .. 
} => {} - } - } - - result - } - - pub fn buffer_declarations_containing_range( - &self, - buffer_id: BufferId, - range: Range, - ) -> impl Iterator { - let Some(buffer_state) = self.buffers.get(&buffer_id) else { - return itertools::Either::Left(iter::empty()); - }; - - let iter = buffer_state - .declarations - .iter() - .filter_map(move |declaration_id| { - let Some(declaration) = self - .declarations - .get(*declaration_id) - .and_then(|d| d.as_buffer()) - else { - log::error!("bug: missing buffer outline declaration"); - return None; - }; - if declaration.item_range.contains_inclusive(&range) { - return Some((*declaration_id, declaration)); - } - return None; - }); - itertools::Either::Right(iter) - } - - pub fn file_declaration_count(&self, declaration: &Declaration) -> usize { - match declaration { - Declaration::File { - project_entry_id, .. - } => self - .files - .get(project_entry_id) - .map(|file_state| file_state.declarations.len()) - .unwrap_or_default(), - Declaration::Buffer { buffer_id, .. 
} => self - .buffers - .get(buffer_id) - .map(|buffer_state| buffer_state.declarations.len()) - .unwrap_or_default(), - } - } - - fn remove_buffer_declarations( - old_declaration_ids: &[DeclarationId], - declarations: &mut SlotMap, - identifiers: &mut HashMap>, - ) { - for old_declaration_id in old_declaration_ids { - let Some(declaration) = declarations.remove(*old_declaration_id) else { - debug_panic!("declaration not found"); - continue; - }; - if let Some(identifier_declarations) = identifiers.get_mut(declaration.identifier()) { - identifier_declarations.remove(old_declaration_id); - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::sync::Arc; - - use gpui::TestAppContext; - use indoc::indoc; - use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust}; - use project::{FakeFs, Project}; - use serde_json::json; - use settings::SettingsStore; - use text::OffsetRangeExt as _; - use util::{path, rel_path::rel_path}; - - use crate::syntax_index::SyntaxIndex; - - #[gpui::test] - async fn test_unopen_indexed_files(cx: &mut TestAppContext) { - let (project, index, rust_lang_id) = init_test(cx).await; - let main = Identifier { - name: "main".into(), - language_id: rust_lang_id, - }; - - let index_state = index.read_with(cx, |index, _cx| index.state().clone()); - let index_state = index_state.lock().await; - cx.update(|cx| { - let decls = index_state.declarations_for_identifier::<8>(&main); - assert_eq!(decls.len(), 2); - - let decl = expect_file_decl("a.rs", &decls[0].1, &project, cx); - assert_eq!(decl.identifier, main); - assert_eq!(decl.item_range, 0..98); - - let decl = expect_file_decl("c.rs", &decls[1].1, &project, cx); - assert_eq!(decl.identifier, main.clone()); - assert_eq!(decl.item_range, 32..280); - }); - } - - #[gpui::test] - async fn test_parents_in_file(cx: &mut TestAppContext) { - let (project, index, rust_lang_id) = init_test(cx).await; - let test_process_data = Identifier { - name: 
"test_process_data".into(), - language_id: rust_lang_id, - }; - - let index_state = index.read_with(cx, |index, _cx| index.state().clone()); - let index_state = index_state.lock().await; - cx.update(|cx| { - let decls = index_state.declarations_for_identifier::<8>(&test_process_data); - assert_eq!(decls.len(), 1); - - let decl = expect_file_decl("c.rs", &decls[0].1, &project, cx); - assert_eq!(decl.identifier, test_process_data); - - let parent_id = decl.parent.unwrap(); - let parent = index_state.declaration(parent_id).unwrap(); - let parent_decl = expect_file_decl("c.rs", &parent, &project, cx); - assert_eq!( - parent_decl.identifier, - Identifier { - name: "tests".into(), - language_id: rust_lang_id - } - ); - assert_eq!(parent_decl.parent, None); - }); - } - - #[gpui::test] - async fn test_parents_in_buffer(cx: &mut TestAppContext) { - let (project, index, rust_lang_id) = init_test(cx).await; - let test_process_data = Identifier { - name: "test_process_data".into(), - language_id: rust_lang_id, - }; - - let buffer = project - .update(cx, |project, cx| { - let project_path = project.find_project_path("c.rs", cx).unwrap(); - project.open_buffer(project_path, cx) - }) - .await - .unwrap(); - - cx.run_until_parked(); - - let index_state = index.read_with(cx, |index, _cx| index.state().clone()); - let index_state = index_state.lock().await; - cx.update(|cx| { - let decls = index_state.declarations_for_identifier::<8>(&test_process_data); - assert_eq!(decls.len(), 1); - - let decl = expect_buffer_decl("c.rs", &decls[0].1, &project, cx); - assert_eq!(decl.identifier, test_process_data); - - let parent_id = decl.parent.unwrap(); - let parent = index_state.declaration(parent_id).unwrap(); - let parent_decl = expect_buffer_decl("c.rs", &parent, &project, cx); - assert_eq!( - parent_decl.identifier, - Identifier { - name: "tests".into(), - language_id: rust_lang_id - } - ); - assert_eq!(parent_decl.parent, None); - }); - - drop(buffer); - } - - #[gpui::test] - async fn 
test_declarations_limit(cx: &mut TestAppContext) { - let (_, index, rust_lang_id) = init_test(cx).await; - - let index_state = index.read_with(cx, |index, _cx| index.state().clone()); - let index_state = index_state.lock().await; - let decls = index_state.declarations_for_identifier::<1>(&Identifier { - name: "main".into(), - language_id: rust_lang_id, - }); - assert_eq!(decls.len(), 0); - } - - #[gpui::test] - async fn test_buffer_shadow(cx: &mut TestAppContext) { - let (project, index, rust_lang_id) = init_test(cx).await; - - let main = Identifier { - name: "main".into(), - language_id: rust_lang_id, - }; - - let buffer = project - .update(cx, |project, cx| { - let project_path = project.find_project_path("c.rs", cx).unwrap(); - project.open_buffer(project_path, cx) - }) - .await - .unwrap(); - - cx.run_until_parked(); - - let index_state_arc = index.read_with(cx, |index, _cx| index.state().clone()); - { - let index_state = index_state_arc.lock().await; - - cx.update(|cx| { - let decls = index_state.declarations_for_identifier::<8>(&main); - assert_eq!(decls.len(), 2); - let decl = expect_buffer_decl("c.rs", &decls[0].1, &project, cx); - assert_eq!(decl.identifier, main); - assert_eq!(decl.item_range.to_offset(&buffer.read(cx)), 32..280); - - expect_file_decl("a.rs", &decls[1].1, &project, cx); - }); - } - - // Drop the buffer and wait for release - cx.update(|_| { - drop(buffer); - }); - cx.run_until_parked(); - - let index_state = index_state_arc.lock().await; - - cx.update(|cx| { - let decls = index_state.declarations_for_identifier::<8>(&main); - assert_eq!(decls.len(), 2); - expect_file_decl("a.rs", &decls[0].1, &project, cx); - expect_file_decl("c.rs", &decls[1].1, &project, cx); - }); - } - - fn expect_buffer_decl<'a>( - path: &str, - declaration: &'a Declaration, - project: &Entity, - cx: &App, - ) -> &'a BufferDeclaration { - if let Declaration::Buffer { - declaration, - project_entry_id, - .. 
- } = declaration - { - let project_path = project - .read(cx) - .path_for_entry(*project_entry_id, cx) - .unwrap(); - assert_eq!(project_path.path.as_ref(), rel_path(path),); - declaration - } else { - panic!("Expected a buffer declaration, found {:?}", declaration); - } - } - - fn expect_file_decl<'a>( - path: &str, - declaration: &'a Declaration, - project: &Entity, - cx: &App, - ) -> &'a FileDeclaration { - if let Declaration::File { - declaration, - project_entry_id: file, - .. - } = declaration - { - assert_eq!( - project - .read(cx) - .path_for_entry(*file, cx) - .unwrap() - .path - .as_ref(), - rel_path(path), - ); - declaration - } else { - panic!("Expected a file declaration, found {:?}", declaration); - } - } - - async fn init_test( - cx: &mut TestAppContext, - ) -> (Entity, Entity, LanguageId) { - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - }); - - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - path!("/root"), - json!({ - "a.rs": indoc! {r#" - fn main() { - let x = 1; - let y = 2; - let z = add(x, y); - println!("Result: {}", z); - } - - fn add(a: i32, b: i32) -> i32 { - a + b - } - "#}, - "b.rs": indoc! {" - pub struct Config { - pub name: String, - pub value: i32, - } - - impl Config { - pub fn new(name: String, value: i32) -> Self { - Config { name, value } - } - } - "}, - "c.rs": indoc! {r#" - use std::collections::HashMap; - - fn main() { - let args: Vec = std::env::args().collect(); - let data: Vec = args[1..] 
- .iter() - .filter_map(|s| s.parse().ok()) - .collect(); - let result = process_data(data); - println!("{:?}", result); - } - - fn process_data(data: Vec) -> HashMap { - let mut counts = HashMap::new(); - for value in data { - *counts.entry(value).or_insert(0) += 1; - } - counts - } - - #[cfg(test)] - mod tests { - use super::*; - - #[test] - fn test_process_data() { - let data = vec![1, 2, 2, 3]; - let result = process_data(data); - assert_eq!(result.get(&2), Some(&2)); - } - } - "#} - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _| project.languages().clone()); - let lang = rust_lang(); - let lang_id = lang.id(); - language_registry.add(Arc::new(lang)); - - let file_indexing_parallelism = 2; - let index = cx.new(|cx| SyntaxIndex::new(&project, file_indexing_parallelism, cx)); - cx.run_until_parked(); - - (project, index, lang_id) - } - - fn rust_lang() -> Language { - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - ) - .with_outline_query(include_str!("../../languages/src/rust/outline.scm")) - .unwrap() - } -} diff --git a/crates/edit_prediction_context/src/text_similarity.rs b/crates/edit_prediction_context/src/text_similarity.rs deleted file mode 100644 index 308a9570206084fc223c72f2e1c49109ea157714..0000000000000000000000000000000000000000 --- a/crates/edit_prediction_context/src/text_similarity.rs +++ /dev/null @@ -1,314 +0,0 @@ -use hashbrown::HashTable; -use regex::Regex; -use std::{ - borrow::Cow, - hash::{Hash, Hasher as _}, - path::Path, - sync::LazyLock, -}; -use util::rel_path::RelPath; - -use crate::reference::Reference; - -// TODO: Consider implementing sliding window similarity matching like -// 
https://github.com/sourcegraph/cody-public-snapshot/blob/8e20ac6c1460c08b0db581c0204658112a246eda/vscode/src/completions/context/retrievers/jaccard-similarity/bestJaccardMatch.ts -// -// That implementation could actually be more efficient - no need to track words in the window that -// are not in the query. - -// TODO: Consider a flat sorted Vec<(String, usize)> representation. Intersection can just walk the -// two in parallel. - -static IDENTIFIER_REGEX: LazyLock = LazyLock::new(|| Regex::new(r"\b\w+\b").unwrap()); - -/// Multiset of text occurrences for text similarity that only stores hashes and counts. -#[derive(Debug, Default)] -pub struct Occurrences { - table: HashTable, - total_count: usize, -} - -#[derive(Debug)] -struct OccurrenceEntry { - hash: u64, - count: usize, -} - -impl Occurrences { - pub fn within_string(text: &str) -> Self { - Self::from_identifiers(IDENTIFIER_REGEX.find_iter(text).map(|mat| mat.as_str())) - } - - #[allow(dead_code)] - pub fn within_references(references: &[Reference]) -> Self { - Self::from_identifiers( - references - .iter() - .map(|reference| reference.identifier.name.as_ref()), - ) - } - - pub fn from_identifiers(identifiers: impl IntoIterator>) -> Self { - let mut this = Self::default(); - // TODO: Score matches that match case higher? - // - // TODO: Also include unsplit identifier? 
- for identifier in identifiers { - for identifier_part in split_identifier(identifier.as_ref()) { - this.add_hash(fx_hash(&identifier_part.to_lowercase())); - } - } - this - } - - pub fn from_worktree_path(worktree_name: Option>, rel_path: &RelPath) -> Self { - if let Some(worktree_name) = worktree_name { - Self::from_identifiers( - std::iter::once(worktree_name) - .chain(iter_path_without_extension(rel_path.as_std_path())), - ) - } else { - Self::from_path(rel_path.as_std_path()) - } - } - - pub fn from_path(path: &Path) -> Self { - Self::from_identifiers(iter_path_without_extension(path)) - } - - fn add_hash(&mut self, hash: u64) { - self.table - .entry( - hash, - |entry: &OccurrenceEntry| entry.hash == hash, - |entry| entry.hash, - ) - .and_modify(|entry| entry.count += 1) - .or_insert(OccurrenceEntry { hash, count: 1 }); - self.total_count += 1; - } - - fn contains_hash(&self, hash: u64) -> bool { - self.get_count(hash) != 0 - } - - fn get_count(&self, hash: u64) -> usize { - self.table - .find(hash, |entry| entry.hash == hash) - .map(|entry| entry.count) - .unwrap_or(0) - } -} - -fn iter_path_without_extension(path: &Path) -> impl Iterator> { - let last_component: Option> = path.file_stem().map(|stem| stem.to_string_lossy()); - let mut path_components = path.components(); - path_components.next_back(); - path_components - .map(|component| component.as_os_str().to_string_lossy()) - .chain(last_component) -} - -pub fn fx_hash(data: &T) -> u64 { - let mut hasher = collections::FxHasher::default(); - data.hash(&mut hasher); - hasher.finish() -} - -// Splits camelcase / snakecase / kebabcase / pascalcase -// -// TODO: Make this more efficient / elegant. 
-fn split_identifier(identifier: &str) -> Vec<&str> { - let mut parts = Vec::new(); - let mut start = 0; - let chars: Vec = identifier.chars().collect(); - - if chars.is_empty() { - return parts; - } - - let mut i = 0; - while i < chars.len() { - let ch = chars[i]; - - // Handle explicit delimiters (underscore and hyphen) - if ch == '_' || ch == '-' { - if i > start { - parts.push(&identifier[start..i]); - } - start = i + 1; - i += 1; - continue; - } - - // Handle camelCase and PascalCase transitions - if i > 0 && i < chars.len() { - let prev_char = chars[i - 1]; - - // Transition from lowercase/digit to uppercase - if (prev_char.is_lowercase() || prev_char.is_ascii_digit()) && ch.is_uppercase() { - parts.push(&identifier[start..i]); - start = i; - } - // Handle sequences like "XMLParser" -> ["XML", "Parser"] - else if i + 1 < chars.len() - && ch.is_uppercase() - && chars[i + 1].is_lowercase() - && prev_char.is_uppercase() - { - parts.push(&identifier[start..i]); - start = i; - } - } - - i += 1; - } - - // Add the last part if there's any remaining - if start < identifier.len() { - parts.push(&identifier[start..]); - } - - // Filter out empty strings - parts.into_iter().filter(|s| !s.is_empty()).collect() -} - -pub fn jaccard_similarity<'a>(mut set_a: &'a Occurrences, mut set_b: &'a Occurrences) -> f32 { - if set_a.table.len() > set_b.table.len() { - std::mem::swap(&mut set_a, &mut set_b); - } - let intersection = set_a - .table - .iter() - .filter(|entry| set_b.contains_hash(entry.hash)) - .count(); - let union = set_a.table.len() + set_b.table.len() - intersection; - intersection as f32 / union as f32 -} - -// TODO -#[allow(dead_code)] -pub fn overlap_coefficient<'a>(mut set_a: &'a Occurrences, mut set_b: &'a Occurrences) -> f32 { - if set_a.table.len() > set_b.table.len() { - std::mem::swap(&mut set_a, &mut set_b); - } - let intersection = set_a - .table - .iter() - .filter(|entry| set_b.contains_hash(entry.hash)) - .count(); - intersection as f32 / 
set_a.table.len() as f32 -} - -// TODO -#[allow(dead_code)] -pub fn weighted_jaccard_similarity<'a>( - mut set_a: &'a Occurrences, - mut set_b: &'a Occurrences, -) -> f32 { - if set_a.table.len() > set_b.table.len() { - std::mem::swap(&mut set_a, &mut set_b); - } - - let mut numerator = 0; - let mut denominator_a = 0; - let mut used_count_b = 0; - for entry_a in set_a.table.iter() { - let count_a = entry_a.count; - let count_b = set_b.get_count(entry_a.hash); - numerator += count_a.min(count_b); - denominator_a += count_a.max(count_b); - used_count_b += count_b; - } - - let denominator = denominator_a + (set_b.total_count - used_count_b); - if denominator == 0 { - 0.0 - } else { - numerator as f32 / denominator as f32 - } -} - -pub fn weighted_overlap_coefficient<'a>( - mut set_a: &'a Occurrences, - mut set_b: &'a Occurrences, -) -> f32 { - if set_a.table.len() > set_b.table.len() { - std::mem::swap(&mut set_a, &mut set_b); - } - - let mut numerator = 0; - for entry_a in set_a.table.iter() { - let count_a = entry_a.count; - let count_b = set_b.get_count(entry_a.hash); - numerator += count_a.min(count_b); - } - - let denominator = set_a.total_count.min(set_b.total_count); - if denominator == 0 { - 0.0 - } else { - numerator as f32 / denominator as f32 - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_split_identifier() { - assert_eq!(split_identifier("snake_case"), vec!["snake", "case"]); - assert_eq!(split_identifier("kebab-case"), vec!["kebab", "case"]); - assert_eq!(split_identifier("PascalCase"), vec!["Pascal", "Case"]); - assert_eq!(split_identifier("camelCase"), vec!["camel", "Case"]); - assert_eq!(split_identifier("XMLParser"), vec!["XML", "Parser"]); - } - - #[test] - fn test_similarity_functions() { - // 10 identifier parts, 8 unique - // Repeats: 2 "outline", 2 "items" - let set_a = Occurrences::within_string( - "let mut outline_items = query_outline_items(&language, &tree, &source);", - ); - // 14 identifier parts, 11 unique - // 
Repeats: 2 "outline", 2 "language", 2 "tree" - let set_b = Occurrences::within_string( - "pub fn query_outline_items(language: &Language, tree: &Tree, source: &str) -> Vec {", - ); - - // 6 overlaps: "outline", "items", "query", "language", "tree", "source" - // 7 non-overlaps: "let", "mut", "pub", "fn", "vec", "item", "str" - assert_eq!(jaccard_similarity(&set_a, &set_b), 6.0 / (6.0 + 7.0)); - - // Numerator is one more than before due to both having 2 "outline". - // Denominator is the same except for 3 more due to the non-overlapping duplicates - assert_eq!( - weighted_jaccard_similarity(&set_a, &set_b), - 7.0 / (7.0 + 7.0 + 3.0) - ); - - // Numerator is the same as jaccard_similarity. Denominator is the size of the smaller set, 8. - assert_eq!(overlap_coefficient(&set_a, &set_b), 6.0 / 8.0); - - // Numerator is the same as weighted_jaccard_similarity. Denominator is the total weight of - // the smaller set, 10. - assert_eq!(weighted_overlap_coefficient(&set_a, &set_b), 7.0 / 10.0); - } - - #[test] - fn test_iter_path_without_extension() { - let mut iter = iter_path_without_extension(Path::new("")); - assert_eq!(iter.next(), None); - - let iter = iter_path_without_extension(Path::new("foo")); - assert_eq!(iter.collect::>(), ["foo"]); - - let iter = iter_path_without_extension(Path::new("foo/bar.txt")); - assert_eq!(iter.collect::>(), ["foo", "bar"]); - - let iter = iter_path_without_extension(Path::new("foo/bar/baz.txt")); - assert_eq!(iter.collect::>(), ["foo", "bar", "baz"]); - } -} diff --git a/crates/edit_prediction_context2/Cargo.toml b/crates/edit_prediction_context2/Cargo.toml deleted file mode 100644 index 597884b44821e24a930c8730225be4c6bf1c90f6..0000000000000000000000000000000000000000 --- a/crates/edit_prediction_context2/Cargo.toml +++ /dev/null @@ -1,42 +0,0 @@ -[package] -name = "edit_prediction_context2" -version = "0.1.0" -edition.workspace = true -publish.workspace = true -license = "GPL-3.0-or-later" - -[lints] -workspace = true - -[lib] -path 
= "src/edit_prediction_context2.rs" - -[dependencies] -parking_lot.workspace = true -anyhow.workspace = true -collections.workspace = true -futures.workspace = true -gpui.workspace = true -language.workspace = true -lsp.workspace = true -project.workspace = true -log.workspace = true -serde.workspace = true -smallvec.workspace = true -tree-sitter.workspace = true -util.workspace = true - -[dev-dependencies] -env_logger.workspace = true -indoc.workspace = true -futures.workspace = true -gpui = { workspace = true, features = ["test-support"] } -language = { workspace = true, features = ["test-support"] } -lsp = { workspace = true, features = ["test-support"] } -pretty_assertions.workspace = true -project = {workspace= true, features = ["test-support"]} -serde_json.workspace = true -settings = {workspace= true, features = ["test-support"]} -text = { workspace = true, features = ["test-support"] } -util = { workspace = true, features = ["test-support"] } -zlog.workspace = true diff --git a/crates/edit_prediction_context2/src/edit_prediction_context2.rs b/crates/edit_prediction_context2/src/edit_prediction_context2.rs deleted file mode 100644 index f8790478547ddb8b7b873015846f2af6c1bcbc2c..0000000000000000000000000000000000000000 --- a/crates/edit_prediction_context2/src/edit_prediction_context2.rs +++ /dev/null @@ -1,465 +0,0 @@ -use crate::assemble_excerpts::assemble_excerpts; -use anyhow::Result; -use collections::HashMap; -use futures::{FutureExt, StreamExt as _, channel::mpsc, future}; -use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity}; -use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, Rope, ToOffset as _}; -use project::{LocationLink, Project, ProjectPath}; -use serde::{Serialize, Serializer}; -use smallvec::SmallVec; -use std::{ - collections::hash_map, - ops::Range, - sync::Arc, - time::{Duration, Instant}, -}; -use util::{RangeExt as _, ResultExt}; - -mod assemble_excerpts; -#[cfg(test)] -mod 
edit_prediction_context_tests; -#[cfg(test)] -mod fake_definition_lsp; - -pub struct RelatedExcerptStore { - project: WeakEntity, - related_files: Vec, - cache: HashMap>, - update_tx: mpsc::UnboundedSender<(Entity, Anchor)>, -} - -pub enum RelatedExcerptStoreEvent { - StartedRefresh, - FinishedRefresh { - cache_hit_count: usize, - cache_miss_count: usize, - mean_definition_latency: Duration, - max_definition_latency: Duration, - }, -} - -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -struct Identifier { - pub name: String, - pub range: Range, -} - -enum DefinitionTask { - CacheHit(Arc), - CacheMiss(Task>>>), -} - -#[derive(Debug)] -struct CacheEntry { - definitions: SmallVec<[CachedDefinition; 1]>, -} - -#[derive(Clone, Debug)] -struct CachedDefinition { - path: ProjectPath, - buffer: Entity, - anchor_range: Range, -} - -#[derive(Clone, Debug, Serialize)] -pub struct RelatedFile { - #[serde(serialize_with = "serialize_project_path")] - pub path: ProjectPath, - #[serde(skip)] - pub buffer: WeakEntity, - pub excerpts: Vec, - pub max_row: u32, -} - -impl RelatedFile { - pub fn merge_excerpts(&mut self) { - self.excerpts.sort_unstable_by(|a, b| { - a.point_range - .start - .cmp(&b.point_range.start) - .then(b.point_range.end.cmp(&a.point_range.end)) - }); - - let mut index = 1; - while index < self.excerpts.len() { - if self.excerpts[index - 1] - .point_range - .end - .cmp(&self.excerpts[index].point_range.start) - .is_ge() - { - let removed = self.excerpts.remove(index); - if removed - .point_range - .end - .cmp(&self.excerpts[index - 1].point_range.end) - .is_gt() - { - self.excerpts[index - 1].point_range.end = removed.point_range.end; - self.excerpts[index - 1].anchor_range.end = removed.anchor_range.end; - } - } else { - index += 1; - } - } - } -} - -#[derive(Clone, Debug, Serialize)] -pub struct RelatedExcerpt { - #[serde(skip)] - pub anchor_range: Range, - #[serde(serialize_with = "serialize_point_range")] - pub point_range: Range, - #[serde(serialize_with = 
"serialize_rope")] - pub text: Rope, -} - -fn serialize_project_path( - project_path: &ProjectPath, - serializer: S, -) -> Result { - project_path.path.serialize(serializer) -} - -fn serialize_rope(rope: &Rope, serializer: S) -> Result { - rope.to_string().serialize(serializer) -} - -fn serialize_point_range( - range: &Range, - serializer: S, -) -> Result { - [ - [range.start.row, range.start.column], - [range.end.row, range.end.column], - ] - .serialize(serializer) -} - -const DEBOUNCE_DURATION: Duration = Duration::from_millis(100); - -impl EventEmitter for RelatedExcerptStore {} - -impl RelatedExcerptStore { - pub fn new(project: &Entity, cx: &mut Context) -> Self { - let (update_tx, mut update_rx) = mpsc::unbounded::<(Entity, Anchor)>(); - cx.spawn(async move |this, cx| { - let executor = cx.background_executor().clone(); - while let Some((mut buffer, mut position)) = update_rx.next().await { - let mut timer = executor.timer(DEBOUNCE_DURATION).fuse(); - loop { - futures::select_biased! 
{ - next = update_rx.next() => { - if let Some((new_buffer, new_position)) = next { - buffer = new_buffer; - position = new_position; - timer = executor.timer(DEBOUNCE_DURATION).fuse(); - } else { - return anyhow::Ok(()); - } - } - _ = timer => break, - } - } - - Self::fetch_excerpts(this.clone(), buffer, position, cx).await?; - } - anyhow::Ok(()) - }) - .detach_and_log_err(cx); - - RelatedExcerptStore { - project: project.downgrade(), - update_tx, - related_files: Vec::new(), - cache: Default::default(), - } - } - - pub fn refresh(&mut self, buffer: Entity, position: Anchor, _: &mut Context) { - self.update_tx.unbounded_send((buffer, position)).ok(); - } - - pub fn related_files(&self) -> &[RelatedFile] { - &self.related_files - } - - async fn fetch_excerpts( - this: WeakEntity, - buffer: Entity, - position: Anchor, - cx: &mut AsyncApp, - ) -> Result<()> { - let (project, snapshot) = this.read_with(cx, |this, cx| { - (this.project.upgrade(), buffer.read(cx).snapshot()) - })?; - let Some(project) = project else { - return Ok(()); - }; - - let file = snapshot.file().cloned(); - if let Some(file) = &file { - log::debug!("retrieving_context buffer:{}", file.path().as_unix_str()); - } - - this.update(cx, |_, cx| { - cx.emit(RelatedExcerptStoreEvent::StartedRefresh); - })?; - - let identifiers = cx - .background_spawn(async move { identifiers_for_position(&snapshot, position) }) - .await; - - let async_cx = cx.clone(); - let start_time = Instant::now(); - let futures = this.update(cx, |this, cx| { - identifiers - .into_iter() - .filter_map(|identifier| { - let task = if let Some(entry) = this.cache.get(&identifier) { - DefinitionTask::CacheHit(entry.clone()) - } else { - DefinitionTask::CacheMiss( - this.project - .update(cx, |project, cx| { - project.definitions(&buffer, identifier.range.start, cx) - }) - .ok()?, - ) - }; - - let cx = async_cx.clone(); - let project = project.clone(); - Some(async move { - match task { - DefinitionTask::CacheHit(cache_entry) => { - 
Some((identifier, cache_entry, None)) - } - DefinitionTask::CacheMiss(task) => { - let locations = task.await.log_err()??; - let duration = start_time.elapsed(); - cx.update(|cx| { - ( - identifier, - Arc::new(CacheEntry { - definitions: locations - .into_iter() - .filter_map(|location| { - process_definition(location, &project, cx) - }) - .collect(), - }), - Some(duration), - ) - }) - .ok() - } - } - }) - }) - .collect::>() - })?; - - let mut cache_hit_count = 0; - let mut cache_miss_count = 0; - let mut mean_definition_latency = Duration::ZERO; - let mut max_definition_latency = Duration::ZERO; - let mut new_cache = HashMap::default(); - new_cache.reserve(futures.len()); - for (identifier, entry, duration) in future::join_all(futures).await.into_iter().flatten() { - new_cache.insert(identifier, entry); - if let Some(duration) = duration { - cache_miss_count += 1; - mean_definition_latency += duration; - max_definition_latency = max_definition_latency.max(duration); - } else { - cache_hit_count += 1; - } - } - mean_definition_latency /= cache_miss_count.max(1) as u32; - - let (new_cache, related_files) = rebuild_related_files(new_cache, cx).await?; - - if let Some(file) = &file { - log::debug!( - "finished retrieving context buffer:{}, latency:{:?}", - file.path().as_unix_str(), - start_time.elapsed() - ); - } - - this.update(cx, |this, cx| { - this.cache = new_cache; - this.related_files = related_files; - cx.emit(RelatedExcerptStoreEvent::FinishedRefresh { - cache_hit_count, - cache_miss_count, - mean_definition_latency, - max_definition_latency, - }); - })?; - - anyhow::Ok(()) - } -} - -async fn rebuild_related_files( - new_entries: HashMap>, - cx: &mut AsyncApp, -) -> Result<(HashMap>, Vec)> { - let mut snapshots = HashMap::default(); - for entry in new_entries.values() { - for definition in &entry.definitions { - if let hash_map::Entry::Vacant(e) = snapshots.entry(definition.buffer.entity_id()) { - definition - .buffer - .read_with(cx, |buffer, _| 
buffer.parsing_idle())? - .await; - e.insert( - definition - .buffer - .read_with(cx, |buffer, _| buffer.snapshot())?, - ); - } - } - } - - Ok(cx - .background_spawn(async move { - let mut files = Vec::::new(); - let mut ranges_by_buffer = HashMap::<_, Vec>>::default(); - let mut paths_by_buffer = HashMap::default(); - for entry in new_entries.values() { - for definition in &entry.definitions { - let Some(snapshot) = snapshots.get(&definition.buffer.entity_id()) else { - continue; - }; - paths_by_buffer.insert(definition.buffer.entity_id(), definition.path.clone()); - ranges_by_buffer - .entry(definition.buffer.clone()) - .or_default() - .push(definition.anchor_range.to_point(snapshot)); - } - } - - for (buffer, ranges) in ranges_by_buffer { - let Some(snapshot) = snapshots.get(&buffer.entity_id()) else { - continue; - }; - let Some(project_path) = paths_by_buffer.get(&buffer.entity_id()) else { - continue; - }; - let excerpts = assemble_excerpts(snapshot, ranges); - files.push(RelatedFile { - path: project_path.clone(), - buffer: buffer.downgrade(), - excerpts, - max_row: snapshot.max_point().row, - }); - } - - files.sort_by_key(|file| file.path.clone()); - (new_entries, files) - }) - .await) -} - -fn process_definition( - location: LocationLink, - project: &Entity, - cx: &mut App, -) -> Option { - let buffer = location.target.buffer.read(cx); - let anchor_range = location.target.range; - let file = buffer.file()?; - let worktree = project.read(cx).worktree_for_id(file.worktree_id(cx), cx)?; - if worktree.read(cx).is_single_file() { - return None; - } - Some(CachedDefinition { - path: ProjectPath { - worktree_id: file.worktree_id(cx), - path: file.path().clone(), - }, - buffer: location.target.buffer, - anchor_range, - }) -} - -/// Gets all of the identifiers that are present in the given line, and its containing -/// outline items. 
-fn identifiers_for_position(buffer: &BufferSnapshot, position: Anchor) -> Vec { - let offset = position.to_offset(buffer); - let point = buffer.offset_to_point(offset); - - let line_range = Point::new(point.row, 0)..Point::new(point.row + 1, 0).min(buffer.max_point()); - let mut ranges = vec![line_range.to_offset(&buffer)]; - - // Include the range of the outline item itself, but not its body. - let outline_items = buffer.outline_items_as_offsets_containing(offset..offset, false, None); - for item in outline_items { - if let Some(body_range) = item.body_range(&buffer) { - ranges.push(item.range.start..body_range.start.to_offset(&buffer)); - } else { - ranges.push(item.range.clone()); - } - } - - ranges.sort_by(|a, b| a.start.cmp(&b.start).then(b.end.cmp(&a.end))); - ranges.dedup_by(|a, b| { - if a.start <= b.end { - b.start = b.start.min(a.start); - b.end = b.end.max(a.end); - true - } else { - false - } - }); - - let mut identifiers = Vec::new(); - let outer_range = - ranges.first().map_or(0, |r| r.start)..ranges.last().map_or(buffer.len(), |r| r.end); - - let mut captures = buffer - .syntax - .captures(outer_range.clone(), &buffer.text, |grammar| { - grammar - .highlights_config - .as_ref() - .map(|config| &config.query) - }); - - for range in ranges { - captures.set_byte_range(range.start..outer_range.end); - - let mut last_range = None; - while let Some(capture) = captures.peek() { - let node_range = capture.node.byte_range(); - if node_range.start > range.end { - break; - } - let config = captures.grammars()[capture.grammar_index] - .highlights_config - .as_ref(); - - if let Some(config) = config - && config.identifier_capture_indices.contains(&capture.index) - && range.contains_inclusive(&node_range) - && Some(&node_range) != last_range.as_ref() - { - let name = buffer.text_for_range(node_range.clone()).collect(); - identifiers.push(Identifier { - range: buffer.anchor_after(node_range.start) - ..buffer.anchor_before(node_range.end), - name, - }); - 
last_range = Some(node_range); - } - - captures.advance(); - } - } - - identifiers -} diff --git a/crates/edit_prediction_types/Cargo.toml b/crates/edit_prediction_types/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..ebc09680e1dcf99dc21e1714eca6a9db337f4a90 --- /dev/null +++ b/crates/edit_prediction_types/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "edit_prediction_types" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/edit_prediction_types.rs" + +[dependencies] +client.workspace = true +gpui.workspace = true +language.workspace = true diff --git a/crates/edit_prediction_context2/LICENSE-GPL b/crates/edit_prediction_types/LICENSE-GPL similarity index 100% rename from crates/edit_prediction_context2/LICENSE-GPL rename to crates/edit_prediction_types/LICENSE-GPL diff --git a/crates/edit_prediction_types/src/edit_prediction_types.rs b/crates/edit_prediction_types/src/edit_prediction_types.rs new file mode 100644 index 0000000000000000000000000000000000000000..1f63b8626d15dfd3e2cba78aacb50505186da01c --- /dev/null +++ b/crates/edit_prediction_types/src/edit_prediction_types.rs @@ -0,0 +1,298 @@ +use std::{ops::Range, sync::Arc}; + +use client::EditPredictionUsage; +use gpui::{App, Context, Entity, SharedString}; +use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt}; + +// TODO: Find a better home for `Direction`. +// +// This should live in an ancestor crate of `editor` and `edit_prediction`, +// but at time of writing there isn't an obvious spot. 
+#[derive(Copy, Clone, PartialEq, Eq)] +pub enum Direction { + Prev, + Next, +} + +#[derive(Clone)] +pub enum EditPrediction { + /// Edits within the buffer that requested the prediction + Local { + id: Option, + edits: Vec<(Range, Arc)>, + edit_preview: Option, + }, + /// Jump to a different file from the one that requested the prediction + Jump { + id: Option, + snapshot: language::BufferSnapshot, + target: language::Anchor, + }, +} + +pub enum DataCollectionState { + /// The provider doesn't support data collection. + Unsupported, + /// Data collection is enabled. + Enabled { is_project_open_source: bool }, + /// Data collection is disabled or unanswered. + Disabled { is_project_open_source: bool }, +} + +impl DataCollectionState { + pub fn is_supported(&self) -> bool { + !matches!(self, DataCollectionState::Unsupported) + } + + pub fn is_enabled(&self) -> bool { + matches!(self, DataCollectionState::Enabled { .. }) + } + + pub fn is_project_open_source(&self) -> bool { + match self { + Self::Enabled { + is_project_open_source, + } + | Self::Disabled { + is_project_open_source, + } => *is_project_open_source, + _ => false, + } + } +} + +pub trait EditPredictionDelegate: 'static + Sized { + fn name() -> &'static str; + fn display_name() -> &'static str; + fn show_predictions_in_menu() -> bool; + fn show_tab_accept_marker() -> bool { + false + } + fn supports_jump_to_edit() -> bool { + true + } + + fn data_collection_state(&self, _cx: &App) -> DataCollectionState { + DataCollectionState::Unsupported + } + + fn usage(&self, _cx: &App) -> Option { + None + } + + fn toggle_data_collection(&mut self, _cx: &mut App) {} + fn is_enabled( + &self, + buffer: &Entity, + cursor_position: language::Anchor, + cx: &App, + ) -> bool; + fn is_refreshing(&self, cx: &App) -> bool; + fn refresh( + &mut self, + buffer: Entity, + cursor_position: language::Anchor, + debounce: bool, + cx: &mut Context, + ); + fn cycle( + &mut self, + buffer: Entity, + cursor_position: language::Anchor, 
+ direction: Direction, + cx: &mut Context, + ); + fn accept(&mut self, cx: &mut Context); + fn discard(&mut self, cx: &mut Context); + fn did_show(&mut self, _cx: &mut Context) {} + fn suggest( + &mut self, + buffer: &Entity, + cursor_position: language::Anchor, + cx: &mut Context, + ) -> Option; +} + +pub trait EditPredictionDelegateHandle { + fn name(&self) -> &'static str; + fn display_name(&self) -> &'static str; + fn is_enabled( + &self, + buffer: &Entity, + cursor_position: language::Anchor, + cx: &App, + ) -> bool; + fn show_predictions_in_menu(&self) -> bool; + fn show_tab_accept_marker(&self) -> bool; + fn supports_jump_to_edit(&self) -> bool; + fn data_collection_state(&self, cx: &App) -> DataCollectionState; + fn usage(&self, cx: &App) -> Option; + fn toggle_data_collection(&self, cx: &mut App); + fn is_refreshing(&self, cx: &App) -> bool; + fn refresh( + &self, + buffer: Entity, + cursor_position: language::Anchor, + debounce: bool, + cx: &mut App, + ); + fn cycle( + &self, + buffer: Entity, + cursor_position: language::Anchor, + direction: Direction, + cx: &mut App, + ); + fn did_show(&self, cx: &mut App); + fn accept(&self, cx: &mut App); + fn discard(&self, cx: &mut App); + fn suggest( + &self, + buffer: &Entity, + cursor_position: language::Anchor, + cx: &mut App, + ) -> Option; +} + +impl EditPredictionDelegateHandle for Entity +where + T: EditPredictionDelegate, +{ + fn name(&self) -> &'static str { + T::name() + } + + fn display_name(&self) -> &'static str { + T::display_name() + } + + fn show_predictions_in_menu(&self) -> bool { + T::show_predictions_in_menu() + } + + fn show_tab_accept_marker(&self) -> bool { + T::show_tab_accept_marker() + } + + fn supports_jump_to_edit(&self) -> bool { + T::supports_jump_to_edit() + } + + fn data_collection_state(&self, cx: &App) -> DataCollectionState { + self.read(cx).data_collection_state(cx) + } + + fn usage(&self, cx: &App) -> Option { + self.read(cx).usage(cx) + } + + fn toggle_data_collection(&self, 
cx: &mut App) { + self.update(cx, |this, cx| this.toggle_data_collection(cx)) + } + + fn is_enabled( + &self, + buffer: &Entity, + cursor_position: language::Anchor, + cx: &App, + ) -> bool { + self.read(cx).is_enabled(buffer, cursor_position, cx) + } + + fn is_refreshing(&self, cx: &App) -> bool { + self.read(cx).is_refreshing(cx) + } + + fn refresh( + &self, + buffer: Entity, + cursor_position: language::Anchor, + debounce: bool, + cx: &mut App, + ) { + self.update(cx, |this, cx| { + this.refresh(buffer, cursor_position, debounce, cx) + }) + } + + fn cycle( + &self, + buffer: Entity, + cursor_position: language::Anchor, + direction: Direction, + cx: &mut App, + ) { + self.update(cx, |this, cx| { + this.cycle(buffer, cursor_position, direction, cx) + }) + } + + fn accept(&self, cx: &mut App) { + self.update(cx, |this, cx| this.accept(cx)) + } + + fn discard(&self, cx: &mut App) { + self.update(cx, |this, cx| this.discard(cx)) + } + + fn did_show(&self, cx: &mut App) { + self.update(cx, |this, cx| this.did_show(cx)) + } + + fn suggest( + &self, + buffer: &Entity, + cursor_position: language::Anchor, + cx: &mut App, + ) -> Option { + self.update(cx, |this, cx| this.suggest(buffer, cursor_position, cx)) + } +} + +/// Returns edits updated based on user edits since the old snapshot. None is returned if any user +/// edit is not a prefix of a predicted insertion. 
+pub fn interpolate_edits( + old_snapshot: &BufferSnapshot, + new_snapshot: &BufferSnapshot, + current_edits: &[(Range, Arc)], +) -> Option, Arc)>> { + let mut edits = Vec::new(); + + let mut model_edits = current_edits.iter().peekable(); + for user_edit in new_snapshot.edits_since::(&old_snapshot.version) { + while let Some((model_old_range, _)) = model_edits.peek() { + let model_old_range = model_old_range.to_offset(old_snapshot); + if model_old_range.end < user_edit.old.start { + let (model_old_range, model_new_text) = model_edits.next().unwrap(); + edits.push((model_old_range.clone(), model_new_text.clone())); + } else { + break; + } + } + + if let Some((model_old_range, model_new_text)) = model_edits.peek() { + let model_old_offset_range = model_old_range.to_offset(old_snapshot); + if user_edit.old == model_old_offset_range { + let user_new_text = new_snapshot + .text_for_range(user_edit.new.clone()) + .collect::(); + + if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) { + if !model_suffix.is_empty() { + let anchor = old_snapshot.anchor_after(user_edit.old.end); + edits.push((anchor..anchor, model_suffix.into())); + } + + model_edits.next(); + continue; + } + } + } + + return None; + } + + edits.extend(model_edits.cloned()); + + if edits.is_empty() { None } else { Some(edits) } +} diff --git a/crates/edit_prediction_button/Cargo.toml b/crates/edit_prediction_ui/Cargo.toml similarity index 77% rename from crates/edit_prediction_button/Cargo.toml rename to crates/edit_prediction_ui/Cargo.toml index d336cf66926d37ab7c0ebb1d5aa5a2172342350c..fb846f35d76ae2f6478ef675f246e4d06fe5f469 100644 --- a/crates/edit_prediction_button/Cargo.toml +++ b/crates/edit_prediction_ui/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "edit_prediction_button" +name = "edit_prediction_ui" version = "0.1.0" edition.workspace = true publish.workspace = true @@ -9,35 +9,43 @@ license = "GPL-3.0-or-later" workspace = true [lib] -path = "src/edit_prediction_button.rs" 
+path = "src/edit_prediction_ui.rs" doctest = false [dependencies] anyhow.workspace = true +buffer_diff.workspace = true client.workspace = true cloud_llm_client.workspace = true +cloud_zeta2_prompt.workspace = true codestral.workspace = true +command_palette_hooks.workspace = true copilot.workspace = true edit_prediction.workspace = true +edit_prediction_types.workspace = true editor.workspace = true feature_flags.workspace = true fs.workspace = true +futures.workspace = true gpui.workspace = true indoc.workspace = true language.workspace = true +markdown.workspace = true +menu.workspace = true +multi_buffer.workspace = true paths.workspace = true project.workspace = true regex.workspace = true settings.workspace = true supermaven.workspace = true telemetry.workspace = true +text.workspace = true +theme.workspace = true ui.workspace = true ui_input.workspace = true -menu.workspace = true util.workspace = true workspace.workspace = true zed_actions.workspace = true -zeta.workspace = true [dev-dependencies] copilot = { workspace = true, features = ["test-support"] } diff --git a/crates/zeta/LICENSE-GPL b/crates/edit_prediction_ui/LICENSE-GPL similarity index 100% rename from crates/zeta/LICENSE-GPL rename to crates/edit_prediction_ui/LICENSE-GPL diff --git a/crates/edit_prediction_button/src/edit_prediction_button.rs b/crates/edit_prediction_ui/src/edit_prediction_button.rs similarity index 97% rename from crates/edit_prediction_button/src/edit_prediction_button.rs rename to crates/edit_prediction_ui/src/edit_prediction_button.rs index 8b234497376aefdc972681c877a1122f3f9cee17..dd3ebab42029f5adb7570b71ae0cd662aff3328e 100644 --- a/crates/edit_prediction_button/src/edit_prediction_button.rs +++ b/crates/edit_prediction_ui/src/edit_prediction_button.rs @@ -1,16 +1,14 @@ -mod sweep_api_token_modal; - -pub use sweep_api_token_modal::SweepApiKeyModal; - use anyhow::Result; use client::{Client, UserStore, zed_urls}; use cloud_llm_client::UsageLimit; -use 
codestral::CodestralCompletionProvider; +use codestral::CodestralEditPredictionDelegate; use copilot::{Copilot, Status}; +use edit_prediction::{SweepFeatureFlag, Zeta2FeatureFlag}; +use edit_prediction_types::EditPredictionDelegateHandle; use editor::{ Editor, MultiBufferOffset, SelectionEffects, actions::ShowEditPrediction, scroll::Autoscroll, }; -use feature_flags::{FeatureFlagAppExt, PredictEditsRateCompletionsFeatureFlag}; +use feature_flags::FeatureFlagAppExt; use fs::Fs; use gpui::{ Action, Animation, AnimationExt, App, AsyncWindowContext, Corner, Entity, FocusHandle, @@ -44,7 +42,11 @@ use workspace::{ notifications::NotificationId, }; use zed_actions::OpenBrowser; -use zeta::{RateCompletions, SweepFeatureFlag, Zeta2FeatureFlag}; + +use crate::{ + RatePredictions, SweepApiKeyModal, + rate_prediction_modal::PredictEditsRatePredictionsFeatureFlag, +}; actions!( edit_prediction, @@ -67,7 +69,7 @@ pub struct EditPredictionButton { editor_focus_handle: Option, language: Option>, file: Option>, - edit_prediction_provider: Option>, + edit_prediction_provider: Option>, fs: Arc, user_store: Entity, popover_menu_handle: PopoverMenuHandle, @@ -244,7 +246,7 @@ impl Render for EditPredictionButton { EditPredictionProvider::Codestral => { let enabled = self.editor_enabled.unwrap_or(true); - let has_api_key = CodestralCompletionProvider::has_api_key(cx); + let has_api_key = CodestralEditPredictionDelegate::has_api_key(cx); let fs = self.fs.clone(); let this = cx.weak_entity(); @@ -317,16 +319,16 @@ impl Render for EditPredictionButton { ); let sweep_missing_token = is_sweep - && !zeta::Zeta::try_global(cx) - .map_or(false, |zeta| zeta.read(cx).has_sweep_api_token()); + && !edit_prediction::EditPredictionStore::try_global(cx) + .map_or(false, |ep_store| ep_store.read(cx).has_sweep_api_token()); - let zeta_icon = match (is_sweep, enabled) { + let ep_icon = match (is_sweep, enabled) { (true, _) => IconName::SweepAi, (false, true) => IconName::ZedPredict, (false, false) => 
IconName::ZedPredictDisabled, }; - if zeta::should_show_upsell_modal() { + if edit_prediction::should_show_upsell_modal() { let tooltip_meta = if self.user_store.read(cx).current_user().is_some() { "Choose a Plan" } else { @@ -334,7 +336,7 @@ impl Render for EditPredictionButton { }; return div().child( - IconButton::new("zed-predict-pending-button", zeta_icon) + IconButton::new("zed-predict-pending-button", ep_icon) .shape(IconButtonShape::Square) .indicator(Indicator::dot().color(Color::Muted)) .indicator_border_color(Some(cx.theme().colors().status_bar_background)) @@ -379,7 +381,7 @@ impl Render for EditPredictionButton { None }; - let icon_button = IconButton::new("zed-predict-pending-button", zeta_icon) + let icon_button = IconButton::new("zed-predict-pending-button", ep_icon) .shape(IconButtonShape::Square) .when_some(indicator_color, |this, color| { this.indicator(Indicator::dot().color(color)) @@ -419,13 +421,13 @@ impl Render for EditPredictionButton { let this = cx.weak_entity(); - let mut popover_menu = PopoverMenu::new("zeta") + let mut popover_menu = PopoverMenu::new("edit-prediction") .when(user.is_some(), |popover_menu| { let this = this.clone(); popover_menu.menu(move |window, cx| { this.update(cx, |this, cx| { - this.build_zeta_context_menu(provider, window, cx) + this.build_edit_prediction_context_menu(provider, window, cx) }) .ok() }) @@ -485,7 +487,7 @@ impl EditPredictionButton { cx.observe_global::(move |_, cx| cx.notify()) .detach(); - CodestralCompletionProvider::ensure_api_key_loaded(client.http_client(), cx); + CodestralEditPredictionDelegate::ensure_api_key_loaded(client.http_client(), cx); Self { editor_subscription: None, @@ -520,7 +522,7 @@ impl EditPredictionButton { } } - if CodestralCompletionProvider::has_api_key(cx) { + if CodestralEditPredictionDelegate::has_api_key(cx) { providers.push(EditPredictionProvider::Codestral); } @@ -599,8 +601,8 @@ impl EditPredictionButton { EditPredictionProvider::Experimental( 
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, ) => { - let has_api_token = zeta::Zeta::try_global(cx) - .map_or(false, |zeta| zeta.read(cx).has_sweep_api_token()); + let has_api_token = edit_prediction::EditPredictionStore::try_global(cx) + .map_or(false, |ep_store| ep_store.read(cx).has_sweep_api_token()); let should_open_modal = !has_api_token || is_current; @@ -947,8 +949,8 @@ impl EditPredictionButton { ) .context(editor_focus_handle) .when( - cx.has_flag::(), - |this| this.action("Rate Completions", RateCompletions.boxed_clone()), + cx.has_flag::(), + |this| this.action("Rate Predictions", RatePredictions.boxed_clone()), ); } @@ -1016,7 +1018,7 @@ impl EditPredictionButton { }) } - fn build_zeta_context_menu( + fn build_edit_prediction_context_menu( &self, provider: EditPredictionProvider, window: &mut Window, diff --git a/crates/zeta2_tools/src/zeta2_context_view.rs b/crates/edit_prediction_ui/src/edit_prediction_context_view.rs similarity index 85% rename from crates/zeta2_tools/src/zeta2_context_view.rs rename to crates/edit_prediction_ui/src/edit_prediction_context_view.rs index 882846929a62f90f349d40f8f6b6996f83613ec7..0e343fe3fcb8ed7bb6bf3e8481927344d63133ee 100644 --- a/crates/zeta2_tools/src/zeta2_context_view.rs +++ b/crates/edit_prediction_ui/src/edit_prediction_context_view.rs @@ -23,16 +23,16 @@ use ui::{ StyledTypography as _, h_flex, v_flex, }; -use workspace::Item; -use zeta::{ - Zeta, ZetaContextRetrievalFinishedDebugInfo, ZetaContextRetrievalStartedDebugInfo, - ZetaDebugInfo, +use edit_prediction::{ + ContextRetrievalFinishedDebugEvent, ContextRetrievalStartedDebugEvent, DebugEvent, + EditPredictionStore, }; +use workspace::Item; -pub struct Zeta2ContextView { +pub struct EditPredictionContextView { empty_focus_handle: FocusHandle, project: Entity, - zeta: Entity, + store: Entity, runs: VecDeque, current_ix: usize, _update_task: Task>, @@ -50,13 +50,13 @@ actions!( dev, [ /// Go to the previous context retrieval run - Zeta2ContextGoBack, + 
EditPredictionContextGoBack, /// Go to the next context retrieval run - Zeta2ContextGoForward + EditPredictionContextGoForward ] ); -impl Zeta2ContextView { +impl EditPredictionContextView { pub fn new( project: Entity, client: &Arc, @@ -64,13 +64,13 @@ impl Zeta2ContextView { window: &mut gpui::Window, cx: &mut Context, ) -> Self { - let zeta = Zeta::global(client, user_store, cx); + let store = EditPredictionStore::global(client, user_store, cx); - let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info()); + let mut debug_rx = store.update(cx, |store, _| store.debug_info()); let _update_task = cx.spawn_in(window, async move |this, cx| { while let Some(event) = debug_rx.next().await { this.update_in(cx, |this, window, cx| { - this.handle_zeta_event(event, window, cx) + this.handle_store_event(event, window, cx) })?; } Ok(()) @@ -81,35 +81,35 @@ impl Zeta2ContextView { project, runs: VecDeque::new(), current_ix: 0, - zeta, + store, _update_task, } } - fn handle_zeta_event( + fn handle_store_event( &mut self, - event: ZetaDebugInfo, + event: DebugEvent, window: &mut gpui::Window, cx: &mut Context, ) { match event { - ZetaDebugInfo::ContextRetrievalStarted(info) => { + DebugEvent::ContextRetrievalStarted(info) => { if info.project_entity_id == self.project.entity_id() { self.handle_context_retrieval_started(info, window, cx); } } - ZetaDebugInfo::ContextRetrievalFinished(info) => { + DebugEvent::ContextRetrievalFinished(info) => { if info.project_entity_id == self.project.entity_id() { self.handle_context_retrieval_finished(info, window, cx); } } - ZetaDebugInfo::EditPredictionRequested(_) => {} + DebugEvent::EditPredictionRequested(_) => {} } } fn handle_context_retrieval_started( &mut self, - info: ZetaContextRetrievalStartedDebugInfo, + info: ContextRetrievalStartedDebugEvent, window: &mut Window, cx: &mut Context, ) { @@ -141,7 +141,7 @@ impl Zeta2ContextView { fn handle_context_retrieval_finished( &mut self, - info: ZetaContextRetrievalFinishedDebugInfo, + 
info: ContextRetrievalFinishedDebugEvent, window: &mut Window, cx: &mut Context, ) { @@ -154,7 +154,7 @@ impl Zeta2ContextView { let project = self.project.clone(); let related_files = self - .zeta + .store .read(cx) .context_for_project(&self.project, cx) .to_vec(); @@ -220,7 +220,7 @@ impl Zeta2ContextView { fn handle_go_back( &mut self, - _: &Zeta2ContextGoBack, + _: &EditPredictionContextGoBack, window: &mut Window, cx: &mut Context, ) { @@ -231,7 +231,7 @@ impl Zeta2ContextView { fn handle_go_forward( &mut self, - _: &Zeta2ContextGoForward, + _: &EditPredictionContextGoForward, window: &mut Window, cx: &mut Context, ) { @@ -243,7 +243,10 @@ impl Zeta2ContextView { cx.notify(); } - fn render_informational_footer(&self, cx: &mut Context<'_, Zeta2ContextView>) -> ui::Div { + fn render_informational_footer( + &self, + cx: &mut Context<'_, EditPredictionContextView>, + ) -> ui::Div { let run = &self.runs[self.current_ix]; let new_run_started = self .runs @@ -279,10 +282,10 @@ impl Zeta2ContextView { .disabled(self.current_ix == 0 || self.runs.len() < 2) .tooltip(ui::Tooltip::for_action_title( "Go to previous run", - &Zeta2ContextGoBack, + &EditPredictionContextGoBack, )) .on_click(cx.listener(|this, _, window, cx| { - this.handle_go_back(&Zeta2ContextGoBack, window, cx); + this.handle_go_back(&EditPredictionContextGoBack, window, cx); })), ) .child( @@ -308,10 +311,14 @@ impl Zeta2ContextView { .disabled(self.current_ix + 1 == self.runs.len()) .tooltip(ui::Tooltip::for_action_title( "Go to next run", - &Zeta2ContextGoBack, + &EditPredictionContextGoBack, )) .on_click(cx.listener(|this, _, window, cx| { - this.handle_go_forward(&Zeta2ContextGoForward, window, cx); + this.handle_go_forward( + &EditPredictionContextGoForward, + window, + cx, + ); })), ), ), @@ -319,7 +326,7 @@ impl Zeta2ContextView { } } -impl Focusable for Zeta2ContextView { +impl Focusable for EditPredictionContextView { fn focus_handle(&self, cx: &App) -> FocusHandle { self.runs 
.get(self.current_ix) @@ -328,9 +335,9 @@ impl Focusable for Zeta2ContextView { } } -impl EventEmitter<()> for Zeta2ContextView {} +impl EventEmitter<()> for EditPredictionContextView {} -impl Item for Zeta2ContextView { +impl Item for EditPredictionContextView { type Event = (); fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString { @@ -357,10 +364,10 @@ impl Item for Zeta2ContextView { } } -impl gpui::Render for Zeta2ContextView { +impl gpui::Render for EditPredictionContextView { fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl ui::IntoElement { v_flex() - .key_context("Zeta2Context") + .key_context("EditPredictionContext") .on_action(cx.listener(Self::handle_go_back)) .on_action(cx.listener(Self::handle_go_forward)) .size_full() diff --git a/crates/edit_prediction_ui/src/edit_prediction_ui.rs b/crates/edit_prediction_ui/src/edit_prediction_ui.rs new file mode 100644 index 0000000000000000000000000000000000000000..51b491c6b3512968bca4ce2e7ed73a505bd73a00 --- /dev/null +++ b/crates/edit_prediction_ui/src/edit_prediction_ui.rs @@ -0,0 +1,128 @@ +mod edit_prediction_button; +mod edit_prediction_context_view; +mod rate_prediction_modal; +mod sweep_api_token_modal; + +use std::any::{Any as _, TypeId}; + +use command_palette_hooks::CommandPaletteFilter; +use edit_prediction::{ResetOnboarding, Zeta2FeatureFlag}; +use edit_prediction_context_view::EditPredictionContextView; +use feature_flags::FeatureFlagAppExt as _; +use gpui::actions; +use project::DisableAiSettings; +use rate_prediction_modal::RatePredictionsModal; +use settings::{Settings as _, SettingsStore}; +use ui::{App, prelude::*}; +use workspace::{SplitDirection, Workspace}; + +pub use edit_prediction_button::{EditPredictionButton, ToggleMenu}; +pub use sweep_api_token_modal::SweepApiKeyModal; + +use crate::rate_prediction_modal::PredictEditsRatePredictionsFeatureFlag; + +actions!( + dev, + [ + /// Opens the edit prediction context view. 
+ OpenEditPredictionContextView, + ] +); + +actions!( + edit_prediction, + [ + /// Opens the rate completions modal. + RatePredictions, + ] +); + +pub fn init(cx: &mut App) { + feature_gate_predict_edits_actions(cx); + + cx.observe_new(move |workspace: &mut Workspace, _, _cx| { + workspace.register_action(|workspace, _: &RatePredictions, window, cx| { + if cx.has_flag::() { + RatePredictionsModal::toggle(workspace, window, cx); + } + }); + + workspace.register_action_renderer(|div, _, _, cx| { + let has_flag = cx.has_flag::(); + div.when(has_flag, |div| { + div.on_action(cx.listener( + move |workspace, _: &OpenEditPredictionContextView, window, cx| { + let project = workspace.project(); + workspace.split_item( + SplitDirection::Right, + Box::new(cx.new(|cx| { + EditPredictionContextView::new( + project.clone(), + workspace.client(), + workspace.user_store(), + window, + cx, + ) + })), + window, + cx, + ); + }, + )) + }) + }); + }) + .detach(); +} + +fn feature_gate_predict_edits_actions(cx: &mut App) { + let rate_completion_action_types = [TypeId::of::()]; + let reset_onboarding_action_types = [TypeId::of::()]; + let all_action_types = [ + TypeId::of::(), + TypeId::of::(), + zed_actions::OpenZedPredictOnboarding.type_id(), + TypeId::of::(), + TypeId::of::(), + TypeId::of::(), + TypeId::of::(), + TypeId::of::(), + ]; + + CommandPaletteFilter::update_global(cx, |filter, _cx| { + filter.hide_action_types(&rate_completion_action_types); + filter.hide_action_types(&reset_onboarding_action_types); + filter.hide_action_types(&[zed_actions::OpenZedPredictOnboarding.type_id()]); + }); + + cx.observe_global::(move |cx| { + let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai; + let has_feature_flag = cx.has_flag::(); + + CommandPaletteFilter::update_global(cx, |filter, _cx| { + if is_ai_disabled { + filter.hide_action_types(&all_action_types); + } else if has_feature_flag { + filter.show_action_types(&rate_completion_action_types); + } else { + 
filter.hide_action_types(&rate_completion_action_types); + } + }); + }) + .detach(); + + cx.observe_flag::(move |is_enabled, cx| { + if !DisableAiSettings::get_global(cx).disable_ai { + if is_enabled { + CommandPaletteFilter::update_global(cx, |filter, _cx| { + filter.show_action_types(&rate_completion_action_types); + }); + } else { + CommandPaletteFilter::update_global(cx, |filter, _cx| { + filter.hide_action_types(&rate_completion_action_types); + }); + } + } + }) + .detach(); +} diff --git a/crates/zeta/src/rate_prediction_modal.rs b/crates/edit_prediction_ui/src/rate_prediction_modal.rs similarity index 95% rename from crates/zeta/src/rate_prediction_modal.rs rename to crates/edit_prediction_ui/src/rate_prediction_modal.rs index 0cceb86608ed609122c81d406c71280894789e88..8e754b33dc18c5be60bc052c33aa08cdcb980acb 100644 --- a/crates/zeta/src/rate_prediction_modal.rs +++ b/crates/edit_prediction_ui/src/rate_prediction_modal.rs @@ -1,7 +1,8 @@ -use crate::{EditPrediction, EditPredictionRating, Zeta}; use buffer_diff::{BufferDiff, BufferDiffSnapshot}; use cloud_zeta2_prompt::write_codeblock; +use edit_prediction::{EditPrediction, EditPredictionRating, EditPredictionStore}; use editor::{Editor, ExcerptRange, MultiBuffer}; +use feature_flags::FeatureFlag; use gpui::{ App, BorderStyle, DismissEvent, EdgesRefinement, Entity, EventEmitter, FocusHandle, Focusable, Length, StyleRefinement, TextStyleRefinement, Window, actions, prelude::*, @@ -9,9 +10,7 @@ use gpui::{ use language::{LanguageRegistry, Point, language_settings}; use markdown::{Markdown, MarkdownStyle}; use settings::Settings as _; -use std::fmt::Write; -use std::sync::Arc; -use std::time::Duration; +use std::{fmt::Write, sync::Arc, time::Duration}; use theme::ThemeSettings; use ui::{KeyBinding, List, ListItem, ListItemSpacing, Tooltip, prelude::*}; use workspace::{ModalView, Workspace}; @@ -34,8 +33,14 @@ actions!( ] ); +pub struct PredictEditsRatePredictionsFeatureFlag; + +impl FeatureFlag for 
PredictEditsRatePredictionsFeatureFlag { + const NAME: &'static str = "predict-edits-rate-completions"; +} + pub struct RatePredictionsModal { - zeta: Entity, + ep_store: Entity, language_registry: Arc, active_prediction: Option, selected_index: usize, @@ -68,10 +73,10 @@ impl RatePredictionView { impl RatePredictionsModal { pub fn toggle(workspace: &mut Workspace, window: &mut Window, cx: &mut Context) { - if let Some(zeta) = Zeta::try_global(cx) { + if let Some(ep_store) = EditPredictionStore::try_global(cx) { let language_registry = workspace.app_state().languages.clone(); workspace.toggle_modal(window, cx, |window, cx| { - RatePredictionsModal::new(zeta, language_registry, window, cx) + RatePredictionsModal::new(ep_store, language_registry, window, cx) }); telemetry::event!("Rate Prediction Modal Open", source = "Edit Prediction"); @@ -79,15 +84,15 @@ impl RatePredictionsModal { } pub fn new( - zeta: Entity, + ep_store: Entity, language_registry: Arc, window: &mut Window, cx: &mut Context, ) -> Self { - let subscription = cx.observe(&zeta, |_, _, cx| cx.notify()); + let subscription = cx.observe(&ep_store, |_, _, cx| cx.notify()); Self { - zeta, + ep_store, language_registry, selected_index: 0, focus_handle: cx.focus_handle(), @@ -113,7 +118,7 @@ impl RatePredictionsModal { self.selected_index += 1; self.selected_index = usize::min( self.selected_index, - self.zeta.read(cx).shown_predictions().count(), + self.ep_store.read(cx).shown_predictions().count(), ); cx.notify(); } @@ -130,7 +135,7 @@ impl RatePredictionsModal { fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context) { let next_index = self - .zeta + .ep_store .read(cx) .shown_predictions() .skip(self.selected_index) @@ -146,11 +151,11 @@ impl RatePredictionsModal { } fn select_prev_edit(&mut self, _: &PreviousEdit, _: &mut Window, cx: &mut Context) { - let zeta = self.zeta.read(cx); - let completions_len = zeta.shown_completions_len(); + let ep_store = self.ep_store.read(cx); + 
let completions_len = ep_store.shown_completions_len(); let prev_index = self - .zeta + .ep_store .read(cx) .shown_predictions() .rev() @@ -173,7 +178,7 @@ impl RatePredictionsModal { } fn select_last(&mut self, _: &menu::SelectLast, _window: &mut Window, cx: &mut Context) { - self.selected_index = self.zeta.read(cx).shown_completions_len() - 1; + self.selected_index = self.ep_store.read(cx).shown_completions_len() - 1; cx.notify(); } @@ -183,9 +188,9 @@ impl RatePredictionsModal { window: &mut Window, cx: &mut Context, ) { - self.zeta.update(cx, |zeta, cx| { + self.ep_store.update(cx, |ep_store, cx| { if let Some(active) = &self.active_prediction { - zeta.rate_prediction( + ep_store.rate_prediction( &active.prediction, EditPredictionRating::Positive, active.feedback_editor.read(cx).text(cx), @@ -216,8 +221,8 @@ impl RatePredictionsModal { return; } - self.zeta.update(cx, |zeta, cx| { - zeta.rate_prediction( + self.ep_store.update(cx, |ep_store, cx| { + ep_store.rate_prediction( &active.prediction, EditPredictionRating::Negative, active.feedback_editor.read(cx).text(cx), @@ -254,7 +259,7 @@ impl RatePredictionsModal { cx: &mut Context, ) { let completion = self - .zeta + .ep_store .read(cx) .shown_predictions() .skip(self.selected_index) @@ -267,7 +272,7 @@ impl RatePredictionsModal { fn confirm(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { let completion = self - .zeta + .ep_store .read(cx) .shown_predictions() .skip(self.selected_index) @@ -288,7 +293,7 @@ impl RatePredictionsModal { // Avoid resetting completion rating if it's already selected. 
if let Some(prediction) = prediction { self.selected_index = self - .zeta + .ep_store .read(cx) .shown_predictions() .enumerate() @@ -376,7 +381,7 @@ impl RatePredictionsModal { &included_file.path, &included_file.excerpts, if included_file.path == prediction.inputs.cursor_path { - cursor_insertions + cursor_insertions.as_slice() } else { &[] }, @@ -564,7 +569,7 @@ impl RatePredictionsModal { let border_color = cx.theme().colors().border; let bg_color = cx.theme().colors().editor_background; - let rated = self.zeta.read(cx).is_prediction_rated(&completion_id); + let rated = self.ep_store.read(cx).is_prediction_rated(&completion_id); let feedback_empty = active_prediction .feedback_editor .read(cx) @@ -715,7 +720,7 @@ impl RatePredictionsModal { } fn render_shown_completions(&self, cx: &Context) -> impl Iterator { - self.zeta + self.ep_store .read(cx) .shown_predictions() .cloned() @@ -725,7 +730,7 @@ impl RatePredictionsModal { .active_prediction .as_ref() .is_some_and(|selected| selected.prediction.id == completion.id); - let rated = self.zeta.read(cx).is_prediction_rated(&completion.id); + let rated = self.ep_store.read(cx).is_prediction_rated(&completion.id); let (icon_name, icon_color, tooltip_text) = match (rated, completion.edits.is_empty()) { diff --git a/crates/edit_prediction_button/src/sweep_api_token_modal.rs b/crates/edit_prediction_ui/src/sweep_api_token_modal.rs similarity index 92% rename from crates/edit_prediction_button/src/sweep_api_token_modal.rs rename to crates/edit_prediction_ui/src/sweep_api_token_modal.rs index ab2102f25a2a7291644ca67ab3c89fd47da7ac0a..80366fc2ac691f165d44e1e6a29a633522146984 100644 --- a/crates/edit_prediction_button/src/sweep_api_token_modal.rs +++ b/crates/edit_prediction_ui/src/sweep_api_token_modal.rs @@ -1,10 +1,10 @@ +use edit_prediction::EditPredictionStore; use gpui::{ DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, IntoElement, ParentElement, Render, }; use ui::{Button, ButtonStyle, Clickable, 
Headline, HeadlineSize, prelude::*}; use ui_input::InputField; use workspace::ModalView; -use zeta::Zeta; pub struct SweepApiKeyModal { api_key_input: Entity, @@ -29,9 +29,10 @@ impl SweepApiKeyModal { let api_key = self.api_key_input.read(cx).text(cx); let api_key = (!api_key.trim().is_empty()).then_some(api_key); - if let Some(zeta) = Zeta::try_global(cx) { - zeta.update(cx, |zeta, cx| { - zeta.sweep_ai + if let Some(ep_store) = EditPredictionStore::try_global(cx) { + ep_store.update(cx, |ep_store, cx| { + ep_store + .sweep_ai .set_api_token(api_key, cx) .detach_and_log_err(cx); }); diff --git a/crates/editor/Cargo.toml b/crates/editor/Cargo.toml index 736916ebbf74f20f11e8c03a0e584bd8ae92e07d..94c9fb10f50f8e0440b2e91cf0c16d1f701d9451 100644 --- a/crates/editor/Cargo.toml +++ b/crates/editor/Cargo.toml @@ -49,7 +49,7 @@ fs.workspace = true git.workspace = true gpui.workspace = true indoc.workspace = true -edit_prediction.workspace = true +edit_prediction_types.workspace = true itertools.workspace = true language.workspace = true linkify.workspace = true diff --git a/crates/editor/src/edit_prediction_tests.rs b/crates/editor/src/edit_prediction_tests.rs index a1839144a47a81f668ba2743cd5e362f6711d0e9..bfce1532ce78699e1fb524fd594df1ba83c864a5 100644 --- a/crates/editor/src/edit_prediction_tests.rs +++ b/crates/editor/src/edit_prediction_tests.rs @@ -1,4 +1,4 @@ -use edit_prediction::EditPredictionProvider; +use edit_prediction_types::EditPredictionDelegate; use gpui::{Entity, KeyBinding, Modifiers, prelude::*}; use indoc::indoc; use multi_buffer::{Anchor, MultiBufferSnapshot, ToPoint}; @@ -15,7 +15,7 @@ async fn test_edit_prediction_insert(cx: &mut gpui::TestAppContext) { init_test(cx, |_| {}); let mut cx = EditorTestContext::new(cx).await; - let provider = cx.new(|_| FakeEditPredictionProvider::default()); + let provider = cx.new(|_| FakeEditPredictionDelegate::default()); assign_editor_completion_provider(provider.clone(), &mut cx); cx.set_state("let 
absolute_zero_celsius = ˇ;"); @@ -37,7 +37,7 @@ async fn test_edit_prediction_modification(cx: &mut gpui::TestAppContext) { init_test(cx, |_| {}); let mut cx = EditorTestContext::new(cx).await; - let provider = cx.new(|_| FakeEditPredictionProvider::default()); + let provider = cx.new(|_| FakeEditPredictionDelegate::default()); assign_editor_completion_provider(provider.clone(), &mut cx); cx.set_state("let pi = ˇ\"foo\";"); @@ -59,7 +59,7 @@ async fn test_edit_prediction_jump_button(cx: &mut gpui::TestAppContext) { init_test(cx, |_| {}); let mut cx = EditorTestContext::new(cx).await; - let provider = cx.new(|_| FakeEditPredictionProvider::default()); + let provider = cx.new(|_| FakeEditPredictionDelegate::default()); assign_editor_completion_provider(provider.clone(), &mut cx); // Cursor is 2+ lines above the proposed edit @@ -128,7 +128,7 @@ async fn test_edit_prediction_invalidation_range(cx: &mut gpui::TestAppContext) init_test(cx, |_| {}); let mut cx = EditorTestContext::new(cx).await; - let provider = cx.new(|_| FakeEditPredictionProvider::default()); + let provider = cx.new(|_| FakeEditPredictionDelegate::default()); assign_editor_completion_provider(provider.clone(), &mut cx); // Cursor is 3+ lines above the proposed edit @@ -233,7 +233,7 @@ async fn test_edit_prediction_jump_disabled_for_non_zed_providers(cx: &mut gpui: init_test(cx, |_| {}); let mut cx = EditorTestContext::new(cx).await; - let provider = cx.new(|_| FakeNonZedEditPredictionProvider::default()); + let provider = cx.new(|_| FakeNonZedEditPredictionDelegate::default()); assign_editor_completion_provider_non_zed(provider.clone(), &mut cx); // Cursor is 2+ lines above the proposed edit @@ -281,7 +281,7 @@ async fn test_edit_prediction_preview_cleanup_on_toggle_off(cx: &mut gpui::TestA cx.update(|cx| cx.bind_keys([KeyBinding::new("ctrl-shift-a", AcceptEditPrediction, None)])); let mut cx = EditorTestContext::new(cx).await; - let provider = cx.new(|_| FakeEditPredictionProvider::default()); + let 
provider = cx.new(|_| FakeEditPredictionDelegate::default()); assign_editor_completion_provider(provider.clone(), &mut cx); cx.set_state("let x = ˇ;"); @@ -371,7 +371,7 @@ fn accept_completion(cx: &mut EditorTestContext) { } fn propose_edits( - provider: &Entity, + provider: &Entity, edits: Vec<(Range, &str)>, cx: &mut EditorTestContext, ) { @@ -383,7 +383,7 @@ fn propose_edits( cx.update(|_, cx| { provider.update(cx, |provider, _| { - provider.set_edit_prediction(Some(edit_prediction::EditPrediction::Local { + provider.set_edit_prediction(Some(edit_prediction_types::EditPrediction::Local { id: None, edits: edits.collect(), edit_preview: None, @@ -393,7 +393,7 @@ fn propose_edits( } fn assign_editor_completion_provider( - provider: Entity, + provider: Entity, cx: &mut EditorTestContext, ) { cx.update_editor(|editor, window, cx| { @@ -402,7 +402,7 @@ fn assign_editor_completion_provider( } fn propose_edits_non_zed( - provider: &Entity, + provider: &Entity, edits: Vec<(Range, &str)>, cx: &mut EditorTestContext, ) { @@ -414,7 +414,7 @@ fn propose_edits_non_zed( cx.update(|_, cx| { provider.update(cx, |provider, _| { - provider.set_edit_prediction(Some(edit_prediction::EditPrediction::Local { + provider.set_edit_prediction(Some(edit_prediction_types::EditPrediction::Local { id: None, edits: edits.collect(), edit_preview: None, @@ -424,7 +424,7 @@ fn propose_edits_non_zed( } fn assign_editor_completion_provider_non_zed( - provider: Entity, + provider: Entity, cx: &mut EditorTestContext, ) { cx.update_editor(|editor, window, cx| { @@ -433,17 +433,20 @@ fn assign_editor_completion_provider_non_zed( } #[derive(Default, Clone)] -pub struct FakeEditPredictionProvider { - pub completion: Option, +pub struct FakeEditPredictionDelegate { + pub completion: Option, } -impl FakeEditPredictionProvider { - pub fn set_edit_prediction(&mut self, completion: Option) { +impl FakeEditPredictionDelegate { + pub fn set_edit_prediction( + &mut self, + completion: Option, + ) { 
self.completion = completion; } } -impl EditPredictionProvider for FakeEditPredictionProvider { +impl EditPredictionDelegate for FakeEditPredictionDelegate { fn name() -> &'static str { "fake-completion-provider" } @@ -452,7 +455,7 @@ impl EditPredictionProvider for FakeEditPredictionProvider { "Fake Completion Provider" } - fn show_completions_in_menu() -> bool { + fn show_predictions_in_menu() -> bool { true } @@ -486,7 +489,7 @@ impl EditPredictionProvider for FakeEditPredictionProvider { &mut self, _buffer: gpui::Entity, _cursor_position: language::Anchor, - _direction: edit_prediction::Direction, + _direction: edit_prediction_types::Direction, _cx: &mut gpui::Context, ) { } @@ -500,23 +503,26 @@ impl EditPredictionProvider for FakeEditPredictionProvider { _buffer: &gpui::Entity, _cursor_position: language::Anchor, _cx: &mut gpui::Context, - ) -> Option { + ) -> Option { self.completion.clone() } } #[derive(Default, Clone)] -pub struct FakeNonZedEditPredictionProvider { - pub completion: Option, +pub struct FakeNonZedEditPredictionDelegate { + pub completion: Option, } -impl FakeNonZedEditPredictionProvider { - pub fn set_edit_prediction(&mut self, completion: Option) { +impl FakeNonZedEditPredictionDelegate { + pub fn set_edit_prediction( + &mut self, + completion: Option, + ) { self.completion = completion; } } -impl EditPredictionProvider for FakeNonZedEditPredictionProvider { +impl EditPredictionDelegate for FakeNonZedEditPredictionDelegate { fn name() -> &'static str { "fake-non-zed-provider" } @@ -525,7 +531,7 @@ impl EditPredictionProvider for FakeNonZedEditPredictionProvider { "Fake Non-Zed Provider" } - fn show_completions_in_menu() -> bool { + fn show_predictions_in_menu() -> bool { false } @@ -559,7 +565,7 @@ impl EditPredictionProvider for FakeNonZedEditPredictionProvider { &mut self, _buffer: gpui::Entity, _cursor_position: language::Anchor, - _direction: edit_prediction::Direction, + _direction: edit_prediction_types::Direction, _cx: &mut 
gpui::Context, ) { } @@ -573,7 +579,7 @@ impl EditPredictionProvider for FakeNonZedEditPredictionProvider { _buffer: &gpui::Entity, _cursor_position: language::Anchor, _cx: &mut gpui::Context, - ) -> Option { + ) -> Option { self.completion.clone() } } diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index 306d7a272b0b8c33e66803ccdbbd74194fde403a..6651cce374001865d21dfdb182659f2a8c008305 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -51,7 +51,7 @@ pub mod test; pub(crate) use actions::*; pub use display_map::{ChunkRenderer, ChunkRendererContext, DisplayPoint, FoldPlaceholder}; -pub use edit_prediction::Direction; +pub use edit_prediction_types::Direction; pub use editor_settings::{ CurrentLineHighlight, DocumentColorsRenderMode, EditorSettings, HideMouseMode, ScrollBeyondLastLine, ScrollbarAxes, SearchSettings, ShowMinimap, @@ -92,7 +92,7 @@ use collections::{BTreeMap, HashMap, HashSet, VecDeque}; use convert_case::{Case, Casing}; use dap::TelemetrySpawnLocation; use display_map::*; -use edit_prediction::{EditPredictionProvider, EditPredictionProviderHandle}; +use edit_prediction_types::{EditPredictionDelegate, EditPredictionDelegateHandle}; use editor_settings::{GoToDefinitionFallback, Minimap as MinimapSettings}; use element::{AcceptEditPredictionBinding, LineWithInvisibles, PositionMap, layout_line}; use futures::{ @@ -1120,7 +1120,7 @@ pub struct Editor { pending_mouse_down: Option>>>, gutter_hovered: bool, hovered_link_state: Option, - edit_prediction_provider: Option, + edit_prediction_provider: Option, code_action_providers: Vec>, active_edit_prediction: Option, /// Used to prevent flickering as the user types while the menu is open @@ -1562,8 +1562,8 @@ pub struct RenameState { struct InvalidationStack(Vec); -struct RegisteredEditPredictionProvider { - provider: Arc, +struct RegisteredEditPredictionDelegate { + provider: Arc, _subscription: Subscription, } @@ -2988,9 +2988,9 @@ impl Editor { window: &mut 
Window, cx: &mut Context, ) where - T: EditPredictionProvider, + T: EditPredictionDelegate, { - self.edit_prediction_provider = provider.map(|provider| RegisteredEditPredictionProvider { + self.edit_prediction_provider = provider.map(|provider| RegisteredEditPredictionDelegate { _subscription: cx.observe_in(&provider, window, |this, _, window, cx| { if this.focus_handle.is_focused(window) { this.update_visible_edit_prediction(window, cx); @@ -7394,7 +7394,7 @@ impl Editor { && self .edit_prediction_provider .as_ref() - .is_some_and(|provider| provider.provider.show_completions_in_menu()); + .is_some_and(|provider| provider.provider.show_predictions_in_menu()); let preview_requires_modifier = all_language_settings(file, cx).edit_predictions_mode() == EditPredictionsMode::Subtle; @@ -8095,12 +8095,12 @@ impl Editor { let edit_prediction = provider.suggest(&buffer, cursor_buffer_position, cx)?; let (completion_id, edits, edit_preview) = match edit_prediction { - edit_prediction::EditPrediction::Local { + edit_prediction_types::EditPrediction::Local { id, edits, edit_preview, } => (id, edits, edit_preview), - edit_prediction::EditPrediction::Jump { + edit_prediction_types::EditPrediction::Jump { id, snapshot, target, @@ -8241,7 +8241,7 @@ impl Editor { Some(()) } - pub fn edit_prediction_provider(&self) -> Option> { + pub fn edit_prediction_provider(&self) -> Option> { Some(self.edit_prediction_provider.as_ref()?.provider.clone()) } @@ -9563,7 +9563,7 @@ impl Editor { editor_bg_color.blend(accent_color.opacity(0.6)) } fn get_prediction_provider_icon_name( - provider: &Option, + provider: &Option, ) -> IconName { match provider { Some(provider) => match provider.provider.name() { diff --git a/crates/editor/src/editor_tests.rs b/crates/editor/src/editor_tests.rs index 7ab3dcc2345dd8a140b7c4762dc5afadb9cef484..683972254ce0ffb719679d431f0a72485cee97f2 100644 --- a/crates/editor/src/editor_tests.rs +++ b/crates/editor/src/editor_tests.rs @@ -2,7 +2,7 @@ use super::*; use 
crate::{ JoinLines, code_context_menus::CodeContextMenu, - edit_prediction_tests::FakeEditPredictionProvider, + edit_prediction_tests::FakeEditPredictionDelegate, element::StickyHeader, linked_editing_ranges::LinkedEditingRanges, scroll::scroll_amount::ScrollAmount, @@ -8636,7 +8636,7 @@ async fn test_undo_edit_prediction_scrolls_to_edit_pos(cx: &mut TestAppContext) let mut cx = EditorTestContext::new(cx).await; - let provider = cx.new(|_| FakeEditPredictionProvider::default()); + let provider = cx.new(|_| FakeEditPredictionDelegate::default()); cx.update_editor(|editor, window, cx| { editor.set_edit_prediction_provider(Some(provider.clone()), window, cx); }); @@ -8659,7 +8659,7 @@ async fn test_undo_edit_prediction_scrolls_to_edit_pos(cx: &mut TestAppContext) cx.update(|_, cx| { provider.update(cx, |provider, _| { - provider.set_edit_prediction(Some(edit_prediction::EditPrediction::Local { + provider.set_edit_prediction(Some(edit_prediction_types::EditPrediction::Local { id: None, edits: vec![(edit_position..edit_position, "X".into())], edit_preview: None, diff --git a/crates/feature_flags/src/flags.rs b/crates/feature_flags/src/flags.rs index 47b6f1230ac747c2633327d1be923d33388cf179..26615aea0f7566ec6dbbd66a128c1a395cc1b9bc 100644 --- a/crates/feature_flags/src/flags.rs +++ b/crates/feature_flags/src/flags.rs @@ -1,11 +1,5 @@ use crate::FeatureFlag; -pub struct PredictEditsRateCompletionsFeatureFlag; - -impl FeatureFlag for PredictEditsRateCompletionsFeatureFlag { - const NAME: &'static str = "predict-edits-rate-completions"; -} - pub struct NotebookFeatureFlag; impl FeatureFlag for NotebookFeatureFlag { diff --git a/crates/supermaven/Cargo.toml b/crates/supermaven/Cargo.toml index 5b86367f35d508579ac6ba999fc8c9236e7fd66a..c2d0c48a9e7733402eae32886c0863326882c134 100644 --- a/crates/supermaven/Cargo.toml +++ b/crates/supermaven/Cargo.toml @@ -16,7 +16,7 @@ doctest = false anyhow.workspace = true client.workspace = true collections.workspace = true 
-edit_prediction.workspace = true +edit_prediction_types.workspace = true futures.workspace = true gpui.workspace = true language.workspace = true diff --git a/crates/supermaven/src/supermaven.rs b/crates/supermaven/src/supermaven.rs index 7a9963dbc424185c52be6879a0a9e722db7106b2..527f4ec37da17c784d3323ebc87a23eb914905ea 100644 --- a/crates/supermaven/src/supermaven.rs +++ b/crates/supermaven/src/supermaven.rs @@ -1,7 +1,7 @@ mod messages; -mod supermaven_completion_provider; +mod supermaven_edit_prediction_delegate; -pub use supermaven_completion_provider::*; +pub use supermaven_edit_prediction_delegate::*; use anyhow::{Context as _, Result}; #[allow(unused_imports)] diff --git a/crates/supermaven/src/supermaven_completion_provider.rs b/crates/supermaven/src/supermaven_edit_prediction_delegate.rs similarity index 95% rename from crates/supermaven/src/supermaven_completion_provider.rs rename to crates/supermaven/src/supermaven_edit_prediction_delegate.rs index 9d5e256aca1b66644145cb688851d0ec5c1b81b9..578bc894f223fd458f510694194aebe633d7a6db 100644 --- a/crates/supermaven/src/supermaven_completion_provider.rs +++ b/crates/supermaven/src/supermaven_edit_prediction_delegate.rs @@ -1,6 +1,6 @@ use crate::{Supermaven, SupermavenCompletionStateId}; use anyhow::Result; -use edit_prediction::{Direction, EditPrediction, EditPredictionProvider}; +use edit_prediction_types::{Direction, EditPrediction, EditPredictionDelegate}; use futures::StreamExt as _; use gpui::{App, Context, Entity, EntityId, Task}; use language::{Anchor, Buffer, BufferSnapshot}; @@ -15,7 +15,7 @@ use unicode_segmentation::UnicodeSegmentation; pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75); -pub struct SupermavenCompletionProvider { +pub struct SupermavenEditPredictionDelegate { supermaven: Entity, buffer_id: Option, completion_id: Option, @@ -25,7 +25,7 @@ pub struct SupermavenCompletionProvider { completion_position: Option, } -impl SupermavenCompletionProvider { +impl 
SupermavenEditPredictionDelegate { pub fn new(supermaven: Entity) -> Self { Self { supermaven, @@ -104,7 +104,7 @@ fn completion_from_diff( } } -impl EditPredictionProvider for SupermavenCompletionProvider { +impl EditPredictionDelegate for SupermavenEditPredictionDelegate { fn name() -> &'static str { "supermaven" } @@ -113,7 +113,7 @@ impl EditPredictionProvider for SupermavenCompletionProvider { "Supermaven" } - fn show_completions_in_menu() -> bool { + fn show_predictions_in_menu() -> bool { true } @@ -269,8 +269,8 @@ impl EditPredictionProvider for SupermavenCompletionProvider { } fn reset_completion_cache( - provider: &mut SupermavenCompletionProvider, - _cx: &mut Context, + provider: &mut SupermavenEditPredictionDelegate, + _cx: &mut Context, ) { provider.pending_refresh = None; provider.completion_id = None; diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index 3358cc5d32bea308083ae1f6ee06268cf22d670a..6ee7d0a4ea75ff5e13a4db6f5fe73c2a5ba80193 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -50,7 +50,6 @@ debugger_tools.workspace = true debugger_ui.workspace = true diagnostics.workspace = true editor.workspace = true -zeta2_tools.workspace = true env_logger.workspace = true extension.workspace = true extension_host.workspace = true @@ -74,7 +73,8 @@ gpui = { workspace = true, features = [ gpui_tokio.workspace = true rayon.workspace = true -edit_prediction_button.workspace = true +edit_prediction.workspace = true +edit_prediction_ui.workspace = true http_client.workspace = true image_viewer.workspace = true inspector_ui.workspace = true @@ -160,7 +160,6 @@ web_search_providers.workspace = true workspace.workspace = true zed_actions.workspace = true zed_env_vars.workspace = true -zeta.workspace = true zlog.workspace = true zlog_settings.workspace = true chrono.workspace = true diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index f92c819dd22c69d95533d16249345e6128e9ded0..10f599e876032bf297d3eaf173093a308d666cc9 100644 
--- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -581,7 +581,7 @@ pub fn main() { language_model::init(app_state.client.clone(), cx); language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx); acp_tools::init(cx); - zeta2_tools::init(cx); + edit_prediction_ui::init(cx); web_search::init(cx); web_search_providers::init(app_state.client.clone(), cx); snippet_provider::init(cx); @@ -640,7 +640,7 @@ pub fn main() { settings_ui::init(cx); keymap_editor::init(cx); extensions_ui::init(cx); - zeta::init(cx); + edit_prediction::init(cx); inspector_ui::init(app_state.clone(), cx); json_schema_store::init(cx); miniprofiler_ui::init(*STARTUP_TIME.get().unwrap(), cx); diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index 49a43eae47fe36c9cd93f3ce6371cf39c5f5e514..164d6b8383fe940e3a92d5461edbff878300474a 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -401,8 +401,8 @@ pub fn initialize_workspace( unstable_version_notification(cx); let edit_prediction_menu_handle = PopoverMenuHandle::default(); - let edit_prediction_button = cx.new(|cx| { - edit_prediction_button::EditPredictionButton::new( + let edit_prediction_ui = cx.new(|cx| { + edit_prediction_ui::EditPredictionButton::new( app_state.fs.clone(), app_state.user_store.clone(), edit_prediction_menu_handle.clone(), @@ -411,7 +411,7 @@ pub fn initialize_workspace( ) }); workspace.register_action({ - move |_, _: &edit_prediction_button::ToggleMenu, window, cx| { + move |_, _: &edit_prediction_ui::ToggleMenu, window, cx| { edit_prediction_menu_handle.toggle(window, cx); } }); @@ -450,7 +450,7 @@ pub fn initialize_workspace( status_bar.add_left_item(lsp_button, window, cx); status_bar.add_left_item(diagnostic_summary, window, cx); status_bar.add_left_item(activity_indicator, window, cx); - status_bar.add_right_item(edit_prediction_button, window, cx); + status_bar.add_right_item(edit_prediction_ui, window, cx); status_bar.add_right_item(active_buffer_language, window, 
cx); status_bar.add_right_item(active_toolchain_language, window, cx); status_bar.add_right_item(line_ending_indicator, window, cx); diff --git a/crates/zed/src/zed/edit_prediction_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs index f413fd94cb1a48adb213120364ed2f59c4cf58e0..2d5746b87ab20de5d0aca47a4d5da60b9ec33d2a 100644 --- a/crates/zed/src/zed/edit_prediction_registry.rs +++ b/crates/zed/src/zed/edit_prediction_registry.rs @@ -1,7 +1,8 @@ use client::{Client, UserStore}; -use codestral::CodestralCompletionProvider; +use codestral::CodestralEditPredictionDelegate; use collections::HashMap; -use copilot::{Copilot, CopilotCompletionProvider}; +use copilot::{Copilot, CopilotEditPredictionDelegate}; +use edit_prediction::{SweepFeatureFlag, ZedEditPredictionDelegate, Zeta2FeatureFlag}; use editor::Editor; use feature_flags::FeatureFlagAppExt; use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity}; @@ -12,9 +13,8 @@ use settings::{ EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, SettingsStore, }; use std::{cell::RefCell, rc::Rc, sync::Arc}; -use supermaven::{Supermaven, SupermavenCompletionProvider}; +use supermaven::{Supermaven, SupermavenEditPredictionDelegate}; use ui::Window; -use zeta::{SweepFeatureFlag, Zeta2FeatureFlag, ZetaEditPredictionProvider}; pub fn init(client: Arc, user_store: Entity, cx: &mut App) { let editors: Rc, AnyWindowHandle>>> = Rc::default(); @@ -59,7 +59,7 @@ pub fn init(client: Arc, user_store: Entity, cx: &mut App) { }) .detach(); - cx.on_action(clear_zeta_edit_history); + cx.on_action(clear_edit_prediction_store_edit_history); let mut provider = all_language_settings(None, cx).edit_predictions.provider; cx.subscribe(&user_store, { @@ -100,9 +100,9 @@ pub fn init(client: Arc, user_store: Entity, cx: &mut App) { .detach(); } -fn clear_zeta_edit_history(_: &zeta::ClearHistory, cx: &mut App) { - if let Some(zeta) = zeta::Zeta::try_global(cx) { - zeta.update(cx, |zeta, _| zeta.clear_history()); +fn 
clear_edit_prediction_store_edit_history(_: &edit_prediction::ClearHistory, cx: &mut App) { + if let Some(ep_store) = edit_prediction::EditPredictionStore::try_global(cx) { + ep_store.update(cx, |ep_store, _| ep_store.clear_history()); } } @@ -176,7 +176,7 @@ fn assign_edit_prediction_provider( match provider { EditPredictionProvider::None => { - editor.set_edit_prediction_provider::(None, window, cx); + editor.set_edit_prediction_provider::(None, window, cx); } EditPredictionProvider::Copilot => { if let Some(copilot) = Copilot::global(cx) { @@ -187,55 +187,61 @@ fn assign_edit_prediction_provider( copilot.register_buffer(&buffer, cx); }); } - let provider = cx.new(|_| CopilotCompletionProvider::new(copilot)); + let provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot)); editor.set_edit_prediction_provider(Some(provider), window, cx); } } EditPredictionProvider::Supermaven => { if let Some(supermaven) = Supermaven::global(cx) { - let provider = cx.new(|_| SupermavenCompletionProvider::new(supermaven)); + let provider = cx.new(|_| SupermavenEditPredictionDelegate::new(supermaven)); editor.set_edit_prediction_provider(Some(provider), window, cx); } } EditPredictionProvider::Codestral => { let http_client = client.http_client(); - let provider = cx.new(|_| CodestralCompletionProvider::new(http_client)); + let provider = cx.new(|_| CodestralEditPredictionDelegate::new(http_client)); editor.set_edit_prediction_provider(Some(provider), window, cx); } value @ (EditPredictionProvider::Experimental(_) | EditPredictionProvider::Zed) => { - let zeta = zeta::Zeta::global(client, &user_store, cx); + let ep_store = edit_prediction::EditPredictionStore::global(client, &user_store, cx); if let Some(project) = editor.project() && let Some(buffer) = &singleton_buffer && buffer.read(cx).file().is_some() { - let has_model = zeta.update(cx, |zeta, cx| { + let has_model = ep_store.update(cx, |ep_store, cx| { let model = if let EditPredictionProvider::Experimental(name) = 
value { if name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME && cx.has_flag::() { - zeta::ZetaEditPredictionModel::Sweep + edit_prediction::EditPredictionModel::Sweep } else if name == EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME && cx.has_flag::() { - zeta::ZetaEditPredictionModel::Zeta2 + edit_prediction::EditPredictionModel::Zeta2 } else { return false; } } else if user_store.read(cx).current_user().is_some() { - zeta::ZetaEditPredictionModel::Zeta1 + edit_prediction::EditPredictionModel::Zeta1 } else { return false; }; - zeta.set_edit_prediction_model(model); - zeta.register_buffer(buffer, project, cx); + ep_store.set_edit_prediction_model(model); + ep_store.register_buffer(buffer, project, cx); true }); if has_model { let provider = cx.new(|cx| { - ZetaEditPredictionProvider::new(project.clone(), &client, &user_store, cx) + ZedEditPredictionDelegate::new( + project.clone(), + singleton_buffer, + &client, + &user_store, + cx, + ) }); editor.set_edit_prediction_provider(Some(provider), window, cx); } diff --git a/crates/zeta/Cargo.toml b/crates/zeta/Cargo.toml deleted file mode 100644 index b90934e67c2a689e1f7bb9704ff28a408de3049a..0000000000000000000000000000000000000000 --- a/crates/zeta/Cargo.toml +++ /dev/null @@ -1,85 +0,0 @@ -[package] -name = "zeta" -version = "0.1.0" -edition.workspace = true -publish.workspace = true -license = "GPL-3.0-or-later" - -[lints] -workspace = true - -[lib] -path = "src/zeta.rs" - -[features] -eval-support = [] - -[dependencies] -ai_onboarding.workspace = true -anyhow.workspace = true -arrayvec.workspace = true -brotli.workspace = true -buffer_diff.workspace = true -client.workspace = true -cloud_llm_client.workspace = true -cloud_zeta2_prompt.workspace = true -collections.workspace = true -command_palette_hooks.workspace = true -copilot.workspace = true -credentials_provider.workspace = true -db.workspace = true -edit_prediction.workspace = true -edit_prediction_context.workspace = true 
-edit_prediction_context2.workspace = true -editor.workspace = true -feature_flags.workspace = true -fs.workspace = true -futures.workspace = true -gpui.workspace = true -indoc.workspace = true -itertools.workspace = true -language.workspace = true -language_model.workspace = true -log.workspace = true -lsp.workspace = true -markdown.workspace = true -menu.workspace = true -open_ai.workspace = true -postage.workspace = true -pretty_assertions.workspace = true -project.workspace = true -rand.workspace = true -regex.workspace = true -release_channel.workspace = true -semver.workspace = true -serde.workspace = true -serde_json.workspace = true -settings.workspace = true -smol.workspace = true -strsim.workspace = true -strum.workspace = true -telemetry.workspace = true -telemetry_events.workspace = true -theme.workspace = true -thiserror.workspace = true -ui.workspace = true -util.workspace = true -uuid.workspace = true -workspace.workspace = true -worktree.workspace = true -zed_actions.workspace = true - -[dev-dependencies] -clock = { workspace = true, features = ["test-support"] } -cloud_api_types.workspace = true -cloud_llm_client = { workspace = true, features = ["test-support"] } -ctor.workspace = true -gpui = { workspace = true, features = ["test-support"] } -indoc.workspace = true -language = { workspace = true, features = ["test-support"] } -language_model = { workspace = true, features = ["test-support"] } -lsp.workspace = true -parking_lot.workspace = true -project = { workspace = true, features = ["test-support"] } -settings = { workspace = true, features = ["test-support"] } -zlog.workspace = true diff --git a/crates/zeta/src/retrieval_search.rs b/crates/zeta/src/retrieval_search.rs deleted file mode 100644 index f429f167744422c3641b5a68ca662af48c8e1614..0000000000000000000000000000000000000000 --- a/crates/zeta/src/retrieval_search.rs +++ /dev/null @@ -1,490 +0,0 @@ -use anyhow::Result; -use cloud_zeta2_prompt::retrieval_prompt::SearchToolQuery; -use 
collections::HashMap; -use edit_prediction_context2::{RelatedExcerpt, RelatedFile}; -use futures::{ - StreamExt, - channel::mpsc::{self, UnboundedSender}, -}; -use gpui::{AppContext, AsyncApp, Entity}; -use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt, Point, ToOffset, ToPoint}; -use project::{ - Project, ProjectPath, WorktreeSettings, - search::{SearchQuery, SearchResult}, -}; -use smol::channel; -use std::ops::Range; -use util::{ - ResultExt as _, - paths::{PathMatcher, PathStyle}, -}; -use workspace::item::Settings as _; - -#[cfg(feature = "eval-support")] -type CachedSearchResults = std::collections::BTreeMap>>; - -pub async fn run_retrieval_searches( - queries: Vec, - project: Entity, - #[cfg(feature = "eval-support")] eval_cache: Option>, - cx: &mut AsyncApp, -) -> Result> { - #[cfg(feature = "eval-support")] - let cache = if let Some(eval_cache) = eval_cache { - use crate::EvalCacheEntryKind; - use anyhow::Context; - use collections::FxHasher; - use std::hash::{Hash, Hasher}; - - let mut hasher = FxHasher::default(); - project.read_with(cx, |project, cx| { - let mut worktrees = project.worktrees(cx); - let Some(worktree) = worktrees.next() else { - panic!("Expected a single worktree in eval project. Found none."); - }; - assert!( - worktrees.next().is_none(), - "Expected a single worktree in eval project. Found more than one." - ); - worktree.read(cx).abs_path().hash(&mut hasher); - })?; - - queries.hash(&mut hasher); - let key = (EvalCacheEntryKind::Search, hasher.finish()); - - if let Some(cached_results) = eval_cache.read(key) { - let file_results = serde_json::from_str::(&cached_results) - .context("Failed to deserialize cached search results")?; - let mut results = Vec::new(); - - for (path, ranges) in file_results { - let project_path = project.update(cx, |project, cx| { - project.find_project_path(path, cx).unwrap() - })?; - let buffer = project - .update(cx, |project, cx| { - project.open_buffer(project_path.clone(), cx) - })? 
- .await?; - let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?; - let mut ranges: Vec<_> = ranges - .into_iter() - .map( - |Range { - start: (start_row, start_col), - end: (end_row, end_col), - }| { - snapshot.anchor_before(Point::new(start_row, start_col)) - ..snapshot.anchor_after(Point::new(end_row, end_col)) - }, - ) - .collect(); - merge_anchor_ranges(&mut ranges, &snapshot); - results.push(RelatedFile { - path: project_path, - buffer: buffer.downgrade(), - excerpts: ranges - .into_iter() - .map(|range| RelatedExcerpt { - point_range: range.to_point(&snapshot), - text: snapshot.as_rope().slice(range.to_offset(&snapshot)), - anchor_range: range, - }) - .collect(), - max_row: snapshot.max_point().row, - }); - } - - return Ok(results); - } - - Some((eval_cache, serde_json::to_string_pretty(&queries)?, key)) - } else { - None - }; - - let (exclude_matcher, path_style) = project.update(cx, |project, cx| { - let global_settings = WorktreeSettings::get_global(cx); - let exclude_patterns = global_settings - .file_scan_exclusions - .sources() - .chain(global_settings.private_files.sources()); - let path_style = project.path_style(cx); - anyhow::Ok((PathMatcher::new(exclude_patterns, path_style)?, path_style)) - })??; - - let (results_tx, mut results_rx) = mpsc::unbounded(); - - for query in queries { - let exclude_matcher = exclude_matcher.clone(); - let results_tx = results_tx.clone(); - let project = project.clone(); - cx.spawn(async move |cx| { - run_query( - query, - results_tx.clone(), - path_style, - exclude_matcher, - &project, - cx, - ) - .await - .log_err(); - }) - .detach() - } - drop(results_tx); - - #[cfg(feature = "eval-support")] - let cache = cache.clone(); - cx.background_spawn(async move { - let mut results: Vec = Vec::default(); - let mut snapshots = HashMap::default(); - - let mut total_bytes = 0; - 'outer: while let Some((project_path, buffer, snapshot, excerpts)) = results_rx.next().await - { - let existing = results - .iter_mut() 
- .find(|related_file| related_file.buffer.entity_id() == buffer.entity_id()); - let existing = match existing { - Some(existing) => existing, - None => { - results.push(RelatedFile { - path: project_path, - buffer: buffer.downgrade(), - excerpts: Vec::new(), - max_row: snapshot.max_point().row, - }); - results.last_mut().unwrap() - } - }; - // let existing = results.entry(buffer).or_default(); - existing.excerpts.reserve(excerpts.len()); - - for (range, size) in excerpts { - // Blunt trimming of the results until we have a proper algorithmic filtering step - if (total_bytes + size) > MAX_RESULTS_LEN { - log::trace!("Combined results reached limit of {MAX_RESULTS_LEN}B"); - break 'outer; - } - total_bytes += size; - existing.excerpts.push(RelatedExcerpt { - point_range: range.to_point(&snapshot), - text: snapshot.as_rope().slice(range.to_offset(&snapshot)), - anchor_range: range, - }); - } - snapshots.insert(buffer.entity_id(), snapshot); - } - - #[cfg(feature = "eval-support")] - if let Some((cache, queries, key)) = cache { - let cached_results: CachedSearchResults = results - .iter() - .map(|related_file| { - let mut ranges = related_file - .excerpts - .iter() - .map( - |RelatedExcerpt { - point_range: Range { start, end }, - .. 
- }| { - (start.row, start.column)..(end.row, end.column) - }, - ) - .collect::>(); - ranges.sort_unstable_by_key(|range| (range.start, range.end)); - (related_file.path.path.as_std_path().to_path_buf(), ranges) - }) - .collect(); - cache.write( - key, - &queries, - &serde_json::to_string_pretty(&cached_results)?, - ); - } - - for related_file in results.iter_mut() { - related_file.merge_excerpts(); - } - - Ok(results) - }) - .await -} - -#[cfg(feature = "eval-support")] -pub(crate) fn merge_anchor_ranges(ranges: &mut Vec>, snapshot: &BufferSnapshot) { - ranges.sort_unstable_by(|a, b| { - a.start - .cmp(&b.start, snapshot) - .then(b.end.cmp(&a.end, snapshot)) - }); - - let mut index = 1; - while index < ranges.len() { - if ranges[index - 1] - .end - .cmp(&ranges[index].start, snapshot) - .is_ge() - { - let removed = ranges.remove(index); - if removed.end.cmp(&ranges[index - 1].end, snapshot).is_gt() { - ranges[index - 1].end = removed.end; - } - } else { - index += 1; - } - } -} - -const MAX_EXCERPT_LEN: usize = 768; -const MAX_RESULTS_LEN: usize = MAX_EXCERPT_LEN * 5; - -struct SearchJob { - buffer: Entity, - snapshot: BufferSnapshot, - project_path: ProjectPath, - ranges: Vec>, - query_ix: usize, - jobs_tx: channel::Sender, -} - -async fn run_query( - input_query: SearchToolQuery, - results_tx: UnboundedSender<( - ProjectPath, - Entity, - BufferSnapshot, - Vec<(Range, usize)>, - )>, - path_style: PathStyle, - exclude_matcher: PathMatcher, - project: &Entity, - cx: &mut AsyncApp, -) -> Result<()> { - let include_matcher = PathMatcher::new(vec![input_query.glob], path_style)?; - - let make_search = |regex: &str| -> Result { - SearchQuery::regex( - regex, - false, - true, - false, - true, - include_matcher.clone(), - exclude_matcher.clone(), - true, - None, - ) - }; - - if let Some(outer_syntax_regex) = input_query.syntax_node.first() { - let outer_syntax_query = make_search(outer_syntax_regex)?; - let nested_syntax_queries = input_query - .syntax_node - 
.into_iter() - .skip(1) - .map(|query| make_search(&query)) - .collect::>>()?; - let content_query = input_query - .content - .map(|regex| make_search(®ex)) - .transpose()?; - - let (jobs_tx, jobs_rx) = channel::unbounded(); - - let outer_search_results_rx = - project.update(cx, |project, cx| project.search(outer_syntax_query, cx))?; - - let outer_search_task = cx.spawn(async move |cx| { - futures::pin_mut!(outer_search_results_rx); - while let Some(SearchResult::Buffer { buffer, ranges }) = - outer_search_results_rx.next().await - { - buffer - .read_with(cx, |buffer, _| buffer.parsing_idle())? - .await; - let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; - let Some(file) = snapshot.file() else { - continue; - }; - - let project_path = cx.update(|cx| ProjectPath { - worktree_id: file.worktree_id(cx), - path: file.path().clone(), - })?; - let expanded_ranges: Vec<_> = ranges - .into_iter() - .filter_map(|range| expand_to_parent_range(&range, &snapshot)) - .collect(); - jobs_tx - .send(SearchJob { - project_path, - buffer, - snapshot, - ranges: expanded_ranges, - query_ix: 0, - jobs_tx: jobs_tx.clone(), - }) - .await?; - } - anyhow::Ok(()) - }); - - let n_workers = cx.background_executor().num_cpus(); - let search_job_task = cx.background_executor().scoped(|scope| { - for _ in 0..n_workers { - scope.spawn(async { - while let Ok(job) = jobs_rx.recv().await { - process_nested_search_job( - &results_tx, - &nested_syntax_queries, - &content_query, - job, - ) - .await; - } - }); - } - }); - - search_job_task.await; - outer_search_task.await?; - } else if let Some(content_regex) = &input_query.content { - let search_query = make_search(&content_regex)?; - - let results_rx = project.update(cx, |project, cx| project.search(search_query, cx))?; - futures::pin_mut!(results_rx); - - while let Some(SearchResult::Buffer { buffer, ranges }) = results_rx.next().await { - let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; - let Some(file) = 
snapshot.file() else { - continue; - }; - let project_path = cx.update(|cx| ProjectPath { - worktree_id: file.worktree_id(cx), - path: file.path().clone(), - })?; - - let ranges = ranges - .into_iter() - .map(|range| { - let range = range.to_offset(&snapshot); - let range = expand_to_entire_lines(range, &snapshot); - let size = range.len(); - let range = - snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end); - (range, size) - }) - .collect(); - - let send_result = - results_tx.unbounded_send((project_path, buffer.clone(), snapshot.clone(), ranges)); - - if let Err(err) = send_result - && !err.is_disconnected() - { - log::error!("{err}"); - } - } - } else { - log::warn!("Context gathering model produced a glob-only search"); - } - - anyhow::Ok(()) -} - -async fn process_nested_search_job( - results_tx: &UnboundedSender<( - ProjectPath, - Entity, - BufferSnapshot, - Vec<(Range, usize)>, - )>, - queries: &Vec, - content_query: &Option, - job: SearchJob, -) { - if let Some(search_query) = queries.get(job.query_ix) { - let mut subranges = Vec::new(); - for range in job.ranges { - let start = range.start; - let search_results = search_query.search(&job.snapshot, Some(range)).await; - for subrange in search_results { - let subrange = start + subrange.start..start + subrange.end; - subranges.extend(expand_to_parent_range(&subrange, &job.snapshot)); - } - } - job.jobs_tx - .send(SearchJob { - project_path: job.project_path, - buffer: job.buffer, - snapshot: job.snapshot, - ranges: subranges, - query_ix: job.query_ix + 1, - jobs_tx: job.jobs_tx.clone(), - }) - .await - .ok(); - } else { - let ranges = if let Some(content_query) = content_query { - let mut subranges = Vec::new(); - for range in job.ranges { - let start = range.start; - let search_results = content_query.search(&job.snapshot, Some(range)).await; - for subrange in search_results { - let subrange = start + subrange.start..start + subrange.end; - subranges.push(subrange); - } - } - subranges - } 
else { - job.ranges - }; - - let matches = ranges - .into_iter() - .map(|range| { - let snapshot = &job.snapshot; - let range = expand_to_entire_lines(range, snapshot); - let size = range.len(); - let range = snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end); - (range, size) - }) - .collect(); - - let send_result = - results_tx.unbounded_send((job.project_path, job.buffer, job.snapshot, matches)); - - if let Err(err) = send_result - && !err.is_disconnected() - { - log::error!("{err}"); - } - } -} - -fn expand_to_entire_lines(range: Range, snapshot: &BufferSnapshot) -> Range { - let mut point_range = range.to_point(snapshot); - point_range.start.column = 0; - if point_range.end.column > 0 { - point_range.end = snapshot.max_point().min(point_range.end + Point::new(1, 0)); - } - point_range.to_offset(snapshot) -} - -fn expand_to_parent_range( - range: &Range, - snapshot: &BufferSnapshot, -) -> Option> { - let mut line_range = range.to_point(&snapshot); - line_range.start.column = snapshot.indent_size_for_line(line_range.start.row).len; - line_range.end.column = snapshot.line_len(line_range.end.row); - // TODO skip result if matched line isn't the first node line? 
- - let node = snapshot.syntax_ancestor(line_range)?; - Some(node.byte_range()) -} diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs deleted file mode 100644 index 576067b9844cd668c69411d7a4098975db4a5d26..0000000000000000000000000000000000000000 --- a/crates/zeta/src/zeta.rs +++ /dev/null @@ -1,3890 +0,0 @@ -use anyhow::{Context as _, Result, anyhow, bail}; -use arrayvec::ArrayVec; -use client::{Client, EditPredictionUsage, UserStore}; -use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat}; -use cloud_llm_client::{ - AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason, - EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, - MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, - ZED_VERSION_HEADER_NAME, -}; -use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery}; -use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES}; -use collections::{HashMap, HashSet}; -use command_palette_hooks::CommandPaletteFilter; -use db::kvp::{Dismissable, KEY_VALUE_STORE}; -use edit_prediction_context::{ - EditPredictionContextOptions, EditPredictionExcerpt, EditPredictionExcerptOptions, - EditPredictionScoreOptions, Line, SyntaxIndex, -}; -use edit_prediction_context2::{ - RelatedExcerpt, RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile, -}; -use feature_flags::{FeatureFlag, FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag}; -use futures::{ - AsyncReadExt as _, FutureExt as _, StreamExt as _, - channel::{ - mpsc::{self, UnboundedReceiver}, - oneshot, - }, - select_biased, -}; -use gpui::BackgroundExecutor; -use gpui::{ - App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions, - http_client::{self, AsyncBody, Method}, - prelude::*, -}; -use language::language_settings::all_language_settings; -use language::{ - Anchor, Buffer, DiagnosticSet, File, LanguageServerId, Point, 
ToOffset as _, ToPoint, -}; -use language::{BufferSnapshot, OffsetRangeExt}; -use language_model::{LlmApiToken, RefreshLlmTokenListener}; -use open_ai::FunctionDefinition; -use project::{DisableAiSettings, Project, ProjectItem as _, ProjectPath, WorktreeId}; -use release_channel::AppVersion; -use semver::Version; -use serde::de::DeserializeOwned; -use settings::{EditPredictionProvider, Settings, SettingsStore, update_settings_file}; -use std::any::{Any as _, TypeId}; -use std::collections::{VecDeque, hash_map}; -use telemetry_events::EditPredictionRating; -use workspace::Workspace; - -use std::ops::Range; -use std::path::Path; -use std::rc::Rc; -use std::str::FromStr as _; -use std::sync::{Arc, LazyLock}; -use std::time::{Duration, Instant}; -use std::{env, mem}; -use thiserror::Error; -use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt}; -use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification}; - -mod license_detection; -mod onboarding_modal; -mod prediction; -mod provider; -mod rate_prediction_modal; -pub mod retrieval_search; -pub mod sweep_ai; -pub mod udiff; -mod xml_edits; -pub mod zeta1; - -#[cfg(test)] -mod zeta_tests; - -use crate::license_detection::LicenseDetectionWatcher; -use crate::onboarding_modal::ZedPredictModal; -pub use crate::prediction::EditPrediction; -pub use crate::prediction::EditPredictionId; -pub use crate::prediction::EditPredictionInputs; -use crate::prediction::EditPredictionResult; -use crate::rate_prediction_modal::{ - NextEdit, PreviousEdit, RatePredictionsModal, ThumbsDownActivePrediction, - ThumbsUpActivePrediction, -}; -pub use crate::sweep_ai::SweepAi; -use crate::zeta1::request_prediction_with_zeta1; -pub use provider::ZetaEditPredictionProvider; - -actions!( - edit_prediction, - [ - /// Resets the edit prediction onboarding state. - ResetOnboarding, - /// Opens the rate completions modal. - RateCompletions, - /// Clears the edit prediction history. 
- ClearHistory, - ] -); - -/// Maximum number of events to track. -const EVENT_COUNT_MAX: usize = 6; -const CHANGE_GROUPING_LINE_SPAN: u32 = 8; -const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice"; -const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15); - -pub struct SweepFeatureFlag; - -impl FeatureFlag for SweepFeatureFlag { - const NAME: &str = "sweep-ai"; -} -pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions { - max_bytes: 512, - min_bytes: 128, - target_before_cursor_over_total_bytes: 0.5, -}; - -pub const DEFAULT_CONTEXT_OPTIONS: ContextMode = ContextMode::Lsp(DEFAULT_EXCERPT_OPTIONS); - -pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions { - excerpt: DEFAULT_EXCERPT_OPTIONS, -}; - -pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions = - EditPredictionContextOptions { - use_imports: true, - max_retrieved_declarations: 0, - excerpt: DEFAULT_EXCERPT_OPTIONS, - score: EditPredictionScoreOptions { - omit_excerpt_overlaps: true, - }, - }; - -pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions { - context: DEFAULT_CONTEXT_OPTIONS, - max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES, - max_diagnostic_bytes: 2048, - prompt_format: PromptFormat::DEFAULT, - file_indexing_parallelism: 1, - buffer_change_grouping_interval: Duration::from_secs(1), -}; - -static USE_OLLAMA: LazyLock = - LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty())); -static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock = LazyLock::new(|| { - env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA { - "qwen3-coder:30b".to_string() - } else { - "yqvev8r3".to_string() - }) -}); -static EDIT_PREDICTIONS_MODEL_ID: LazyLock = LazyLock::new(|| { - match env::var("ZED_ZETA2_MODEL").as_deref() { - Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten - Ok(model) => model, - Err(_) if *USE_OLLAMA => "qwen3-coder:30b", - Err(_) => 
"yqvev8r3", // Vanilla qwen3-coder @ Baseten - } - .to_string() -}); -static PREDICT_EDITS_URL: LazyLock> = LazyLock::new(|| { - env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| { - if *USE_OLLAMA { - Some("http://localhost:11434/v1/chat/completions".into()) - } else { - None - } - }) -}); - -pub struct Zeta2FeatureFlag; - -impl FeatureFlag for Zeta2FeatureFlag { - const NAME: &'static str = "zeta2"; - - fn enabled_for_staff() -> bool { - true - } -} - -#[derive(Clone)] -struct ZetaGlobal(Entity); - -impl Global for ZetaGlobal {} - -pub struct Zeta { - client: Arc, - user_store: Entity, - llm_token: LlmApiToken, - _llm_token_subscription: Subscription, - projects: HashMap, - use_context: bool, - options: ZetaOptions, - update_required: bool, - debug_tx: Option>, - #[cfg(feature = "eval-support")] - eval_cache: Option>, - edit_prediction_model: ZetaEditPredictionModel, - pub sweep_ai: SweepAi, - data_collection_choice: DataCollectionChoice, - reject_predictions_tx: mpsc::UnboundedSender, - shown_predictions: VecDeque, - rated_predictions: HashSet, -} - -#[derive(Copy, Clone, Default, PartialEq, Eq)] -pub enum ZetaEditPredictionModel { - #[default] - Zeta1, - Zeta2, - Sweep, -} - -#[derive(Debug, Clone, PartialEq)] -pub struct ZetaOptions { - pub context: ContextMode, - pub max_prompt_bytes: usize, - pub max_diagnostic_bytes: usize, - pub prompt_format: predict_edits_v3::PromptFormat, - pub file_indexing_parallelism: usize, - pub buffer_change_grouping_interval: Duration, -} - -#[derive(Debug, Clone, PartialEq)] -pub enum ContextMode { - Agentic(AgenticContextOptions), - Syntax(EditPredictionContextOptions), - Lsp(EditPredictionExcerptOptions), -} - -#[derive(Debug, Clone, PartialEq)] -pub struct AgenticContextOptions { - pub excerpt: EditPredictionExcerptOptions, -} - -impl ContextMode { - pub fn excerpt(&self) -> &EditPredictionExcerptOptions { - match self { - ContextMode::Agentic(options) => &options.excerpt, - ContextMode::Syntax(options) => &options.excerpt, - 
ContextMode::Lsp(options) => &options, - } - } -} - -#[derive(Debug)] -pub enum ZetaDebugInfo { - ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo), - ContextRetrievalFinished(ZetaContextRetrievalFinishedDebugInfo), - EditPredictionRequested(ZetaEditPredictionDebugInfo), -} - -#[derive(Debug)] -pub struct ZetaContextRetrievalStartedDebugInfo { - pub project_entity_id: EntityId, - pub timestamp: Instant, - pub search_prompt: String, -} - -#[derive(Debug)] -pub struct ZetaContextRetrievalFinishedDebugInfo { - pub project_entity_id: EntityId, - pub timestamp: Instant, - pub metadata: Vec<(&'static str, SharedString)>, -} - -#[derive(Debug)] -pub struct ZetaEditPredictionDebugInfo { - pub inputs: EditPredictionInputs, - pub retrieval_time: Duration, - pub buffer: WeakEntity, - pub position: language::Anchor, - pub local_prompt: Result, - pub response_rx: oneshot::Receiver<(Result, Duration)>, -} - -pub type RequestDebugInfo = predict_edits_v3::DebugInfo; - -struct ZetaProject { - events: VecDeque>, - last_event: Option, - recent_paths: VecDeque, - registered_buffers: HashMap, - current_prediction: Option, - next_pending_prediction_id: usize, - pending_predictions: ArrayVec, - context_updates_tx: smol::channel::Sender<()>, - context_updates_rx: smol::channel::Receiver<()>, - last_prediction_refresh: Option<(EntityId, Instant)>, - cancelled_predictions: HashSet, - context: ZetaProjectContext, - license_detection_watchers: HashMap>, - _subscription: gpui::Subscription, -} - -enum ZetaProjectContext { - Syntax(Entity), - Lsp(Entity), - Agentic { - refresh_context_task: Option>>>, - refresh_context_debounce_task: Option>>, - refresh_context_timestamp: Option, - context: Vec, - }, -} - -impl ZetaProject { - pub fn events(&self, cx: &App) -> Vec> { - self.events - .iter() - .cloned() - .chain( - self.last_event - .as_ref() - .and_then(|event| event.finalize(&self.license_detection_watchers, cx)), - ) - .collect() - } - - fn cancel_pending_prediction( - &mut self, 
- pending_prediction: PendingPrediction, - cx: &mut Context, - ) { - self.cancelled_predictions.insert(pending_prediction.id); - - cx.spawn(async move |this, cx| { - let Some(prediction_id) = pending_prediction.task.await else { - return; - }; - - this.update(cx, |this, _cx| { - this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false); - }) - .ok(); - }) - .detach() - } -} - -#[derive(Debug, Clone)] -struct CurrentEditPrediction { - pub requested_by: PredictionRequestedBy, - pub prediction: EditPrediction, - pub was_shown: bool, -} - -impl CurrentEditPrediction { - fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool { - let Some(new_edits) = self - .prediction - .interpolate(&self.prediction.buffer.read(cx)) - else { - return false; - }; - - if self.prediction.buffer != old_prediction.prediction.buffer { - return true; - } - - let Some(old_edits) = old_prediction - .prediction - .interpolate(&old_prediction.prediction.buffer.read(cx)) - else { - return true; - }; - - let requested_by_buffer_id = self.requested_by.buffer_id(); - - // This reduces the occurrence of UI thrash from replacing edits - // - // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits. 
- if requested_by_buffer_id == Some(self.prediction.buffer.entity_id()) - && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id()) - && old_edits.len() == 1 - && new_edits.len() == 1 - { - let (old_range, old_text) = &old_edits[0]; - let (new_range, new_text) = &new_edits[0]; - new_range == old_range && new_text.starts_with(old_text.as_ref()) - } else { - true - } - } -} - -#[derive(Debug, Clone)] -enum PredictionRequestedBy { - DiagnosticsUpdate, - Buffer(EntityId), -} - -impl PredictionRequestedBy { - pub fn buffer_id(&self) -> Option { - match self { - PredictionRequestedBy::DiagnosticsUpdate => None, - PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id), - } - } -} - -#[derive(Debug)] -struct PendingPrediction { - id: usize, - task: Task>, -} - -/// A prediction from the perspective of a buffer. -#[derive(Debug)] -enum BufferEditPrediction<'a> { - Local { prediction: &'a EditPrediction }, - Jump { prediction: &'a EditPrediction }, -} - -#[cfg(test)] -impl std::ops::Deref for BufferEditPrediction<'_> { - type Target = EditPrediction; - - fn deref(&self) -> &Self::Target { - match self { - BufferEditPrediction::Local { prediction } => prediction, - BufferEditPrediction::Jump { prediction } => prediction, - } - } -} - -struct RegisteredBuffer { - snapshot: BufferSnapshot, - _subscriptions: [gpui::Subscription; 2], -} - -struct LastEvent { - old_snapshot: BufferSnapshot, - new_snapshot: BufferSnapshot, - end_edit_anchor: Option, -} - -impl LastEvent { - pub fn finalize( - &self, - license_detection_watchers: &HashMap>, - cx: &App, - ) -> Option> { - let path = buffer_path_with_id_fallback(&self.new_snapshot, cx); - let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx); - - let file = self.new_snapshot.file(); - let old_file = self.old_snapshot.file(); - - let in_open_source_repo = [file, old_file].iter().all(|file| { - file.is_some_and(|file| { - license_detection_watchers - .get(&file.worktree_id(cx)) - 
.is_some_and(|watcher| watcher.is_project_open_source()) - }) - }); - - let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text()); - - if path == old_path && diff.is_empty() { - None - } else { - Some(Arc::new(predict_edits_v3::Event::BufferChange { - old_path, - path, - diff, - in_open_source_repo, - // TODO: Actually detect if this edit was predicted or not - predicted: false, - })) - } - } -} - -fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc { - if let Some(file) = snapshot.file() { - file.full_path(cx).into() - } else { - Path::new(&format!("untitled-{}", snapshot.remote_id())).into() - } -} - -impl Zeta { - pub fn try_global(cx: &App) -> Option> { - cx.try_global::().map(|global| global.0.clone()) - } - - pub fn global( - client: &Arc, - user_store: &Entity, - cx: &mut App, - ) -> Entity { - cx.try_global::() - .map(|global| global.0.clone()) - .unwrap_or_else(|| { - let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx)); - cx.set_global(ZetaGlobal(zeta.clone())); - zeta - }) - } - - pub fn new(client: Arc, user_store: Entity, cx: &mut Context) -> Self { - let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx); - let data_collection_choice = Self::load_data_collection_choice(); - - let llm_token = LlmApiToken::default(); - - let (reject_tx, reject_rx) = mpsc::unbounded(); - cx.background_spawn({ - let client = client.clone(); - let llm_token = llm_token.clone(); - let app_version = AppVersion::global(cx); - let background_executor = cx.background_executor().clone(); - async move { - Self::handle_rejected_predictions( - reject_rx, - client, - llm_token, - app_version, - background_executor, - ) - .await - } - }) - .detach(); - - let mut this = Self { - projects: HashMap::default(), - client, - user_store, - options: DEFAULT_OPTIONS, - use_context: false, - llm_token, - _llm_token_subscription: cx.subscribe( - &refresh_llm_token_listener, - |this, _listener, _event, cx| { 
- let client = this.client.clone(); - let llm_token = this.llm_token.clone(); - cx.spawn(async move |_this, _cx| { - llm_token.refresh(&client).await?; - anyhow::Ok(()) - }) - .detach_and_log_err(cx); - }, - ), - update_required: false, - debug_tx: None, - #[cfg(feature = "eval-support")] - eval_cache: None, - edit_prediction_model: ZetaEditPredictionModel::Zeta2, - sweep_ai: SweepAi::new(cx), - data_collection_choice, - reject_predictions_tx: reject_tx, - rated_predictions: Default::default(), - shown_predictions: Default::default(), - }; - - this.enable_or_disable_context_retrieval(cx); - let weak_this = cx.weak_entity(); - cx.on_flags_ready(move |_, cx| { - weak_this - .update(cx, |this, cx| this.enable_or_disable_context_retrieval(cx)) - .ok(); - }) - .detach(); - cx.observe_global::(|this, cx| { - this.enable_or_disable_context_retrieval(cx); - }) - .detach(); - - this - } - - pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) { - self.edit_prediction_model = model; - } - - pub fn has_sweep_api_token(&self) -> bool { - self.sweep_ai - .api_token - .clone() - .now_or_never() - .flatten() - .is_some() - } - - #[cfg(feature = "eval-support")] - pub fn with_eval_cache(&mut self, cache: Arc) { - self.eval_cache = Some(cache); - } - - pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver { - let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded(); - self.debug_tx = Some(debug_watch_tx); - debug_watch_rx - } - - pub fn options(&self) -> &ZetaOptions { - &self.options - } - - pub fn set_options(&mut self, options: ZetaOptions) { - self.options = options; - } - - pub fn set_use_context(&mut self, use_context: bool) { - self.use_context = use_context; - } - - pub fn clear_history(&mut self) { - for zeta_project in self.projects.values_mut() { - zeta_project.events.clear(); - } - } - - pub fn context_for_project<'a>( - &'a self, - project: &Entity, - cx: &'a App, - ) -> &'a [RelatedFile] { - self.projects - .get(&project.entity_id()) - 
.and_then(|project| match &project.context { - ZetaProjectContext::Syntax(_) => None, - ZetaProjectContext::Lsp(store) => Some(store.read(cx).related_files()), - ZetaProjectContext::Agentic { context, .. } => Some(context.as_slice()), - }) - .unwrap_or(&[]) - } - - pub fn usage(&self, cx: &App) -> Option { - if self.edit_prediction_model == ZetaEditPredictionModel::Zeta2 { - self.user_store.read(cx).edit_prediction_usage() - } else { - None - } - } - - pub fn register_project(&mut self, project: &Entity, cx: &mut Context) { - self.get_or_init_zeta_project(project, cx); - } - - pub fn register_buffer( - &mut self, - buffer: &Entity, - project: &Entity, - cx: &mut Context, - ) { - let zeta_project = self.get_or_init_zeta_project(project, cx); - Self::register_buffer_impl(zeta_project, buffer, project, cx); - } - - fn get_or_init_zeta_project( - &mut self, - project: &Entity, - cx: &mut Context, - ) -> &mut ZetaProject { - let entity_id = project.entity_id(); - let (context_updates_tx, context_updates_rx) = smol::channel::unbounded(); - self.projects - .entry(entity_id) - .or_insert_with(|| ZetaProject { - context: match &self.options.context { - ContextMode::Agentic(_) => ZetaProjectContext::Agentic { - refresh_context_task: None, - refresh_context_debounce_task: None, - refresh_context_timestamp: None, - context: Vec::new(), - }, - ContextMode::Syntax(_) => ZetaProjectContext::Syntax(cx.new(|cx| { - SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx) - })), - ContextMode::Lsp(_) => { - let related_excerpt_store = - cx.new(|cx| RelatedExcerptStore::new(project, cx)); - cx.subscribe( - &related_excerpt_store, - move |this, _, event, _| match event { - RelatedExcerptStoreEvent::StartedRefresh => { - if let Some(debug_tx) = this.debug_tx.clone() { - debug_tx - .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted( - ZetaContextRetrievalStartedDebugInfo { - project_entity_id: entity_id, - timestamp: Instant::now(), - search_prompt: String::new(), - }, 
- )) - .ok(); - } - } - RelatedExcerptStoreEvent::FinishedRefresh { - cache_hit_count, - cache_miss_count, - mean_definition_latency, - max_definition_latency, - } => { - if let Some(debug_tx) = this.debug_tx.clone() { - debug_tx - .unbounded_send( - ZetaDebugInfo::ContextRetrievalFinished( - ZetaContextRetrievalFinishedDebugInfo { - project_entity_id: entity_id, - timestamp: Instant::now(), - metadata: vec![ - ( - "Cache Hits", - format!( - "{}/{}", - cache_hit_count, - cache_hit_count - + cache_miss_count - ) - .into(), - ), - ( - "Max LSP Time", - format!( - "{} ms", - max_definition_latency - .as_millis() - ) - .into(), - ), - ( - "Mean LSP Time", - format!( - "{} ms", - mean_definition_latency - .as_millis() - ) - .into(), - ), - ], - }, - ), - ) - .ok(); - } - if let Some(project_state) = this.projects.get(&entity_id) { - project_state.context_updates_tx.send_blocking(()).ok(); - } - } - }, - ) - .detach(); - ZetaProjectContext::Lsp(related_excerpt_store) - } - }, - events: VecDeque::new(), - last_event: None, - recent_paths: VecDeque::new(), - context_updates_rx, - context_updates_tx, - registered_buffers: HashMap::default(), - current_prediction: None, - cancelled_predictions: HashSet::default(), - pending_predictions: ArrayVec::new(), - next_pending_prediction_id: 0, - last_prediction_refresh: None, - license_detection_watchers: HashMap::default(), - _subscription: cx.subscribe(&project, Self::handle_project_event), - }) - } - - pub fn project_context_updates( - &self, - project: &Entity, - ) -> Option> { - let project_state = self.projects.get(&project.entity_id())?; - Some(project_state.context_updates_rx.clone()) - } - - fn handle_project_event( - &mut self, - project: Entity, - event: &project::Event, - cx: &mut Context, - ) { - // TODO [zeta2] init with recent paths - match event { - project::Event::ActiveEntryChanged(Some(active_entry_id)) => { - let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else { - return; - }; - let path = 
project.read(cx).path_for_entry(*active_entry_id, cx); - if let Some(path) = path { - if let Some(ix) = zeta_project - .recent_paths - .iter() - .position(|probe| probe == &path) - { - zeta_project.recent_paths.remove(ix); - } - zeta_project.recent_paths.push_front(path); - } - } - project::Event::DiagnosticsUpdated { .. } => { - if cx.has_flag::() { - self.refresh_prediction_from_diagnostics(project, cx); - } - } - _ => (), - } - } - - fn register_buffer_impl<'a>( - zeta_project: &'a mut ZetaProject, - buffer: &Entity, - project: &Entity, - cx: &mut Context, - ) -> &'a mut RegisteredBuffer { - let buffer_id = buffer.entity_id(); - - if let Some(file) = buffer.read(cx).file() { - let worktree_id = file.worktree_id(cx); - if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) { - zeta_project - .license_detection_watchers - .entry(worktree_id) - .or_insert_with(|| { - let project_entity_id = project.entity_id(); - cx.observe_release(&worktree, move |this, _worktree, _cx| { - let Some(zeta_project) = this.projects.get_mut(&project_entity_id) - else { - return; - }; - zeta_project.license_detection_watchers.remove(&worktree_id); - }) - .detach(); - Rc::new(LicenseDetectionWatcher::new(&worktree, cx)) - }); - } - } - - match zeta_project.registered_buffers.entry(buffer_id) { - hash_map::Entry::Occupied(entry) => entry.into_mut(), - hash_map::Entry::Vacant(entry) => { - let snapshot = buffer.read(cx).snapshot(); - let project_entity_id = project.entity_id(); - entry.insert(RegisteredBuffer { - snapshot, - _subscriptions: [ - cx.subscribe(buffer, { - let project = project.downgrade(); - move |this, buffer, event, cx| { - if let language::BufferEvent::Edited = event - && let Some(project) = project.upgrade() - { - this.report_changes_for_buffer(&buffer, &project, cx); - } - } - }), - cx.observe_release(buffer, move |this, _buffer, _cx| { - let Some(zeta_project) = this.projects.get_mut(&project_entity_id) - else { - return; - }; - 
zeta_project.registered_buffers.remove(&buffer_id); - }), - ], - }) - } - } - } - - fn report_changes_for_buffer( - &mut self, - buffer: &Entity, - project: &Entity, - cx: &mut Context, - ) { - let project_state = self.get_or_init_zeta_project(project, cx); - let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx); - - let new_snapshot = buffer.read(cx).snapshot(); - if new_snapshot.version == registered_buffer.snapshot.version { - return; - } - - let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone()); - let end_edit_anchor = new_snapshot - .anchored_edits_since::(&old_snapshot.version) - .last() - .map(|(_, range)| range.end); - let events = &mut project_state.events; - - if let Some(LastEvent { - new_snapshot: last_new_snapshot, - end_edit_anchor: last_end_edit_anchor, - .. - }) = project_state.last_event.as_mut() - { - let is_next_snapshot_of_same_buffer = old_snapshot.remote_id() - == last_new_snapshot.remote_id() - && old_snapshot.version == last_new_snapshot.version; - - let should_coalesce = is_next_snapshot_of_same_buffer - && end_edit_anchor - .as_ref() - .zip(last_end_edit_anchor.as_ref()) - .is_some_and(|(a, b)| { - let a = a.to_point(&new_snapshot); - let b = b.to_point(&new_snapshot); - a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN - }); - - if should_coalesce { - *last_end_edit_anchor = end_edit_anchor; - *last_new_snapshot = new_snapshot; - return; - } - } - - if events.len() + 1 >= EVENT_COUNT_MAX { - events.pop_front(); - } - - if let Some(event) = project_state.last_event.take() { - events.extend(event.finalize(&project_state.license_detection_watchers, cx)); - } - - project_state.last_event = Some(LastEvent { - old_snapshot, - new_snapshot, - end_edit_anchor, - }); - } - - fn current_prediction_for_buffer( - &self, - buffer: &Entity, - project: &Entity, - cx: &App, - ) -> Option> { - let project_state = self.projects.get(&project.entity_id())?; - - let CurrentEditPrediction { - 
requested_by, - prediction, - .. - } = project_state.current_prediction.as_ref()?; - - if prediction.targets_buffer(buffer.read(cx)) { - Some(BufferEditPrediction::Local { prediction }) - } else { - let show_jump = match requested_by { - PredictionRequestedBy::Buffer(requested_by_buffer_id) => { - requested_by_buffer_id == &buffer.entity_id() - } - PredictionRequestedBy::DiagnosticsUpdate => true, - }; - - if show_jump { - Some(BufferEditPrediction::Jump { prediction }) - } else { - None - } - } - } - - fn accept_current_prediction(&mut self, project: &Entity, cx: &mut Context) { - match self.edit_prediction_model { - ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {} - ZetaEditPredictionModel::Sweep => return, - } - - let Some(project_state) = self.projects.get_mut(&project.entity_id()) else { - return; - }; - - let Some(prediction) = project_state.current_prediction.take() else { - return; - }; - let request_id = prediction.prediction.id.to_string(); - for pending_prediction in mem::take(&mut project_state.pending_predictions) { - project_state.cancel_pending_prediction(pending_prediction, cx); - } - - let client = self.client.clone(); - let llm_token = self.llm_token.clone(); - let app_version = AppVersion::global(cx); - cx.spawn(async move |this, cx| { - let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") { - http_client::Url::parse(&predict_edits_url)? - } else { - client - .http_client() - .build_zed_llm_url("/predict_edits/accept", &[])? - }; - - let response = cx - .background_spawn(Self::send_api_request::<()>( - move |builder| { - let req = builder.uri(url.as_ref()).body( - serde_json::to_string(&AcceptEditPredictionBody { - request_id: request_id.clone(), - })? - .into(), - ); - Ok(req?) 
- }, - client, - llm_token, - app_version, - )) - .await; - - Self::handle_api_response(&this, response, cx)?; - anyhow::Ok(()) - }) - .detach_and_log_err(cx); - } - - async fn handle_rejected_predictions( - rx: UnboundedReceiver, - client: Arc, - llm_token: LlmApiToken, - app_version: Version, - background_executor: BackgroundExecutor, - ) { - let mut rx = std::pin::pin!(rx.peekable()); - let mut batched = Vec::new(); - - while let Some(rejection) = rx.next().await { - batched.push(rejection); - - if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 { - select_biased! { - next = rx.as_mut().peek().fuse() => { - if next.is_some() { - continue; - } - } - () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {}, - } - } - - let url = client - .http_client() - .build_zed_llm_url("/predict_edits/reject", &[]) - .unwrap(); - - let flush_count = batched - .len() - // in case items have accumulated after failure - .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST); - let start = batched.len() - flush_count; - - let body = RejectEditPredictionsBodyRef { - rejections: &batched[start..], - }; - - let result = Self::send_api_request::<()>( - |builder| { - let req = builder - .uri(url.as_ref()) - .body(serde_json::to_string(&body)?.into()); - anyhow::Ok(req?) 
- }, - client.clone(), - llm_token.clone(), - app_version.clone(), - ) - .await; - - if result.log_err().is_some() { - batched.drain(start..); - } - } - } - - fn reject_current_prediction( - &mut self, - reason: EditPredictionRejectReason, - project: &Entity, - ) { - if let Some(project_state) = self.projects.get_mut(&project.entity_id()) { - project_state.pending_predictions.clear(); - if let Some(prediction) = project_state.current_prediction.take() { - self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown); - } - }; - } - - fn did_show_current_prediction(&mut self, project: &Entity, _cx: &mut Context) { - if let Some(project_state) = self.projects.get_mut(&project.entity_id()) { - if let Some(current_prediction) = project_state.current_prediction.as_mut() { - if !current_prediction.was_shown { - current_prediction.was_shown = true; - self.shown_predictions - .push_front(current_prediction.prediction.clone()); - if self.shown_predictions.len() > 50 { - let completion = self.shown_predictions.pop_back().unwrap(); - self.rated_predictions.remove(&completion.id); - } - } - } - } - } - - fn reject_prediction( - &mut self, - prediction_id: EditPredictionId, - reason: EditPredictionRejectReason, - was_shown: bool, - ) { - match self.edit_prediction_model { - ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {} - ZetaEditPredictionModel::Sweep => return, - } - - self.reject_predictions_tx - .unbounded_send(EditPredictionRejection { - request_id: prediction_id.to_string(), - reason, - was_shown, - }) - .log_err(); - } - - fn is_refreshing(&self, project: &Entity) -> bool { - self.projects - .get(&project.entity_id()) - .is_some_and(|project_state| !project_state.pending_predictions.is_empty()) - } - - pub fn refresh_prediction_from_buffer( - &mut self, - project: Entity, - buffer: Entity, - position: language::Anchor, - cx: &mut Context, - ) { - self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| 
{ - let Some(request_task) = this - .update(cx, |this, cx| { - this.request_prediction( - &project, - &buffer, - position, - PredictEditsRequestTrigger::Other, - cx, - ) - }) - .log_err() - else { - return Task::ready(anyhow::Ok(None)); - }; - - cx.spawn(async move |_cx| { - request_task.await.map(|prediction_result| { - prediction_result.map(|prediction_result| { - ( - prediction_result, - PredictionRequestedBy::Buffer(buffer.entity_id()), - ) - }) - }) - }) - }) - } - - pub fn refresh_prediction_from_diagnostics( - &mut self, - project: Entity, - cx: &mut Context, - ) { - let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else { - return; - }; - - // Prefer predictions from buffer - if zeta_project.current_prediction.is_some() { - return; - }; - - self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| { - let Some(open_buffer_task) = project - .update(cx, |project, cx| { - project - .active_entry() - .and_then(|entry| project.path_for_entry(entry, cx)) - .map(|path| project.open_buffer(path, cx)) - }) - .log_err() - .flatten() - else { - return Task::ready(anyhow::Ok(None)); - }; - - cx.spawn(async move |cx| { - let active_buffer = open_buffer_task.await?; - let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; - - let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location( - active_buffer, - &snapshot, - Default::default(), - Default::default(), - &project, - cx, - ) - .await? - else { - return anyhow::Ok(None); - }; - - let Some(prediction_result) = this - .update(cx, |this, cx| { - this.request_prediction( - &project, - &jump_buffer, - jump_position, - PredictEditsRequestTrigger::Diagnostics, - cx, - ) - })? - .await? 
- else { - return anyhow::Ok(None); - }; - - this.update(cx, |this, cx| { - Some(( - if this - .get_or_init_zeta_project(&project, cx) - .current_prediction - .is_none() - { - prediction_result - } else { - EditPredictionResult { - id: prediction_result.id, - prediction: Err(EditPredictionRejectReason::CurrentPreferred), - } - }, - PredictionRequestedBy::DiagnosticsUpdate, - )) - }) - }) - }); - } - - #[cfg(not(test))] - pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300); - #[cfg(test)] - pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO; - - fn queue_prediction_refresh( - &mut self, - project: Entity, - throttle_entity: EntityId, - cx: &mut Context, - do_refresh: impl FnOnce( - WeakEntity, - &mut AsyncApp, - ) - -> Task>> - + 'static, - ) { - let zeta_project = self.get_or_init_zeta_project(&project, cx); - let pending_prediction_id = zeta_project.next_pending_prediction_id; - zeta_project.next_pending_prediction_id += 1; - let last_request = zeta_project.last_prediction_refresh; - - let task = cx.spawn(async move |this, cx| { - if let Some((last_entity, last_timestamp)) = last_request - && throttle_entity == last_entity - && let Some(timeout) = - (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now()) - { - cx.background_executor().timer(timeout).await; - } - - // If this task was cancelled before the throttle timeout expired, - // do not perform a request. 
- let mut is_cancelled = true; - this.update(cx, |this, cx| { - let project_state = this.get_or_init_zeta_project(&project, cx); - if !project_state - .cancelled_predictions - .remove(&pending_prediction_id) - { - project_state.last_prediction_refresh = Some((throttle_entity, Instant::now())); - is_cancelled = false; - } - }) - .ok(); - if is_cancelled { - return None; - } - - let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten(); - let new_prediction_id = new_prediction_result - .as_ref() - .map(|(prediction, _)| prediction.id.clone()); - - // When a prediction completes, remove it from the pending list, and cancel - // any pending predictions that were enqueued before it. - this.update(cx, |this, cx| { - let zeta_project = this.get_or_init_zeta_project(&project, cx); - - let is_cancelled = zeta_project - .cancelled_predictions - .remove(&pending_prediction_id); - - let new_current_prediction = if !is_cancelled - && let Some((prediction_result, requested_by)) = new_prediction_result - { - match prediction_result.prediction { - Ok(prediction) => { - let new_prediction = CurrentEditPrediction { - requested_by, - prediction, - was_shown: false, - }; - - if let Some(current_prediction) = - zeta_project.current_prediction.as_ref() - { - if new_prediction.should_replace_prediction(¤t_prediction, cx) - { - this.reject_current_prediction( - EditPredictionRejectReason::Replaced, - &project, - ); - - Some(new_prediction) - } else { - this.reject_prediction( - new_prediction.prediction.id, - EditPredictionRejectReason::CurrentPreferred, - false, - ); - None - } - } else { - Some(new_prediction) - } - } - Err(reject_reason) => { - this.reject_prediction(prediction_result.id, reject_reason, false); - None - } - } - } else { - None - }; - - let zeta_project = this.get_or_init_zeta_project(&project, cx); - - if let Some(new_prediction) = new_current_prediction { - zeta_project.current_prediction = Some(new_prediction); - } - - let mut 
pending_predictions = mem::take(&mut zeta_project.pending_predictions); - for (ix, pending_prediction) in pending_predictions.iter().enumerate() { - if pending_prediction.id == pending_prediction_id { - pending_predictions.remove(ix); - for pending_prediction in pending_predictions.drain(0..ix) { - zeta_project.cancel_pending_prediction(pending_prediction, cx) - } - break; - } - } - this.get_or_init_zeta_project(&project, cx) - .pending_predictions = pending_predictions; - cx.notify(); - }) - .ok(); - - new_prediction_id - }); - - if zeta_project.pending_predictions.len() <= 1 { - zeta_project.pending_predictions.push(PendingPrediction { - id: pending_prediction_id, - task, - }); - } else if zeta_project.pending_predictions.len() == 2 { - let pending_prediction = zeta_project.pending_predictions.pop().unwrap(); - zeta_project.pending_predictions.push(PendingPrediction { - id: pending_prediction_id, - task, - }); - zeta_project.cancel_pending_prediction(pending_prediction, cx); - } - } - - pub fn request_prediction( - &mut self, - project: &Entity, - active_buffer: &Entity, - position: language::Anchor, - trigger: PredictEditsRequestTrigger, - cx: &mut Context, - ) -> Task>> { - self.request_prediction_internal( - project.clone(), - active_buffer.clone(), - position, - trigger, - cx.has_flag::(), - cx, - ) - } - - fn request_prediction_internal( - &mut self, - project: Entity, - active_buffer: Entity, - position: language::Anchor, - trigger: PredictEditsRequestTrigger, - allow_jump: bool, - cx: &mut Context, - ) -> Task>> { - const DIAGNOSTIC_LINES_RANGE: u32 = 20; - - self.get_or_init_zeta_project(&project, cx); - let zeta_project = self.projects.get(&project.entity_id()).unwrap(); - let events = zeta_project.events(cx); - let has_events = !events.is_empty(); - - let snapshot = active_buffer.read(cx).snapshot(); - let cursor_point = position.to_point(&snapshot); - let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE); - let 
diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE; - let diagnostic_search_range = - Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0); - - let task = match self.edit_prediction_model { - ZetaEditPredictionModel::Zeta1 => request_prediction_with_zeta1( - self, - &project, - &active_buffer, - snapshot.clone(), - position, - events, - trigger, - cx, - ), - ZetaEditPredictionModel::Zeta2 => self.request_prediction_with_zeta2( - &project, - &active_buffer, - snapshot.clone(), - position, - events, - trigger, - cx, - ), - ZetaEditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep( - &project, - &active_buffer, - snapshot.clone(), - position, - events, - &zeta_project.recent_paths, - if self.use_context { - self.context_for_project(&project, cx).to_vec() - } else { - Vec::new() - }, - diagnostic_search_range.clone(), - cx, - ), - }; - - cx.spawn(async move |this, cx| { - let prediction = task.await?; - - if prediction.is_none() && allow_jump { - let cursor_point = position.to_point(&snapshot); - if has_events - && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location( - active_buffer.clone(), - &snapshot, - diagnostic_search_range, - cursor_point, - &project, - cx, - ) - .await? - { - return this - .update(cx, |this, cx| { - this.request_prediction_internal( - project, - jump_buffer, - jump_position, - trigger, - false, - cx, - ) - })? 
- .await; - } - - return anyhow::Ok(None); - } - - Ok(prediction) - }) - } - - async fn next_diagnostic_location( - active_buffer: Entity, - active_buffer_snapshot: &BufferSnapshot, - active_buffer_diagnostic_search_range: Range, - active_buffer_cursor_point: Point, - project: &Entity, - cx: &mut AsyncApp, - ) -> Result, language::Anchor)>> { - // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request - let mut jump_location = active_buffer_snapshot - .diagnostic_groups(None) - .into_iter() - .filter_map(|(_, group)| { - let range = &group.entries[group.primary_ix] - .range - .to_point(&active_buffer_snapshot); - if range.overlaps(&active_buffer_diagnostic_search_range) { - None - } else { - Some(range.start) - } - }) - .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row)) - .map(|position| { - ( - active_buffer.clone(), - active_buffer_snapshot.anchor_before(position), - ) - }); - - if jump_location.is_none() { - let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| { - let file = buffer.file()?; - - Some(ProjectPath { - worktree_id: file.worktree_id(cx), - path: file.path().clone(), - }) - })?; - - let buffer_task = project.update(cx, |project, cx| { - let (path, _, _) = project - .diagnostic_summaries(false, cx) - .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref()) - .max_by_key(|(path, _, _)| { - // find the buffer with errors that shares most parent directories - path.path - .components() - .zip( - active_buffer_path - .as_ref() - .map(|p| p.path.components()) - .unwrap_or_default(), - ) - .take_while(|(a, b)| a == b) - .count() - })?; - - Some(project.open_buffer(path, cx)) - })?; - - if let Some(buffer_task) = buffer_task { - let closest_buffer = buffer_task.await?; - - jump_location = closest_buffer - .read_with(cx, |buffer, _cx| { - buffer - .buffer_diagnostics(None) - .into_iter() - .min_by_key(|entry| entry.diagnostic.severity) - .map(|entry| entry.range.start) - 
})? - .map(|position| (closest_buffer, position)); - } - } - - anyhow::Ok(jump_location) - } - - fn request_prediction_with_zeta2( - &mut self, - project: &Entity, - active_buffer: &Entity, - active_snapshot: BufferSnapshot, - position: language::Anchor, - events: Vec>, - trigger: PredictEditsRequestTrigger, - cx: &mut Context, - ) -> Task>> { - let options = self.options.clone(); - let buffer_snapshotted_at = Instant::now(); - - let Some((excerpt_path, active_project_path)) = active_snapshot - .file() - .map(|file| -> Arc { file.full_path(cx).into() }) - .zip(active_buffer.read(cx).project_path(cx)) - else { - return Task::ready(Err(anyhow!("No file path for excerpt"))); - }; - - let client = self.client.clone(); - let llm_token = self.llm_token.clone(); - let app_version = AppVersion::global(cx); - let debug_tx = self.debug_tx.clone(); - - let diagnostics = active_snapshot.diagnostic_sets().clone(); - - let file = active_buffer.read(cx).file(); - - let active_file_full_path = file.as_ref().map(|f| f.full_path(cx)); - - // TODO data collection - let can_collect_data = file - .as_ref() - .map_or(false, |file| self.can_collect_file(project, file, cx)); - - let mut included_files = self.context_for_project(project, cx).to_vec(); - - #[cfg(feature = "eval-support")] - let eval_cache = self.eval_cache.clone(); - - let request_task = cx.background_spawn({ - let active_buffer = active_buffer.clone(); - async move { - let cursor_offset = position.to_offset(&active_snapshot); - let cursor_point = cursor_offset.to_point(&active_snapshot); - - let before_retrieval = Instant::now(); - - let (diagnostic_groups, diagnostic_groups_truncated) = - Self::gather_nearby_diagnostics( - cursor_offset, - &diagnostics, - &active_snapshot, - options.max_diagnostic_bytes, - ); - - let excerpt_options = options.context.excerpt(); - - let Some(excerpt) = EditPredictionExcerpt::select_from_buffer( - cursor_point, - &active_snapshot, - &excerpt_options, - None, - ) else { - return Ok((None, 
None)); - }; - - let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start) - ..active_snapshot.anchor_before(excerpt.range.end); - let related_excerpt = RelatedExcerpt { - anchor_range: excerpt_anchor_range.clone(), - point_range: Point::new(excerpt.line_range.start.0, 0) - ..Point::new(excerpt.line_range.end.0, 0), - text: active_snapshot.as_rope().slice(excerpt.range), - }; - - if let Some(buffer_ix) = included_files - .iter() - .position(|file| file.buffer.entity_id() == active_buffer.entity_id()) - { - let file = &mut included_files[buffer_ix]; - file.excerpts.push(related_excerpt); - file.merge_excerpts(); - let last_ix = included_files.len() - 1; - included_files.swap(buffer_ix, last_ix); - } else { - let active_file = RelatedFile { - path: active_project_path, - buffer: active_buffer.downgrade(), - excerpts: vec![related_excerpt], - max_row: active_snapshot.max_point().row, - }; - included_files.push(active_file); - } - - let included_files = included_files - .iter() - .map(|related_file| predict_edits_v3::IncludedFile { - path: Arc::from(related_file.path.path.as_std_path()), - max_row: Line(related_file.max_row), - excerpts: related_file - .excerpts - .iter() - .map(|excerpt| predict_edits_v3::Excerpt { - start_line: Line(excerpt.point_range.start.row), - text: excerpt.text.to_string().into(), - }) - .collect(), - }) - .collect::>(); - - let cloud_request = predict_edits_v3::PredictEditsRequest { - excerpt_path, - excerpt: String::new(), - excerpt_line_range: Line(0)..Line(0), - excerpt_range: 0..0, - cursor_point: predict_edits_v3::Point { - line: predict_edits_v3::Line(cursor_point.row), - column: cursor_point.column, - }, - included_files, - referenced_declarations: vec![], - events, - can_collect_data, - diagnostic_groups, - diagnostic_groups_truncated, - debug_info: debug_tx.is_some(), - prompt_max_bytes: Some(options.max_prompt_bytes), - prompt_format: options.prompt_format, - // TODO [zeta2] - signatures: vec![], - excerpt_parent: 
None, - git_info: None, - trigger, - }; - - let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request); - - let inputs = EditPredictionInputs { - included_files: cloud_request.included_files, - events: cloud_request.events, - cursor_point: cloud_request.cursor_point, - cursor_path: cloud_request.excerpt_path, - }; - - let retrieval_time = Instant::now() - before_retrieval; - - let debug_response_tx = if let Some(debug_tx) = &debug_tx { - let (response_tx, response_rx) = oneshot::channel(); - - debug_tx - .unbounded_send(ZetaDebugInfo::EditPredictionRequested( - ZetaEditPredictionDebugInfo { - inputs: inputs.clone(), - retrieval_time, - buffer: active_buffer.downgrade(), - local_prompt: match prompt_result.as_ref() { - Ok((prompt, _)) => Ok(prompt.clone()), - Err(err) => Err(err.to_string()), - }, - position, - response_rx, - }, - )) - .ok(); - Some(response_tx) - } else { - None - }; - - if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() { - if let Some(debug_response_tx) = debug_response_tx { - debug_response_tx - .send((Err("Request skipped".to_string()), Duration::ZERO)) - .ok(); - } - anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set") - } - - let (prompt, _) = prompt_result?; - let generation_params = - cloud_zeta2_prompt::generation_params(cloud_request.prompt_format); - let request = open_ai::Request { - model: EDIT_PREDICTIONS_MODEL_ID.clone(), - messages: vec![open_ai::RequestMessage::User { - content: open_ai::MessageContent::Plain(prompt), - }], - stream: false, - max_completion_tokens: None, - stop: generation_params.stop.unwrap_or_default(), - temperature: generation_params.temperature.unwrap_or(0.7), - tool_choice: None, - parallel_tool_calls: None, - tools: vec![], - prompt_cache_key: None, - reasoning_effort: None, - }; - - log::trace!("Sending edit prediction request"); - - let before_request = Instant::now(); - let response = Self::send_raw_llm_request( - request, - client, - llm_token, - 
app_version, - #[cfg(feature = "eval-support")] - eval_cache, - #[cfg(feature = "eval-support")] - EvalCacheEntryKind::Prediction, - ) - .await; - let received_response_at = Instant::now(); - let request_time = received_response_at - before_request; - - log::trace!("Got edit prediction response"); - - if let Some(debug_response_tx) = debug_response_tx { - debug_response_tx - .send(( - response - .as_ref() - .map_err(|err| err.to_string()) - .map(|response| response.0.clone()), - request_time, - )) - .ok(); - } - - let (res, usage) = response?; - let request_id = EditPredictionId(res.id.clone().into()); - let Some(mut output_text) = text_from_response(res) else { - return Ok((Some((request_id, None)), usage)); - }; - - if output_text.contains(CURSOR_MARKER) { - log::trace!("Stripping out {CURSOR_MARKER} from response"); - output_text = output_text.replace(CURSOR_MARKER, ""); - } - - let get_buffer_from_context = |path: &Path| { - if Some(path) == active_file_full_path.as_deref() { - Some(( - &active_snapshot, - std::slice::from_ref(&excerpt_anchor_range), - )) - } else { - None - } - }; - - let (_, edits) = match options.prompt_format { - PromptFormat::NumLinesUniDiff => { - // TODO: Implement parsing of multi-file diffs - crate::udiff::parse_diff(&output_text, get_buffer_from_context).await? - } - PromptFormat::Minimal - | PromptFormat::MinimalQwen - | PromptFormat::SeedCoder1120 => { - if output_text.contains("--- a/\n+++ b/\nNo edits") { - let edits = vec![]; - (&active_snapshot, edits) - } else { - crate::udiff::parse_diff(&output_text, get_buffer_from_context).await? - } - } - PromptFormat::OldTextNewText => { - crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context) - .await? 
- } - _ => { - bail!("unsupported prompt format {}", options.prompt_format) - } - }; - - anyhow::Ok(( - Some(( - request_id, - Some(( - inputs, - active_buffer, - active_snapshot.clone(), - edits, - received_response_at, - )), - )), - usage, - )) - } - }); - - cx.spawn({ - async move |this, cx| { - let Some((id, prediction)) = - Self::handle_api_response(&this, request_task.await, cx)? - else { - return Ok(None); - }; - - let Some(( - inputs, - edited_buffer, - edited_buffer_snapshot, - edits, - received_response_at, - )) = prediction - else { - return Ok(Some(EditPredictionResult { - id, - prediction: Err(EditPredictionRejectReason::Empty), - })); - }; - - // TODO telemetry: duration, etc - Ok(Some( - EditPredictionResult::new( - id, - &edited_buffer, - &edited_buffer_snapshot, - edits.into(), - buffer_snapshotted_at, - received_response_at, - inputs, - cx, - ) - .await, - )) - } - }) - } - - async fn send_raw_llm_request( - request: open_ai::Request, - client: Arc, - llm_token: LlmApiToken, - app_version: Version, - #[cfg(feature = "eval-support")] eval_cache: Option>, - #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind, - ) -> Result<(open_ai::Response, Option)> { - let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() { - http_client::Url::parse(&predict_edits_url)? - } else { - client - .http_client() - .build_zed_llm_url("/predict_edits/raw", &[])? 
- }; - - #[cfg(feature = "eval-support")] - let cache_key = if let Some(cache) = eval_cache { - use collections::FxHasher; - use std::hash::{Hash, Hasher}; - - let mut hasher = FxHasher::default(); - url.hash(&mut hasher); - let request_str = serde_json::to_string_pretty(&request)?; - request_str.hash(&mut hasher); - let hash = hasher.finish(); - - let key = (eval_cache_kind, hash); - if let Some(response_str) = cache.read(key) { - return Ok((serde_json::from_str(&response_str)?, None)); - } - - Some((cache, request_str, key)) - } else { - None - }; - - let (response, usage) = Self::send_api_request( - |builder| { - let req = builder - .uri(url.as_ref()) - .body(serde_json::to_string(&request)?.into()); - Ok(req?) - }, - client, - llm_token, - app_version, - ) - .await?; - - #[cfg(feature = "eval-support")] - if let Some((cache, request, key)) = cache_key { - cache.write(key, &request, &serde_json::to_string_pretty(&response)?); - } - - Ok((response, usage)) - } - - fn handle_api_response( - this: &WeakEntity, - response: Result<(T, Option)>, - cx: &mut gpui::AsyncApp, - ) -> Result { - match response { - Ok((data, usage)) => { - if let Some(usage) = usage { - this.update(cx, |this, cx| { - this.user_store.update(cx, |user_store, cx| { - user_store.update_edit_prediction_usage(usage, cx); - }); - }) - .ok(); - } - Ok(data) - } - Err(err) => { - if err.is::() { - cx.update(|cx| { - this.update(cx, |this, _cx| { - this.update_required = true; - }) - .ok(); - - let error_message: SharedString = err.to_string().into(); - show_app_notification( - NotificationId::unique::(), - cx, - move |cx| { - cx.new(|cx| { - ErrorMessagePrompt::new(error_message.clone(), cx) - .with_link_button("Update Zed", "https://zed.dev/releases") - }) - }, - ); - }) - .ok(); - } - Err(err) - } - } - } - - async fn send_api_request( - build: impl Fn(http_client::http::request::Builder) -> Result>, - client: Arc, - llm_token: LlmApiToken, - app_version: Version, - ) -> Result<(Res, Option)> - 
where - Res: DeserializeOwned, - { - let http_client = client.http_client(); - let mut token = llm_token.acquire(&client).await?; - let mut did_retry = false; - - loop { - let request_builder = http_client::Request::builder().method(Method::POST); - - let request = build( - request_builder - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", token)) - .header(ZED_VERSION_HEADER_NAME, app_version.to_string()), - )?; - - let mut response = http_client.send(request).await?; - - if let Some(minimum_required_version) = response - .headers() - .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME) - .and_then(|version| Version::from_str(version.to_str().ok()?).ok()) - { - anyhow::ensure!( - app_version >= minimum_required_version, - ZedUpdateRequiredError { - minimum_version: minimum_required_version - } - ); - } - - if response.status().is_success() { - let usage = EditPredictionUsage::from_headers(response.headers()).ok(); - - let mut body = Vec::new(); - response.body_mut().read_to_end(&mut body).await?; - return Ok((serde_json::from_slice(&body)?, usage)); - } else if !did_retry - && response - .headers() - .get(EXPIRED_LLM_TOKEN_HEADER_NAME) - .is_some() - { - did_retry = true; - token = llm_token.refresh(&client).await?; - } else { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - anyhow::bail!( - "Request failed with status: {:?}\nBody: {}", - response.status(), - body - ); - } - } - } - - pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10); - pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3); - - pub fn refresh_context_if_needed( - &mut self, - project: &Entity, - buffer: &Entity, - cursor_position: language::Anchor, - cx: &mut Context, - ) { - if !self.use_context { - return; - } - let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else { - return; - }; - - match &mut zeta_project.context { - 
ZetaProjectContext::Syntax(_entity) => {} - ZetaProjectContext::Lsp(related_excerpt_store) => { - related_excerpt_store.update(cx, |store, cx| { - store.refresh(buffer.clone(), cursor_position, cx); - }); - } - ZetaProjectContext::Agentic { - refresh_context_debounce_task, - refresh_context_timestamp, - .. - } => { - let now = Instant::now(); - let was_idle = refresh_context_timestamp.map_or(true, |timestamp| { - now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION - }); - *refresh_context_timestamp = Some(now); - *refresh_context_debounce_task = Some(cx.spawn({ - let buffer = buffer.clone(); - let project = project.clone(); - async move |this, cx| { - if was_idle { - log::debug!("refetching edit prediction context after idle"); - } else { - cx.background_executor() - .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION) - .await; - log::debug!("refetching edit prediction context after pause"); - } - this.update(cx, |this, cx| { - let task = this.refresh_context_with_agentic_retrieval( - project.clone(), - buffer, - cursor_position, - cx, - ); - - if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) - { - if let ZetaProjectContext::Agentic { - refresh_context_task, - .. - } = &mut zeta_project.context - { - *refresh_context_task = Some(task.log_err()); - } - }; - }) - .ok() - } - })); - } - } - } - - // Refresh the related excerpts asynchronously. Ensure the task runs to completion, - // and avoid spawning more than one concurrent task. 
- pub fn refresh_context_with_agentic_retrieval( - &mut self, - project: Entity, - buffer: Entity, - cursor_position: language::Anchor, - cx: &mut Context, - ) -> Task> { - let Some(zeta_project) = self.projects.get(&project.entity_id()) else { - return Task::ready(anyhow::Ok(())); - }; - - let ContextMode::Agentic(options) = &self.options().context else { - return Task::ready(anyhow::Ok(())); - }; - - let snapshot = buffer.read(cx).snapshot(); - let cursor_point = cursor_position.to_point(&snapshot); - let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer( - cursor_point, - &snapshot, - &options.excerpt, - None, - ) else { - return Task::ready(Ok(())); - }; - - let app_version = AppVersion::global(cx); - let client = self.client.clone(); - let llm_token = self.llm_token.clone(); - let debug_tx = self.debug_tx.clone(); - let current_file_path: Arc = snapshot - .file() - .map(|f| f.full_path(cx).into()) - .unwrap_or_else(|| Path::new("untitled").into()); - - let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt( - predict_edits_v3::PlanContextRetrievalRequest { - excerpt: cursor_excerpt.text(&snapshot).body, - excerpt_path: current_file_path, - excerpt_line_range: cursor_excerpt.line_range, - cursor_file_max_row: Line(snapshot.max_point().row), - events: zeta_project.events(cx), - }, - ) { - Ok(prompt) => prompt, - Err(err) => { - return Task::ready(Err(err)); - } - }; - - let retrieval_started_at = Instant::now(); - - if let Some(debug_tx) = &debug_tx { - debug_tx - .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted( - ZetaContextRetrievalStartedDebugInfo { - project_entity_id: project.entity_id(), - timestamp: retrieval_started_at, - search_prompt: prompt.clone(), - }, - )) - .ok(); - } - - pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| { - let schema = language_model::tool_schema::root_schema_for::( - language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset, - ); - - let description = 
schema - .get("description") - .and_then(|description| description.as_str()) - .unwrap() - .to_string(); - - (schema.into(), description) - }); - - let (tool_schema, tool_description) = TOOL_SCHEMA.clone(); - - let request = open_ai::Request { - model: CONTEXT_RETRIEVAL_MODEL_ID.clone(), - messages: vec![open_ai::RequestMessage::User { - content: open_ai::MessageContent::Plain(prompt), - }], - stream: false, - max_completion_tokens: None, - stop: Default::default(), - temperature: 0.7, - tool_choice: None, - parallel_tool_calls: None, - tools: vec![open_ai::ToolDefinition::Function { - function: FunctionDefinition { - name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(), - description: Some(tool_description), - parameters: Some(tool_schema), - }, - }], - prompt_cache_key: None, - reasoning_effort: None, - }; - - #[cfg(feature = "eval-support")] - let eval_cache = self.eval_cache.clone(); - - cx.spawn(async move |this, cx| { - log::trace!("Sending search planning request"); - let response = Self::send_raw_llm_request( - request, - client, - llm_token, - app_version, - #[cfg(feature = "eval-support")] - eval_cache.clone(), - #[cfg(feature = "eval-support")] - EvalCacheEntryKind::Context, - ) - .await; - let mut response = Self::handle_api_response(&this, response, cx)?; - log::trace!("Got search planning response"); - - let choice = response - .choices - .pop() - .context("No choices in retrieval response")?; - let open_ai::RequestMessage::Assistant { - content: _, - tool_calls, - } = choice.message - else { - anyhow::bail!("Retrieval response didn't include an assistant message"); - }; - - let mut queries: Vec = Vec::new(); - for tool_call in tool_calls { - let open_ai::ToolCallContent::Function { function } = tool_call.content; - if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME { - log::warn!( - "Context retrieval response tried to call an unknown tool: {}", - function.name - ); - - continue; - } - - let input: SearchToolInput = 
serde_json::from_str(&function.arguments) - .with_context(|| format!("invalid search json {}", &function.arguments))?; - queries.extend(input.queries); - } - - log::trace!("Running retrieval search: {queries:#?}"); - let query_generation_finished_at = Instant::now(); - - let related_excerpts_result = retrieval_search::run_retrieval_searches( - queries, - project.clone(), - #[cfg(feature = "eval-support")] - eval_cache, - cx, - ) - .await; - - log::trace!("Search queries executed"); - let query_execution_finished_at = Instant::now(); - - this.update(cx, |this, _cx| { - let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else { - return Ok(()); - }; - if let ZetaProjectContext::Agentic { - refresh_context_task, - context, - .. - } = &mut zeta_project.context - { - refresh_context_task.take(); - if let Some(debug_tx) = &this.debug_tx { - debug_tx - .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished( - ZetaContextRetrievalFinishedDebugInfo { - project_entity_id: project.entity_id(), - timestamp: Instant::now(), - metadata: vec![ - ( - "query_generation", - format!( - "{:?}", - query_generation_finished_at - retrieval_started_at - ) - .into(), - ), - ( - "search_execution", - format!( - "{:?}", - query_execution_finished_at - - query_generation_finished_at - ) - .into(), - ), - ], - }, - )) - .ok(); - } - match related_excerpts_result { - Ok(excerpts) => { - *context = excerpts; - Ok(()) - } - Err(error) => Err(error), - } - } else { - Ok(()) - } - })? 
- }) - } - - fn gather_nearby_diagnostics( - cursor_offset: usize, - diagnostic_sets: &[(LanguageServerId, DiagnosticSet)], - snapshot: &BufferSnapshot, - max_diagnostics_bytes: usize, - ) -> (Vec, bool) { - // TODO: Could make this more efficient - let mut diagnostic_groups = Vec::new(); - for (language_server_id, diagnostics) in diagnostic_sets { - let mut groups = Vec::new(); - diagnostics.groups(*language_server_id, &mut groups, &snapshot); - diagnostic_groups.extend( - groups - .into_iter() - .map(|(_, group)| group.resolve::(&snapshot)), - ); - } - - // sort by proximity to cursor - diagnostic_groups.sort_by_key(|group| { - let range = &group.entries[group.primary_ix].range; - if range.start >= cursor_offset { - range.start - cursor_offset - } else if cursor_offset >= range.end { - cursor_offset - range.end - } else { - (cursor_offset - range.start).min(range.end - cursor_offset) - } - }); - - let mut results = Vec::new(); - let mut diagnostic_groups_truncated = false; - let mut diagnostics_byte_count = 0; - for group in diagnostic_groups { - let raw_value = serde_json::value::to_raw_value(&group).unwrap(); - diagnostics_byte_count += raw_value.get().len(); - if diagnostics_byte_count > max_diagnostics_bytes { - diagnostic_groups_truncated = true; - break; - } - results.push(predict_edits_v3::DiagnosticGroup(raw_value)); - } - - (results, diagnostic_groups_truncated) - } - - pub fn wait_for_initial_indexing( - &mut self, - project: &Entity, - cx: &mut Context, - ) -> Task> { - let zeta_project = self.get_or_init_zeta_project(project, cx); - if let ZetaProjectContext::Syntax(syntax_index) = &zeta_project.context { - syntax_index.read(cx).wait_for_initial_file_indexing(cx) - } else { - Task::ready(Ok(())) - } - } - - fn is_file_open_source( - &self, - project: &Entity, - file: &Arc, - cx: &App, - ) -> bool { - if !file.is_local() || file.is_private() { - return false; - } - let Some(zeta_project) = self.projects.get(&project.entity_id()) else { - return false; 
- }; - zeta_project - .license_detection_watchers - .get(&file.worktree_id(cx)) - .as_ref() - .is_some_and(|watcher| watcher.is_project_open_source()) - } - - fn can_collect_file(&self, project: &Entity, file: &Arc, cx: &App) -> bool { - self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx) - } - - fn can_collect_events(&self, events: &[Arc]) -> bool { - if !self.data_collection_choice.is_enabled() { - return false; - } - events.iter().all(|event| { - matches!( - event.as_ref(), - Event::BufferChange { - in_open_source_repo: true, - .. - } - ) - }) - } - - fn load_data_collection_choice() -> DataCollectionChoice { - let choice = KEY_VALUE_STORE - .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE) - .log_err() - .flatten(); - - match choice.as_deref() { - Some("true") => DataCollectionChoice::Enabled, - Some("false") => DataCollectionChoice::Disabled, - Some(_) => { - log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'"); - DataCollectionChoice::NotAnswered - } - None => DataCollectionChoice::NotAnswered, - } - } - - pub fn shown_predictions(&self) -> impl DoubleEndedIterator { - self.shown_predictions.iter() - } - - pub fn shown_completions_len(&self) -> usize { - self.shown_predictions.len() - } - - pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool { - self.rated_predictions.contains(id) - } - - pub fn rate_prediction( - &mut self, - prediction: &EditPrediction, - rating: EditPredictionRating, - feedback: String, - cx: &mut Context, - ) { - self.rated_predictions.insert(prediction.id.clone()); - telemetry::event!( - "Edit Prediction Rated", - rating, - inputs = prediction.inputs, - output = prediction.edit_preview.as_unified_diff(&prediction.edits), - feedback - ); - self.client.telemetry().flush_events().detach(); - cx.notify(); - } - - fn enable_or_disable_context_retrieval(&mut self, cx: &mut Context<'_, Zeta>) { - self.use_context = cx.has_flag::() - && all_language_settings(None, 
cx).edit_predictions.use_context; - } -} - -pub fn text_from_response(mut res: open_ai::Response) -> Option { - let choice = res.choices.pop()?; - let output_text = match choice.message { - open_ai::RequestMessage::Assistant { - content: Some(open_ai::MessageContent::Plain(content)), - .. - } => content, - open_ai::RequestMessage::Assistant { - content: Some(open_ai::MessageContent::Multipart(mut content)), - .. - } => { - if content.is_empty() { - log::error!("No output from Baseten completion response"); - return None; - } - - match content.remove(0) { - open_ai::MessagePart::Text { text } => text, - open_ai::MessagePart::Image { .. } => { - log::error!("Expected text, got an image"); - return None; - } - } - } - _ => { - log::error!("Invalid response message: {:?}", choice.message); - return None; - } - }; - Some(output_text) -} - -#[derive(Error, Debug)] -#[error( - "You must update to Zed version {minimum_version} or higher to continue using edit predictions." -)] -pub struct ZedUpdateRequiredError { - minimum_version: Version, -} - -#[cfg(feature = "eval-support")] -pub type EvalCacheKey = (EvalCacheEntryKind, u64); - -#[cfg(feature = "eval-support")] -#[derive(Debug, Clone, Copy, PartialEq)] -pub enum EvalCacheEntryKind { - Context, - Search, - Prediction, -} - -#[cfg(feature = "eval-support")] -impl std::fmt::Display for EvalCacheEntryKind { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - EvalCacheEntryKind::Search => write!(f, "search"), - EvalCacheEntryKind::Context => write!(f, "context"), - EvalCacheEntryKind::Prediction => write!(f, "prediction"), - } - } -} - -#[cfg(feature = "eval-support")] -pub trait EvalCache: Send + Sync { - fn read(&self, key: EvalCacheKey) -> Option; - fn write(&self, key: EvalCacheKey, input: &str, value: &str); -} - -#[derive(Debug, Clone, Copy)] -pub enum DataCollectionChoice { - NotAnswered, - Enabled, - Disabled, -} - -impl DataCollectionChoice { - pub fn is_enabled(self) -> bool { - 
match self { - Self::Enabled => true, - Self::NotAnswered | Self::Disabled => false, - } - } - - pub fn is_answered(self) -> bool { - match self { - Self::Enabled | Self::Disabled => true, - Self::NotAnswered => false, - } - } - - #[must_use] - pub fn toggle(&self) -> DataCollectionChoice { - match self { - Self::Enabled => Self::Disabled, - Self::Disabled => Self::Enabled, - Self::NotAnswered => Self::Enabled, - } - } -} - -impl From for DataCollectionChoice { - fn from(value: bool) -> Self { - match value { - true => DataCollectionChoice::Enabled, - false => DataCollectionChoice::Disabled, - } - } -} - -struct ZedPredictUpsell; - -impl Dismissable for ZedPredictUpsell { - const KEY: &'static str = "dismissed-edit-predict-upsell"; - - fn dismissed() -> bool { - // To make this backwards compatible with older versions of Zed, we - // check if the user has seen the previous Edit Prediction Onboarding - // before, by checking the data collection choice which was written to - // the database once the user clicked on "Accept and Enable" - if KEY_VALUE_STORE - .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE) - .log_err() - .is_some_and(|s| s.is_some()) - { - return true; - } - - KEY_VALUE_STORE - .read_kvp(Self::KEY) - .log_err() - .is_some_and(|s| s.is_some()) - } -} - -pub fn should_show_upsell_modal() -> bool { - !ZedPredictUpsell::dismissed() -} - -pub fn init(cx: &mut App) { - feature_gate_predict_edits_actions(cx); - - cx.observe_new(move |workspace: &mut Workspace, _, _cx| { - workspace.register_action(|workspace, _: &RateCompletions, window, cx| { - if cx.has_flag::() { - RatePredictionsModal::toggle(workspace, window, cx); - } - }); - - workspace.register_action( - move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| { - ZedPredictModal::toggle( - workspace, - workspace.user_store().clone(), - workspace.client().clone(), - window, - cx, - ) - }, - ); - - workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| { - 
update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| { - settings - .project - .all_languages - .features - .get_or_insert_default() - .edit_prediction_provider = Some(EditPredictionProvider::None) - }); - }); - }) - .detach(); -} - -fn feature_gate_predict_edits_actions(cx: &mut App) { - let rate_completion_action_types = [TypeId::of::()]; - let reset_onboarding_action_types = [TypeId::of::()]; - let zeta_all_action_types = [ - TypeId::of::(), - TypeId::of::(), - zed_actions::OpenZedPredictOnboarding.type_id(), - TypeId::of::(), - TypeId::of::(), - TypeId::of::(), - TypeId::of::(), - TypeId::of::(), - ]; - - CommandPaletteFilter::update_global(cx, |filter, _cx| { - filter.hide_action_types(&rate_completion_action_types); - filter.hide_action_types(&reset_onboarding_action_types); - filter.hide_action_types(&[zed_actions::OpenZedPredictOnboarding.type_id()]); - }); - - cx.observe_global::(move |cx| { - let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai; - let has_feature_flag = cx.has_flag::(); - - CommandPaletteFilter::update_global(cx, |filter, _cx| { - if is_ai_disabled { - filter.hide_action_types(&zeta_all_action_types); - } else if has_feature_flag { - filter.show_action_types(&rate_completion_action_types); - } else { - filter.hide_action_types(&rate_completion_action_types); - } - }); - }) - .detach(); - - cx.observe_flag::(move |is_enabled, cx| { - if !DisableAiSettings::get_global(cx).disable_ai { - if is_enabled { - CommandPaletteFilter::update_global(cx, |filter, _cx| { - filter.show_action_types(&rate_completion_action_types); - }); - } else { - CommandPaletteFilter::update_global(cx, |filter, _cx| { - filter.hide_action_types(&rate_completion_action_types); - }); - } - } - }) - .detach(); -} - -#[cfg(test)] -mod tests { - use std::{path::Path, sync::Arc, time::Duration}; - - use client::UserStore; - use clock::FakeSystemClock; - use cloud_llm_client::{ - EditPredictionRejectReason, EditPredictionRejection, 
RejectEditPredictionsBody, - }; - use futures::{ - AsyncReadExt, StreamExt, - channel::{mpsc, oneshot}, - }; - use gpui::{ - Entity, TestAppContext, - http_client::{FakeHttpClient, Response}, - prelude::*, - }; - use indoc::indoc; - use language::OffsetRangeExt as _; - use lsp::LanguageServerId; - use open_ai::Usage; - use pretty_assertions::{assert_eq, assert_matches}; - use project::{FakeFs, Project}; - use serde_json::json; - use settings::SettingsStore; - use util::path; - use uuid::Uuid; - - use crate::{BufferEditPrediction, EditPredictionId, REJECT_REQUEST_DEBOUNCE, Zeta}; - - #[gpui::test] - async fn test_current_state(cx: &mut TestAppContext) { - let (zeta, mut requests) = init_test(cx); - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "1.txt": "Hello!\nHow\nBye\n", - "2.txt": "Hola!\nComo\nAdios\n" - }), - ) - .await; - let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; - - zeta.update(cx, |zeta, cx| { - zeta.register_project(&project, cx); - }); - - let buffer1 = project - .update(cx, |project, cx| { - let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap(); - project.set_active_path(Some(path.clone()), cx); - project.open_buffer(path, cx) - }) - .await - .unwrap(); - let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot()); - let position = snapshot1.anchor_before(language::Point::new(1, 3)); - - // Prediction for current file - - zeta.update(cx, |zeta, cx| { - zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx) - }); - let (_request, respond_tx) = requests.predict.next().await.unwrap(); - - respond_tx - .send(model_response(indoc! {r" - --- a/root/1.txt - +++ b/root/1.txt - @@ ... @@ - Hello! - -How - +How are you? 
- Bye - "})) - .unwrap(); - - cx.run_until_parked(); - - zeta.read_with(cx, |zeta, cx| { - let prediction = zeta - .current_prediction_for_buffer(&buffer1, &project, cx) - .unwrap(); - assert_matches!(prediction, BufferEditPrediction::Local { .. }); - }); - - zeta.update(cx, |zeta, _cx| { - zeta.reject_current_prediction(EditPredictionRejectReason::Discarded, &project); - }); - - // Prediction for diagnostic in another file - - let diagnostic = lsp::Diagnostic { - range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)), - severity: Some(lsp::DiagnosticSeverity::ERROR), - message: "Sentence is incomplete".to_string(), - ..Default::default() - }; - - project.update(cx, |project, cx| { - project.lsp_store().update(cx, |lsp_store, cx| { - lsp_store - .update_diagnostics( - LanguageServerId(0), - lsp::PublishDiagnosticsParams { - uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(), - diagnostics: vec![diagnostic], - version: None, - }, - None, - language::DiagnosticSourceKind::Pushed, - &[], - cx, - ) - .unwrap(); - }); - }); - - let (_request, respond_tx) = requests.predict.next().await.unwrap(); - respond_tx - .send(model_response(indoc! {r#" - --- a/root/2.txt - +++ b/root/2.txt - Hola! - -Como - +Como estas? - Adios - "#})) - .unwrap(); - cx.run_until_parked(); - - zeta.read_with(cx, |zeta, cx| { - let prediction = zeta - .current_prediction_for_buffer(&buffer1, &project, cx) - .unwrap(); - assert_matches!( - prediction, - BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt")) - ); - }); - - let buffer2 = project - .update(cx, |project, cx| { - let path = project.find_project_path(path!("root/2.txt"), cx).unwrap(); - project.open_buffer(path, cx) - }) - .await - .unwrap(); - - zeta.read_with(cx, |zeta, cx| { - let prediction = zeta - .current_prediction_for_buffer(&buffer2, &project, cx) - .unwrap(); - assert_matches!(prediction, BufferEditPrediction::Local { .. 
}); - }); - } - - #[gpui::test] - async fn test_simple_request(cx: &mut TestAppContext) { - let (zeta, mut requests) = init_test(cx); - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "foo.md": "Hello!\nHow\nBye\n" - }), - ) - .await; - let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; - - let buffer = project - .update(cx, |project, cx| { - let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); - project.open_buffer(path, cx) - }) - .await - .unwrap(); - let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); - let position = snapshot.anchor_before(language::Point::new(1, 3)); - - let prediction_task = zeta.update(cx, |zeta, cx| { - zeta.request_prediction(&project, &buffer, position, Default::default(), cx) - }); - - let (_, respond_tx) = requests.predict.next().await.unwrap(); - - // TODO Put back when we have a structured request again - // assert_eq!( - // request.excerpt_path.as_ref(), - // Path::new(path!("root/foo.md")) - // ); - // assert_eq!( - // request.cursor_point, - // Point { - // line: Line(1), - // column: 3 - // } - // ); - - respond_tx - .send(model_response(indoc! { r" - --- a/root/foo.md - +++ b/root/foo.md - @@ ... @@ - Hello! - -How - +How are you? 
- Bye - "})) - .unwrap(); - - let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap(); - - assert_eq!(prediction.edits.len(), 1); - assert_eq!( - prediction.edits[0].0.to_point(&snapshot).start, - language::Point::new(1, 3) - ); - assert_eq!(prediction.edits[0].1.as_ref(), " are you?"); - } - - #[gpui::test] - async fn test_request_events(cx: &mut TestAppContext) { - let (zeta, mut requests) = init_test(cx); - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "foo.md": "Hello!\n\nBye\n" - }), - ) - .await; - let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; - - let buffer = project - .update(cx, |project, cx| { - let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); - project.open_buffer(path, cx) - }) - .await - .unwrap(); - - zeta.update(cx, |zeta, cx| { - zeta.register_buffer(&buffer, &project, cx); - }); - - buffer.update(cx, |buffer, cx| { - buffer.edit(vec![(7..7, "How")], None, cx); - }); - - let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); - let position = snapshot.anchor_before(language::Point::new(1, 3)); - - let prediction_task = zeta.update(cx, |zeta, cx| { - zeta.request_prediction(&project, &buffer, position, Default::default(), cx) - }); - - let (request, respond_tx) = requests.predict.next().await.unwrap(); - - let prompt = prompt_from_request(&request); - assert!( - prompt.contains(indoc! {" - --- a/root/foo.md - +++ b/root/foo.md - @@ -1,3 +1,3 @@ - Hello! - - - +How - Bye - "}), - "{prompt}" - ); - - respond_tx - .send(model_response(indoc! {r#" - --- a/root/foo.md - +++ b/root/foo.md - @@ ... @@ - Hello! - -How - +How are you? 
- Bye - "#})) - .unwrap(); - - let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap(); - - assert_eq!(prediction.edits.len(), 1); - assert_eq!( - prediction.edits[0].0.to_point(&snapshot).start, - language::Point::new(1, 3) - ); - assert_eq!(prediction.edits[0].1.as_ref(), " are you?"); - } - - #[gpui::test] - async fn test_empty_prediction(cx: &mut TestAppContext) { - let (zeta, mut requests) = init_test(cx); - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "foo.md": "Hello!\nHow\nBye\n" - }), - ) - .await; - let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; - - let buffer = project - .update(cx, |project, cx| { - let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); - project.open_buffer(path, cx) - }) - .await - .unwrap(); - let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); - let position = snapshot.anchor_before(language::Point::new(1, 3)); - - zeta.update(cx, |zeta, cx| { - zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); - }); - - const NO_OP_DIFF: &str = indoc! { r" - --- a/root/foo.md - +++ b/root/foo.md - @@ ... @@ - Hello! 
- -How - +How - Bye - "}; - - let (_, respond_tx) = requests.predict.next().await.unwrap(); - let response = model_response(NO_OP_DIFF); - let id = response.id.clone(); - respond_tx.send(response).unwrap(); - - cx.run_until_parked(); - - zeta.read_with(cx, |zeta, cx| { - assert!( - zeta.current_prediction_for_buffer(&buffer, &project, cx) - .is_none() - ); - }); - - // prediction is reported as rejected - let (reject_request, _) = requests.reject.next().await.unwrap(); - - assert_eq!( - &reject_request.rejections, - &[EditPredictionRejection { - request_id: id, - reason: EditPredictionRejectReason::Empty, - was_shown: false - }] - ); - } - - #[gpui::test] - async fn test_interpolated_empty(cx: &mut TestAppContext) { - let (zeta, mut requests) = init_test(cx); - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "foo.md": "Hello!\nHow\nBye\n" - }), - ) - .await; - let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; - - let buffer = project - .update(cx, |project, cx| { - let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); - project.open_buffer(path, cx) - }) - .await - .unwrap(); - let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); - let position = snapshot.anchor_before(language::Point::new(1, 3)); - - zeta.update(cx, |zeta, cx| { - zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); - }); - - let (_, respond_tx) = requests.predict.next().await.unwrap(); - - buffer.update(cx, |buffer, cx| { - buffer.set_text("Hello!\nHow are you?\nBye", cx); - }); - - let response = model_response(SIMPLE_DIFF); - let id = response.id.clone(); - respond_tx.send(response).unwrap(); - - cx.run_until_parked(); - - zeta.read_with(cx, |zeta, cx| { - assert!( - zeta.current_prediction_for_buffer(&buffer, &project, cx) - .is_none() - ); - }); - - // prediction is reported as rejected - let (reject_request, _) = requests.reject.next().await.unwrap(); - - assert_eq!( - 
&reject_request.rejections, - &[EditPredictionRejection { - request_id: id, - reason: EditPredictionRejectReason::InterpolatedEmpty, - was_shown: false - }] - ); - } - - const SIMPLE_DIFF: &str = indoc! { r" - --- a/root/foo.md - +++ b/root/foo.md - @@ ... @@ - Hello! - -How - +How are you? - Bye - "}; - - #[gpui::test] - async fn test_replace_current(cx: &mut TestAppContext) { - let (zeta, mut requests) = init_test(cx); - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "foo.md": "Hello!\nHow\nBye\n" - }), - ) - .await; - let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; - - let buffer = project - .update(cx, |project, cx| { - let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); - project.open_buffer(path, cx) - }) - .await - .unwrap(); - let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); - let position = snapshot.anchor_before(language::Point::new(1, 3)); - - zeta.update(cx, |zeta, cx| { - zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); - }); - - let (_, respond_tx) = requests.predict.next().await.unwrap(); - let first_response = model_response(SIMPLE_DIFF); - let first_id = first_response.id.clone(); - respond_tx.send(first_response).unwrap(); - - cx.run_until_parked(); - - zeta.read_with(cx, |zeta, cx| { - assert_eq!( - zeta.current_prediction_for_buffer(&buffer, &project, cx) - .unwrap() - .id - .0, - first_id - ); - }); - - // a second request is triggered - zeta.update(cx, |zeta, cx| { - zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); - }); - - let (_, respond_tx) = requests.predict.next().await.unwrap(); - let second_response = model_response(SIMPLE_DIFF); - let second_id = second_response.id.clone(); - respond_tx.send(second_response).unwrap(); - - cx.run_until_parked(); - - zeta.read_with(cx, |zeta, cx| { - // second replaces first - assert_eq!( - zeta.current_prediction_for_buffer(&buffer, 
&project, cx) - .unwrap() - .id - .0, - second_id - ); - }); - - // first is reported as replaced - let (reject_request, _) = requests.reject.next().await.unwrap(); - - assert_eq!( - &reject_request.rejections, - &[EditPredictionRejection { - request_id: first_id, - reason: EditPredictionRejectReason::Replaced, - was_shown: false - }] - ); - } - - #[gpui::test] - async fn test_current_preferred(cx: &mut TestAppContext) { - let (zeta, mut requests) = init_test(cx); - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "foo.md": "Hello!\nHow\nBye\n" - }), - ) - .await; - let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; - - let buffer = project - .update(cx, |project, cx| { - let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); - project.open_buffer(path, cx) - }) - .await - .unwrap(); - let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); - let position = snapshot.anchor_before(language::Point::new(1, 3)); - - zeta.update(cx, |zeta, cx| { - zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); - }); - - let (_, respond_tx) = requests.predict.next().await.unwrap(); - let first_response = model_response(SIMPLE_DIFF); - let first_id = first_response.id.clone(); - respond_tx.send(first_response).unwrap(); - - cx.run_until_parked(); - - zeta.read_with(cx, |zeta, cx| { - assert_eq!( - zeta.current_prediction_for_buffer(&buffer, &project, cx) - .unwrap() - .id - .0, - first_id - ); - }); - - // a second request is triggered - zeta.update(cx, |zeta, cx| { - zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); - }); - - let (_, respond_tx) = requests.predict.next().await.unwrap(); - // worse than current prediction - let second_response = model_response(indoc! { r" - --- a/root/foo.md - +++ b/root/foo.md - @@ ... @@ - Hello! 
- -How - +How are - Bye - "}); - let second_id = second_response.id.clone(); - respond_tx.send(second_response).unwrap(); - - cx.run_until_parked(); - - zeta.read_with(cx, |zeta, cx| { - // first is preferred over second - assert_eq!( - zeta.current_prediction_for_buffer(&buffer, &project, cx) - .unwrap() - .id - .0, - first_id - ); - }); - - // second is reported as rejected - let (reject_request, _) = requests.reject.next().await.unwrap(); - - assert_eq!( - &reject_request.rejections, - &[EditPredictionRejection { - request_id: second_id, - reason: EditPredictionRejectReason::CurrentPreferred, - was_shown: false - }] - ); - } - - #[gpui::test] - async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) { - let (zeta, mut requests) = init_test(cx); - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "foo.md": "Hello!\nHow\nBye\n" - }), - ) - .await; - let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; - - let buffer = project - .update(cx, |project, cx| { - let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); - project.open_buffer(path, cx) - }) - .await - .unwrap(); - let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); - let position = snapshot.anchor_before(language::Point::new(1, 3)); - - // start two refresh tasks - zeta.update(cx, |zeta, cx| { - zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); - }); - - let (_, respond_first) = requests.predict.next().await.unwrap(); - - zeta.update(cx, |zeta, cx| { - zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); - }); - - let (_, respond_second) = requests.predict.next().await.unwrap(); - - // wait for throttle - cx.run_until_parked(); - - // second responds first - let second_response = model_response(SIMPLE_DIFF); - let second_id = second_response.id.clone(); - respond_second.send(second_response).unwrap(); - - cx.run_until_parked(); - - 
zeta.read_with(cx, |zeta, cx| { - // current prediction is second - assert_eq!( - zeta.current_prediction_for_buffer(&buffer, &project, cx) - .unwrap() - .id - .0, - second_id - ); - }); - - let first_response = model_response(SIMPLE_DIFF); - let first_id = first_response.id.clone(); - respond_first.send(first_response).unwrap(); - - cx.run_until_parked(); - - zeta.read_with(cx, |zeta, cx| { - // current prediction is still second, since first was cancelled - assert_eq!( - zeta.current_prediction_for_buffer(&buffer, &project, cx) - .unwrap() - .id - .0, - second_id - ); - }); - - // first is reported as rejected - let (reject_request, _) = requests.reject.next().await.unwrap(); - - cx.run_until_parked(); - - assert_eq!( - &reject_request.rejections, - &[EditPredictionRejection { - request_id: first_id, - reason: EditPredictionRejectReason::Canceled, - was_shown: false - }] - ); - } - - #[gpui::test] - async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) { - let (zeta, mut requests) = init_test(cx); - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "foo.md": "Hello!\nHow\nBye\n" - }), - ) - .await; - let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; - - let buffer = project - .update(cx, |project, cx| { - let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); - project.open_buffer(path, cx) - }) - .await - .unwrap(); - let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); - let position = snapshot.anchor_before(language::Point::new(1, 3)); - - // start two refresh tasks - zeta.update(cx, |zeta, cx| { - zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); - }); - - let (_, respond_first) = requests.predict.next().await.unwrap(); - - zeta.update(cx, |zeta, cx| { - zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); - }); - - let (_, respond_second) = requests.predict.next().await.unwrap(); - - // wait 
for throttle, so requests are sent - cx.run_until_parked(); - - zeta.update(cx, |zeta, cx| { - // start a third request - zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); - - // 2 are pending, so 2nd is cancelled - assert_eq!( - zeta.get_or_init_zeta_project(&project, cx) - .cancelled_predictions - .iter() - .copied() - .collect::>(), - [1] - ); - }); - - // wait for throttle - cx.run_until_parked(); - - let (_, respond_third) = requests.predict.next().await.unwrap(); - - let first_response = model_response(SIMPLE_DIFF); - let first_id = first_response.id.clone(); - respond_first.send(first_response).unwrap(); - - cx.run_until_parked(); - - zeta.read_with(cx, |zeta, cx| { - // current prediction is first - assert_eq!( - zeta.current_prediction_for_buffer(&buffer, &project, cx) - .unwrap() - .id - .0, - first_id - ); - }); - - let cancelled_response = model_response(SIMPLE_DIFF); - let cancelled_id = cancelled_response.id.clone(); - respond_second.send(cancelled_response).unwrap(); - - cx.run_until_parked(); - - zeta.read_with(cx, |zeta, cx| { - // current prediction is still first, since second was cancelled - assert_eq!( - zeta.current_prediction_for_buffer(&buffer, &project, cx) - .unwrap() - .id - .0, - first_id - ); - }); - - let third_response = model_response(SIMPLE_DIFF); - let third_response_id = third_response.id.clone(); - respond_third.send(third_response).unwrap(); - - cx.run_until_parked(); - - zeta.read_with(cx, |zeta, cx| { - // third completes and replaces first - assert_eq!( - zeta.current_prediction_for_buffer(&buffer, &project, cx) - .unwrap() - .id - .0, - third_response_id - ); - }); - - // second is reported as rejected - let (reject_request, _) = requests.reject.next().await.unwrap(); - - cx.run_until_parked(); - - assert_eq!( - &reject_request.rejections, - &[ - EditPredictionRejection { - request_id: cancelled_id, - reason: EditPredictionRejectReason::Canceled, - was_shown: false - }, - 
EditPredictionRejection { - request_id: first_id, - reason: EditPredictionRejectReason::Replaced, - was_shown: false - } - ] - ); - } - - #[gpui::test] - async fn test_rejections_flushing(cx: &mut TestAppContext) { - let (zeta, mut requests) = init_test(cx); - - zeta.update(cx, |zeta, _cx| { - zeta.reject_prediction( - EditPredictionId("test-1".into()), - EditPredictionRejectReason::Discarded, - false, - ); - zeta.reject_prediction( - EditPredictionId("test-2".into()), - EditPredictionRejectReason::Canceled, - true, - ); - }); - - cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE); - cx.run_until_parked(); - - let (reject_request, respond_tx) = requests.reject.next().await.unwrap(); - respond_tx.send(()).unwrap(); - - // batched - assert_eq!(reject_request.rejections.len(), 2); - assert_eq!( - reject_request.rejections[0], - EditPredictionRejection { - request_id: "test-1".to_string(), - reason: EditPredictionRejectReason::Discarded, - was_shown: false - } - ); - assert_eq!( - reject_request.rejections[1], - EditPredictionRejection { - request_id: "test-2".to_string(), - reason: EditPredictionRejectReason::Canceled, - was_shown: true - } - ); - - // Reaching batch size limit sends without debounce - zeta.update(cx, |zeta, _cx| { - for i in 0..70 { - zeta.reject_prediction( - EditPredictionId(format!("batch-{}", i).into()), - EditPredictionRejectReason::Discarded, - false, - ); - } - }); - - // First MAX/2 items are sent immediately - cx.run_until_parked(); - let (reject_request, respond_tx) = requests.reject.next().await.unwrap(); - respond_tx.send(()).unwrap(); - - assert_eq!(reject_request.rejections.len(), 50); - assert_eq!(reject_request.rejections[0].request_id, "batch-0"); - assert_eq!(reject_request.rejections[49].request_id, "batch-49"); - - // Remaining items are debounced with the next batch - cx.executor().advance_clock(Duration::from_secs(15)); - cx.run_until_parked(); - - let (reject_request, respond_tx) = requests.reject.next().await.unwrap(); - 
respond_tx.send(()).unwrap(); - - assert_eq!(reject_request.rejections.len(), 20); - assert_eq!(reject_request.rejections[0].request_id, "batch-50"); - assert_eq!(reject_request.rejections[19].request_id, "batch-69"); - - // Request failure - zeta.update(cx, |zeta, _cx| { - zeta.reject_prediction( - EditPredictionId("retry-1".into()), - EditPredictionRejectReason::Discarded, - false, - ); - }); - - cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE); - cx.run_until_parked(); - - let (reject_request, _respond_tx) = requests.reject.next().await.unwrap(); - assert_eq!(reject_request.rejections.len(), 1); - assert_eq!(reject_request.rejections[0].request_id, "retry-1"); - // Simulate failure - drop(_respond_tx); - - // Add another rejection - zeta.update(cx, |zeta, _cx| { - zeta.reject_prediction( - EditPredictionId("retry-2".into()), - EditPredictionRejectReason::Discarded, - false, - ); - }); - - cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE); - cx.run_until_parked(); - - // Retry should include both the failed item and the new one - let (reject_request, respond_tx) = requests.reject.next().await.unwrap(); - respond_tx.send(()).unwrap(); - - assert_eq!(reject_request.rejections.len(), 2); - assert_eq!(reject_request.rejections[0].request_id, "retry-1"); - assert_eq!(reject_request.rejections[1].request_id, "retry-2"); - } - - // Skipped until we start including diagnostics in prompt - // #[gpui::test] - // async fn test_request_diagnostics(cx: &mut TestAppContext) { - // let (zeta, mut req_rx) = init_test(cx); - // let fs = FakeFs::new(cx.executor()); - // fs.insert_tree( - // "/root", - // json!({ - // "foo.md": "Hello!\nBye" - // }), - // ) - // .await; - // let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; - - // let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap(); - // let diagnostic = lsp::Diagnostic { - // range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)), - // severity: 
Some(lsp::DiagnosticSeverity::ERROR), - // message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(), - // ..Default::default() - // }; - - // project.update(cx, |project, cx| { - // project.lsp_store().update(cx, |lsp_store, cx| { - // // Create some diagnostics - // lsp_store - // .update_diagnostics( - // LanguageServerId(0), - // lsp::PublishDiagnosticsParams { - // uri: path_to_buffer_uri.clone(), - // diagnostics: vec![diagnostic], - // version: None, - // }, - // None, - // language::DiagnosticSourceKind::Pushed, - // &[], - // cx, - // ) - // .unwrap(); - // }); - // }); - - // let buffer = project - // .update(cx, |project, cx| { - // let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); - // project.open_buffer(path, cx) - // }) - // .await - // .unwrap(); - - // let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); - // let position = snapshot.anchor_before(language::Point::new(0, 0)); - - // let _prediction_task = zeta.update(cx, |zeta, cx| { - // zeta.request_prediction(&project, &buffer, position, cx) - // }); - - // let (request, _respond_tx) = req_rx.next().await.unwrap(); - - // assert_eq!(request.diagnostic_groups.len(), 1); - // let value = serde_json::from_str::(request.diagnostic_groups[0].0.get()) - // .unwrap(); - // // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3 - // assert_eq!( - // value, - // json!({ - // "entries": [{ - // "range": { - // "start": 8, - // "end": 10 - // }, - // "diagnostic": { - // "source": null, - // "code": null, - // "code_description": null, - // "severity": 1, - // "message": "\"Hello\" deprecated. 
Use \"Hi\" instead", - // "markdown": null, - // "group_id": 0, - // "is_primary": true, - // "is_disk_based": false, - // "is_unnecessary": false, - // "source_kind": "Pushed", - // "data": null, - // "underline": true - // } - // }], - // "primary_ix": 0 - // }) - // ); - // } - - fn model_response(text: &str) -> open_ai::Response { - open_ai::Response { - id: Uuid::new_v4().to_string(), - object: "response".into(), - created: 0, - model: "model".into(), - choices: vec![open_ai::Choice { - index: 0, - message: open_ai::RequestMessage::Assistant { - content: Some(open_ai::MessageContent::Plain(text.to_string())), - tool_calls: vec![], - }, - finish_reason: None, - }], - usage: Usage { - prompt_tokens: 0, - completion_tokens: 0, - total_tokens: 0, - }, - } - } - - fn prompt_from_request(request: &open_ai::Request) -> &str { - assert_eq!(request.messages.len(), 1); - let open_ai::RequestMessage::User { - content: open_ai::MessageContent::Plain(content), - .. - } = &request.messages[0] - else { - panic!( - "Request does not have single user message of type Plain. 
{:#?}", - request - ); - }; - content - } - - struct RequestChannels { - predict: mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender)>, - reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>, - } - - fn init_test(cx: &mut TestAppContext) -> (Entity, RequestChannels) { - cx.update(move |cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - zlog::init_test(); - - let (predict_req_tx, predict_req_rx) = mpsc::unbounded(); - let (reject_req_tx, reject_req_rx) = mpsc::unbounded(); - - let http_client = FakeHttpClient::create({ - move |req| { - let uri = req.uri().path().to_string(); - let mut body = req.into_body(); - let predict_req_tx = predict_req_tx.clone(); - let reject_req_tx = reject_req_tx.clone(); - async move { - let resp = match uri.as_str() { - "/client/llm_tokens" => serde_json::to_string(&json!({ - "token": "test" - })) - .unwrap(), - "/predict_edits/raw" => { - let mut buf = Vec::new(); - body.read_to_end(&mut buf).await.ok(); - let req = serde_json::from_slice(&buf).unwrap(); - let (res_tx, res_rx) = oneshot::channel(); - predict_req_tx.unbounded_send((req, res_tx)).unwrap(); - serde_json::to_string(&res_rx.await?).unwrap() - } - "/predict_edits/reject" => { - let mut buf = Vec::new(); - body.read_to_end(&mut buf).await.ok(); - let req = serde_json::from_slice(&buf).unwrap(); - - let (res_tx, res_rx) = oneshot::channel(); - reject_req_tx.unbounded_send((req, res_tx)).unwrap(); - serde_json::to_string(&res_rx.await?).unwrap() - } - _ => { - panic!("Unexpected path: {}", uri) - } - }; - - Ok(Response::builder().body(resp.into()).unwrap()) - } - } - }); - - let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx); - client.cloud_client().set_credentials(1, "test".into()); - - language_model::init(client.clone(), cx); - - let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - let zeta = Zeta::global(&client, &user_store, cx); - - ( - zeta, - 
RequestChannels { - predict: predict_req_rx, - reject: reject_req_rx, - }, - ) - }) - } -} diff --git a/crates/zeta/src/zeta_tests.rs b/crates/zeta/src/zeta_tests.rs deleted file mode 100644 index 3549cda36d575a989f5bc4bd5bb8bea6810d3180..0000000000000000000000000000000000000000 --- a/crates/zeta/src/zeta_tests.rs +++ /dev/null @@ -1,671 +0,0 @@ -use client::test::FakeServer; -use clock::{FakeSystemClock, ReplicaId}; -use cloud_api_types::{CreateLlmTokenResponse, LlmToken}; -use cloud_llm_client::{PredictEditsBody, PredictEditsResponse}; -use gpui::TestAppContext; -use http_client::FakeHttpClient; -use indoc::indoc; -use language::Point; -use parking_lot::Mutex; -use serde_json::json; -use settings::SettingsStore; -use util::{path, rel_path::rel_path}; - -use crate::zeta1::MAX_EVENT_TOKENS; - -use super::*; - -const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt"); - -#[gpui::test] -async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) { - let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx)); - let edits: Arc<[(Range, Arc)]> = cx.update(|cx| { - to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into() - }); - - let edit_preview = cx - .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx)) - .await; - - let completion = EditPrediction { - edits, - edit_preview, - buffer: buffer.clone(), - snapshot: cx.read(|cx| buffer.read(cx).snapshot()), - id: EditPredictionId("the-id".into()), - inputs: EditPredictionInputs { - events: Default::default(), - included_files: Default::default(), - cursor_point: cloud_llm_client::predict_edits_v3::Point { - line: Line(0), - column: 0, - }, - cursor_path: Path::new("").into(), - }, - buffer_snapshotted_at: Instant::now(), - response_received_at: Instant::now(), - }; - - cx.update(|cx| { - assert_eq!( - from_completion_edits( - &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), - &buffer, - cx - ), - vec![(2..5, "REM".into()), (9..11, 
"".into())] - ); - - buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx)); - assert_eq!( - from_completion_edits( - &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), - &buffer, - cx - ), - vec![(2..2, "REM".into()), (6..8, "".into())] - ); - - buffer.update(cx, |buffer, cx| buffer.undo(cx)); - assert_eq!( - from_completion_edits( - &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), - &buffer, - cx - ), - vec![(2..5, "REM".into()), (9..11, "".into())] - ); - - buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx)); - assert_eq!( - from_completion_edits( - &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), - &buffer, - cx - ), - vec![(3..3, "EM".into()), (7..9, "".into())] - ); - - buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx)); - assert_eq!( - from_completion_edits( - &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), - &buffer, - cx - ), - vec![(4..4, "M".into()), (8..10, "".into())] - ); - - buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx)); - assert_eq!( - from_completion_edits( - &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), - &buffer, - cx - ), - vec![(9..11, "".into())] - ); - - buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx)); - assert_eq!( - from_completion_edits( - &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), - &buffer, - cx - ), - vec![(4..4, "M".into()), (8..10, "".into())] - ); - - buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx)); - assert_eq!( - from_completion_edits( - &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), - &buffer, - cx - ), - vec![(4..4, "M".into())] - ); - - buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx)); - assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None); - }) -} - -#[gpui::test] -async fn test_clean_up_diff(cx: &mut TestAppContext) { - init_test(cx); - - assert_eq!( - apply_edit_prediction( 
- indoc! {" - fn main() { - let word_1 = \"lorem\"; - let range = word.len()..word.len(); - } - "}, - indoc! {" - <|editable_region_start|> - fn main() { - let word_1 = \"lorem\"; - let range = word_1.len()..word_1.len(); - } - - <|editable_region_end|> - "}, - cx, - ) - .await, - indoc! {" - fn main() { - let word_1 = \"lorem\"; - let range = word_1.len()..word_1.len(); - } - "}, - ); - - assert_eq!( - apply_edit_prediction( - indoc! {" - fn main() { - let story = \"the quick\" - } - "}, - indoc! {" - <|editable_region_start|> - fn main() { - let story = \"the quick brown fox jumps over the lazy dog\"; - } - - <|editable_region_end|> - "}, - cx, - ) - .await, - indoc! {" - fn main() { - let story = \"the quick brown fox jumps over the lazy dog\"; - } - "}, - ); -} - -#[gpui::test] -async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) { - init_test(cx); - - let buffer_content = "lorem\n"; - let completion_response = indoc! {" - ```animals.js - <|start_of_file|> - <|editable_region_start|> - lorem - ipsum - <|editable_region_end|> - ```"}; - - assert_eq!( - apply_edit_prediction(buffer_content, completion_response, cx).await, - "lorem\nipsum" - ); -} - -#[gpui::test] -async fn test_can_collect_data(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT })) - .await; - - let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - let buffer = project - .update(cx, |project, cx| { - project.open_local_buffer(path!("/project/src/main.rs"), cx) - }) - .await - .unwrap(); - - let (zeta, captured_request, _) = make_test_zeta(&project, cx).await; - zeta.update(cx, |zeta, _cx| { - zeta.data_collection_choice = DataCollectionChoice::Enabled - }); - - run_edit_prediction(&buffer, &project, &zeta, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - true - ); - - zeta.update(cx, |zeta, _cx| { - 
zeta.data_collection_choice = DataCollectionChoice::Disabled - }); - - run_edit_prediction(&buffer, &project, &zeta, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - false - ); -} - -#[gpui::test] -async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - let project = Project::test(fs.clone(), [], cx).await; - - let buffer = cx.new(|_cx| { - Buffer::remote( - language::BufferId::new(1).unwrap(), - ReplicaId::new(1), - language::Capability::ReadWrite, - "fn main() {\n println!(\"Hello\");\n}", - ) - }); - - let (zeta, captured_request, _) = make_test_zeta(&project, cx).await; - zeta.update(cx, |zeta, _cx| { - zeta.data_collection_choice = DataCollectionChoice::Enabled - }); - - run_edit_prediction(&buffer, &project, &zeta, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - false - ); -} - -#[gpui::test] -async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - path!("/project"), - json!({ - "LICENSE": BSD_0_TXT, - ".env": "SECRET_KEY=secret" - }), - ) - .await; - - let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - let buffer = project - .update(cx, |project, cx| { - project.open_local_buffer("/project/.env", cx) - }) - .await - .unwrap(); - - let (zeta, captured_request, _) = make_test_zeta(&project, cx).await; - zeta.update(cx, |zeta, _cx| { - zeta.data_collection_choice = DataCollectionChoice::Enabled - }); - - run_edit_prediction(&buffer, &project, &zeta, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - false - ); -} - -#[gpui::test] -async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - let project = Project::test(fs.clone(), [], 
cx).await; - let buffer = cx.new(|cx| Buffer::local("", cx)); - - let (zeta, captured_request, _) = make_test_zeta(&project, cx).await; - zeta.update(cx, |zeta, _cx| { - zeta.data_collection_choice = DataCollectionChoice::Enabled - }); - - run_edit_prediction(&buffer, &project, &zeta, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - false - ); -} - -#[gpui::test] -async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" })) - .await; - - let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - let buffer = project - .update(cx, |project, cx| { - project.open_local_buffer("/project/main.rs", cx) - }) - .await - .unwrap(); - - let (zeta, captured_request, _) = make_test_zeta(&project, cx).await; - zeta.update(cx, |zeta, _cx| { - zeta.data_collection_choice = DataCollectionChoice::Enabled - }); - - run_edit_prediction(&buffer, &project, &zeta, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - false - ); -} - -#[gpui::test] -async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - path!("/open_source_worktree"), - json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }), - ) - .await; - fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" })) - .await; - - let project = Project::test( - fs.clone(), - [ - path!("/open_source_worktree").as_ref(), - path!("/closed_source_worktree").as_ref(), - ], - cx, - ) - .await; - let buffer = project - .update(cx, |project, cx| { - project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx) - }) - .await - .unwrap(); - - let (zeta, captured_request, _) = make_test_zeta(&project, cx).await; - zeta.update(cx, |zeta, _cx| { - zeta.data_collection_choice 
= DataCollectionChoice::Enabled - }); - - run_edit_prediction(&buffer, &project, &zeta, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - true - ); - - let closed_source_file = project - .update(cx, |project, cx| { - let worktree2 = project - .worktree_for_root_name("closed_source_worktree", cx) - .unwrap(); - worktree2.update(cx, |worktree2, cx| { - worktree2.load_file(rel_path("main.rs"), cx) - }) - }) - .await - .unwrap() - .file; - - buffer.update(cx, |buffer, cx| { - buffer.file_updated(closed_source_file, cx); - }); - - run_edit_prediction(&buffer, &project, &zeta, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - false - ); -} - -#[gpui::test] -async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - path!("/worktree1"), - json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }), - ) - .await; - fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" })) - .await; - - let project = Project::test( - fs.clone(), - [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()], - cx, - ) - .await; - let buffer = project - .update(cx, |project, cx| { - project.open_local_buffer(path!("/worktree1/main.rs"), cx) - }) - .await - .unwrap(); - let private_buffer = project - .update(cx, |project, cx| { - project.open_local_buffer(path!("/worktree2/file.rs"), cx) - }) - .await - .unwrap(); - - let (zeta, captured_request, _) = make_test_zeta(&project, cx).await; - zeta.update(cx, |zeta, _cx| { - zeta.data_collection_choice = DataCollectionChoice::Enabled - }); - - run_edit_prediction(&buffer, &project, &zeta, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - true - ); - - // this has a side effect of registering the buffer to watch for edits - run_edit_prediction(&private_buffer, &project, &zeta, cx).await; - assert_eq!( - 
captured_request.lock().clone().unwrap().can_collect_data, - false - ); - - private_buffer.update(cx, |private_buffer, cx| { - private_buffer.edit([(0..0, "An edit for the history!")], None, cx); - }); - - run_edit_prediction(&buffer, &project, &zeta, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - false - ); - - // make an edit that uses too many bytes, causing private_buffer edit to not be able to be - // included - buffer.update(cx, |buffer, cx| { - buffer.edit( - [( - 0..0, - " ".repeat(MAX_EVENT_TOKENS * zeta1::BYTES_PER_TOKEN_GUESS), - )], - None, - cx, - ); - }); - - run_edit_prediction(&buffer, &project, &zeta, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - true - ); -} - -fn init_test(cx: &mut TestAppContext) { - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - }); -} - -async fn apply_edit_prediction( - buffer_content: &str, - completion_response: &str, - cx: &mut TestAppContext, -) -> String { - let fs = project::FakeFs::new(cx.executor()); - let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); - let (zeta, _, response) = make_test_zeta(&project, cx).await; - *response.lock() = completion_response.to_string(); - let edit_prediction = run_edit_prediction(&buffer, &project, &zeta, cx).await; - buffer.update(cx, |buffer, cx| { - buffer.edit(edit_prediction.edits.iter().cloned(), None, cx) - }); - buffer.read_with(cx, |buffer, _| buffer.text()) -} - -async fn run_edit_prediction( - buffer: &Entity, - project: &Entity, - zeta: &Entity, - cx: &mut TestAppContext, -) -> EditPrediction { - let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0))); - zeta.update(cx, |zeta, cx| zeta.register_buffer(buffer, &project, cx)); - cx.background_executor.run_until_parked(); - let prediction_task = zeta.update(cx, |zeta, cx| { 
- zeta.request_prediction(&project, buffer, cursor, Default::default(), cx) - }); - prediction_task.await.unwrap().unwrap().prediction.unwrap() -} - -async fn make_test_zeta( - project: &Entity, - cx: &mut TestAppContext, -) -> ( - Entity, - Arc>>, - Arc>, -) { - let default_response = indoc! {" - ```main.rs - <|start_of_file|> - <|editable_region_start|> - hello world - <|editable_region_end|> - ```" - }; - let captured_request: Arc>> = Arc::new(Mutex::new(None)); - let completion_response: Arc> = - Arc::new(Mutex::new(default_response.to_string())); - let http_client = FakeHttpClient::create({ - let captured_request = captured_request.clone(); - let completion_response = completion_response.clone(); - let mut next_request_id = 0; - move |req| { - let captured_request = captured_request.clone(); - let completion_response = completion_response.clone(); - async move { - match (req.method(), req.uri().path()) { - (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder() - .status(200) - .body( - serde_json::to_string(&CreateLlmTokenResponse { - token: LlmToken("the-llm-token".to_string()), - }) - .unwrap() - .into(), - ) - .unwrap()), - (&Method::POST, "/predict_edits/v2") => { - let mut request_body = String::new(); - req.into_body().read_to_string(&mut request_body).await?; - *captured_request.lock() = - Some(serde_json::from_str(&request_body).unwrap()); - next_request_id += 1; - Ok(http_client::Response::builder() - .status(200) - .body( - serde_json::to_string(&PredictEditsResponse { - request_id: format!("request-{next_request_id}"), - output_excerpt: completion_response.lock().clone(), - }) - .unwrap() - .into(), - ) - .unwrap()) - } - _ => Ok(http_client::Response::builder() - .status(404) - .body("Not Found".into()) - .unwrap()), - } - } - } - }); - - let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx)); - cx.update(|cx| { - RefreshLlmTokenListener::register(client.clone(), cx); - }); - let _server = 
FakeServer::for_client(42, &client, cx).await; - - let zeta = cx.new(|cx| { - let mut zeta = Zeta::new(client, project.read(cx).user_store(), cx); - zeta.set_edit_prediction_model(ZetaEditPredictionModel::Zeta1); - - let worktrees = project.read(cx).worktrees(cx).collect::>(); - for worktree in worktrees { - let worktree_id = worktree.read(cx).id(); - zeta.get_or_init_zeta_project(project, cx) - .license_detection_watchers - .entry(worktree_id) - .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx))); - } - - zeta - }); - - (zeta, captured_request, completion_response) -} - -fn to_completion_edits( - iterator: impl IntoIterator, Arc)>, - buffer: &Entity, - cx: &App, -) -> Vec<(Range, Arc)> { - let buffer = buffer.read(cx); - iterator - .into_iter() - .map(|(range, text)| { - ( - buffer.anchor_after(range.start)..buffer.anchor_before(range.end), - text, - ) - }) - .collect() -} - -fn from_completion_edits( - editor_edits: &[(Range, Arc)], - buffer: &Entity, - cx: &App, -) -> Vec<(Range, Arc)> { - let buffer = buffer.read(cx); - editor_edits - .iter() - .map(|(range, text)| { - ( - range.start.to_offset(buffer)..range.end.to_offset(buffer), - text.clone(), - ) - }) - .collect() -} - -#[ctor::ctor] -fn init_logger() { - zlog::init_test(); -} diff --git a/crates/zeta2_tools/Cargo.toml b/crates/zeta2_tools/Cargo.toml deleted file mode 100644 index 8e20224736c658d4d80d678b29d4231ec7e4b2f5..0000000000000000000000000000000000000000 --- a/crates/zeta2_tools/Cargo.toml +++ /dev/null @@ -1,48 +0,0 @@ -[package] -name = "zeta2_tools" -version = "0.1.0" -edition.workspace = true -publish.workspace = true -license = "GPL-3.0-or-later" - -[lints] -workspace = true - -[lib] -path = "src/zeta2_tools.rs" - -[dependencies] -anyhow.workspace = true -client.workspace = true -cloud_llm_client.workspace = true -collections.workspace = true -edit_prediction_context.workspace = true -editor.workspace = true -feature_flags.workspace = true -futures.workspace = true 
-gpui.workspace = true -language.workspace = true -multi_buffer.workspace = true -project.workspace = true -serde.workspace = true -serde_json.workspace = true -telemetry.workspace = true -text.workspace = true -ui.workspace = true -ui_input.workspace = true -util.workspace = true -workspace.workspace = true -zeta.workspace = true - -[dev-dependencies] -clap.workspace = true -gpui = { workspace = true, features = ["test-support"] } -indoc.workspace = true -language = { workspace = true, features = ["test-support"] } -pretty_assertions.workspace = true -project = { workspace = true, features = ["test-support"] } -serde_json.workspace = true -settings = { workspace = true, features = ["test-support"] } -text = { workspace = true, features = ["test-support"] } -util = { workspace = true, features = ["test-support"] } -zlog.workspace = true diff --git a/crates/zeta2_tools/LICENSE-GPL b/crates/zeta2_tools/LICENSE-GPL deleted file mode 120000 index 89e542f750cd3860a0598eff0dc34b56d7336dc4..0000000000000000000000000000000000000000 --- a/crates/zeta2_tools/LICENSE-GPL +++ /dev/null @@ -1 +0,0 @@ -../../LICENSE-GPL \ No newline at end of file diff --git a/crates/zeta2_tools/src/zeta2_tools.rs b/crates/zeta2_tools/src/zeta2_tools.rs deleted file mode 100644 index 26d68b075153557ab50ed0a231c5d45f0bb9646c..0000000000000000000000000000000000000000 --- a/crates/zeta2_tools/src/zeta2_tools.rs +++ /dev/null @@ -1,1035 +0,0 @@ -mod zeta2_context_view; - -use std::{str::FromStr, sync::Arc, time::Duration}; - -use client::{Client, UserStore}; -use cloud_llm_client::predict_edits_v3::PromptFormat; -use collections::HashMap; -use editor::{Editor, EditorEvent, EditorMode, MultiBuffer}; -use feature_flags::FeatureFlagAppExt as _; -use futures::{FutureExt, StreamExt as _, channel::oneshot, future::Shared}; -use gpui::{ - Empty, Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity, actions, - prelude::*, -}; -use language::Buffer; -use project::{Project, 
telemetry_snapshot::TelemetrySnapshot}; -use ui::{ButtonLike, ContextMenu, ContextMenuEntry, DropdownMenu, KeyBinding, prelude::*}; -use ui_input::InputField; -use util::ResultExt; -use workspace::{Item, SplitDirection, Workspace}; -use zeta::{ - AgenticContextOptions, ContextMode, DEFAULT_SYNTAX_CONTEXT_OPTIONS, EditPredictionInputs, Zeta, - Zeta2FeatureFlag, ZetaDebugInfo, ZetaEditPredictionDebugInfo, ZetaOptions, -}; - -use edit_prediction_context::{EditPredictionContextOptions, EditPredictionExcerptOptions}; -use zeta2_context_view::Zeta2ContextView; - -actions!( - dev, - [ - /// Opens the edit prediction context view. - OpenZeta2ContextView, - /// Opens the edit prediction inspector. - OpenZeta2Inspector, - /// Rate prediction as positive. - Zeta2RatePredictionPositive, - /// Rate prediction as negative. - Zeta2RatePredictionNegative, - ] -); - -pub fn init(cx: &mut App) { - cx.observe_new(move |workspace: &mut Workspace, _, _cx| { - workspace.register_action_renderer(|div, _, _, cx| { - let has_flag = cx.has_flag::(); - div.when(has_flag, |div| { - div.on_action( - cx.listener(move |workspace, _: &OpenZeta2Inspector, window, cx| { - let project = workspace.project(); - workspace.split_item( - SplitDirection::Right, - Box::new(cx.new(|cx| { - Zeta2Inspector::new( - &project, - workspace.client(), - workspace.user_store(), - window, - cx, - ) - })), - window, - cx, - ) - }), - ) - .on_action(cx.listener( - move |workspace, _: &OpenZeta2ContextView, window, cx| { - let project = workspace.project(); - workspace.split_item( - SplitDirection::Right, - Box::new(cx.new(|cx| { - Zeta2ContextView::new( - project.clone(), - workspace.client(), - workspace.user_store(), - window, - cx, - ) - })), - window, - cx, - ); - }, - )) - }) - }); - }) - .detach(); -} - -// TODO show included diagnostics, and events - -pub struct Zeta2Inspector { - focus_handle: FocusHandle, - project: Entity, - last_prediction: Option, - max_excerpt_bytes_input: Entity, - 
min_excerpt_bytes_input: Entity, - cursor_context_ratio_input: Entity, - max_prompt_bytes_input: Entity, - context_mode: ContextModeState, - zeta: Entity, - _active_editor_subscription: Option, - _update_state_task: Task<()>, - _receive_task: Task<()>, -} - -pub enum ContextModeState { - Llm, - Lsp, - Syntax { - max_retrieved_declarations: Entity, - }, -} - -struct LastPrediction { - prompt_editor: Entity, - retrieval_time: Duration, - request_time: Option, - buffer: WeakEntity, - position: language::Anchor, - state: LastPredictionState, - inputs: EditPredictionInputs, - project_snapshot: Shared>>, - _task: Option>, -} - -#[derive(Clone, Copy, PartialEq)] -enum Feedback { - Positive, - Negative, -} - -enum LastPredictionState { - Requested, - Success { - model_response_editor: Entity, - feedback_editor: Entity, - feedback: Option, - request_id: String, - }, - Failed { - message: String, - }, -} - -impl Zeta2Inspector { - pub fn new( - project: &Entity, - client: &Arc, - user_store: &Entity, - window: &mut Window, - cx: &mut Context, - ) -> Self { - let zeta = Zeta::global(client, user_store, cx); - let mut request_rx = zeta.update(cx, |zeta, _cx| zeta.debug_info()); - - let receive_task = cx.spawn_in(window, async move |this, cx| { - while let Some(prediction) = request_rx.next().await { - this.update_in(cx, |this, window, cx| { - this.update_last_prediction(prediction, window, cx) - }) - .ok(); - } - }); - - let mut this = Self { - focus_handle: cx.focus_handle(), - project: project.clone(), - last_prediction: None, - max_excerpt_bytes_input: Self::number_input("Max Excerpt Bytes", window, cx), - min_excerpt_bytes_input: Self::number_input("Min Excerpt Bytes", window, cx), - cursor_context_ratio_input: Self::number_input("Cursor Context Ratio", window, cx), - max_prompt_bytes_input: Self::number_input("Max Prompt Bytes", window, cx), - context_mode: ContextModeState::Llm, - zeta: zeta.clone(), - _active_editor_subscription: None, - _update_state_task: 
Task::ready(()), - _receive_task: receive_task, - }; - this.set_options_state(&zeta.read(cx).options().clone(), window, cx); - this - } - - fn set_options_state( - &mut self, - options: &ZetaOptions, - window: &mut Window, - cx: &mut Context, - ) { - let excerpt_options = options.context.excerpt(); - self.max_excerpt_bytes_input.update(cx, |input, cx| { - input.set_text(excerpt_options.max_bytes.to_string(), window, cx); - }); - self.min_excerpt_bytes_input.update(cx, |input, cx| { - input.set_text(excerpt_options.min_bytes.to_string(), window, cx); - }); - self.cursor_context_ratio_input.update(cx, |input, cx| { - input.set_text( - format!( - "{:.2}", - excerpt_options.target_before_cursor_over_total_bytes - ), - window, - cx, - ); - }); - self.max_prompt_bytes_input.update(cx, |input, cx| { - input.set_text(options.max_prompt_bytes.to_string(), window, cx); - }); - - match &options.context { - ContextMode::Agentic(_) => { - self.context_mode = ContextModeState::Llm; - } - ContextMode::Syntax(_) => { - self.context_mode = ContextModeState::Syntax { - max_retrieved_declarations: Self::number_input( - "Max Retrieved Definitions", - window, - cx, - ), - }; - } - ContextMode::Lsp(_) => { - self.context_mode = ContextModeState::Lsp; - } - } - cx.notify(); - } - - fn set_zeta_options(&mut self, options: ZetaOptions, cx: &mut Context) { - self.zeta.update(cx, |this, _cx| this.set_options(options)); - - if let Some(prediction) = self.last_prediction.as_mut() { - if let Some(buffer) = prediction.buffer.upgrade() { - let position = prediction.position; - let project = self.project.clone(); - self.zeta.update(cx, |zeta, cx| { - zeta.refresh_prediction_from_buffer(project, buffer, position, cx) - }); - prediction.state = LastPredictionState::Requested; - } else { - self.last_prediction.take(); - } - } - - cx.notify(); - } - - fn number_input( - label: &'static str, - window: &mut Window, - cx: &mut Context, - ) -> Entity { - let input = cx.new(|cx| { - InputField::new(window, 
cx, "") - .label(label) - .label_min_width(px(64.)) - }); - - cx.subscribe_in( - &input.read(cx).editor().clone(), - window, - |this, _, event, _window, cx| { - let EditorEvent::BufferEdited = event else { - return; - }; - - fn number_input_value( - input: &Entity, - cx: &App, - ) -> T { - input - .read(cx) - .editor() - .read(cx) - .text(cx) - .parse::() - .unwrap_or_default() - } - - let zeta_options = this.zeta.read(cx).options().clone(); - - let excerpt_options = EditPredictionExcerptOptions { - max_bytes: number_input_value(&this.max_excerpt_bytes_input, cx), - min_bytes: number_input_value(&this.min_excerpt_bytes_input, cx), - target_before_cursor_over_total_bytes: number_input_value( - &this.cursor_context_ratio_input, - cx, - ), - }; - - let context = match zeta_options.context { - ContextMode::Agentic(_context_options) => { - ContextMode::Agentic(AgenticContextOptions { - excerpt: excerpt_options, - }) - } - ContextMode::Syntax(context_options) => { - let max_retrieved_declarations = match &this.context_mode { - ContextModeState::Llm => { - zeta::DEFAULT_SYNTAX_CONTEXT_OPTIONS.max_retrieved_declarations - } - ContextModeState::Syntax { - max_retrieved_declarations, - } => number_input_value(max_retrieved_declarations, cx), - ContextModeState::Lsp => { - zeta::DEFAULT_SYNTAX_CONTEXT_OPTIONS.max_retrieved_declarations - } - }; - - ContextMode::Syntax(EditPredictionContextOptions { - excerpt: excerpt_options, - max_retrieved_declarations, - ..context_options - }) - } - ContextMode::Lsp(excerpt_options) => ContextMode::Lsp(excerpt_options), - }; - - this.set_zeta_options( - ZetaOptions { - context, - max_prompt_bytes: number_input_value(&this.max_prompt_bytes_input, cx), - max_diagnostic_bytes: zeta_options.max_diagnostic_bytes, - prompt_format: zeta_options.prompt_format, - file_indexing_parallelism: zeta_options.file_indexing_parallelism, - buffer_change_grouping_interval: zeta_options - .buffer_change_grouping_interval, - }, - cx, - ); - }, - ) - .detach(); 
- input - } - - fn update_last_prediction( - &mut self, - prediction: zeta::ZetaDebugInfo, - window: &mut Window, - cx: &mut Context, - ) { - self._update_state_task = cx.spawn_in(window, { - let language_registry = self.project.read(cx).languages().clone(); - async move |this, cx| { - let mut languages = HashMap::default(); - let ZetaDebugInfo::EditPredictionRequested(prediction) = prediction else { - return; - }; - for ext in prediction - .inputs - .included_files - .iter() - .filter_map(|file| file.path.extension()) - { - if !languages.contains_key(ext) { - // Most snippets are gonna be the same language, - // so we think it's fine to do this sequentially for now - languages.insert( - ext.to_owned(), - language_registry - .language_for_name_or_extension(&ext.to_string_lossy()) - .await - .ok(), - ); - } - } - - let markdown_language = language_registry - .language_for_name("Markdown") - .await - .log_err(); - - let json_language = language_registry.language_for_name("Json").await.log_err(); - - this.update_in(cx, |this, window, cx| { - let ZetaEditPredictionDebugInfo { - response_rx, - position, - buffer, - retrieval_time, - local_prompt, - .. 
- } = prediction; - - let task = cx.spawn_in(window, { - let markdown_language = markdown_language.clone(); - let json_language = json_language.clone(); - async move |this, cx| { - let response = response_rx.await; - - this.update_in(cx, |this, window, cx| { - if let Some(prediction) = this.last_prediction.as_mut() { - prediction.state = match response { - Ok((Ok(response), request_time)) => { - prediction.request_time = Some(request_time); - - let feedback_editor = cx.new(|cx| { - let buffer = cx.new(|cx| { - let mut buffer = Buffer::local("", cx); - buffer.set_language( - markdown_language.clone(), - cx, - ); - buffer - }); - let buffer = - cx.new(|cx| MultiBuffer::singleton(buffer, cx)); - let mut editor = Editor::new( - EditorMode::AutoHeight { - min_lines: 3, - max_lines: None, - }, - buffer, - None, - window, - cx, - ); - editor.set_placeholder_text( - "Write feedback here", - window, - cx, - ); - editor.set_show_line_numbers(false, cx); - editor.set_show_gutter(false, cx); - editor.set_show_scrollbars(false, cx); - editor - }); - - cx.subscribe_in( - &feedback_editor, - window, - |this, editor, ev, window, cx| match ev { - EditorEvent::BufferEdited => { - if let Some(last_prediction) = - this.last_prediction.as_mut() - && let LastPredictionState::Success { - feedback: feedback_state, - .. 
- } = &mut last_prediction.state - { - if feedback_state.take().is_some() { - editor.update(cx, |editor, cx| { - editor.set_placeholder_text( - "Write feedback here", - window, - cx, - ); - }); - cx.notify(); - } - } - } - _ => {} - }, - ) - .detach(); - - LastPredictionState::Success { - model_response_editor: cx.new(|cx| { - let buffer = cx.new(|cx| { - let mut buffer = Buffer::local( - serde_json::to_string_pretty(&response) - .unwrap_or_default(), - cx, - ); - buffer.set_language(json_language, cx); - buffer - }); - let buffer = cx.new(|cx| { - MultiBuffer::singleton(buffer, cx) - }); - let mut editor = Editor::new( - EditorMode::full(), - buffer, - None, - window, - cx, - ); - editor.set_read_only(true); - editor.set_show_line_numbers(false, cx); - editor.set_show_gutter(false, cx); - editor.set_show_scrollbars(false, cx); - editor - }), - feedback_editor, - feedback: None, - request_id: response.id.clone(), - } - } - Ok((Err(err), request_time)) => { - prediction.request_time = Some(request_time); - LastPredictionState::Failed { message: err } - } - Err(oneshot::Canceled) => LastPredictionState::Failed { - message: "Canceled".to_string(), - }, - }; - } - }) - .ok(); - } - }); - - let project_snapshot_task = TelemetrySnapshot::new(&this.project, cx); - - this.last_prediction = Some(LastPrediction { - prompt_editor: cx.new(|cx| { - let buffer = cx.new(|cx| { - let mut buffer = - Buffer::local(local_prompt.unwrap_or_else(|err| err), cx); - buffer.set_language(markdown_language.clone(), cx); - buffer - }); - let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx)); - let mut editor = - Editor::new(EditorMode::full(), buffer, None, window, cx); - editor.set_read_only(true); - editor.set_show_line_numbers(false, cx); - editor.set_show_gutter(false, cx); - editor.set_show_scrollbars(false, cx); - editor - }), - retrieval_time, - request_time: None, - buffer, - position, - state: LastPredictionState::Requested, - project_snapshot: cx - .foreground_executor() - 
.spawn(async move { Arc::new(project_snapshot_task.await) }) - .shared(), - inputs: prediction.inputs, - _task: Some(task), - }); - cx.notify(); - }) - .ok(); - } - }); - } - - fn handle_rate_positive( - &mut self, - _action: &Zeta2RatePredictionPositive, - window: &mut Window, - cx: &mut Context, - ) { - self.handle_rate(Feedback::Positive, window, cx); - } - - fn handle_rate_negative( - &mut self, - _action: &Zeta2RatePredictionNegative, - window: &mut Window, - cx: &mut Context, - ) { - self.handle_rate(Feedback::Negative, window, cx); - } - - fn handle_rate(&mut self, kind: Feedback, window: &mut Window, cx: &mut Context) { - let Some(last_prediction) = self.last_prediction.as_mut() else { - return; - }; - - let project_snapshot_task = last_prediction.project_snapshot.clone(); - - cx.spawn_in(window, async move |this, cx| { - let project_snapshot = project_snapshot_task.await; - this.update_in(cx, |this, window, cx| { - let Some(last_prediction) = this.last_prediction.as_mut() else { - return; - }; - - let LastPredictionState::Success { - feedback: feedback_state, - feedback_editor, - model_response_editor, - request_id, - .. - } = &mut last_prediction.state - else { - return; - }; - - *feedback_state = Some(kind); - let text = feedback_editor.update(cx, |feedback_editor, cx| { - feedback_editor.set_placeholder_text( - "Submitted. 
Edit or submit again to change.", - window, - cx, - ); - feedback_editor.text(cx) - }); - cx.notify(); - - cx.defer_in(window, { - let model_response_editor = model_response_editor.downgrade(); - move |_, window, cx| { - if let Some(model_response_editor) = model_response_editor.upgrade() { - model_response_editor.focus_handle(cx).focus(window); - } - } - }); - - let kind = match kind { - Feedback::Positive => "positive", - Feedback::Negative => "negative", - }; - - telemetry::event!( - "Zeta2 Prediction Rated", - id = request_id, - kind = kind, - text = text, - request = last_prediction.inputs, - project_snapshot = project_snapshot, - ); - }) - .log_err(); - }) - .detach(); - } - - fn render_options(&self, window: &mut Window, cx: &mut Context) -> Div { - v_flex() - .gap_2() - .child( - h_flex() - .child(Headline::new("Options").size(HeadlineSize::Small)) - .justify_between() - .child( - ui::Button::new("reset-options", "Reset") - .disabled(self.zeta.read(cx).options() == &zeta::DEFAULT_OPTIONS) - .style(ButtonStyle::Outlined) - .size(ButtonSize::Large) - .on_click(cx.listener(|this, _, window, cx| { - this.set_options_state(&zeta::DEFAULT_OPTIONS, window, cx); - })), - ), - ) - .child( - v_flex() - .gap_2() - .child( - h_flex() - .gap_2() - .items_end() - .child(self.max_excerpt_bytes_input.clone()) - .child(self.min_excerpt_bytes_input.clone()) - .child(self.cursor_context_ratio_input.clone()) - .child(self.render_context_mode_dropdown(window, cx)), - ) - .child( - h_flex() - .gap_2() - .items_end() - .children(match &self.context_mode { - ContextModeState::Llm => None, - ContextModeState::Syntax { - max_retrieved_declarations, - } => Some(max_retrieved_declarations.clone()), - ContextModeState::Lsp => None, - }) - .child(self.max_prompt_bytes_input.clone()) - .child(self.render_prompt_format_dropdown(window, cx)), - ), - ) - } - - fn render_context_mode_dropdown(&self, window: &mut Window, cx: &mut Context) -> Div { - let this = cx.weak_entity(); - - v_flex() - 
.gap_1p5() - .child( - Label::new("Context Mode") - .size(LabelSize::Small) - .color(Color::Muted), - ) - .child( - DropdownMenu::new( - "ep-ctx-mode", - match &self.context_mode { - ContextModeState::Llm => "LLM-based", - ContextModeState::Syntax { .. } => "Syntax", - ContextModeState::Lsp => "LSP-based", - }, - ContextMenu::build(window, cx, move |menu, _window, _cx| { - menu.item( - ContextMenuEntry::new("LLM-based") - .toggleable( - IconPosition::End, - matches!(self.context_mode, ContextModeState::Llm), - ) - .handler({ - let this = this.clone(); - move |window, cx| { - this.update(cx, |this, cx| { - let current_options = - this.zeta.read(cx).options().clone(); - match current_options.context.clone() { - ContextMode::Agentic(_) => {} - ContextMode::Lsp(_) => {} - ContextMode::Syntax(context_options) => { - let options = ZetaOptions { - context: ContextMode::Agentic( - AgenticContextOptions { - excerpt: context_options.excerpt, - }, - ), - ..current_options - }; - this.set_options_state(&options, window, cx); - this.set_zeta_options(options, cx); - } - } - }) - .ok(); - } - }), - ) - .item( - ContextMenuEntry::new("Syntax") - .toggleable( - IconPosition::End, - matches!(self.context_mode, ContextModeState::Syntax { .. 
}), - ) - .handler({ - move |window, cx| { - this.update(cx, |this, cx| { - let current_options = - this.zeta.read(cx).options().clone(); - match current_options.context.clone() { - ContextMode::Agentic(context_options) => { - let options = ZetaOptions { - context: ContextMode::Syntax( - EditPredictionContextOptions { - excerpt: context_options.excerpt, - ..DEFAULT_SYNTAX_CONTEXT_OPTIONS - }, - ), - ..current_options - }; - this.set_options_state(&options, window, cx); - this.set_zeta_options(options, cx); - } - ContextMode::Syntax(_) => {} - ContextMode::Lsp(_) => {} - } - }) - .ok(); - } - }), - ) - }), - ) - .style(ui::DropdownStyle::Outlined), - ) - } - - fn render_prompt_format_dropdown(&self, window: &mut Window, cx: &mut Context) -> Div { - let active_format = self.zeta.read(cx).options().prompt_format; - let this = cx.weak_entity(); - - v_flex() - .gap_1p5() - .child( - Label::new("Prompt Format") - .size(LabelSize::Small) - .color(Color::Muted), - ) - .child( - DropdownMenu::new( - "ep-prompt-format", - active_format.to_string(), - ContextMenu::build(window, cx, move |mut menu, _window, _cx| { - for prompt_format in PromptFormat::iter() { - menu = menu.item( - ContextMenuEntry::new(prompt_format.to_string()) - .toggleable(IconPosition::End, active_format == prompt_format) - .handler({ - let this = this.clone(); - move |_window, cx| { - this.update(cx, |this, cx| { - let current_options = - this.zeta.read(cx).options().clone(); - let options = ZetaOptions { - prompt_format, - ..current_options - }; - this.set_zeta_options(options, cx); - }) - .ok(); - } - }), - ) - } - menu - }), - ) - .style(ui::DropdownStyle::Outlined), - ) - } - - fn render_stats(&self) -> Option
{ - let Some(prediction) = self.last_prediction.as_ref() else { - return None; - }; - - Some( - v_flex() - .p_4() - .gap_2() - .min_w(px(160.)) - .child(Headline::new("Stats").size(HeadlineSize::Small)) - .child(Self::render_duration( - "Context retrieval", - Some(prediction.retrieval_time), - )) - .child(Self::render_duration("Request", prediction.request_time)), - ) - } - - fn render_duration(name: &'static str, time: Option) -> Div { - h_flex() - .gap_1() - .child(Label::new(name).color(Color::Muted).size(LabelSize::Small)) - .child(match time { - Some(time) => Label::new(if time.as_micros() >= 1000 { - format!("{} ms", time.as_millis()) - } else { - format!("{} µs", time.as_micros()) - }) - .size(LabelSize::Small), - None => Label::new("...").size(LabelSize::Small), - }) - } - - fn render_content(&self, _: &mut Window, cx: &mut Context) -> AnyElement { - if !cx.has_flag::() { - return Self::render_message("`zeta2` feature flag is not enabled"); - } - - match self.last_prediction.as_ref() { - None => Self::render_message("No prediction"), - Some(prediction) => self.render_last_prediction(prediction, cx).into_any(), - } - } - - fn render_message(message: impl Into) -> AnyElement { - v_flex() - .size_full() - .justify_center() - .items_center() - .child(Label::new(message).size(LabelSize::Large)) - .into_any() - } - - fn render_last_prediction(&self, prediction: &LastPrediction, cx: &mut Context) -> Div { - h_flex() - .items_start() - .w_full() - .flex_1() - .border_t_1() - .border_color(cx.theme().colors().border) - .bg(cx.theme().colors().editor_background) - .child( - v_flex() - .flex_1() - .gap_2() - .p_4() - .h_full() - .child( - h_flex() - .justify_between() - .child(ui::Headline::new("Prompt").size(ui::HeadlineSize::XSmall)) - .child(match prediction.state { - LastPredictionState::Requested - | LastPredictionState::Failed { .. 
} => ui::Chip::new("Local") - .bg_color(cx.theme().status().warning_background) - .label_color(Color::Success), - LastPredictionState::Success { .. } => ui::Chip::new("Cloud") - .bg_color(cx.theme().status().success_background) - .label_color(Color::Success), - }), - ) - .child(prediction.prompt_editor.clone()), - ) - .child(ui::vertical_divider()) - .child( - v_flex() - .flex_1() - .gap_2() - .h_full() - .child( - v_flex() - .flex_1() - .gap_2() - .p_4() - .child( - ui::Headline::new("Model Response").size(ui::HeadlineSize::XSmall), - ) - .child(match &prediction.state { - LastPredictionState::Success { - model_response_editor, - .. - } => model_response_editor.clone().into_any_element(), - LastPredictionState::Requested => v_flex() - .gap_2() - .child(Label::new("Loading...").buffer_font(cx)) - .into_any_element(), - LastPredictionState::Failed { message } => v_flex() - .gap_2() - .max_w_96() - .child(Label::new(message.clone()).buffer_font(cx)) - .into_any_element(), - }), - ) - .child(ui::divider()) - .child( - if let LastPredictionState::Success { - feedback_editor, - feedback: feedback_state, - .. 
- } = &prediction.state - { - v_flex() - .key_context("Zeta2Feedback") - .on_action(cx.listener(Self::handle_rate_positive)) - .on_action(cx.listener(Self::handle_rate_negative)) - .gap_2() - .p_2() - .child(feedback_editor.clone()) - .child( - h_flex() - .justify_end() - .w_full() - .child( - ButtonLike::new("rate-positive") - .when( - *feedback_state == Some(Feedback::Positive), - |this| this.style(ButtonStyle::Filled), - ) - .child( - KeyBinding::for_action( - &Zeta2RatePredictionPositive, - cx, - ) - .size(TextSize::Small.rems(cx)), - ) - .child(ui::Icon::new(ui::IconName::ThumbsUp)) - .on_click(cx.listener(|this, _, window, cx| { - this.handle_rate_positive( - &Zeta2RatePredictionPositive, - window, - cx, - ); - })), - ) - .child( - ButtonLike::new("rate-negative") - .when( - *feedback_state == Some(Feedback::Negative), - |this| this.style(ButtonStyle::Filled), - ) - .child( - KeyBinding::for_action( - &Zeta2RatePredictionNegative, - cx, - ) - .size(TextSize::Small.rems(cx)), - ) - .child(ui::Icon::new(ui::IconName::ThumbsDown)) - .on_click(cx.listener(|this, _, window, cx| { - this.handle_rate_negative( - &Zeta2RatePredictionNegative, - window, - cx, - ); - })), - ), - ) - .into_any() - } else { - Empty.into_any_element() - }, - ), - ) - } -} - -impl Focusable for Zeta2Inspector { - fn focus_handle(&self, _cx: &App) -> FocusHandle { - self.focus_handle.clone() - } -} - -impl Item for Zeta2Inspector { - type Event = (); - - fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString { - "Zeta2 Inspector".into() - } -} - -impl EventEmitter<()> for Zeta2Inspector {} - -impl Render for Zeta2Inspector { - fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { - v_flex() - .size_full() - .bg(cx.theme().colors().editor_background) - .child( - h_flex() - .w_full() - .child( - v_flex() - .flex_1() - .p_4() - .h_full() - .justify_between() - .child(self.render_options(window, cx)) - .gap_4(), - ) - .child(ui::vertical_divider()) - 
.children(self.render_stats()), - ) - .child(self.render_content(window, cx)) - } -} diff --git a/crates/zeta_cli/LICENSE-GPL b/crates/zeta_cli/LICENSE-GPL deleted file mode 120000 index 89e542f750cd3860a0598eff0dc34b56d7336dc4..0000000000000000000000000000000000000000 --- a/crates/zeta_cli/LICENSE-GPL +++ /dev/null @@ -1 +0,0 @@ -../../LICENSE-GPL \ No newline at end of file diff --git a/crates/zeta_cli/src/syntax_retrieval_stats.rs b/crates/zeta_cli/src/syntax_retrieval_stats.rs deleted file mode 100644 index 4c7506ff78952da79acfeae751959bfe8182b9d4..0000000000000000000000000000000000000000 --- a/crates/zeta_cli/src/syntax_retrieval_stats.rs +++ /dev/null @@ -1,1260 +0,0 @@ -use ::util::rel_path::RelPath; -use ::util::{RangeExt, ResultExt as _}; -use anyhow::{Context as _, Result}; -use cloud_llm_client::predict_edits_v3::DeclarationScoreComponents; -use edit_prediction_context::{ - Declaration, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions, Identifier, - Imports, Reference, ReferenceRegion, SyntaxIndex, SyntaxIndexState, references_in_range, -}; -use futures::StreamExt as _; -use futures::channel::mpsc; -use gpui::Entity; -use gpui::{AppContext, AsyncApp}; -use language::OffsetRangeExt; -use language::{BufferSnapshot, Point}; -use ordered_float::OrderedFloat; -use polars::prelude::*; -use project::{Project, ProjectEntryId, ProjectPath, Worktree}; -use serde::{Deserialize, Serialize}; -use std::fs; -use std::{ - cmp::Reverse, - collections::{HashMap, HashSet}, - fs::File, - hash::{Hash, Hasher}, - io::{BufRead, BufReader, BufWriter, Write as _}, - ops::Range, - path::{Path, PathBuf}, - sync::{ - Arc, - atomic::{self, AtomicUsize}, - }, - time::Duration, -}; -use util::paths::PathStyle; -use zeta::ContextMode; - -use crate::headless::ZetaCliAppState; -use crate::source_location::SourceLocation; -use crate::util::{open_buffer, open_buffer_with_language_server}; - -pub async fn retrieval_stats( - worktree: PathBuf, - app_state: Arc, - 
only_extension: Option, - file_limit: Option, - skip_files: Option, - options: zeta::ZetaOptions, - cx: &mut AsyncApp, -) -> Result { - let ContextMode::Syntax(context_options) = options.context.clone() else { - anyhow::bail!("retrieval stats only works in ContextMode::Syntax"); - }; - - let options = Arc::new(options); - let worktree_path = worktree.canonicalize()?; - - let project = cx.update(|cx| { - Project::local( - app_state.client.clone(), - app_state.node_runtime.clone(), - app_state.user_store.clone(), - app_state.languages.clone(), - app_state.fs.clone(), - None, - cx, - ) - })?; - - let worktree = project - .update(cx, |project, cx| { - project.create_worktree(&worktree_path, true, cx) - })? - .await?; - - // wait for worktree scan so that wait_for_initial_file_indexing waits for the whole worktree. - worktree - .read_with(cx, |worktree, _cx| { - worktree.as_local().unwrap().scan_complete() - })? - .await; - - let index = cx.new(|cx| SyntaxIndex::new(&project, options.file_indexing_parallelism, cx))?; - index - .read_with(cx, |index, cx| index.wait_for_initial_file_indexing(cx))? - .await?; - let indexed_files = index - .read_with(cx, |index, cx| index.indexed_file_paths(cx))? - .await; - let mut filtered_files = indexed_files - .into_iter() - .filter(|project_path| { - let file_extension = project_path.path.extension(); - if let Some(only_extension) = only_extension.as_ref() { - file_extension.is_some_and(|extension| extension == only_extension) - } else { - file_extension - .is_some_and(|extension| !["md", "json", "sh", "diff"].contains(&extension)) - } - }) - .collect::>(); - filtered_files.sort_by(|a, b| a.path.cmp(&b.path)); - - let index_state = index.read_with(cx, |index, _cx| index.state().clone())?; - cx.update(|_| { - drop(index); - })?; - let index_state = Arc::new( - Arc::into_inner(index_state) - .context("Index state had more than 1 reference")? 
- .into_inner(), - ); - - struct FileSnapshot { - project_entry_id: ProjectEntryId, - snapshot: BufferSnapshot, - hash: u64, - parent_abs_path: Arc, - } - - let files: Vec = futures::future::try_join_all({ - filtered_files - .iter() - .map(|file| { - let buffer_task = - open_buffer(project.clone(), worktree.clone(), file.path.clone(), cx); - cx.spawn(async move |cx| { - let buffer = buffer_task.await?; - let (project_entry_id, parent_abs_path, snapshot) = - buffer.read_with(cx, |buffer, cx| { - let file = project::File::from_dyn(buffer.file()).unwrap(); - let project_entry_id = file.project_entry_id().unwrap(); - let mut parent_abs_path = file.worktree.read(cx).absolutize(&file.path); - if !parent_abs_path.pop() { - panic!("Invalid worktree path"); - } - - (project_entry_id, parent_abs_path, buffer.snapshot()) - })?; - - anyhow::Ok( - cx.background_spawn(async move { - let mut hasher = collections::FxHasher::default(); - snapshot.text().hash(&mut hasher); - FileSnapshot { - project_entry_id, - snapshot, - hash: hasher.finish(), - parent_abs_path: parent_abs_path.into(), - } - }) - .await, - ) - }) - }) - .collect::>() - }) - .await?; - - let mut file_snapshots = HashMap::default(); - let mut hasher = collections::FxHasher::default(); - for FileSnapshot { - project_entry_id, - snapshot, - hash, - .. 
- } in &files - { - file_snapshots.insert(*project_entry_id, snapshot.clone()); - hash.hash(&mut hasher); - } - let files_hash = hasher.finish(); - let file_snapshots = Arc::new(file_snapshots); - let target_cli_dir = Path::new(env!("CARGO_MANIFEST_DIR")).join("../../target/zeta_cli"); - fs::create_dir_all(&target_cli_dir).unwrap(); - let target_cli_dir = target_cli_dir.canonicalize().unwrap(); - - let lsp_cache_dir = target_cli_dir.join("cache"); - fs::create_dir_all(&lsp_cache_dir).unwrap(); - - let lsp_definitions_path = lsp_cache_dir.join(format!( - "{}-{:x}.jsonl", - worktree_path.file_stem().unwrap_or_default().display(), - files_hash - )); - - let mut lsp_definitions = HashMap::default(); - let mut lsp_files = 0; - - if fs::exists(&lsp_definitions_path)? { - log::info!( - "Using cached LSP definitions from {}", - lsp_definitions_path.display() - ); - - let file = File::options() - .read(true) - .write(true) - .open(&lsp_definitions_path)?; - let lines = BufReader::new(&file).lines(); - let mut valid_len: usize = 0; - - for (line, expected_file) in lines.zip(files.iter()) { - let line = line?; - let FileLspDefinitions { path, references } = match serde_json::from_str(&line) { - Ok(ok) => ok, - Err(_) => { - log::error!("Found invalid cache line. Truncating to #{lsp_files}.",); - file.set_len(valid_len as u64)?; - break; - } - }; - let expected_path = expected_file.snapshot.file().unwrap().path().as_unix_str(); - if expected_path != path.as_ref() { - log::error!( - "Expected file #{} to be {expected_path}, but found {path}. 
Truncating to #{lsp_files}.", - lsp_files + 1 - ); - file.set_len(valid_len as u64)?; - break; - } - for (point, ranges) in references { - let Ok(path) = RelPath::new(Path::new(path.as_ref()), PathStyle::Posix) else { - log::warn!("Invalid path: {}", path); - continue; - }; - lsp_definitions.insert( - SourceLocation { - path: path.into_arc(), - point: point.into(), - }, - ranges, - ); - } - lsp_files += 1; - valid_len += line.len() + 1 - } - } - - if lsp_files < files.len() { - if lsp_files == 0 { - log::warn!( - "No LSP definitions found, populating {}", - lsp_definitions_path.display() - ); - } else { - log::warn!("{} files missing from LSP cache", files.len() - lsp_files); - } - - gather_lsp_definitions( - &lsp_definitions_path, - lsp_files, - &filtered_files, - &worktree, - &project, - &mut lsp_definitions, - cx, - ) - .await?; - } - let files_len = files.len().min(file_limit.unwrap_or(usize::MAX)); - let done_count = Arc::new(AtomicUsize::new(0)); - - let (output_tx, output_rx) = mpsc::unbounded::(); - - let tasks = files - .into_iter() - .skip(skip_files.unwrap_or(0)) - .take(file_limit.unwrap_or(usize::MAX)) - .map(|project_file| { - let index_state = index_state.clone(); - let lsp_definitions = lsp_definitions.clone(); - let output_tx = output_tx.clone(); - let done_count = done_count.clone(); - let file_snapshots = file_snapshots.clone(); - let context_options = context_options.clone(); - cx.background_spawn(async move { - let snapshot = project_file.snapshot; - - let full_range = 0..snapshot.len(); - let references = references_in_range( - full_range, - &snapshot.text(), - ReferenceRegion::Nearby, - &snapshot, - ); - - let imports = if context_options.use_imports { - Imports::gather(&snapshot, Some(&project_file.parent_abs_path)) - } else { - Imports::default() - }; - - let path = snapshot.file().unwrap().path(); - - for reference in references { - let query_point = snapshot.offset_to_point(reference.range.start); - let source_location = SourceLocation { 
- path: path.clone(), - point: query_point, - }; - let lsp_definitions = lsp_definitions - .get(&source_location) - .cloned() - .unwrap_or_else(|| { - log::warn!( - "No definitions found for source location: {:?}", - source_location - ); - Vec::new() - }); - - let retrieve_result = retrieve_definitions( - &reference, - &imports, - query_point, - &snapshot, - &index_state, - &file_snapshots, - &context_options, - ) - .await?; - - let result = ReferenceRetrievalResult { - cursor_path: path.clone(), - identifier: reference.identifier, - cursor_point: query_point, - lsp_definitions, - retrieved_definitions: retrieve_result.definitions, - excerpt_range: retrieve_result.excerpt_range, - }; - - output_tx.unbounded_send(result).ok(); - } - - println!( - "{:02}/{:02} done", - done_count.fetch_add(1, atomic::Ordering::Relaxed) + 1, - files_len, - ); - - anyhow::Ok(()) - }) - }) - .collect::>(); - - drop(output_tx); - - let df_task = cx.background_spawn(build_dataframe(output_rx)); - - futures::future::try_join_all(tasks).await?; - let mut df = df_task.await?; - - let run_id = format!( - "{}-{}", - worktree_path.file_stem().unwrap_or_default().display(), - chrono::Local::now().format("%Y%m%d_%H%M%S") - ); - let run_dir = target_cli_dir.join(run_id); - fs::create_dir(&run_dir).unwrap(); - - let parquet_path = run_dir.join("stats.parquet"); - let mut parquet_file = fs::File::create(&parquet_path)?; - - ParquetWriter::new(&mut parquet_file) - .finish(&mut df) - .unwrap(); - - let stats = SummaryStats::from_dataframe(df)?; - - let stats_path = run_dir.join("stats.txt"); - fs::write(&stats_path, format!("{}", stats))?; - - println!("{}", stats); - println!("\nWrote:"); - println!("- {}", relativize_path(&parquet_path).display()); - println!("- {}", relativize_path(&stats_path).display()); - println!("- {}", relativize_path(&lsp_definitions_path).display()); - - Ok("".to_string()) -} - -async fn build_dataframe( - mut output_rx: mpsc::UnboundedReceiver, -) -> Result { - use 
soa_rs::{Soa, Soars}; - - #[derive(Default, Soars)] - struct Row { - ref_id: u32, - cursor_path: String, - cursor_row: u32, - cursor_column: u32, - cursor_identifier: String, - gold_in_excerpt: bool, - gold_path: String, - gold_row: u32, - gold_column: u32, - gold_is_external: bool, - candidate_count: u32, - candidate_path: Option, - candidate_row: Option, - candidate_column: Option, - candidate_is_gold: Option, - candidate_rank: Option, - candidate_is_same_file: Option, - candidate_is_referenced_nearby: Option, - candidate_is_referenced_in_breadcrumb: Option, - candidate_reference_count: Option, - candidate_same_file_declaration_count: Option, - candidate_declaration_count: Option, - candidate_reference_line_distance: Option, - candidate_declaration_line_distance: Option, - candidate_excerpt_vs_item_jaccard: Option, - candidate_excerpt_vs_signature_jaccard: Option, - candidate_adjacent_vs_item_jaccard: Option, - candidate_adjacent_vs_signature_jaccard: Option, - candidate_excerpt_vs_item_weighted_overlap: Option, - candidate_excerpt_vs_signature_weighted_overlap: Option, - candidate_adjacent_vs_item_weighted_overlap: Option, - candidate_adjacent_vs_signature_weighted_overlap: Option, - candidate_path_import_match_count: Option, - candidate_wildcard_path_import_match_count: Option, - candidate_import_similarity: Option, - candidate_max_import_similarity: Option, - candidate_normalized_import_similarity: Option, - candidate_wildcard_import_similarity: Option, - candidate_normalized_wildcard_import_similarity: Option, - candidate_included_by_others: Option, - candidate_includes_others: Option, - } - let mut rows = Soa::::new(); - let mut next_ref_id = 0; - - while let Some(result) = output_rx.next().await { - let mut gold_is_external = false; - let mut gold_in_excerpt = false; - let cursor_path = result.cursor_path.as_unix_str(); - let cursor_row = result.cursor_point.row + 1; - let cursor_column = result.cursor_point.column + 1; - let cursor_identifier = 
result.identifier.name.to_string(); - let ref_id = next_ref_id; - next_ref_id += 1; - - for lsp_definition in result.lsp_definitions { - let SourceRange { - path: gold_path, - point_range: gold_point_range, - offset_range: gold_offset_range, - } = lsp_definition; - let lsp_point_range = - SerializablePoint::into_language_point_range(gold_point_range.clone()); - - gold_is_external = gold_is_external - || gold_path.is_absolute() - || gold_path - .components() - .any(|component| component.as_os_str() == "node_modules"); - - gold_in_excerpt = gold_in_excerpt - || result.excerpt_range.as_ref().is_some_and(|excerpt_range| { - excerpt_range.contains_inclusive(&gold_offset_range) - }); - - let gold_row = gold_point_range.start.row; - let gold_column = gold_point_range.start.column; - let candidate_count = result.retrieved_definitions.len() as u32; - - for (candidate_rank, retrieved_definition) in - result.retrieved_definitions.iter().enumerate() - { - let candidate_is_gold = gold_path.as_path() - == retrieved_definition.path.as_std_path() - && retrieved_definition - .range - .contains_inclusive(&lsp_point_range); - - let candidate_row = retrieved_definition.range.start.row + 1; - let candidate_column = retrieved_definition.range.start.column + 1; - - let DeclarationScoreComponents { - is_same_file, - is_referenced_nearby, - is_referenced_in_breadcrumb, - reference_count, - same_file_declaration_count, - declaration_count, - reference_line_distance, - declaration_line_distance, - excerpt_vs_item_jaccard, - excerpt_vs_signature_jaccard, - adjacent_vs_item_jaccard, - adjacent_vs_signature_jaccard, - excerpt_vs_item_weighted_overlap, - excerpt_vs_signature_weighted_overlap, - adjacent_vs_item_weighted_overlap, - adjacent_vs_signature_weighted_overlap, - path_import_match_count, - wildcard_path_import_match_count, - import_similarity, - max_import_similarity, - normalized_import_similarity, - wildcard_import_similarity, - normalized_wildcard_import_similarity, - 
included_by_others, - includes_others, - } = retrieved_definition.components; - - rows.push(Row { - ref_id, - cursor_path: cursor_path.to_string(), - cursor_row, - cursor_column, - cursor_identifier: cursor_identifier.clone(), - gold_in_excerpt, - gold_path: gold_path.to_string_lossy().to_string(), - gold_row, - gold_column, - gold_is_external, - candidate_count, - candidate_path: Some(retrieved_definition.path.as_unix_str().to_string()), - candidate_row: Some(candidate_row), - candidate_column: Some(candidate_column), - candidate_is_gold: Some(candidate_is_gold), - candidate_rank: Some(candidate_rank as u32), - candidate_is_same_file: Some(is_same_file), - candidate_is_referenced_nearby: Some(is_referenced_nearby), - candidate_is_referenced_in_breadcrumb: Some(is_referenced_in_breadcrumb), - candidate_reference_count: Some(reference_count as u32), - candidate_same_file_declaration_count: Some(same_file_declaration_count as u32), - candidate_declaration_count: Some(declaration_count as u32), - candidate_reference_line_distance: Some(reference_line_distance), - candidate_declaration_line_distance: Some(declaration_line_distance), - candidate_excerpt_vs_item_jaccard: Some(excerpt_vs_item_jaccard), - candidate_excerpt_vs_signature_jaccard: Some(excerpt_vs_signature_jaccard), - candidate_adjacent_vs_item_jaccard: Some(adjacent_vs_item_jaccard), - candidate_adjacent_vs_signature_jaccard: Some(adjacent_vs_signature_jaccard), - candidate_excerpt_vs_item_weighted_overlap: Some( - excerpt_vs_item_weighted_overlap, - ), - candidate_excerpt_vs_signature_weighted_overlap: Some( - excerpt_vs_signature_weighted_overlap, - ), - candidate_adjacent_vs_item_weighted_overlap: Some( - adjacent_vs_item_weighted_overlap, - ), - candidate_adjacent_vs_signature_weighted_overlap: Some( - adjacent_vs_signature_weighted_overlap, - ), - candidate_path_import_match_count: Some(path_import_match_count as u32), - candidate_wildcard_path_import_match_count: Some( - 
wildcard_path_import_match_count as u32, - ), - candidate_import_similarity: Some(import_similarity), - candidate_max_import_similarity: Some(max_import_similarity), - candidate_normalized_import_similarity: Some(normalized_import_similarity), - candidate_wildcard_import_similarity: Some(wildcard_import_similarity), - candidate_normalized_wildcard_import_similarity: Some( - normalized_wildcard_import_similarity, - ), - candidate_included_by_others: Some(included_by_others as u32), - candidate_includes_others: Some(includes_others as u32), - }); - } - - if result.retrieved_definitions.is_empty() { - rows.push(Row { - ref_id, - cursor_path: cursor_path.to_string(), - cursor_row, - cursor_column, - cursor_identifier: cursor_identifier.clone(), - gold_in_excerpt, - gold_path: gold_path.to_string_lossy().to_string(), - gold_row, - gold_column, - gold_is_external, - candidate_count, - ..Default::default() - }); - } - } - } - let slices = rows.slices(); - - let RowSlices { - ref_id, - cursor_path, - cursor_row, - cursor_column, - cursor_identifier, - gold_in_excerpt, - gold_path, - gold_row, - gold_column, - gold_is_external, - candidate_path, - candidate_row, - candidate_column, - candidate_is_gold, - candidate_rank, - candidate_count, - candidate_is_same_file, - candidate_is_referenced_nearby, - candidate_is_referenced_in_breadcrumb, - candidate_reference_count, - candidate_same_file_declaration_count, - candidate_declaration_count, - candidate_reference_line_distance, - candidate_declaration_line_distance, - candidate_excerpt_vs_item_jaccard, - candidate_excerpt_vs_signature_jaccard, - candidate_adjacent_vs_item_jaccard, - candidate_adjacent_vs_signature_jaccard, - candidate_excerpt_vs_item_weighted_overlap, - candidate_excerpt_vs_signature_weighted_overlap, - candidate_adjacent_vs_item_weighted_overlap, - candidate_adjacent_vs_signature_weighted_overlap, - candidate_path_import_match_count, - candidate_wildcard_path_import_match_count, - candidate_import_similarity, - 
candidate_max_import_similarity, - candidate_normalized_import_similarity, - candidate_wildcard_import_similarity, - candidate_normalized_wildcard_import_similarity, - candidate_included_by_others, - candidate_includes_others, - } = slices; - - let df = DataFrame::new(vec![ - Series::new(PlSmallStr::from_str("ref_id"), ref_id).into(), - Series::new(PlSmallStr::from_str("cursor_path"), cursor_path).into(), - Series::new(PlSmallStr::from_str("cursor_row"), cursor_row).into(), - Series::new(PlSmallStr::from_str("cursor_column"), cursor_column).into(), - Series::new(PlSmallStr::from_str("cursor_identifier"), cursor_identifier).into(), - Series::new(PlSmallStr::from_str("gold_in_excerpt"), gold_in_excerpt).into(), - Series::new(PlSmallStr::from_str("gold_path"), gold_path).into(), - Series::new(PlSmallStr::from_str("gold_row"), gold_row).into(), - Series::new(PlSmallStr::from_str("gold_column"), gold_column).into(), - Series::new(PlSmallStr::from_str("gold_is_external"), gold_is_external).into(), - Series::new(PlSmallStr::from_str("candidate_count"), candidate_count).into(), - Series::new(PlSmallStr::from_str("candidate_path"), candidate_path).into(), - Series::new(PlSmallStr::from_str("candidate_row"), candidate_row).into(), - Series::new(PlSmallStr::from_str("candidate_column"), candidate_column).into(), - Series::new(PlSmallStr::from_str("candidate_is_gold"), candidate_is_gold).into(), - Series::new(PlSmallStr::from_str("candidate_rank"), candidate_rank).into(), - Series::new( - PlSmallStr::from_str("candidate_is_same_file"), - candidate_is_same_file, - ) - .into(), - Series::new( - PlSmallStr::from_str("candidate_is_referenced_nearby"), - candidate_is_referenced_nearby, - ) - .into(), - Series::new( - PlSmallStr::from_str("candidate_is_referenced_in_breadcrumb"), - candidate_is_referenced_in_breadcrumb, - ) - .into(), - Series::new( - PlSmallStr::from_str("candidate_reference_count"), - candidate_reference_count, - ) - .into(), - Series::new( - 
PlSmallStr::from_str("candidate_same_file_declaration_count"), - candidate_same_file_declaration_count, - ) - .into(), - Series::new( - PlSmallStr::from_str("candidate_declaration_count"), - candidate_declaration_count, - ) - .into(), - Series::new( - PlSmallStr::from_str("candidate_reference_line_distance"), - candidate_reference_line_distance, - ) - .into(), - Series::new( - PlSmallStr::from_str("candidate_declaration_line_distance"), - candidate_declaration_line_distance, - ) - .into(), - Series::new( - PlSmallStr::from_str("candidate_excerpt_vs_item_jaccard"), - candidate_excerpt_vs_item_jaccard, - ) - .into(), - Series::new( - PlSmallStr::from_str("candidate_excerpt_vs_signature_jaccard"), - candidate_excerpt_vs_signature_jaccard, - ) - .into(), - Series::new( - PlSmallStr::from_str("candidate_adjacent_vs_item_jaccard"), - candidate_adjacent_vs_item_jaccard, - ) - .into(), - Series::new( - PlSmallStr::from_str("candidate_adjacent_vs_signature_jaccard"), - candidate_adjacent_vs_signature_jaccard, - ) - .into(), - Series::new( - PlSmallStr::from_str("candidate_excerpt_vs_item_weighted_overlap"), - candidate_excerpt_vs_item_weighted_overlap, - ) - .into(), - Series::new( - PlSmallStr::from_str("candidate_excerpt_vs_signature_weighted_overlap"), - candidate_excerpt_vs_signature_weighted_overlap, - ) - .into(), - Series::new( - PlSmallStr::from_str("candidate_adjacent_vs_item_weighted_overlap"), - candidate_adjacent_vs_item_weighted_overlap, - ) - .into(), - Series::new( - PlSmallStr::from_str("candidate_adjacent_vs_signature_weighted_overlap"), - candidate_adjacent_vs_signature_weighted_overlap, - ) - .into(), - Series::new( - PlSmallStr::from_str("candidate_path_import_match_count"), - candidate_path_import_match_count, - ) - .into(), - Series::new( - PlSmallStr::from_str("candidate_wildcard_path_import_match_count"), - candidate_wildcard_path_import_match_count, - ) - .into(), - Series::new( - PlSmallStr::from_str("candidate_import_similarity"), - 
candidate_import_similarity, - ) - .into(), - Series::new( - PlSmallStr::from_str("candidate_max_import_similarity"), - candidate_max_import_similarity, - ) - .into(), - Series::new( - PlSmallStr::from_str("candidate_normalized_import_similarity"), - candidate_normalized_import_similarity, - ) - .into(), - Series::new( - PlSmallStr::from_str("candidate_wildcard_import_similarity"), - candidate_wildcard_import_similarity, - ) - .into(), - Series::new( - PlSmallStr::from_str("candidate_normalized_wildcard_import_similarity"), - candidate_normalized_wildcard_import_similarity, - ) - .into(), - Series::new( - PlSmallStr::from_str("candidate_included_by_others"), - candidate_included_by_others, - ) - .into(), - Series::new( - PlSmallStr::from_str("candidate_includes_others"), - candidate_includes_others, - ) - .into(), - ])?; - - Ok(df) -} - -fn relativize_path(path: &Path) -> &Path { - path.strip_prefix(std::env::current_dir().unwrap()) - .unwrap_or(path) -} - -struct SummaryStats { - references_count: u32, - retrieved_count: u32, - top_match_count: u32, - non_top_match_count: u32, - ranking_involved_top_match_count: u32, - missing_none_retrieved: u32, - missing_wrong_retrieval: u32, - missing_external: u32, - in_excerpt_count: u32, -} - -impl SummaryStats { - fn from_dataframe(df: DataFrame) -> Result { - // TODO: use lazy more - let unique_refs = - df.unique::<(), ()>(Some(&["ref_id".into()]), UniqueKeepStrategy::Any, None)?; - let references_count = unique_refs.height() as u32; - - let gold_mask = df.column("candidate_is_gold")?.bool()?; - let gold_df = df.filter(&gold_mask)?; - let retrieved_count = gold_df.height() as u32; - - let top_match_mask = gold_df.column("candidate_rank")?.u32()?.equal(0); - let top_match_df = gold_df.filter(&top_match_mask)?; - let top_match_count = top_match_df.height() as u32; - - let ranking_involved_top_match_count = top_match_df - .column("candidate_count")? - .u32()? 
- .gt(1) - .sum() - .unwrap_or_default(); - - let non_top_match_count = (!top_match_mask).sum().unwrap_or(0); - - let not_retrieved_df = df - .lazy() - .group_by(&[col("ref_id"), col("candidate_count")]) - .agg(&[ - col("candidate_is_gold") - .fill_null(false) - .sum() - .alias("gold_count"), - col("gold_in_excerpt").sum().alias("gold_in_excerpt_count"), - col("gold_is_external") - .sum() - .alias("gold_is_external_count"), - ]) - .filter(col("gold_count").eq(lit(0))) - .collect()?; - - let in_excerpt_mask = not_retrieved_df - .column("gold_in_excerpt_count")? - .u32()? - .gt(0); - let in_excerpt_count = in_excerpt_mask.sum().unwrap_or(0); - - let missing_df = not_retrieved_df.filter(&!in_excerpt_mask)?; - - let missing_none_retrieved_mask = missing_df.column("candidate_count")?.u32()?.equal(0); - let missing_none_retrieved = missing_none_retrieved_mask.sum().unwrap_or(0); - let external_mask = missing_df.column("gold_is_external_count")?.u32()?.gt(0); - let missing_external = (missing_none_retrieved_mask & external_mask) - .sum() - .unwrap_or(0); - - let missing_wrong_retrieval = missing_df - .column("candidate_count")? - .u32()? 
- .gt(0) - .sum() - .unwrap_or(0); - - Ok(SummaryStats { - references_count, - retrieved_count, - top_match_count, - non_top_match_count, - ranking_involved_top_match_count, - missing_none_retrieved, - missing_wrong_retrieval, - missing_external, - in_excerpt_count, - }) - } - - fn count_and_percentage(part: u32, total: u32) -> String { - format!("{} ({:.2}%)", part, (part as f64 / total as f64) * 100.0) - } -} - -impl std::fmt::Display for SummaryStats { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let included = self.in_excerpt_count + self.retrieved_count; - let missing = self.references_count - included; - writeln!(f)?; - writeln!(f, "╮ references: {}", self.references_count)?; - writeln!( - f, - "├─╮ included: {}", - Self::count_and_percentage(included, self.references_count), - )?; - writeln!( - f, - "│ ├─╮ retrieved: {}", - Self::count_and_percentage(self.retrieved_count, self.references_count) - )?; - writeln!( - f, - "│ │ ├─╮ top match : {}", - Self::count_and_percentage(self.top_match_count, self.retrieved_count) - )?; - writeln!( - f, - "│ │ │ ╰─╴ involving ranking: {}", - Self::count_and_percentage(self.ranking_involved_top_match_count, self.top_match_count) - )?; - writeln!( - f, - "│ │ ╰─╴ non-top match: {}", - Self::count_and_percentage(self.non_top_match_count, self.retrieved_count) - )?; - writeln!( - f, - "│ ╰─╴ in excerpt: {}", - Self::count_and_percentage(self.in_excerpt_count, included) - )?; - writeln!( - f, - "╰─╮ missing: {}", - Self::count_and_percentage(missing, self.references_count) - )?; - writeln!( - f, - " ├─╮ none retrieved: {}", - Self::count_and_percentage(self.missing_none_retrieved, missing) - )?; - writeln!( - f, - " │ ╰─╴ external (expected): {}", - Self::count_and_percentage(self.missing_external, missing) - )?; - writeln!( - f, - " ╰─╴ wrong retrieval: {}", - Self::count_and_percentage(self.missing_wrong_retrieval, missing) - )?; - Ok(()) - } -} - -#[derive(Debug)] -struct ReferenceRetrievalResult 
{ - cursor_path: Arc, - cursor_point: Point, - identifier: Identifier, - excerpt_range: Option>, - lsp_definitions: Vec, - retrieved_definitions: Vec, -} - -#[derive(Debug)] -struct RetrievedDefinition { - path: Arc, - range: Range, - score: f32, - #[allow(dead_code)] - retrieval_score: f32, - #[allow(dead_code)] - components: DeclarationScoreComponents, -} - -struct RetrieveResult { - definitions: Vec, - excerpt_range: Option>, -} - -async fn retrieve_definitions( - reference: &Reference, - imports: &Imports, - query_point: Point, - snapshot: &BufferSnapshot, - index: &Arc, - file_snapshots: &Arc>, - context_options: &EditPredictionContextOptions, -) -> Result { - let mut single_reference_map = HashMap::default(); - single_reference_map.insert(reference.identifier.clone(), vec![reference.clone()]); - let edit_prediction_context = EditPredictionContext::gather_context_with_references_fn( - query_point, - snapshot, - imports, - &context_options, - Some(&index), - |_, _, _| single_reference_map, - ); - - let Some(edit_prediction_context) = edit_prediction_context else { - return Ok(RetrieveResult { - definitions: Vec::new(), - excerpt_range: None, - }); - }; - - let mut retrieved_definitions = Vec::new(); - for scored_declaration in edit_prediction_context.declarations { - match &scored_declaration.declaration { - Declaration::File { - project_entry_id, - declaration, - .. 
- } => { - let Some(snapshot) = file_snapshots.get(&project_entry_id) else { - log::error!("bug: file project entry not found"); - continue; - }; - let path = snapshot.file().unwrap().path().clone(); - retrieved_definitions.push(RetrievedDefinition { - path, - range: snapshot.offset_to_point(declaration.item_range.start) - ..snapshot.offset_to_point(declaration.item_range.end), - score: scored_declaration.score(DeclarationStyle::Declaration), - retrieval_score: scored_declaration.retrieval_score(), - components: scored_declaration.components, - }); - } - Declaration::Buffer { - project_entry_id, - rope, - declaration, - .. - } => { - let Some(snapshot) = file_snapshots.get(&project_entry_id) else { - // This case happens when dependency buffers have been opened by - // go-to-definition, resulting in single-file worktrees. - continue; - }; - let path = snapshot.file().unwrap().path().clone(); - retrieved_definitions.push(RetrievedDefinition { - path, - range: rope.offset_to_point(declaration.item_range.start) - ..rope.offset_to_point(declaration.item_range.end), - score: scored_declaration.score(DeclarationStyle::Declaration), - retrieval_score: scored_declaration.retrieval_score(), - components: scored_declaration.components, - }); - } - } - } - retrieved_definitions.sort_by_key(|definition| Reverse(OrderedFloat(definition.score))); - - Ok(RetrieveResult { - definitions: retrieved_definitions, - excerpt_range: Some(edit_prediction_context.excerpt.range), - }) -} - -async fn gather_lsp_definitions( - lsp_definitions_path: &Path, - start_index: usize, - files: &[ProjectPath], - worktree: &Entity, - project: &Entity, - definitions: &mut HashMap>, - cx: &mut AsyncApp, -) -> Result<()> { - let worktree_id = worktree.read_with(cx, |worktree, _cx| worktree.id())?; - - let lsp_store = project.read_with(cx, |project, _cx| project.lsp_store())?; - cx.subscribe(&lsp_store, { - move |_, event, _| { - if let project::LspStoreEvent::LanguageServerUpdate { - message: - 
client::proto::update_language_server::Variant::WorkProgress( - client::proto::LspWorkProgress { - message: Some(message), - .. - }, - ), - .. - } = event - { - println!("⟲ {message}") - } - } - })? - .detach(); - - let (cache_line_tx, mut cache_line_rx) = mpsc::unbounded::(); - - let cache_file = File::options() - .append(true) - .create(true) - .open(lsp_definitions_path) - .unwrap(); - - let cache_task = cx.background_spawn(async move { - let mut writer = BufWriter::new(cache_file); - while let Some(line) = cache_line_rx.next().await { - serde_json::to_writer(&mut writer, &line).unwrap(); - writer.write_all(&[b'\n']).unwrap(); - } - writer.flush().unwrap(); - }); - - let mut error_count = 0; - let mut lsp_open_handles = Vec::new(); - let mut ready_languages = HashSet::default(); - for (file_index, project_path) in files[start_index..].iter().enumerate() { - println!( - "Processing file {} of {}: {}", - start_index + file_index + 1, - files.len(), - project_path.path.display(PathStyle::Posix) - ); - - let Some((lsp_open_handle, language_server_id, buffer)) = open_buffer_with_language_server( - project.clone(), - worktree.clone(), - project_path.path.clone(), - &mut ready_languages, - cx, - ) - .await - .log_err() else { - continue; - }; - lsp_open_handles.push(lsp_open_handle); - - let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; - let full_range = 0..snapshot.len(); - let references = references_in_range( - full_range, - &snapshot.text(), - ReferenceRegion::Nearby, - &snapshot, - ); - - loop { - let is_ready = lsp_store - .read_with(cx, |lsp_store, _cx| { - lsp_store - .language_server_statuses - .get(&language_server_id) - .is_some_and(|status| status.pending_work.is_empty()) - }) - .unwrap(); - if is_ready { - break; - } - cx.background_executor() - .timer(Duration::from_millis(10)) - .await; - } - - let mut cache_line_references = Vec::with_capacity(references.len()); - - for reference in references { - // TODO: Rename declaration to 
definition in edit_prediction_context? - let lsp_result = project - .update(cx, |project, cx| { - project.definitions(&buffer, reference.range.start, cx) - })? - .await; - - match lsp_result { - Ok(lsp_definitions) => { - let mut targets = Vec::new(); - for target in lsp_definitions.unwrap_or_default() { - let buffer = target.target.buffer; - let anchor_range = target.target.range; - buffer.read_with(cx, |buffer, cx| { - let Some(file) = project::File::from_dyn(buffer.file()) else { - return; - }; - let file_worktree = file.worktree.read(cx); - let file_worktree_id = file_worktree.id(); - // Relative paths for worktree files, absolute for all others - let path = if worktree_id != file_worktree_id { - file.worktree.read(cx).absolutize(&file.path) - } else { - file.path.as_std_path().to_path_buf() - }; - let offset_range = anchor_range.to_offset(&buffer); - let point_range = SerializablePoint::from_language_point_range( - offset_range.to_point(&buffer), - ); - targets.push(SourceRange { - path, - offset_range, - point_range, - }); - })?; - } - - let point = snapshot.offset_to_point(reference.range.start); - - cache_line_references.push((point.into(), targets.clone())); - definitions.insert( - SourceLocation { - path: project_path.path.clone(), - point, - }, - targets, - ); - } - Err(err) => { - log::error!("Language server error: {err}"); - error_count += 1; - } - } - } - - cache_line_tx - .unbounded_send(FileLspDefinitions { - path: project_path.path.as_unix_str().into(), - references: cache_line_references, - }) - .log_err(); - } - - drop(cache_line_tx); - - if error_count > 0 { - log::error!("Encountered {} language server errors", error_count); - } - - cache_task.await; - - Ok(()) -} - -#[derive(Serialize, Deserialize)] -struct FileLspDefinitions { - path: Arc, - references: Vec<(SerializablePoint, Vec)>, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -struct SourceRange { - path: PathBuf, - point_range: Range, - offset_range: Range, -} - -/// Serializes 
to 1-based row and column indices. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct SerializablePoint { - pub row: u32, - pub column: u32, -} - -impl SerializablePoint { - pub fn into_language_point_range(range: Range) -> Range { - range.start.into()..range.end.into() - } - - pub fn from_language_point_range(range: Range) -> Range { - range.start.into()..range.end.into() - } -} - -impl From for SerializablePoint { - fn from(point: Point) -> Self { - SerializablePoint { - row: point.row + 1, - column: point.column + 1, - } - } -} - -impl From for Point { - fn from(serializable: SerializablePoint) -> Self { - Point { - row: serializable.row.saturating_sub(1), - column: serializable.column.saturating_sub(1), - } - } -} From d6241b17d35c4eae1d23dfbde16cae0eb8187ac2 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Thu, 4 Dec 2025 22:51:26 -0800 Subject: [PATCH 11/81] Fix infinite loop in assemble_excerpts (#44195) Also, expand the number of identifiers fetched. Release Notes: - N/A --- crates/edit_prediction/src/edit_prediction.rs | 8 +- .../src/assemble_excerpts.rs | 165 +---------------- .../src/edit_prediction_context.rs | 33 +++- .../src/edit_prediction_context_tests.rs | 174 +++++++++++++++++- 4 files changed, 204 insertions(+), 176 deletions(-) diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index ddb29d0796a6c6b24ee3914533b29b967d224ac8..ea8f0af7e16dedd30a86284af5386829053d7fab 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -480,16 +480,16 @@ impl EditPredictionStore { shown_predictions: Default::default(), }; - this.enable_or_disable_context_retrieval(cx); + this.configure_context_retrieval(cx); let weak_this = cx.weak_entity(); cx.on_flags_ready(move |_, cx| { weak_this - .update(cx, |this, cx| this.enable_or_disable_context_retrieval(cx)) + .update(cx, |this, cx| this.configure_context_retrieval(cx)) .ok(); }) .detach(); 
cx.observe_global::(|this, cx| { - this.enable_or_disable_context_retrieval(cx); + this.configure_context_retrieval(cx); }) .detach(); @@ -1770,7 +1770,7 @@ impl EditPredictionStore { cx.notify(); } - fn enable_or_disable_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) { + fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) { self.use_context = cx.has_flag::() && all_language_settings(None, cx).edit_predictions.use_context; } diff --git a/crates/edit_prediction_context/src/assemble_excerpts.rs b/crates/edit_prediction_context/src/assemble_excerpts.rs index b3b8d4f8bc480053a1e9ab9d498d5350039ed609..15f4c03d653429af671c22d6b5abc652d282a38e 100644 --- a/crates/edit_prediction_context/src/assemble_excerpts.rs +++ b/crates/edit_prediction_context/src/assemble_excerpts.rs @@ -61,8 +61,8 @@ pub fn assemble_excerpts( buffer, &mut outline_ranges, ); - child_outline_ix += 1; } + child_outline_ix += 1; } } } @@ -159,166 +159,3 @@ pub fn merge_ranges(ranges: &mut Vec>) { } } } - -#[cfg(test)] -mod tests { - use super::*; - use gpui::{TestAppContext, prelude::*}; - use indoc::indoc; - use language::{Buffer, Language, LanguageConfig, LanguageMatcher, OffsetRangeExt}; - use pretty_assertions::assert_eq; - use std::{fmt::Write as _, sync::Arc}; - use util::test::marked_text_ranges; - - #[gpui::test] - fn test_rust(cx: &mut TestAppContext) { - let table = [ - ( - indoc! {r#" - struct User { - first_name: String, - «last_name»: String, - age: u32, - email: String, - create_at: Instant, - } - - impl User { - pub fn first_name(&self) -> String { - self.first_name.clone() - } - - pub fn full_name(&self) -> String { - « format!("{} {}", self.first_name, self.last_name) - » } - } - "#}, - indoc! {r#" - struct User { - first_name: String, - last_name: String, - … - } - - impl User { - … - pub fn full_name(&self) -> String { - format!("{} {}", self.first_name, self.last_name) - } - } - "#}, - ), - ( - indoc! 
{r#" - struct «User» { - first_name: String, - last_name: String, - age: u32, - } - - impl User { - // methods - } - "# - }, - indoc! {r#" - struct User { - first_name: String, - last_name: String, - age: u32, - } - … - "#}, - ), - ( - indoc! {r#" - trait «FooProvider» { - const NAME: &'static str; - - fn provide_foo(&self, id: usize) -> Foo; - - fn provide_foo_batched(&self, ids: &[usize]) -> Vec { - ids.iter() - .map(|id| self.provide_foo(*id)) - .collect() - } - - fn sync(&self); - } - "# - }, - indoc! {r#" - trait FooProvider { - const NAME: &'static str; - - fn provide_foo(&self, id: usize) -> Foo; - - fn provide_foo_batched(&self, ids: &[usize]) -> Vec { - … - } - - fn sync(&self); - } - "#}, - ), - ]; - - for (input, expected_output) in table { - let (input, ranges) = marked_text_ranges(&input, false); - let buffer = - cx.new(|cx| Buffer::local(input, cx).with_language(Arc::new(rust_lang()), cx)); - buffer.read_with(cx, |buffer, _cx| { - let ranges: Vec> = ranges - .into_iter() - .map(|range| range.to_point(&buffer)) - .collect(); - - let excerpts = assemble_excerpts(&buffer.snapshot(), ranges); - - let output = format_excerpts(buffer, &excerpts); - assert_eq!(output, expected_output); - }); - } - } - - fn format_excerpts(buffer: &Buffer, excerpts: &[RelatedExcerpt]) -> String { - let mut output = String::new(); - let file_line_count = buffer.max_point().row; - let mut current_row = 0; - for excerpt in excerpts { - if excerpt.text.is_empty() { - continue; - } - if current_row < excerpt.point_range.start.row { - writeln!(&mut output, "…").unwrap(); - } - current_row = excerpt.point_range.start.row; - - for line in excerpt.text.to_string().lines() { - output.push_str(line); - output.push('\n'); - current_row += 1; - } - } - if current_row < file_line_count { - writeln!(&mut output, "…").unwrap(); - } - output - } - - fn rust_lang() -> Language { - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: 
vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(language::tree_sitter_rust::LANGUAGE.into()), - ) - .with_outline_query(include_str!("../../languages/src/rust/outline.scm")) - .unwrap() - } -} diff --git a/crates/edit_prediction_context/src/edit_prediction_context.rs b/crates/edit_prediction_context/src/edit_prediction_context.rs index e316c5a052acd241e7d33356bd0d5dfa5fd075bd..475050fabb8b17ad76c34234094cf798e36a76ab 100644 --- a/crates/edit_prediction_context/src/edit_prediction_context.rs +++ b/crates/edit_prediction_context/src/edit_prediction_context.rs @@ -25,11 +25,14 @@ mod fake_definition_lsp; pub use cloud_llm_client::predict_edits_v3::Line; pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText}; +const IDENTIFIER_LINE_COUNT: u32 = 3; + pub struct RelatedExcerptStore { project: WeakEntity, related_files: Vec, cache: HashMap>, update_tx: mpsc::UnboundedSender<(Entity, Anchor)>, + identifier_line_count: u32, } pub enum RelatedExcerptStoreEvent { @@ -178,9 +181,14 @@ impl RelatedExcerptStore { update_tx, related_files: Vec::new(), cache: Default::default(), + identifier_line_count: IDENTIFIER_LINE_COUNT, } } + pub fn set_identifier_line_count(&mut self, count: u32) { + self.identifier_line_count = count; + } + pub fn refresh(&mut self, buffer: Entity, position: Anchor, _: &mut Context) { self.update_tx.unbounded_send((buffer, position)).ok(); } @@ -195,8 +203,12 @@ impl RelatedExcerptStore { position: Anchor, cx: &mut AsyncApp, ) -> Result<()> { - let (project, snapshot) = this.read_with(cx, |this, cx| { - (this.project.upgrade(), buffer.read(cx).snapshot()) + let (project, snapshot, identifier_line_count) = this.read_with(cx, |this, cx| { + ( + this.project.upgrade(), + buffer.read(cx).snapshot(), + this.identifier_line_count, + ) })?; let Some(project) = project else { return Ok(()); @@ -212,7 +224,9 @@ impl RelatedExcerptStore { })?; let identifiers = cx - 
.background_spawn(async move { identifiers_for_position(&snapshot, position) }) + .background_spawn(async move { + identifiers_for_position(&snapshot, position, identifier_line_count) + }) .await; let async_cx = cx.clone(); @@ -393,14 +407,21 @@ fn process_definition( /// Gets all of the identifiers that are present in the given line, and its containing /// outline items. -fn identifiers_for_position(buffer: &BufferSnapshot, position: Anchor) -> Vec { +fn identifiers_for_position( + buffer: &BufferSnapshot, + position: Anchor, + identifier_line_count: u32, +) -> Vec { let offset = position.to_offset(buffer); let point = buffer.offset_to_point(offset); - let line_range = Point::new(point.row, 0)..Point::new(point.row + 1, 0).min(buffer.max_point()); + // Search for identifiers on lines adjacent to the cursor. + let start = Point::new(point.row.saturating_sub(identifier_line_count), 0); + let end = Point::new(point.row + identifier_line_count + 1, 0).min(buffer.max_point()); + let line_range = start..end; let mut ranges = vec![line_range.to_offset(&buffer)]; - // Include the range of the outline item itself, but not its body. + // Search for identifiers mentioned in headers/signatures of containing outline items. 
let outline_items = buffer.outline_items_as_offsets_containing(offset..offset, false, None); for item in outline_items { if let Some(body_range) = item.body_range(&buffer) { diff --git a/crates/edit_prediction_context/src/edit_prediction_context_tests.rs b/crates/edit_prediction_context/src/edit_prediction_context_tests.rs index 05d1becc2167837a5f9741d77e7bc96c2f5b8d34..f62df37e551db19145e9ea631b6ab6a16fefda78 100644 --- a/crates/edit_prediction_context/src/edit_prediction_context_tests.rs +++ b/crates/edit_prediction_context/src/edit_prediction_context_tests.rs @@ -7,8 +7,8 @@ use lsp::FakeLanguageServer; use project::{FakeFs, LocationLink, Project}; use serde_json::json; use settings::SettingsStore; -use std::sync::Arc; -use util::path; +use std::{fmt::Write as _, sync::Arc}; +use util::{path, test::marked_text_ranges}; #[gpui::test] async fn test_edit_prediction_context(cx: &mut TestAppContext) { @@ -37,6 +37,7 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) { buffer.anchor_before(offset) }; + store.set_identifier_line_count(0); store.refresh(buffer.clone(), position, cx); }); @@ -85,6 +86,150 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) { }); } +#[gpui::test] +fn test_assemble_excerpts(cx: &mut TestAppContext) { + let table = [ + ( + indoc! {r#" + struct User { + first_name: String, + «last_name»: String, + age: u32, + email: String, + create_at: Instant, + } + + impl User { + pub fn first_name(&self) -> String { + self.first_name.clone() + } + + pub fn full_name(&self) -> String { + « format!("{} {}", self.first_name, self.last_name) + » } + } + "#}, + indoc! {r#" + struct User { + first_name: String, + last_name: String, + … + } + + impl User { + … + pub fn full_name(&self) -> String { + format!("{} {}", self.first_name, self.last_name) + } + } + "#}, + ), + ( + indoc! {r#" + struct «User» { + first_name: String, + last_name: String, + age: u32, + } + + impl User { + // methods + } + "#}, + indoc! 
{r#" + struct User { + first_name: String, + last_name: String, + age: u32, + } + … + "#}, + ), + ( + indoc! {r#" + trait «FooProvider» { + const NAME: &'static str; + + fn provide_foo(&self, id: usize) -> Foo; + + fn provide_foo_batched(&self, ids: &[usize]) -> Vec { + ids.iter() + .map(|id| self.provide_foo(*id)) + .collect() + } + + fn sync(&self); + } + "# + }, + indoc! {r#" + trait FooProvider { + const NAME: &'static str; + + fn provide_foo(&self, id: usize) -> Foo; + + fn provide_foo_batched(&self, ids: &[usize]) -> Vec { + … + } + + fn sync(&self); + } + "#}, + ), + ( + indoc! {r#" + trait «Something» { + fn method1(&self, id: usize) -> Foo; + + fn method2(&self, ids: &[usize]) -> Vec { + struct Helper1 { + field1: usize, + } + + struct Helper2 { + field2: usize, + } + + struct Helper3 { + filed2: usize, + } + } + + fn sync(&self); + } + "# + }, + indoc! {r#" + trait Something { + fn method1(&self, id: usize) -> Foo; + + fn method2(&self, ids: &[usize]) -> Vec { + … + } + + fn sync(&self); + } + "#}, + ), + ]; + + for (input, expected_output) in table { + let (input, ranges) = marked_text_ranges(&input, false); + let buffer = cx.new(|cx| Buffer::local(input, cx).with_language(rust_lang(), cx)); + buffer.read_with(cx, |buffer, _cx| { + let ranges: Vec> = ranges + .into_iter() + .map(|range| range.to_point(&buffer)) + .collect(); + + let excerpts = assemble_excerpts(&buffer.snapshot(), ranges); + + let output = format_excerpts(buffer, &excerpts); + assert_eq!(output, expected_output); + }); + } +} + #[gpui::test] async fn test_fake_definition_lsp(cx: &mut TestAppContext) { init_test(cx); @@ -339,6 +484,31 @@ fn assert_definitions(definitions: &[LocationLink], first_lines: &[&str], cx: &m assert_eq!(actual_first_lines, first_lines); } +fn format_excerpts(buffer: &Buffer, excerpts: &[RelatedExcerpt]) -> String { + let mut output = String::new(); + let file_line_count = buffer.max_point().row; + let mut current_row = 0; + for excerpt in excerpts { + if 
excerpt.text.is_empty() { + continue; + } + if current_row < excerpt.point_range.start.row { + writeln!(&mut output, "…").unwrap(); + } + current_row = excerpt.point_range.start.row; + + for line in excerpt.text.to_string().lines() { + output.push_str(line); + output.push('\n'); + current_row += 1; + } + } + if current_row < file_line_count { + writeln!(&mut output, "…").unwrap(); + } + output +} + pub(crate) fn rust_lang() -> Arc { Arc::new( Language::new( From 35da6d000aa047484bd5d489ecee6455c39cda57 Mon Sep 17 00:00:00 2001 From: Piotr Osiewicz <24362066+osiewicz@users.noreply.github.com> Date: Fri, 5 Dec 2025 12:08:04 +0100 Subject: [PATCH 12/81] debugger: Fix evaluate selection running two evaluations & failing for Python and go (#44205) Evaluate selection now acts as if the text was typed verbatim into the console. Closes ##33526 Release Notes: - debugger: Fixed "evaluate selection" not behaving as if the highlighted text was not typed verbatim into the console. --- crates/debugger_ui/src/debugger_ui.rs | 13 +++++++++++-- crates/debugger_ui/src/new_process_modal.rs | 2 +- crates/project/src/debugger/session.rs | 2 ++ 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/crates/debugger_ui/src/debugger_ui.rs b/crates/debugger_ui/src/debugger_ui.rs index a9abb50bb68851334285b05064176e0347474014..bd5a7cda4a21a3d3fd0ac132d6ba2e7aace68722 100644 --- a/crates/debugger_ui/src/debugger_ui.rs +++ b/crates/debugger_ui/src/debugger_ui.rs @@ -387,7 +387,7 @@ pub fn init(cx: &mut App) { window.on_action( TypeId::of::(), move |_, _, window, cx| { - maybe!({ + let status = maybe!({ let text = editor .update(cx, |editor, cx| { let range = editor @@ -411,7 +411,13 @@ pub fn init(cx: &mut App) { state.session().update(cx, |session, cx| { session - .evaluate(text, None, stack_id, None, cx) + .evaluate( + text, + Some(dap::EvaluateArgumentsContext::Repl), + stack_id, + None, + cx, + ) .detach(); }); }); @@ -419,6 +425,9 @@ pub fn init(cx: &mut App) { Some(()) }); + if 
status.is_some() { + cx.stop_propagation(); + } }, ); }) diff --git a/crates/debugger_ui/src/new_process_modal.rs b/crates/debugger_ui/src/new_process_modal.rs index 40187cef9cc55cb4192a3cea773f42dca15a2571..ca13e3eed5fd770e8b80f0cd5b8610ff1e9e2f51 100644 --- a/crates/debugger_ui/src/new_process_modal.rs +++ b/crates/debugger_ui/src/new_process_modal.rs @@ -1023,7 +1023,7 @@ impl DebugDelegate { Some(TaskSourceKind::Lsp { language_name, .. }) => { Some(format!("LSP: {language_name}")) } - Some(TaskSourceKind::Language { name }) => Some(format!("Lang: {name}")), + Some(TaskSourceKind::Language { name }) => Some(format!("Language: {name}")), _ => context.clone().and_then(|ctx| { ctx.task_context .task_variables diff --git a/crates/project/src/debugger/session.rs b/crates/project/src/debugger/session.rs index b5fbfd80d6152faf9d04715138859dc565e8cba8..47fe98827cbc163682ef6f002eff4008967d4ced 100644 --- a/crates/project/src/debugger/session.rs +++ b/crates/project/src/debugger/session.rs @@ -2656,6 +2656,8 @@ impl Session { this.update(cx, |this, cx| { this.memory.clear(cx.background_executor()); this.invalidate_command_type::(); + this.invalidate_command_type::(); + cx.emit(SessionEvent::Variables); match response { Ok(response) => { let event = dap::OutputEvent { From a5ab5c7d5dae496123624ed655ecd6ecd456b05a Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Fri, 5 Dec 2025 13:35:05 +0100 Subject: [PATCH 13/81] gpui: Document the leak detector (#44208) Release Notes: - N/A *or* Added/Fixed/Improved ... 
--- crates/gpui/src/app/entity_map.rs | 112 ++++++++++++++++++++++++++++-- 1 file changed, 107 insertions(+), 5 deletions(-) diff --git a/crates/gpui/src/app/entity_map.rs b/crates/gpui/src/app/entity_map.rs index 81dbfdbf5733eed92a77fc2dc18fb971bd9bd4a7..8c1bdfa1cee509dcbc061200cb651ce5d3bf4fcd 100644 --- a/crates/gpui/src/app/entity_map.rs +++ b/crates/gpui/src/app/entity_map.rs @@ -584,7 +584,33 @@ impl AnyWeakEntity { }) } - /// Assert that entity referenced by this weak handle has been released. + /// Asserts that the entity referenced by this weak handle has been fully released. + /// + /// # Example + /// + /// ```ignore + /// let entity = cx.new(|_| MyEntity::new()); + /// let weak = entity.downgrade(); + /// drop(entity); + /// + /// // Verify the entity was released + /// weak.assert_released(); + /// ``` + /// + /// # Debugging Leaks + /// + /// If this method panics due to leaked handles, set the `LEAK_BACKTRACE` environment + /// variable to see where the leaked handles were allocated: + /// + /// ```bash + /// LEAK_BACKTRACE=1 cargo test my_test + /// ``` + /// + /// # Panics + /// + /// - Panics if any strong handles to the entity are still alive. + /// - Panics if the entity was recently dropped but cleanup hasn't completed yet + /// (resources are retained until the end of the effect cycle). #[cfg(any(test, feature = "leak-detection"))] pub fn assert_released(&self) { self.entity_ref_counts @@ -814,16 +840,70 @@ impl PartialOrd for WeakEntity { } } +/// Controls whether backtraces are captured when entity handles are created. +/// +/// Set the `LEAK_BACKTRACE` environment variable to any non-empty value to enable +/// backtrace capture. This helps identify where leaked handles were allocated. #[cfg(any(test, feature = "leak-detection"))] static LEAK_BACKTRACE: std::sync::LazyLock = std::sync::LazyLock::new(|| std::env::var("LEAK_BACKTRACE").is_ok_and(|b| !b.is_empty())); +/// Unique identifier for a specific entity handle instance. 
+/// +/// This is distinct from `EntityId` - while multiple handles can point to the same +/// entity (same `EntityId`), each handle has its own unique `HandleId`. #[cfg(any(test, feature = "leak-detection"))] #[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq)] pub(crate) struct HandleId { - id: u64, // id of the handle itself, not the pointed at object + id: u64, } +/// Tracks entity handle allocations to detect leaks. +/// +/// The leak detector is enabled in tests and when the `leak-detection` feature is active. +/// It tracks every `Entity` and `AnyEntity` handle that is created and released, +/// allowing you to verify that all handles to an entity have been properly dropped. +/// +/// # How do leaks happen? +/// +/// Entities are reference-counted structures that can own other entities +/// allowing to form cycles. If such a strong-reference counted cycle is +/// created, all participating strong entities in this cycle will effectively +/// leak as they cannot be released anymore. +/// +/// # Usage +/// +/// You can use `WeakEntity::assert_released` or `AnyWeakEntity::assert_released` +/// to verify that an entity has been fully released: +/// +/// ```ignore +/// let entity = cx.new(|_| MyEntity::new()); +/// let weak = entity.downgrade(); +/// drop(entity); +/// +/// // This will panic if any handles to the entity are still alive +/// weak.assert_released(); +/// ``` +/// +/// # Debugging Leaks +/// +/// When a leak is detected, the detector will panic with information about the leaked +/// handles. To see where the leaked handles were allocated, set the `LEAK_BACKTRACE` +/// environment variable: +/// +/// ```bash +/// LEAK_BACKTRACE=1 cargo test my_test +/// ``` +/// +/// This will capture and display backtraces for each leaked handle, helping you +/// identify where handles were created but not released. 
+/// +/// # How It Works +/// +/// - When an entity handle is created (via `Entity::new`, `Entity::clone`, or +/// `WeakEntity::upgrade`), `handle_created` is called to register the handle. +/// - When a handle is dropped, `handle_released` removes it from tracking. +/// - `assert_released` verifies that no handles remain for a given entity. #[cfg(any(test, feature = "leak-detection"))] pub(crate) struct LeakDetector { next_handle_id: u64, @@ -832,6 +912,11 @@ pub(crate) struct LeakDetector { #[cfg(any(test, feature = "leak-detection"))] impl LeakDetector { + /// Records that a new handle has been created for the given entity. + /// + /// Returns a unique `HandleId` that must be passed to `handle_released` when + /// the handle is dropped. If `LEAK_BACKTRACE` is set, captures a backtrace + /// at the allocation site. #[track_caller] pub fn handle_created(&mut self, entity_id: EntityId) -> HandleId { let id = util::post_inc(&mut self.next_handle_id); @@ -844,23 +929,40 @@ impl LeakDetector { handle_id } + /// Records that a handle has been released (dropped). + /// + /// This removes the handle from tracking. The `handle_id` should be the same + /// one returned by `handle_created` when the handle was allocated. pub fn handle_released(&mut self, entity_id: EntityId, handle_id: HandleId) { let handles = self.entity_handles.entry(entity_id).or_default(); handles.remove(&handle_id); } + /// Asserts that all handles to the given entity have been released. + /// + /// # Panics + /// + /// Panics if any handles to the entity are still alive. The panic message + /// includes backtraces for each leaked handle if `LEAK_BACKTRACE` is set, + /// otherwise it suggests setting the environment variable to get more info. 
pub fn assert_released(&mut self, entity_id: EntityId) { + use std::fmt::Write as _; let handles = self.entity_handles.entry(entity_id).or_default(); if !handles.is_empty() { + let mut out = String::new(); for backtrace in handles.values_mut() { if let Some(mut backtrace) = backtrace.take() { backtrace.resolve(); - eprintln!("Leaked handle: {:#?}", backtrace); + writeln!(out, "Leaked handle:\n{:?}", backtrace).unwrap(); } else { - eprintln!("Leaked handle: export LEAK_BACKTRACE to find allocation site"); + writeln!( + out, + "Leaked handle: (export LEAK_BACKTRACE to find allocation site)" + ) + .unwrap(); } } - panic!(); + panic!("{out}"); } } } From 126d708fa1dc493e2d0fafb477bd814d40df0238 Mon Sep 17 00:00:00 2001 From: Anthony Eid <56899983+Anthony-Eid@users.noreply.github.com> Date: Fri, 5 Dec 2025 07:59:13 -0500 Subject: [PATCH 14/81] git: Fix branch picker creating new branches with refs/head/ prefixed on the branch name (#44206) The bug was introduced in this recent PR: https://github.com/zed-industries/zed/pull/42819. Since it's still in nightly, there is no need for release notes. I also polished the feature a bit by: - Ensuring branch names are always a single line so the branch picker's uniform list uses the correct element height. - Adding tooltip text for the filter remotes button. - Fixing the create branch from default icon showing up for non-new branch entries. 
Release Notes: - N/A --- assets/keymaps/default-linux.json | 3 +- assets/keymaps/default-macos.json | 3 +- assets/keymaps/default-windows.json | 3 +- crates/fs/src/fake_git_repo.rs | 17 ++++-- crates/git_ui/src/branch_picker.rs | 89 +++++++++++++++++------------ 5 files changed, 69 insertions(+), 46 deletions(-) diff --git a/assets/keymaps/default-linux.json b/assets/keymaps/default-linux.json index 0b001f31790c7f8211a6b44d227c15a6ff605ca4..41415bf2047e1faadd86dd5be159f526d6c57678 100644 --- a/assets/keymaps/default-linux.json +++ b/assets/keymaps/default-linux.json @@ -1332,7 +1332,8 @@ "context": "GitBranchSelector || (GitBranchSelector > Picker > Editor)", "use_key_equivalents": true, "bindings": { - "ctrl-shift-backspace": "branch_picker::DeleteBranch" + "ctrl-shift-backspace": "branch_picker::DeleteBranch", + "ctrl-shift-i": "branch_picker::FilterRemotes" } } ] diff --git a/assets/keymaps/default-macos.json b/assets/keymaps/default-macos.json index e4595242d570628e2e70c43b66d14a0f9820512b..fa8edbe5c23b008eb2c267850e440a851c54087d 100644 --- a/assets/keymaps/default-macos.json +++ b/assets/keymaps/default-macos.json @@ -1437,7 +1437,8 @@ "context": "GitBranchSelector || (GitBranchSelector > Picker > Editor)", "use_key_equivalents": true, "bindings": { - "cmd-shift-backspace": "branch_picker::DeleteBranch" + "cmd-shift-backspace": "branch_picker::DeleteBranch", + "cmd-shift-i": "branch_picker::FilterRemotes" } } ] diff --git a/assets/keymaps/default-windows.json b/assets/keymaps/default-windows.json index b625e7c7018c0f4c8277fcf3f739a8f06361c4df..45f37fbd41af3fcc3108f0ffe150a80ff25332e1 100644 --- a/assets/keymaps/default-windows.json +++ b/assets/keymaps/default-windows.json @@ -1351,7 +1351,8 @@ "context": "GitBranchSelector || (GitBranchSelector > Picker > Editor)", "use_key_equivalents": true, "bindings": { - "ctrl-shift-backspace": "branch_picker::DeleteBranch" + "ctrl-shift-backspace": "branch_picker::DeleteBranch", + "ctrl-shift-i": 
"branch_picker::FilterRemotes" } } ] diff --git a/crates/fs/src/fake_git_repo.rs b/crates/fs/src/fake_git_repo.rs index 3bc411ff2d9b917fd409c29cca03d2191ee80978..6ca8b5a58f9f8f75023aa73e7a80e8547eb156f3 100644 --- a/crates/fs/src/fake_git_repo.rs +++ b/crates/fs/src/fake_git_repo.rs @@ -381,11 +381,18 @@ impl GitRepository for FakeGitRepository { Ok(state .branches .iter() - .map(|branch_name| Branch { - is_head: Some(branch_name) == current_branch.as_ref(), - ref_name: branch_name.into(), - most_recent_commit: None, - upstream: None, + .map(|branch_name| { + let ref_name = if branch_name.starts_with("refs/") { + branch_name.into() + } else { + format!("refs/heads/{branch_name}").into() + }; + Branch { + is_head: Some(branch_name) == current_branch.as_ref(), + ref_name, + most_recent_commit: None, + upstream: None, + } }) .collect()) }) diff --git a/crates/git_ui/src/branch_picker.rs b/crates/git_ui/src/branch_picker.rs index 42e043cada2813126af3489c9769aca9c675999f..33b852c1de9b1bd1a8abcc36dff964d14cbe1807 100644 --- a/crates/git_ui/src/branch_picker.rs +++ b/crates/git_ui/src/branch_picker.rs @@ -770,7 +770,7 @@ impl PickerDelegate for BranchListDelegate { } else { None }; - self.create_branch(from_branch, format!("refs/heads/{name}").into(), window, cx); + self.create_branch(from_branch, name.into(), window, cx); } } @@ -812,28 +812,21 @@ impl PickerDelegate for BranchListDelegate { }) .unwrap_or_else(|| (None, None, None)); - let icon = if let Some(default_branch) = self.default_branch.clone() { - let icon = match &entry { - Entry::Branch { .. } => Some(( - IconName::GitBranchAlt, - format!("Create branch based off default: {default_branch}"), - )), - Entry::NewUrl { url } => { - Some((IconName::Screen, format!("Create remote based off {url}"))) - } - Entry::NewBranch { .. } => None, - }; + let icon = if let Some(default_branch) = self.default_branch.clone() + && matches!(entry, Entry::NewBranch { .. 
}) + { + let tooltip_text = format!("Create branch based off default: {default_branch}"); - icon.map(|(icon, tooltip_text)| { - IconButton::new("branch-from-default", icon) + Some( + IconButton::new("branch-from-default", IconName::GitBranchAlt) .on_click(cx.listener(move |this, _, window, cx| { this.delegate.set_selected_index(ix, window, cx); this.delegate.confirm(true, window, cx); })) .tooltip(move |_window, cx| { Tooltip::for_action(tooltip_text.clone(), &menu::SecondaryConfirm, cx) - }) - }) + }), + ) } else { None }; @@ -875,7 +868,9 @@ impl PickerDelegate for BranchListDelegate { .max_w_48() .child(h_flex().mr_1().child(icon_element)) .child( - HighlightedLabel::new(branch.name().to_string(), positions.clone()).truncate(), + HighlightedLabel::new(branch.name().to_string(), positions.clone()) + .single_line() + .truncate(), ) .into_any_element(), }; @@ -962,18 +957,13 @@ impl PickerDelegate for BranchListDelegate { _window: &mut Window, cx: &mut Context>, ) -> Option { - if matches!( - self.state, - PickerState::CreateRemote(_) | PickerState::NewRemote | PickerState::NewBranch - ) { - return None; - } - let label = if self.display_remotes { - "Remote" - } else { - "Local" - }; - Some( + matches!(self.state, PickerState::List).then(|| { + let label = if self.display_remotes { + "Remote" + } else { + "Local" + }; + h_flex() .w_full() .p_1p5() @@ -981,8 +971,8 @@ impl PickerDelegate for BranchListDelegate { .border_t_1() .border_color(cx.theme().colors().border_variant) .child(Label::new(label).size(LabelSize::Small).color(Color::Muted)) - .into_any(), - ) + .into_any() + }) } fn render_footer(&self, _: &mut Window, cx: &mut Context>) -> Option { @@ -1010,7 +1000,8 @@ impl PickerDelegate for BranchListDelegate { .border_t_1() .border_color(cx.theme().colors().border_variant) .justify_between() - .child( + .child({ + let focus_handle = focus_handle.clone(); Button::new("filter-remotes", "Filter remotes") .key_binding( KeyBinding::for_action_in( @@ -1028,8 
+1019,26 @@ impl PickerDelegate for BranchListDelegate { }) .disabled(self.loading) .style(ButtonStyle::Subtle) - .toggle_state(self.display_remotes), - ) + .toggle_state(self.display_remotes) + .tooltip({ + let state = self.display_remotes; + + move |_window, cx| { + let tooltip_text = if state { + "Show local branches" + } else { + "Show remote branches" + }; + + Tooltip::for_action_in( + tooltip_text, + &branch_picker::FilterRemotes, + &focus_handle, + cx, + ) + } + }) + }) .child( Button::new("delete-branch", "Delete") .key_binding( @@ -1527,10 +1536,14 @@ mod tests { .unwrap() .unwrap(); - assert!( - branches - .into_iter() - .any(|branch| branch.name() == "new-feature-branch") + let new_branch = branches + .into_iter() + .find(|branch| branch.name() == "new-feature-branch") + .expect("new-feature-branch should exist"); + assert_eq!( + new_branch.ref_name.as_ref(), + "refs/heads/new-feature-branch", + "branch ref_name should not have duplicate refs/heads/ prefix" ); } From 822fc7ef167e7a26358afe7f13d0b05b4df468eb Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Fri, 5 Dec 2025 10:04:01 -0300 Subject: [PATCH 15/81] remote: Use last line of `uname` and shell output (#44165) We have seen cases (see https://github.com/zed-industries/zed/issues/43694) where the user's shell initialization script includes text that ends up in the output of the commands we use to detect the platform and shell of the remote. This solution isn't perfect, but it should address the issue in most situations since both commands should only output one line. 
Release Notes: - remote: Improve resiliency when initialization scripts output text --- crates/remote/src/transport/ssh.rs | 150 +++++++++++++++++++++-------- 1 file changed, 109 insertions(+), 41 deletions(-) diff --git a/crates/remote/src/transport/ssh.rs b/crates/remote/src/transport/ssh.rs index 20cd0c5ff4b427d3a37882603ce2962db9e4e1e0..56f29be092b5ed6ab4993664eb256056837047f5 100644 --- a/crates/remote/src/transport/ssh.rs +++ b/crates/remote/src/transport/ssh.rs @@ -1055,57 +1055,74 @@ impl SshSocket { } async fn platform(&self, shell: ShellKind) -> Result { - let uname = self.run_command(shell, "uname", &["-sm"], false).await?; - let Some((os, arch)) = uname.split_once(" ") else { - anyhow::bail!("unknown uname: {uname:?}") - }; - - let os = match os.trim() { - "Darwin" => "macos", - "Linux" => "linux", - _ => anyhow::bail!( - "Prebuilt remote servers are not yet available for {os:?}. See https://zed.dev/docs/remote-development" - ), - }; - // exclude armv5,6,7 as they are 32-bit. - let arch = if arch.starts_with("armv8") - || arch.starts_with("armv9") - || arch.starts_with("arm64") - || arch.starts_with("aarch64") - { - "aarch64" - } else if arch.starts_with("x86") { - "x86_64" - } else { - anyhow::bail!( - "Prebuilt remote servers are not yet available for {arch:?}. 
See https://zed.dev/docs/remote-development" - ) - }; - - Ok(RemotePlatform { os, arch }) + let output = self.run_command(shell, "uname", &["-sm"], false).await?; + parse_platform(&output) } async fn shell(&self) -> String { - let default_shell = "sh"; match self .run_command(ShellKind::Posix, "sh", &["-c", "echo $SHELL"], false) .await { - Ok(shell) => match shell.trim() { - "" => { - log::error!("$SHELL is not set, falling back to {default_shell}"); - default_shell.to_owned() - } - shell => shell.to_owned(), - }, + Ok(output) => parse_shell(&output), Err(e) => { log::error!("Failed to get shell: {e}"); - default_shell.to_owned() + DEFAULT_SHELL.to_owned() } } } } +const DEFAULT_SHELL: &str = "sh"; + +/// Parses the output of `uname -sm` to determine the remote platform. +/// Takes the last line to skip possible shell initialization output. +fn parse_platform(output: &str) -> Result { + let output = output.trim(); + let uname = output.rsplit_once('\n').map_or(output, |(_, last)| last); + let Some((os, arch)) = uname.split_once(" ") else { + anyhow::bail!("unknown uname: {uname:?}") + }; + + let os = match os { + "Darwin" => "macos", + "Linux" => "linux", + _ => anyhow::bail!( + "Prebuilt remote servers are not yet available for {os:?}. See https://zed.dev/docs/remote-development" + ), + }; + + // exclude armv5,6,7 as they are 32-bit. + let arch = if arch.starts_with("armv8") + || arch.starts_with("armv9") + || arch.starts_with("arm64") + || arch.starts_with("aarch64") + { + "aarch64" + } else if arch.starts_with("x86") { + "x86_64" + } else { + anyhow::bail!( + "Prebuilt remote servers are not yet available for {arch:?}. See https://zed.dev/docs/remote-development" + ) + }; + + Ok(RemotePlatform { os, arch }) +} + +/// Parses the output of `echo $SHELL` to determine the remote shell. +/// Takes the last line to skip possible shell initialization output. 
+fn parse_shell(output: &str) -> String { + let output = output.trim(); + let shell = output.rsplit_once('\n').map_or(output, |(_, last)| last); + if shell.is_empty() { + log::error!("$SHELL is not set, falling back to {DEFAULT_SHELL}"); + DEFAULT_SHELL.to_owned() + } else { + shell.to_owned() + } +} + fn parse_port_number(port_str: &str) -> Result { port_str .parse() @@ -1502,12 +1519,63 @@ mod tests { "-p".to_string(), "2222".to_string(), "-o".to_string(), - "StrictHostKeyChecking=no".to_string() + "StrictHostKeyChecking=no".to_string(), ] ); - assert!( - scp_args.iter().all(|arg| !arg.starts_with("-L")), - "scp args should not contain port forward flags: {scp_args:?}" + } + + #[test] + fn test_parse_platform() { + let result = parse_platform("Linux x86_64\n").unwrap(); + assert_eq!(result.os, "linux"); + assert_eq!(result.arch, "x86_64"); + + let result = parse_platform("Darwin arm64\n").unwrap(); + assert_eq!(result.os, "macos"); + assert_eq!(result.arch, "aarch64"); + + let result = parse_platform("Linux x86_64").unwrap(); + assert_eq!(result.os, "linux"); + assert_eq!(result.arch, "x86_64"); + + let result = parse_platform("some shell init output\nLinux aarch64\n").unwrap(); + assert_eq!(result.os, "linux"); + assert_eq!(result.arch, "aarch64"); + + let result = parse_platform("some shell init output\nLinux aarch64").unwrap(); + assert_eq!(result.os, "linux"); + assert_eq!(result.arch, "aarch64"); + + assert_eq!(parse_platform("Linux armv8l\n").unwrap().arch, "aarch64"); + assert_eq!(parse_platform("Linux aarch64\n").unwrap().arch, "aarch64"); + assert_eq!(parse_platform("Linux x86_64\n").unwrap().arch, "x86_64"); + + let result = parse_platform( + r#"Linux x86_64 - What you're referring to as Linux, is in fact, GNU/Linux...\n"#, + ) + .unwrap(); + assert_eq!(result.os, "linux"); + assert_eq!(result.arch, "x86_64"); + + assert!(parse_platform("Windows x86_64\n").is_err()); + assert!(parse_platform("Linux armv7l\n").is_err()); + } + + #[test] + fn 
test_parse_shell() { + assert_eq!(parse_shell("/bin/bash\n"), "/bin/bash"); + assert_eq!(parse_shell("/bin/zsh\n"), "/bin/zsh"); + + assert_eq!(parse_shell("/bin/bash"), "/bin/bash"); + assert_eq!( + parse_shell("some shell init output\n/bin/bash\n"), + "/bin/bash" + ); + assert_eq!( + parse_shell("some shell init output\n/bin/bash"), + "/bin/bash" ); + assert_eq!(parse_shell(""), "sh"); + assert_eq!(parse_shell("\n"), "sh"); } } From c7ef3025e42c1e16f16011ee7330856be9438e67 Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Fri, 5 Dec 2025 12:16:46 -0300 Subject: [PATCH 16/81] remoting: Server download connect timeout (#44216) Sometimes machines are configured to drop outbound packets (rather than reject connections). In these cases, curl/wget just hang causing our download step to never complete. This PR adds a timeout of 10s for the connection (not the whole download), so that in situations like this we can fallback to our client-side download eventually. Related to but doesn't fully fix: https://github.com/zed-industries/zed/issues/43694 and https://github.com/zed-industries/zed/issues/43718 Release Notes: - remote: Add 10s connect timeout for server download --- crates/remote/src/transport/ssh.rs | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/crates/remote/src/transport/ssh.rs b/crates/remote/src/transport/ssh.rs index 56f29be092b5ed6ab4993664eb256056837047f5..bdc9cda08a9634258a4e18532556c1cde2bf8f32 100644 --- a/crates/remote/src/transport/ssh.rs +++ b/crates/remote/src/transport/ssh.rs @@ -668,6 +668,8 @@ impl SshRemoteConnection { delegate.set_status(Some("Downloading remote development server on host"), cx); + const CONNECT_TIMEOUT_SECS: &str = "10"; + match self .socket .run_command( @@ -676,6 +678,8 @@ impl SshRemoteConnection { &[ "-f", "-L", + "--connect-timeout", + CONNECT_TIMEOUT_SECS, url, "-o", &tmp_path_gz.display(self.path_style()), @@ -701,7 +705,15 @@ impl SshRemoteConnection { .run_command( self.ssh_shell_kind, 
"wget", - &[url, "-O", &tmp_path_gz.display(self.path_style())], + &[ + "--connect-timeout", + CONNECT_TIMEOUT_SECS, + "--tries", + "1", + url, + "-O", + &tmp_path_gz.display(self.path_style()), + ], true, ) .await From 1d0aef6b2229acc81de7708c315472fb0e7c627c Mon Sep 17 00:00:00 2001 From: Dino Date: Fri, 5 Dec 2025 15:24:07 +0000 Subject: [PATCH 17/81] Ensure font features are applied to styled text (#44219) - Replace `gpui::styled::Styled.font_family()` calls with `gpui::styled::Styled.font()` when laying out inline diagnostics and inline blame, to ensure that the font's features are also used, and not just the font feature. - Update both `editor::hover_popover::hover_markdown_style` and `editor::hover_popover::diagnostics_markdown_style` to ensure that both the UI and Buffer font features are used in both markdown and diagnostics popover. Closes #44209 Release Notes: - Fixed font feature application for inline git blame, inline diagnostics, markdown popovers and diagnostics popovers --- crates/editor/src/element.rs | 2 +- crates/editor/src/hover_popover.rs | 8 ++++++++ crates/git_ui/src/blame_ui.rs | 2 +- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/crates/editor/src/element.rs b/crates/editor/src/element.rs index fb9dc31a94441c81bccedfea66e2881acaf7ed82..edb3778ff94809ef880ffa167f2ff410a3199a37 100644 --- a/crates/editor/src/element.rs +++ b/crates/editor/src/element.rs @@ -2340,7 +2340,7 @@ impl EditorElement { .opacity(0.05)) .text_color(severity_to_color(&diagnostic_to_render.severity).color(cx)) .text_sm() - .font_family(style.text.font().family) + .font(style.text.font()) .child(diagnostic_to_render.message.clone()) .into_any(); diff --git a/crates/editor/src/hover_popover.rs b/crates/editor/src/hover_popover.rs index caabe6e6f5ab6ae80b3ead9d72fdcbec59937ff6..9ef54139d39ece6e9414d8fee3c7a75c9a89036d 100644 --- a/crates/editor/src/hover_popover.rs +++ b/crates/editor/src/hover_popover.rs @@ -607,13 +607,16 @@ async fn parse_blocks( pub fn 
hover_markdown_style(window: &Window, cx: &App) -> MarkdownStyle { let settings = ThemeSettings::get_global(cx); let ui_font_family = settings.ui_font.family.clone(); + let ui_font_features = settings.ui_font.features.clone(); let ui_font_fallbacks = settings.ui_font.fallbacks.clone(); let buffer_font_family = settings.buffer_font.family.clone(); + let buffer_font_features = settings.buffer_font.features.clone(); let buffer_font_fallbacks = settings.buffer_font.fallbacks.clone(); let mut base_text_style = window.text_style(); base_text_style.refine(&TextStyleRefinement { font_family: Some(ui_font_family), + font_features: Some(ui_font_features), font_fallbacks: ui_font_fallbacks, color: Some(cx.theme().colors().editor_foreground), ..Default::default() @@ -624,6 +627,7 @@ pub fn hover_markdown_style(window: &Window, cx: &App) -> MarkdownStyle { inline_code: TextStyleRefinement { background_color: Some(cx.theme().colors().background), font_family: Some(buffer_font_family), + font_features: Some(buffer_font_features), font_fallbacks: buffer_font_fallbacks, ..Default::default() }, @@ -657,12 +661,15 @@ pub fn diagnostics_markdown_style(window: &Window, cx: &App) -> MarkdownStyle { let settings = ThemeSettings::get_global(cx); let ui_font_family = settings.ui_font.family.clone(); let ui_font_fallbacks = settings.ui_font.fallbacks.clone(); + let ui_font_features = settings.ui_font.features.clone(); let buffer_font_family = settings.buffer_font.family.clone(); + let buffer_font_features = settings.buffer_font.features.clone(); let buffer_font_fallbacks = settings.buffer_font.fallbacks.clone(); let mut base_text_style = window.text_style(); base_text_style.refine(&TextStyleRefinement { font_family: Some(ui_font_family), + font_features: Some(ui_font_features), font_fallbacks: ui_font_fallbacks, color: Some(cx.theme().colors().editor_foreground), ..Default::default() @@ -673,6 +680,7 @@ pub fn diagnostics_markdown_style(window: &Window, cx: &App) -> MarkdownStyle { 
inline_code: TextStyleRefinement { background_color: Some(cx.theme().colors().editor_background.opacity(0.5)), font_family: Some(buffer_font_family), + font_features: Some(buffer_font_features), font_fallbacks: buffer_font_fallbacks, ..Default::default() }, diff --git a/crates/git_ui/src/blame_ui.rs b/crates/git_ui/src/blame_ui.rs index 47703e09824a49c633798c7967652d7f48f821be..c904c4b3b7cba499f6a81399a1ff87d2108f3012 100644 --- a/crates/git_ui/src/blame_ui.rs +++ b/crates/git_ui/src/blame_ui.rs @@ -148,7 +148,7 @@ impl BlameRenderer for GitBlameRenderer { h_flex() .id("inline-blame") .w_full() - .font_family(style.font().family) + .font(style.font()) .text_color(cx.theme().status().hint) .line_height(style.line_height) .child(Icon::new(IconName::FileGit).color(Color::Hint)) From b776178b52aa46e3aaed2720f019295def7eae45 Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Fri, 5 Dec 2025 16:50:32 +0100 Subject: [PATCH 18/81] agent_ui: Fix mention and slash command menu not appearing with show_completions_on_input set to false (#44222) Addresses a regression introduced by https://github.com/zed-industries/zed/pull/44021 that caused @mentions and slash commands to stop working if you set `show_completions_on_input: false` in your settings. In this case, we should always show these menus, otherwise the features won't work at all. 
Release Notes: - N/A --- crates/agent_ui/src/acp/message_editor.rs | 1 + crates/editor/src/editor.rs | 11 ++++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/crates/agent_ui/src/acp/message_editor.rs b/crates/agent_ui/src/acp/message_editor.rs index ae634e45dc17cc471d9ac621faf5b98c0a754c2b..827990599912fe832d40605fb1dceb58eab4ff2f 100644 --- a/crates/agent_ui/src/acp/message_editor.rs +++ b/crates/agent_ui/src/acp/message_editor.rs @@ -124,6 +124,7 @@ impl MessageEditor { let mut editor = Editor::new(mode, buffer, None, window, cx); editor.set_placeholder_text(placeholder, window, cx); editor.set_show_indent_guides(false, cx); + editor.set_show_completions_on_input(Some(true)); editor.set_soft_wrap(); editor.set_use_modal_editing(true); editor.set_context_menu_options(ContextMenuOptions { diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index 6651cce374001865d21dfdb182659f2a8c008305..a4a8a5e02baad4e3306278ed11709d3527e868ce 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -1128,6 +1128,7 @@ pub struct Editor { edit_prediction_settings: EditPredictionSettings, edit_predictions_hidden_for_vim_mode: bool, show_edit_predictions_override: Option, + show_completions_on_input_override: Option, menu_edit_predictions_policy: MenuEditPredictionsPolicy, edit_prediction_preview: EditPredictionPreview, edit_prediction_indent_conflict: bool, @@ -2275,6 +2276,7 @@ impl Editor { editor_actions: Rc::default(), edit_predictions_hidden_for_vim_mode: false, show_edit_predictions_override: None, + show_completions_on_input_override: None, menu_edit_predictions_policy: MenuEditPredictionsPolicy::ByProvider, edit_prediction_settings: EditPredictionSettings::Disabled, edit_prediction_indent_conflict: false, @@ -3157,6 +3159,10 @@ impl Editor { } } + pub fn set_show_completions_on_input(&mut self, show_completions_on_input: Option) { + self.show_completions_on_input_override = show_completions_on_input; + } + pub fn 
set_show_edit_predictions( &mut self, show_edit_predictions: Option, @@ -5533,7 +5539,10 @@ impl Editor { let language_settings = language_settings(language.clone(), buffer_snapshot.file(), cx); let completion_settings = language_settings.completions.clone(); - if !menu_is_open && trigger.is_some() && !language_settings.show_completions_on_input { + let show_completions_on_input = self + .show_completions_on_input_override + .unwrap_or(language_settings.show_completions_on_input); + if !menu_is_open && trigger.is_some() && !show_completions_on_input { return; } From 07fe8e9bb1484b2771d8a9d80f7fc370cee9c4ac Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Fri, 5 Dec 2025 13:47:29 -0300 Subject: [PATCH 19/81] remoting: Proxy configuration docs (#44225) Adds an explicit section about how to configure proxies when remoting. Release Notes: - N/A --- docs/src/remote-development.md | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/docs/src/remote-development.md b/docs/src/remote-development.md index f046fa44334554230d19f885e6e38ab0274f2b44..c25d160a17549f6338f25741afd68391cf88d769 100644 --- a/docs/src/remote-development.md +++ b/docs/src/remote-development.md @@ -174,14 +174,38 @@ When opening a remote project there are three relevant settings locations: Both the local Zed and the server Zed read the project settings, but they are not aware of the other's main `settings.json`. -Depending on the kind of setting you want to make, which settings file you should use: +Which settings file you should use depends on the kind of setting you want to make: - Project settings should be used for things that affect the project: indentation settings, which formatter / language server to use, etc. -- Server settings should be used for things that affect the server: paths to language servers, etc. +- Server settings should be used for things that affect the server: paths to language servers, proxy settings, etc. 
- Local settings should be used for things that affect the UI: font size, etc. In addition any extensions you have installed locally will be propagated to the remote server. This means that language servers, etc. will run correctly. +## Proxy Configuration + +The remote server will not use your local machine's proxy configuration because they may be under different network policies. If your remote server requires a proxy to access the internet, you must configure it on the remote server itself. + +In most cases, your remote server will already have proxy environment variables configured. Zed will automatically use them when downloading language servers, communicating with LLM models, etc. + +If needed, you can set these environment variables in the server's shell configuration (e.g., `~/.bashrc`): + +```bash +export http_proxy="http://proxy.example.com:8080" +export https_proxy="http://proxy.example.com:8080" +export no_proxy="localhost,127.0.0.1" +``` + +Alternatively, you can configure the proxy in the remote machine's `~/.config/zed/settings.json` (Linux) or `~/.zed/settings.json` (macOS): + +```json +{ + "proxy": "http://proxy.example.com:8080" +} +``` + +See the [proxy documentation](./configuring-zed.md#network-proxy) for supported proxy types and additional configuration options. + ## Initializing the remote server Once you provide the SSH options, Zed shells out to `ssh` on your local machine to create a ControlMaster connection with the options you provide. 
From b558be7ec60b265837e34d6f9b6f0ef176c20082 Mon Sep 17 00:00:00 2001 From: David Kleingeld Date: Fri, 5 Dec 2025 18:23:06 +0100 Subject: [PATCH 20/81] adds tracing for instrumenting non-async functions (#44147) Tracing code is not included in normal release builds Documents how to use them in our performance docs Only the maps and cursors are instrumented atm # Compile times: current main: fresh release build (cargo clean then build --release) 377.34 secs current main: fresh debug build (cargo clean then build ) 89.31 secs tracing tracy: fresh release build (cargo clean then build --release) 374.84 secs tracing tracy: fresh debug build (cargo clean then build ) 88.95 secs tracing tracy: fresh release build with timings (cargo clean then build --release --features tracing) 375.77 secs tracing tracy: fresh debug build with timings (cargo clean then build --features tracing) 90.03 secs Release Notes: - N/A --------- Co-authored-by: localcc --- Cargo.lock | 103 ++- Cargo.toml | 6 + crates/collab/Cargo.toml | 2 +- crates/editor/Cargo.toml | 3 + crates/editor/src/display_map/block_map.rs | 42 + crates/editor/src/display_map/crease_map.rs | 17 + .../src/display_map/custom_highlights.rs | 3 + crates/editor/src/display_map/fold_map.rs | 35 + crates/editor/src/display_map/inlay_map.rs | 28 + crates/editor/src/display_map/invisibles.rs | 1 + crates/editor/src/display_map/tab_map.rs | 27 + crates/editor/src/display_map/wrap_map.rs | 38 + crates/git_ui/Cargo.toml | 7 +- crates/git_ui/src/project_diff.rs | 2 + crates/multi_buffer/Cargo.toml | 5 + crates/multi_buffer/src/multi_buffer.rs | 5 + crates/multi_buffer/src/path_key.rs | 872 +++++++++--------- crates/project/Cargo.toml | 5 + crates/project/src/git_store/branch_diff.rs | 3 + crates/rope/Cargo.toml | 5 + crates/rope/src/rope.rs | 2 + crates/sum_tree/Cargo.toml | 5 + crates/sum_tree/src/cursor.rs | 5 + crates/sum_tree/src/sum_tree.rs | 3 + crates/zed/Cargo.toml | 4 +- crates/zed/src/main.rs | 3 +- 
crates/ztracing/Cargo.toml | 17 + crates/ztracing/LICENSE-AGPL | 1 + crates/ztracing/LICENSE-APACHE | 1 + crates/ztracing/LICENSE-GPL | 1 + crates/ztracing/build.rs | 9 + crates/ztracing/src/lib.rs | 16 + crates/ztracing_macro/Cargo.toml | 11 + crates/ztracing_macro/LICENSE-AGPL | 1 + crates/ztracing_macro/LICENSE-APACHE | 1 + crates/ztracing_macro/LICENSE-GPL | 1 + crates/ztracing_macro/src/lib.rs | 7 + docs/src/performance.md | 52 +- 38 files changed, 898 insertions(+), 451 deletions(-) create mode 100644 crates/ztracing/Cargo.toml create mode 120000 crates/ztracing/LICENSE-AGPL create mode 120000 crates/ztracing/LICENSE-APACHE create mode 120000 crates/ztracing/LICENSE-GPL create mode 100644 crates/ztracing/build.rs create mode 100644 crates/ztracing/src/lib.rs create mode 100644 crates/ztracing_macro/Cargo.toml create mode 120000 crates/ztracing_macro/LICENSE-AGPL create mode 120000 crates/ztracing_macro/LICENSE-APACHE create mode 120000 crates/ztracing_macro/LICENSE-GPL create mode 100644 crates/ztracing_macro/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 885fbe77fd17a90e4cc948d4c40166d41a26cd35..a8f0096a7a1219ee30b287c61efd9f77f4b9d223 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5344,6 +5344,7 @@ dependencies = [ "text", "theme", "time", + "tracing", "tree-sitter-bash", "tree-sitter-c", "tree-sitter-html", @@ -5363,6 +5364,7 @@ dependencies = [ "workspace", "zed_actions", "zlog", + "ztracing", ] [[package]] @@ -6824,6 +6826,20 @@ dependencies = [ "seq-macro", ] +[[package]] +name = "generator" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "605183a538e3e2a9c1038635cc5c2d194e2ee8fd0d1b66b8349fad7dbacce5a2" +dependencies = [ + "cc", + "cfg-if", + "libc", + "log", + "rustversion", + "windows 0.61.3", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -7042,6 +7058,7 @@ dependencies = [ "theme", "time", "time_format", + "tracing", "ui", "unindent", "util", @@ -7051,6 +7068,7 @@ dependencies = [ 
"zed_actions", "zeroize", "zlog", + "ztracing", ] [[package]] @@ -9373,6 +9391,19 @@ dependencies = [ "value-bag", ] +[[package]] +name = "loom" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "419e0dc8046cb947daa77eb95ae174acfbddb7673b4151f56d1eed8e93fbfaca" +dependencies = [ + "cfg-if", + "generator", + "scoped-tls", + "tracing", + "tracing-subscriber", +] + [[package]] name = "loop9" version = "0.1.5" @@ -10043,9 +10074,11 @@ dependencies = [ "sum_tree", "text", "theme", + "tracing", "tree-sitter", "util", "zlog", + "ztracing", ] [[package]] @@ -12400,6 +12433,7 @@ dependencies = [ "terminal", "text", "toml 0.8.23", + "tracing", "unindent", "url", "util", @@ -12409,6 +12443,7 @@ dependencies = [ "worktree", "zeroize", "zlog", + "ztracing", ] [[package]] @@ -13677,9 +13712,11 @@ dependencies = [ "rand 0.9.2", "rayon", "sum_tree", + "tracing", "unicode-segmentation", "util", "zlog", + "ztracing", ] [[package]] @@ -15615,7 +15652,9 @@ dependencies = [ "log", "rand 0.9.2", "rayon", + "tracing", "zlog", + "ztracing", ] [[package]] @@ -17100,9 +17139,9 @@ checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" -version = "0.1.41" +version = "0.1.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" +checksum = "2d15d90a0b5c19378952d479dc858407149d7bb45a14de0142f6c534b16fc647" dependencies = [ "log", "pin-project-lite", @@ -17112,9 +17151,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.30" +version = "0.1.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" dependencies = [ "proc-macro2", "quote", @@ -17123,9 +17162,9 @@ dependencies = [ [[package]] name = "tracing-core" 
-version = "0.1.34" +version = "0.1.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9d12581f227e93f094d3af2ae690a574abb8a2b9b7a96e7cfe9647b2b617678" +checksum = "7a04e24fab5c89c6a36eb8558c9656f30d81de51dfa4d3b45f26b21d61fa0a6c" dependencies = [ "once_cell", "valuable", @@ -17154,9 +17193,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.20" +version = "0.3.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2054a14f5307d601f88daf0553e1cbf472acc4f2c51afab632431cdcd72124d5" +checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e" dependencies = [ "matchers", "nu-ansi-term", @@ -17173,6 +17212,38 @@ dependencies = [ "tracing-serde", ] +[[package]] +name = "tracing-tracy" +version = "0.11.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eaa1852afa96e0fe9e44caa53dc0bd2d9d05e0f2611ce09f97f8677af56e4ba" +dependencies = [ + "tracing-core", + "tracing-subscriber", + "tracy-client", +] + +[[package]] +name = "tracy-client" +version = "0.18.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91d722a05fe49b31fef971c4732a7d4aa6a18283d9ba46abddab35f484872947" +dependencies = [ + "loom", + "once_cell", + "tracy-client-sys", +] + +[[package]] +name = "tracy-client-sys" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2fb391ac70462b3097a755618fbf9c8f95ecc1eb379a414f7b46f202ed10db1f" +dependencies = [ + "cc", + "windows-targets 0.52.6", +] + [[package]] name = "trait-variant" version = "0.1.2" @@ -20515,6 +20586,7 @@ dependencies = [ "time", "title_bar", "toolchain_selector", + "tracing", "tree-sitter-md", "tree-sitter-rust", "ui", @@ -20537,6 +20609,7 @@ dependencies = [ "zed_env_vars", "zlog", "zlog_settings", + "ztracing", ] [[package]] @@ -20931,6 +21004,20 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "ztracing" +version = "0.1.0" 
+dependencies = [ + "tracing", + "tracing-subscriber", + "tracing-tracy", + "ztracing_macro", +] + +[[package]] +name = "ztracing_macro" +version = "0.1.0" + [[package]] name = "zune-core" version = "0.4.12" diff --git a/Cargo.toml b/Cargo.toml index 83bc42e353f6462148abe15327373a3d57a029e8..858da1dc460cda2fecbaf2ed94d437bfd25d9644 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -204,6 +204,8 @@ members = [ "crates/edit_prediction_cli", "crates/zlog", "crates/zlog_settings", + "crates/ztracing", + "crates/ztracing_macro", # # Extensions @@ -434,6 +436,8 @@ zed_env_vars = { path = "crates/zed_env_vars" } edit_prediction = { path = "crates/edit_prediction" } zlog = { path = "crates/zlog" } zlog_settings = { path = "crates/zlog_settings" } +ztracing = { path = "crates/ztracing" } +ztracing_macro = { path = "crates/ztracing_macro" } # # External crates @@ -694,6 +698,8 @@ tree-sitter-ruby = "0.23" tree-sitter-rust = "0.24" tree-sitter-typescript = { git = "https://github.com/zed-industries/tree-sitter-typescript", rev = "e2c53597d6a5d9cf7bbe8dccde576fe1e46c5899" } # https://github.com/tree-sitter/tree-sitter-typescript/pull/347 tree-sitter-yaml = { git = "https://github.com/zed-industries/tree-sitter-yaml", rev = "baff0b51c64ef6a1fb1f8390f3ad6015b83ec13a" } +tracing = "0.1.40" +tracing-tracy = "0.11.4" unicase = "2.6" unicode-script = "0.5.7" unicode-segmentation = "1.10" diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index b8a4c035499d45adc494c9f8175a772d15aa96df..79fc21fe33423d7eb887744b4ad84094a022862e 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -65,7 +65,7 @@ tokio = { workspace = true, features = ["full"] } toml.workspace = true tower = "0.4" tower-http = { workspace = true, features = ["trace"] } -tracing = "0.1.40" +tracing.workspace = true tracing-subscriber = { version = "0.3.18", features = ["env-filter", "json", "registry", "tracing-log"] } # workaround for https://github.com/tokio-rs/tracing/issues/2927 
util.workspace = true uuid.workspace = true diff --git a/crates/editor/Cargo.toml b/crates/editor/Cargo.toml index 94c9fb10f50f8e0440b2e91cf0c16d1f701d9451..f3ed28ab05c6839a478ebbf6c81ca5e66fc372e3 100644 --- a/crates/editor/Cargo.toml +++ b/crates/editor/Cargo.toml @@ -84,6 +84,8 @@ tree-sitter-html = { workspace = true, optional = true } tree-sitter-rust = { workspace = true, optional = true } tree-sitter-typescript = { workspace = true, optional = true } tree-sitter-python = { workspace = true, optional = true } +ztracing.workspace = true +tracing.workspace = true unicode-segmentation.workspace = true unicode-script.workspace = true unindent = { workspace = true, optional = true } @@ -94,6 +96,7 @@ uuid.workspace = true vim_mode_setting.workspace = true workspace.workspace = true zed_actions.workspace = true +zlog.workspace = true [dev-dependencies] criterion.workspace = true diff --git a/crates/editor/src/display_map/block_map.rs b/crates/editor/src/display_map/block_map.rs index a6744041971101dafa4957523fb7a16250f38996..79d06dbf8b6e27cffffd47d6637c83eadcb00424 100644 --- a/crates/editor/src/display_map/block_map.rs +++ b/crates/editor/src/display_map/block_map.rs @@ -164,6 +164,7 @@ impl BlockPlacement { } impl BlockPlacement { + #[ztracing::instrument(skip_all)] fn cmp(&self, other: &Self, buffer: &MultiBufferSnapshot) -> Ordering { self.start() .cmp(other.start(), buffer) @@ -171,6 +172,7 @@ impl BlockPlacement { .then_with(|| self.tie_break().cmp(&other.tie_break())) } + #[ztracing::instrument(skip_all)] fn to_wrap_row(&self, wrap_snapshot: &WrapSnapshot) -> Option> { let buffer_snapshot = wrap_snapshot.buffer_snapshot(); match self { @@ -474,6 +476,7 @@ pub struct BlockRows<'a> { } impl BlockMap { + #[ztracing::instrument(skip_all)] pub fn new( wrap_snapshot: WrapSnapshot, buffer_header_height: u32, @@ -503,6 +506,7 @@ impl BlockMap { map } + #[ztracing::instrument(skip_all)] pub fn read(&self, wrap_snapshot: WrapSnapshot, edits: WrapPatch) -> 
BlockMapReader<'_> { self.sync(&wrap_snapshot, edits); *self.wrap_snapshot.borrow_mut() = wrap_snapshot.clone(); @@ -518,13 +522,17 @@ impl BlockMap { } } + #[ztracing::instrument(skip_all)] pub fn write(&mut self, wrap_snapshot: WrapSnapshot, edits: WrapPatch) -> BlockMapWriter<'_> { self.sync(&wrap_snapshot, edits); *self.wrap_snapshot.borrow_mut() = wrap_snapshot; BlockMapWriter(self) } + #[ztracing::instrument(skip_all, fields(edits))] fn sync(&self, wrap_snapshot: &WrapSnapshot, mut edits: WrapPatch) { + let _timer = zlog::time!("BlockMap::sync").warn_if_gt(std::time::Duration::from_millis(50)); + let buffer = wrap_snapshot.buffer_snapshot(); // Handle changing the last excerpt if it is empty. @@ -784,6 +792,7 @@ impl BlockMap { *transforms = new_transforms; } + #[ztracing::instrument(skip_all)] pub fn replace_blocks(&mut self, mut renderers: HashMap) { for block in &mut self.custom_blocks { if let Some(render) = renderers.remove(&block.id) { @@ -793,6 +802,7 @@ impl BlockMap { } /// Guarantees that `wrap_row_for` is called with points in increasing order. 
+ #[ztracing::instrument(skip_all)] fn header_and_footer_blocks<'a, R, T>( &'a self, buffer: &'a multi_buffer::MultiBufferSnapshot, @@ -880,6 +890,7 @@ impl BlockMap { }) } + #[ztracing::instrument(skip_all)] fn sort_blocks(blocks: &mut Vec<(BlockPlacement, Block)>) { blocks.sort_unstable_by(|(placement_a, block_a), (placement_b, block_b)| { placement_a @@ -1016,6 +1027,7 @@ impl DerefMut for BlockMapReader<'_> { } impl BlockMapReader<'_> { + #[ztracing::instrument(skip_all)] pub fn row_for_block(&self, block_id: CustomBlockId) -> Option { let block = self.blocks.iter().find(|block| block.id == block_id)?; let buffer_row = block @@ -1054,6 +1066,7 @@ impl BlockMapReader<'_> { } impl BlockMapWriter<'_> { + #[ztracing::instrument(skip_all)] pub fn insert( &mut self, blocks: impl IntoIterator>, @@ -1120,6 +1133,7 @@ impl BlockMapWriter<'_> { ids } + #[ztracing::instrument(skip_all)] pub fn resize(&mut self, mut heights: HashMap) { let wrap_snapshot = &*self.0.wrap_snapshot.borrow(); let buffer = wrap_snapshot.buffer_snapshot(); @@ -1172,6 +1186,7 @@ impl BlockMapWriter<'_> { self.0.sync(wrap_snapshot, edits); } + #[ztracing::instrument(skip_all)] pub fn remove(&mut self, block_ids: HashSet) { let wrap_snapshot = &*self.0.wrap_snapshot.borrow(); let buffer = wrap_snapshot.buffer_snapshot(); @@ -1217,6 +1232,7 @@ impl BlockMapWriter<'_> { self.0.sync(wrap_snapshot, edits); } + #[ztracing::instrument(skip_all)] pub fn remove_intersecting_replace_blocks( &mut self, ranges: impl IntoIterator>, @@ -1239,6 +1255,7 @@ impl BlockMapWriter<'_> { self.0.buffers_with_disabled_headers.insert(buffer_id); } + #[ztracing::instrument(skip_all)] pub fn fold_buffers( &mut self, buffer_ids: impl IntoIterator, @@ -1248,6 +1265,7 @@ impl BlockMapWriter<'_> { self.fold_or_unfold_buffers(true, buffer_ids, multi_buffer, cx); } + #[ztracing::instrument(skip_all)] pub fn unfold_buffers( &mut self, buffer_ids: impl IntoIterator, @@ -1257,6 +1275,7 @@ impl BlockMapWriter<'_> { 
self.fold_or_unfold_buffers(false, buffer_ids, multi_buffer, cx); } + #[ztracing::instrument(skip_all)] fn fold_or_unfold_buffers( &mut self, fold: bool, @@ -1292,6 +1311,7 @@ impl BlockMapWriter<'_> { self.0.sync(&wrap_snapshot, edits); } + #[ztracing::instrument(skip_all)] fn blocks_intersecting_buffer_range( &self, range: Range, @@ -1326,6 +1346,7 @@ impl BlockMapWriter<'_> { impl BlockSnapshot { #[cfg(test)] + #[ztracing::instrument(skip_all)] pub fn text(&self) -> String { self.chunks( BlockRow(0)..self.transforms.summary().output_rows, @@ -1337,6 +1358,7 @@ impl BlockSnapshot { .collect() } + #[ztracing::instrument(skip_all)] pub(crate) fn chunks<'a>( &'a self, rows: Range, @@ -1378,6 +1400,7 @@ impl BlockSnapshot { } } + #[ztracing::instrument(skip_all)] pub(super) fn row_infos(&self, start_row: BlockRow) -> BlockRows<'_> { let mut cursor = self.transforms.cursor::>(()); cursor.seek(&start_row, Bias::Right); @@ -1399,6 +1422,7 @@ impl BlockSnapshot { } } + #[ztracing::instrument(skip_all)] pub fn blocks_in_range( &self, rows: Range, @@ -1432,6 +1456,7 @@ impl BlockSnapshot { }) } + #[ztracing::instrument(skip_all)] pub(crate) fn sticky_header_excerpt(&self, position: f64) -> Option> { let top_row = position as u32; let mut cursor = self.transforms.cursor::(()); @@ -1455,6 +1480,7 @@ impl BlockSnapshot { None } + #[ztracing::instrument(skip_all)] pub fn block_for_id(&self, block_id: BlockId) -> Option { let buffer = self.wrap_snapshot.buffer_snapshot(); let wrap_point = match block_id { @@ -1491,6 +1517,7 @@ impl BlockSnapshot { None } + #[ztracing::instrument(skip_all)] pub fn max_point(&self) -> BlockPoint { let row = self .transforms @@ -1500,10 +1527,12 @@ impl BlockSnapshot { BlockPoint::new(row, self.line_len(row)) } + #[ztracing::instrument(skip_all)] pub fn longest_row(&self) -> BlockRow { self.transforms.summary().longest_row } + #[ztracing::instrument(skip_all)] pub fn longest_row_in_range(&self, range: Range) -> BlockRow { let mut cursor = 
self.transforms.cursor::>(()); cursor.seek(&range.start, Bias::Right); @@ -1555,6 +1584,7 @@ impl BlockSnapshot { longest_row } + #[ztracing::instrument(skip_all)] pub(super) fn line_len(&self, row: BlockRow) -> u32 { let (start, _, item) = self.transforms @@ -1574,11 +1604,13 @@ impl BlockSnapshot { } } + #[ztracing::instrument(skip_all)] pub(super) fn is_block_line(&self, row: BlockRow) -> bool { let (_, _, item) = self.transforms.find::((), &row, Bias::Right); item.is_some_and(|t| t.block.is_some()) } + #[ztracing::instrument(skip_all)] pub(super) fn is_folded_buffer_header(&self, row: BlockRow) -> bool { let (_, _, item) = self.transforms.find::((), &row, Bias::Right); let Some(transform) = item else { @@ -1587,6 +1619,7 @@ impl BlockSnapshot { matches!(transform.block, Some(Block::FoldedBuffer { .. })) } + #[ztracing::instrument(skip_all)] pub(super) fn is_line_replaced(&self, row: MultiBufferRow) -> bool { let wrap_point = self .wrap_snapshot @@ -1602,6 +1635,7 @@ impl BlockSnapshot { }) } + #[ztracing::instrument(skip_all)] pub fn clip_point(&self, point: BlockPoint, bias: Bias) -> BlockPoint { let mut cursor = self.transforms.cursor::>(()); cursor.seek(&BlockRow(point.row), Bias::Right); @@ -1663,6 +1697,7 @@ impl BlockSnapshot { } } + #[ztracing::instrument(skip_all)] pub fn to_block_point(&self, wrap_point: WrapPoint) -> BlockPoint { let (start, _, item) = self.transforms.find::, _>( (), @@ -1684,6 +1719,7 @@ impl BlockSnapshot { } } + #[ztracing::instrument(skip_all)] pub fn to_wrap_point(&self, block_point: BlockPoint, bias: Bias) -> WrapPoint { let (start, end, item) = self.transforms.find::, _>( (), @@ -1719,6 +1755,7 @@ impl BlockSnapshot { impl BlockChunks<'_> { /// Go to the next transform + #[ztracing::instrument(skip_all)] fn advance(&mut self) { self.input_chunk = Chunk::default(); self.transforms.next(); @@ -1759,6 +1796,7 @@ pub struct StickyHeaderExcerpt<'a> { impl<'a> Iterator for BlockChunks<'a> { type Item = Chunk<'a>; + 
#[ztracing::instrument(skip_all)] fn next(&mut self) -> Option { if self.output_row >= self.max_output_row { return None; @@ -1858,6 +1896,7 @@ impl<'a> Iterator for BlockChunks<'a> { impl Iterator for BlockRows<'_> { type Item = RowInfo; + #[ztracing::instrument(skip_all)] fn next(&mut self) -> Option { if self.started { self.output_row.0 += 1; @@ -1960,14 +1999,17 @@ impl DerefMut for BlockContext<'_, '_> { } impl CustomBlock { + #[ztracing::instrument(skip_all)] pub fn render(&self, cx: &mut BlockContext) -> AnyElement { self.render.lock()(cx) } + #[ztracing::instrument(skip_all)] pub fn start(&self) -> Anchor { *self.placement.start() } + #[ztracing::instrument(skip_all)] pub fn end(&self) -> Anchor { *self.placement.end() } diff --git a/crates/editor/src/display_map/crease_map.rs b/crates/editor/src/display_map/crease_map.rs index a68c27886733d34a60ef0ce2ef4006b92b679db9..8f4a3781f4f335f1a3e61ec5a19818661a7c6ea5 100644 --- a/crates/editor/src/display_map/crease_map.rs +++ b/crates/editor/src/display_map/crease_map.rs @@ -19,6 +19,7 @@ pub struct CreaseMap { } impl CreaseMap { + #[ztracing::instrument(skip_all)] pub fn new(snapshot: &MultiBufferSnapshot) -> Self { CreaseMap { snapshot: CreaseSnapshot::new(snapshot), @@ -40,11 +41,13 @@ impl CreaseSnapshot { } } + #[ztracing::instrument(skip_all)] pub fn creases(&self) -> impl Iterator)> { self.creases.iter().map(|item| (item.id, &item.crease)) } /// Returns the first Crease starting on the specified buffer row. 
+ #[ztracing::instrument(skip_all)] pub fn query_row<'a>( &'a self, row: MultiBufferRow, @@ -69,6 +72,7 @@ impl CreaseSnapshot { None } + #[ztracing::instrument(skip_all)] pub fn creases_in_range<'a>( &'a self, range: Range, @@ -95,6 +99,7 @@ impl CreaseSnapshot { }) } + #[ztracing::instrument(skip_all)] pub fn crease_items_with_offsets( &self, snapshot: &MultiBufferSnapshot, @@ -156,6 +161,7 @@ pub struct CreaseMetadata { } impl Crease { + #[ztracing::instrument(skip_all)] pub fn simple(range: Range, placeholder: FoldPlaceholder) -> Self { Crease::Inline { range, @@ -166,6 +172,7 @@ impl Crease { } } + #[ztracing::instrument(skip_all)] pub fn block(range: Range, height: u32, style: BlockStyle, render: RenderBlock) -> Self { Self::Block { range, @@ -177,6 +184,7 @@ impl Crease { } } + #[ztracing::instrument(skip_all)] pub fn inline( range: Range, placeholder: FoldPlaceholder, @@ -216,6 +224,7 @@ impl Crease { } } + #[ztracing::instrument(skip_all)] pub fn with_metadata(self, metadata: CreaseMetadata) -> Self { match self { Crease::Inline { @@ -235,6 +244,7 @@ impl Crease { } } + #[ztracing::instrument(skip_all)] pub fn range(&self) -> &Range { match self { Crease::Inline { range, .. } => range, @@ -242,6 +252,7 @@ impl Crease { } } + #[ztracing::instrument(skip_all)] pub fn metadata(&self) -> Option<&CreaseMetadata> { match self { Self::Inline { metadata, .. 
} => metadata.as_ref(), @@ -287,6 +298,7 @@ impl CreaseMap { self.snapshot.clone() } + #[ztracing::instrument(skip_all)] pub fn insert( &mut self, creases: impl IntoIterator>, @@ -312,6 +324,7 @@ impl CreaseMap { new_ids } + #[ztracing::instrument(skip_all)] pub fn remove( &mut self, ids: impl IntoIterator, @@ -379,6 +392,7 @@ impl sum_tree::Summary for ItemSummary { impl sum_tree::Item for CreaseItem { type Summary = ItemSummary; + #[ztracing::instrument(skip_all)] fn summary(&self, _cx: &MultiBufferSnapshot) -> Self::Summary { ItemSummary { range: self.crease.range().clone(), @@ -388,12 +402,14 @@ impl sum_tree::Item for CreaseItem { /// Implements `SeekTarget` for `Range` to enable seeking within a `SumTree` of `CreaseItem`s. impl SeekTarget<'_, ItemSummary, ItemSummary> for Range { + #[ztracing::instrument(skip_all)] fn cmp(&self, cursor_location: &ItemSummary, snapshot: &MultiBufferSnapshot) -> Ordering { AnchorRangeExt::cmp(self, &cursor_location.range, snapshot) } } impl SeekTarget<'_, ItemSummary, ItemSummary> for Anchor { + #[ztracing::instrument(skip_all)] fn cmp(&self, other: &ItemSummary, snapshot: &MultiBufferSnapshot) -> Ordering { self.cmp(&other.range.start, snapshot) } @@ -461,6 +477,7 @@ mod test { } #[gpui::test] + #[ztracing::instrument(skip_all)] fn test_creases_in_range(cx: &mut App) { let text = "line1\nline2\nline3\nline4\nline5\nline6\nline7"; let buffer = MultiBuffer::build_simple(text, cx); diff --git a/crates/editor/src/display_map/custom_highlights.rs b/crates/editor/src/display_map/custom_highlights.rs index a40d1adc82f4bc79308eaec901586232e9e2e5c2..c9202280bf957fac4d729bab558f686c0f62e774 100644 --- a/crates/editor/src/display_map/custom_highlights.rs +++ b/crates/editor/src/display_map/custom_highlights.rs @@ -30,6 +30,7 @@ struct HighlightEndpoint { } impl<'a> CustomHighlightsChunks<'a> { + #[ztracing::instrument(skip_all)] pub fn new( range: Range, language_aware: bool, @@ -51,6 +52,7 @@ impl<'a> CustomHighlightsChunks<'a> { } } + 
#[ztracing::instrument(skip_all)] pub fn seek(&mut self, new_range: Range) { self.highlight_endpoints = create_highlight_endpoints(&new_range, self.text_highlights, self.multibuffer_snapshot); @@ -108,6 +110,7 @@ fn create_highlight_endpoints( impl<'a> Iterator for CustomHighlightsChunks<'a> { type Item = Chunk<'a>; + #[ztracing::instrument(skip_all)] fn next(&mut self) -> Option { let mut next_highlight_endpoint = MultiBufferOffset(usize::MAX); while let Some(endpoint) = self.highlight_endpoints.peek().copied() { diff --git a/crates/editor/src/display_map/fold_map.rs b/crates/editor/src/display_map/fold_map.rs index 2d37dea38a93cc609a3a92064a6e35cdc76eb3da..bb0d6885acc2afd95e97fe9121acd2d0580554f3 100644 --- a/crates/editor/src/display_map/fold_map.rs +++ b/crates/editor/src/display_map/fold_map.rs @@ -99,6 +99,7 @@ impl FoldPoint { &mut self.0.column } + #[ztracing::instrument(skip_all)] pub fn to_inlay_point(self, snapshot: &FoldSnapshot) -> InlayPoint { let (start, _, _) = snapshot .transforms @@ -107,6 +108,7 @@ impl FoldPoint { InlayPoint(start.1.0 + overshoot) } + #[ztracing::instrument(skip_all)] pub fn to_offset(self, snapshot: &FoldSnapshot) -> FoldOffset { let (start, _, item) = snapshot .transforms @@ -138,6 +140,7 @@ impl<'a> sum_tree::Dimension<'a, TransformSummary> for FoldPoint { pub(crate) struct FoldMapWriter<'a>(&'a mut FoldMap); impl FoldMapWriter<'_> { + #[ztracing::instrument(skip_all)] pub(crate) fn fold( &mut self, ranges: impl IntoIterator, FoldPlaceholder)>, @@ -202,6 +205,7 @@ impl FoldMapWriter<'_> { } /// Removes any folds with the given ranges. + #[ztracing::instrument(skip_all)] pub(crate) fn remove_folds( &mut self, ranges: impl IntoIterator>, @@ -215,6 +219,7 @@ impl FoldMapWriter<'_> { } /// Removes any folds whose ranges intersect the given ranges. 
+ #[ztracing::instrument(skip_all)] pub(crate) fn unfold_intersecting( &mut self, ranges: impl IntoIterator>, @@ -225,6 +230,7 @@ impl FoldMapWriter<'_> { /// Removes any folds that intersect the given ranges and for which the given predicate /// returns true. + #[ztracing::instrument(skip_all)] fn remove_folds_with( &mut self, ranges: impl IntoIterator>, @@ -277,6 +283,7 @@ impl FoldMapWriter<'_> { (self.0.snapshot.clone(), edits) } + #[ztracing::instrument(skip_all)] pub(crate) fn update_fold_widths( &mut self, new_widths: impl IntoIterator, @@ -326,6 +333,7 @@ pub struct FoldMap { } impl FoldMap { + #[ztracing::instrument(skip_all)] pub fn new(inlay_snapshot: InlaySnapshot) -> (Self, FoldSnapshot) { let this = Self { snapshot: FoldSnapshot { @@ -350,6 +358,7 @@ impl FoldMap { (this, snapshot) } + #[ztracing::instrument(skip_all)] pub fn read( &mut self, inlay_snapshot: InlaySnapshot, @@ -360,6 +369,7 @@ impl FoldMap { (self.snapshot.clone(), edits) } + #[ztracing::instrument(skip_all)] pub(crate) fn write( &mut self, inlay_snapshot: InlaySnapshot, @@ -369,6 +379,7 @@ impl FoldMap { (FoldMapWriter(self), snapshot, edits) } + #[ztracing::instrument(skip_all)] fn check_invariants(&self) { if cfg!(test) { assert_eq!( @@ -398,6 +409,7 @@ impl FoldMap { } } + #[ztracing::instrument(skip_all)] fn sync( &mut self, inlay_snapshot: InlaySnapshot, @@ -645,6 +657,7 @@ impl FoldSnapshot { &self.inlay_snapshot.buffer } + #[ztracing::instrument(skip_all)] fn fold_width(&self, fold_id: &FoldId) -> Option { self.fold_metadata_by_id.get(fold_id)?.width } @@ -665,6 +678,7 @@ impl FoldSnapshot { self.folds.items(&self.inlay_snapshot.buffer).len() } + #[ztracing::instrument(skip_all)] pub fn text_summary_for_range(&self, range: Range) -> MBTextSummary { let mut summary = MBTextSummary::default(); @@ -718,6 +732,7 @@ impl FoldSnapshot { summary } + #[ztracing::instrument(skip_all)] pub fn to_fold_point(&self, point: InlayPoint, bias: Bias) -> FoldPoint { let (start, end, item) = self 
.transforms @@ -734,6 +749,7 @@ impl FoldSnapshot { } } + #[ztracing::instrument(skip_all)] pub fn fold_point_cursor(&self) -> FoldPointCursor<'_> { let cursor = self .transforms @@ -741,10 +757,12 @@ impl FoldSnapshot { FoldPointCursor { cursor } } + #[ztracing::instrument(skip_all)] pub fn len(&self) -> FoldOffset { FoldOffset(self.transforms.summary().output.len) } + #[ztracing::instrument(skip_all)] pub fn line_len(&self, row: u32) -> u32 { let line_start = FoldPoint::new(row, 0).to_offset(self).0; let line_end = if row >= self.max_point().row() { @@ -755,6 +773,7 @@ impl FoldSnapshot { (line_end - line_start) as u32 } + #[ztracing::instrument(skip_all)] pub fn row_infos(&self, start_row: u32) -> FoldRows<'_> { if start_row > self.transforms.summary().output.lines.row { panic!("invalid display row {}", start_row); @@ -777,6 +796,7 @@ impl FoldSnapshot { } } + #[ztracing::instrument(skip_all)] pub fn max_point(&self) -> FoldPoint { FoldPoint(self.transforms.summary().output.lines) } @@ -786,6 +806,7 @@ impl FoldSnapshot { self.transforms.summary().output.longest_row } + #[ztracing::instrument(skip_all)] pub fn folds_in_range(&self, range: Range) -> impl Iterator where T: ToOffset, @@ -800,6 +821,7 @@ impl FoldSnapshot { }) } + #[ztracing::instrument(skip_all)] pub fn intersects_fold(&self, offset: T) -> bool where T: ToOffset, @@ -812,6 +834,7 @@ impl FoldSnapshot { item.is_some_and(|t| t.placeholder.is_some()) } + #[ztracing::instrument(skip_all)] pub fn is_line_folded(&self, buffer_row: MultiBufferRow) -> bool { let mut inlay_point = self .inlay_snapshot @@ -840,6 +863,7 @@ impl FoldSnapshot { } } + #[ztracing::instrument(skip_all)] pub(crate) fn chunks<'a>( &'a self, range: Range, @@ -884,6 +908,7 @@ impl FoldSnapshot { } } + #[ztracing::instrument(skip_all)] pub fn chars_at(&self, start: FoldPoint) -> impl '_ + Iterator { self.chunks( start.to_offset(self)..self.len(), @@ -893,6 +918,7 @@ impl FoldSnapshot { .flat_map(|chunk| chunk.text.chars()) } + 
#[ztracing::instrument(skip_all)] pub fn chunks_at(&self, start: FoldPoint) -> FoldChunks<'_> { self.chunks( start.to_offset(self)..self.len(), @@ -902,6 +928,7 @@ impl FoldSnapshot { } #[cfg(test)] + #[ztracing::instrument(skip_all)] pub fn clip_offset(&self, offset: FoldOffset, bias: Bias) -> FoldOffset { if offset > self.len() { self.len() @@ -910,6 +937,7 @@ impl FoldSnapshot { } } + #[ztracing::instrument(skip_all)] pub fn clip_point(&self, point: FoldPoint, bias: Bias) -> FoldPoint { let (start, end, item) = self .transforms @@ -939,6 +967,7 @@ pub struct FoldPointCursor<'transforms> { } impl FoldPointCursor<'_> { + #[ztracing::instrument(skip_all)] pub fn map(&mut self, point: InlayPoint, bias: Bias) -> FoldPoint { let cursor = &mut self.cursor; if cursor.did_seek() { @@ -1267,6 +1296,7 @@ pub struct FoldRows<'a> { } impl FoldRows<'_> { + #[ztracing::instrument(skip_all)] pub(crate) fn seek(&mut self, row: u32) { let fold_point = FoldPoint::new(row, 0); self.cursor.seek(&fold_point, Bias::Left); @@ -1280,6 +1310,7 @@ impl FoldRows<'_> { impl Iterator for FoldRows<'_> { type Item = RowInfo; + #[ztracing::instrument(skip_all)] fn next(&mut self) -> Option { let mut traversed_fold = false; while self.fold_point > self.cursor.end().0 { @@ -1391,6 +1422,7 @@ pub struct FoldChunks<'a> { } impl FoldChunks<'_> { + #[ztracing::instrument(skip_all)] pub(crate) fn seek(&mut self, range: Range) { self.transform_cursor.seek(&range.start, Bias::Right); @@ -1425,6 +1457,7 @@ impl FoldChunks<'_> { impl<'a> Iterator for FoldChunks<'a> { type Item = Chunk<'a>; + #[ztracing::instrument(skip_all)] fn next(&mut self) -> Option { if self.output_offset >= self.max_output_offset { return None; @@ -1524,6 +1557,7 @@ impl<'a> Iterator for FoldChunks<'a> { pub struct FoldOffset(pub MultiBufferOffset); impl FoldOffset { + #[ztracing::instrument(skip_all)] pub fn to_point(self, snapshot: &FoldSnapshot) -> FoldPoint { let (start, _, item) = snapshot .transforms @@ -1539,6 +1573,7 @@ impl 
FoldOffset { } #[cfg(test)] + #[ztracing::instrument(skip_all)] pub fn to_inlay_offset(self, snapshot: &FoldSnapshot) -> InlayOffset { let (start, _, _) = snapshot .transforms diff --git a/crates/editor/src/display_map/inlay_map.rs b/crates/editor/src/display_map/inlay_map.rs index 73174c3018e1f76a16acbff3f4bad1c7af84da33..d85f761a82e2f466b6868c4ce28bcb3a4e6b061d 100644 --- a/crates/editor/src/display_map/inlay_map.rs +++ b/crates/editor/src/display_map/inlay_map.rs @@ -52,6 +52,7 @@ enum Transform { impl sum_tree::Item for Transform { type Summary = TransformSummary; + #[ztracing::instrument(skip_all)] fn summary(&self, _: ()) -> Self::Summary { match self { Transform::Isomorphic(summary) => TransformSummary { @@ -228,6 +229,7 @@ pub struct InlayChunk<'a> { } impl InlayChunks<'_> { + #[ztracing::instrument(skip_all)] pub fn seek(&mut self, new_range: Range) { self.transforms.seek(&new_range.start, Bias::Right); @@ -248,6 +250,7 @@ impl InlayChunks<'_> { impl<'a> Iterator for InlayChunks<'a> { type Item = InlayChunk<'a>; + #[ztracing::instrument(skip_all)] fn next(&mut self) -> Option { if self.output_offset == self.max_output_offset { return None; @@ -441,6 +444,7 @@ impl<'a> Iterator for InlayChunks<'a> { } impl InlayBufferRows<'_> { + #[ztracing::instrument(skip_all)] pub fn seek(&mut self, row: u32) { let inlay_point = InlayPoint::new(row, 0); self.transforms.seek(&inlay_point, Bias::Left); @@ -465,6 +469,7 @@ impl InlayBufferRows<'_> { impl Iterator for InlayBufferRows<'_> { type Item = RowInfo; + #[ztracing::instrument(skip_all)] fn next(&mut self) -> Option { let buffer_row = if self.inlay_row == 0 { self.buffer_rows.next().unwrap() @@ -494,6 +499,7 @@ impl InlayPoint { } impl InlayMap { + #[ztracing::instrument(skip_all)] pub fn new(buffer: MultiBufferSnapshot) -> (Self, InlaySnapshot) { let version = 0; let snapshot = InlaySnapshot { @@ -511,6 +517,7 @@ impl InlayMap { ) } + #[ztracing::instrument(skip_all)] pub fn sync( &mut self, buffer_snapshot: 
MultiBufferSnapshot, @@ -643,6 +650,7 @@ impl InlayMap { } } + #[ztracing::instrument(skip_all)] pub fn splice( &mut self, to_remove: &[InlayId], @@ -693,11 +701,13 @@ impl InlayMap { (snapshot, edits) } + #[ztracing::instrument(skip_all)] pub fn current_inlays(&self) -> impl Iterator { self.inlays.iter() } #[cfg(test)] + #[ztracing::instrument(skip_all)] pub(crate) fn randomly_mutate( &mut self, next_inlay_id: &mut usize, @@ -766,6 +776,7 @@ impl InlayMap { } impl InlaySnapshot { + #[ztracing::instrument(skip_all)] pub fn to_point(&self, offset: InlayOffset) -> InlayPoint { let (start, _, item) = self.transforms.find:: InlayOffset { InlayOffset(self.transforms.summary().output.len) } + #[ztracing::instrument(skip_all)] pub fn max_point(&self) -> InlayPoint { InlayPoint(self.transforms.summary().output.lines) } + #[ztracing::instrument(skip_all, fields(point))] pub fn to_offset(&self, point: InlayPoint) -> InlayOffset { let (start, _, item) = self .transforms @@ -817,6 +831,7 @@ impl InlaySnapshot { None => self.len(), } } + #[ztracing::instrument(skip_all)] pub fn to_buffer_point(&self, point: InlayPoint) -> Point { let (start, _, item) = self.transforms @@ -830,6 +845,7 @@ impl InlaySnapshot { None => self.buffer.max_point(), } } + #[ztracing::instrument(skip_all)] pub fn to_buffer_offset(&self, offset: InlayOffset) -> MultiBufferOffset { let (start, _, item) = self .transforms @@ -844,6 +860,7 @@ impl InlaySnapshot { } } + #[ztracing::instrument(skip_all)] pub fn to_inlay_offset(&self, offset: MultiBufferOffset) -> InlayOffset { let mut cursor = self .transforms @@ -880,10 +897,12 @@ impl InlaySnapshot { } } + #[ztracing::instrument(skip_all)] pub fn to_inlay_point(&self, point: Point) -> InlayPoint { self.inlay_point_cursor().map(point) } + #[ztracing::instrument(skip_all)] pub fn inlay_point_cursor(&self) -> InlayPointCursor<'_> { let cursor = self.transforms.cursor::>(()); InlayPointCursor { @@ -892,6 +911,7 @@ impl InlaySnapshot { } } + 
#[ztracing::instrument(skip_all)] pub fn clip_point(&self, mut point: InlayPoint, mut bias: Bias) -> InlayPoint { let mut cursor = self.transforms.cursor::>(()); cursor.seek(&point, Bias::Left); @@ -983,10 +1003,12 @@ impl InlaySnapshot { } } + #[ztracing::instrument(skip_all)] pub fn text_summary(&self) -> MBTextSummary { self.transforms.summary().output } + #[ztracing::instrument(skip_all)] pub fn text_summary_for_range(&self, range: Range) -> MBTextSummary { let mut summary = MBTextSummary::default(); @@ -1044,6 +1066,7 @@ impl InlaySnapshot { summary } + #[ztracing::instrument(skip_all)] pub fn row_infos(&self, row: u32) -> InlayBufferRows<'_> { let mut cursor = self.transforms.cursor::>(()); let inlay_point = InlayPoint::new(row, 0); @@ -1071,6 +1094,7 @@ impl InlaySnapshot { } } + #[ztracing::instrument(skip_all)] pub fn line_len(&self, row: u32) -> u32 { let line_start = self.to_offset(InlayPoint::new(row, 0)).0; let line_end = if row >= self.max_point().row() { @@ -1081,6 +1105,7 @@ impl InlaySnapshot { (line_end - line_start) as u32 } + #[ztracing::instrument(skip_all)] pub(crate) fn chunks<'a>( &'a self, range: Range, @@ -1115,12 +1140,14 @@ impl InlaySnapshot { } #[cfg(test)] + #[ztracing::instrument(skip_all)] pub fn text(&self) -> String { self.chunks(Default::default()..self.len(), false, Highlights::default()) .map(|chunk| chunk.chunk.text) .collect() } + #[ztracing::instrument(skip_all)] fn check_invariants(&self) { #[cfg(any(debug_assertions, feature = "test-support"))] { @@ -1147,6 +1174,7 @@ pub struct InlayPointCursor<'transforms> { } impl InlayPointCursor<'_> { + #[ztracing::instrument(skip_all)] pub fn map(&mut self, point: Point) -> InlayPoint { let cursor = &mut self.cursor; if cursor.did_seek() { diff --git a/crates/editor/src/display_map/invisibles.rs b/crates/editor/src/display_map/invisibles.rs index 5622a659b7acf850d24f6a476b23b53d214d855d..90bd54ab2807bbef703ac29e4ac4eaf49bcf71fd 100644 --- a/crates/editor/src/display_map/invisibles.rs 
+++ b/crates/editor/src/display_map/invisibles.rs @@ -30,6 +30,7 @@ // ref: https://gist.github.com/ConradIrwin/f759e1fc29267143c4c7895aa495dca5?h=1 // ref: https://unicode.org/Public/emoji/13.0/emoji-test.txt // https://github.com/bits/UTF-8-Unicode-Test-Documents/blob/master/UTF-8_sequence_separated/utf8_sequence_0-0x10ffff_assigned_including-unprintable-asis.txt +#[ztracing::instrument(skip_all)] pub fn is_invisible(c: char) -> bool { if c <= '\u{1f}' { c != '\t' && c != '\n' && c != '\r' diff --git a/crates/editor/src/display_map/tab_map.rs b/crates/editor/src/display_map/tab_map.rs index 347d7732151e172812de1e0252ca8d65f4cdbb8b..4e768a477159820ea380aa48a123d103c0c2f6a2 100644 --- a/crates/editor/src/display_map/tab_map.rs +++ b/crates/editor/src/display_map/tab_map.rs @@ -20,6 +20,7 @@ const MAX_TABS: NonZeroU32 = NonZeroU32::new(SPACES.len() as u32).unwrap(); pub struct TabMap(TabSnapshot); impl TabMap { + #[ztracing::instrument(skip_all)] pub fn new(fold_snapshot: FoldSnapshot, tab_size: NonZeroU32) -> (Self, TabSnapshot) { let snapshot = TabSnapshot { fold_snapshot, @@ -36,6 +37,7 @@ impl TabMap { self.0.clone() } + #[ztracing::instrument(skip_all)] pub fn sync( &mut self, fold_snapshot: FoldSnapshot, @@ -176,10 +178,12 @@ impl std::ops::Deref for TabSnapshot { } impl TabSnapshot { + #[ztracing::instrument(skip_all)] pub fn buffer_snapshot(&self) -> &MultiBufferSnapshot { &self.fold_snapshot.inlay_snapshot.buffer } + #[ztracing::instrument(skip_all)] pub fn line_len(&self, row: u32) -> u32 { let max_point = self.max_point(); if row < max_point.row() { @@ -191,10 +195,12 @@ impl TabSnapshot { } } + #[ztracing::instrument(skip_all)] pub fn text_summary(&self) -> TextSummary { self.text_summary_for_range(TabPoint::zero()..self.max_point()) } + #[ztracing::instrument(skip_all, fields(rows))] pub fn text_summary_for_range(&self, range: Range) -> TextSummary { let input_start = self.tab_point_to_fold_point(range.start, Bias::Left).0; let input_end = 
self.tab_point_to_fold_point(range.end, Bias::Right).0; @@ -234,6 +240,7 @@ impl TabSnapshot { } } + #[ztracing::instrument(skip_all)] pub(crate) fn chunks<'a>( &'a self, range: Range, @@ -276,11 +283,13 @@ impl TabSnapshot { } } + #[ztracing::instrument(skip_all)] pub fn rows(&self, row: u32) -> fold_map::FoldRows<'_> { self.fold_snapshot.row_infos(row) } #[cfg(test)] + #[ztracing::instrument(skip_all)] pub fn text(&self) -> String { self.chunks( TabPoint::zero()..self.max_point(), @@ -291,10 +300,12 @@ impl TabSnapshot { .collect() } + #[ztracing::instrument(skip_all)] pub fn max_point(&self) -> TabPoint { self.fold_point_to_tab_point(self.fold_snapshot.max_point()) } + #[ztracing::instrument(skip_all)] pub fn clip_point(&self, point: TabPoint, bias: Bias) -> TabPoint { self.fold_point_to_tab_point( self.fold_snapshot @@ -302,6 +313,7 @@ impl TabSnapshot { ) } + #[ztracing::instrument(skip_all)] pub fn fold_point_to_tab_point(&self, input: FoldPoint) -> TabPoint { let chunks = self.fold_snapshot.chunks_at(FoldPoint::new(input.row(), 0)); let tab_cursor = TabStopCursor::new(chunks); @@ -309,10 +321,12 @@ impl TabSnapshot { TabPoint::new(input.row(), expanded) } + #[ztracing::instrument(skip_all)] pub fn tab_point_cursor(&self) -> TabPointCursor<'_> { TabPointCursor { this: self } } + #[ztracing::instrument(skip_all)] pub fn tab_point_to_fold_point(&self, output: TabPoint, bias: Bias) -> (FoldPoint, u32, u32) { let chunks = self .fold_snapshot @@ -330,12 +344,14 @@ impl TabSnapshot { ) } + #[ztracing::instrument(skip_all)] pub fn point_to_tab_point(&self, point: Point, bias: Bias) -> TabPoint { let inlay_point = self.fold_snapshot.inlay_snapshot.to_inlay_point(point); let fold_point = self.fold_snapshot.to_fold_point(inlay_point, bias); self.fold_point_to_tab_point(fold_point) } + #[ztracing::instrument(skip_all)] pub fn tab_point_to_point(&self, point: TabPoint, bias: Bias) -> Point { let fold_point = self.tab_point_to_fold_point(point, bias).0; let inlay_point = 
fold_point.to_inlay_point(&self.fold_snapshot); @@ -344,6 +360,7 @@ impl TabSnapshot { .to_buffer_point(inlay_point) } + #[ztracing::instrument(skip_all)] fn expand_tabs<'a, I>(&self, mut cursor: TabStopCursor<'a, I>, column: u32) -> u32 where I: Iterator>, @@ -377,6 +394,7 @@ impl TabSnapshot { expanded_bytes + column.saturating_sub(collapsed_bytes) } + #[ztracing::instrument(skip_all)] fn collapse_tabs<'a, I>( &self, mut cursor: TabStopCursor<'a, I>, @@ -442,6 +460,7 @@ pub struct TabPointCursor<'this> { } impl TabPointCursor<'_> { + #[ztracing::instrument(skip_all)] pub fn map(&mut self, point: FoldPoint) -> TabPoint { self.this.fold_point_to_tab_point(point) } @@ -486,6 +505,7 @@ pub struct TextSummary { } impl<'a> From<&'a str> for TextSummary { + #[ztracing::instrument(skip_all)] fn from(text: &'a str) -> Self { let sum = text::TextSummary::from(text); @@ -500,6 +520,7 @@ impl<'a> From<&'a str> for TextSummary { } impl<'a> std::ops::AddAssign<&'a Self> for TextSummary { + #[ztracing::instrument(skip_all)] fn add_assign(&mut self, other: &'a Self) { let joined_chars = self.last_line_chars + other.first_line_chars; if joined_chars > self.longest_row_chars { @@ -541,6 +562,7 @@ pub struct TabChunks<'a> { } impl TabChunks<'_> { + #[ztracing::instrument(skip_all)] pub(crate) fn seek(&mut self, range: Range) { let (input_start, expanded_char_column, to_next_stop) = self .snapshot @@ -576,6 +598,7 @@ impl TabChunks<'_> { impl<'a> Iterator for TabChunks<'a> { type Item = Chunk<'a>; + #[ztracing::instrument(skip_all)] fn next(&mut self) -> Option { if self.chunk.text.is_empty() { if let Some(chunk) = self.fold_chunks.next() { @@ -1452,6 +1475,7 @@ impl<'a, I> TabStopCursor<'a, I> where I: Iterator>, { + #[ztracing::instrument(skip_all)] fn new(chunks: impl IntoIterator, IntoIter = I>) -> Self { Self { chunks: chunks.into_iter(), @@ -1461,6 +1485,7 @@ where } } + #[ztracing::instrument(skip_all)] fn bytes_until_next_char(&self) -> Option { 
self.current_chunk.as_ref().and_then(|(chunk, idx)| { let mut idx = *idx; @@ -1482,6 +1507,7 @@ where }) } + #[ztracing::instrument(skip_all)] fn is_char_boundary(&self) -> bool { self.current_chunk .as_ref() @@ -1489,6 +1515,7 @@ where } /// distance: length to move forward while searching for the next tab stop + #[ztracing::instrument(skip_all)] fn seek(&mut self, distance: u32) -> Option { if distance == 0 { return None; diff --git a/crates/editor/src/display_map/wrap_map.rs b/crates/editor/src/display_map/wrap_map.rs index 20ef9391888e6a824b87fe5de2607500049904ff..51d5324c838dc7cb7f4df04b0e58577108aab6c8 100644 --- a/crates/editor/src/display_map/wrap_map.rs +++ b/crates/editor/src/display_map/wrap_map.rs @@ -86,6 +86,7 @@ pub struct WrapRows<'a> { } impl WrapRows<'_> { + #[ztracing::instrument(skip_all)] pub(crate) fn seek(&mut self, start_row: WrapRow) { self.transforms .seek(&WrapPoint::new(start_row, 0), Bias::Left); @@ -101,6 +102,7 @@ impl WrapRows<'_> { } impl WrapMap { + #[ztracing::instrument(skip_all)] pub fn new( tab_snapshot: TabSnapshot, font: Font, @@ -131,6 +133,7 @@ impl WrapMap { self.background_task.is_some() } + #[ztracing::instrument(skip_all)] pub fn sync( &mut self, tab_snapshot: TabSnapshot, @@ -150,6 +153,7 @@ impl WrapMap { (self.snapshot.clone(), mem::take(&mut self.edits_since_sync)) } + #[ztracing::instrument(skip_all)] pub fn set_font_with_size( &mut self, font: Font, @@ -167,6 +171,7 @@ impl WrapMap { } } + #[ztracing::instrument(skip_all)] pub fn set_wrap_width(&mut self, wrap_width: Option, cx: &mut Context) -> bool { if wrap_width == self.wrap_width { return false; @@ -177,6 +182,7 @@ impl WrapMap { true } + #[ztracing::instrument(skip_all)] fn rewrap(&mut self, cx: &mut Context) { self.background_task.take(); self.interpolated_edits.clear(); @@ -248,6 +254,7 @@ impl WrapMap { } } + #[ztracing::instrument(skip_all)] fn flush_edits(&mut self, cx: &mut Context) { if !self.snapshot.interpolated { let mut to_remove_len = 0; @@ 
-330,6 +337,7 @@ impl WrapMap { } impl WrapSnapshot { + #[ztracing::instrument(skip_all)] fn new(tab_snapshot: TabSnapshot) -> Self { let mut transforms = SumTree::default(); let extent = tab_snapshot.text_summary(); @@ -343,10 +351,12 @@ impl WrapSnapshot { } } + #[ztracing::instrument(skip_all)] pub fn buffer_snapshot(&self) -> &MultiBufferSnapshot { self.tab_snapshot.buffer_snapshot() } + #[ztracing::instrument(skip_all)] fn interpolate(&mut self, new_tab_snapshot: TabSnapshot, tab_edits: &[TabEdit]) -> WrapPatch { let mut new_transforms; if tab_edits.is_empty() { @@ -411,6 +421,7 @@ impl WrapSnapshot { old_snapshot.compute_edits(tab_edits, self) } + #[ztracing::instrument(skip_all)] async fn update( &mut self, new_tab_snapshot: TabSnapshot, @@ -570,6 +581,7 @@ impl WrapSnapshot { old_snapshot.compute_edits(tab_edits, self) } + #[ztracing::instrument(skip_all)] fn compute_edits(&self, tab_edits: &[TabEdit], new_snapshot: &WrapSnapshot) -> WrapPatch { let mut wrap_edits = Vec::with_capacity(tab_edits.len()); let mut old_cursor = self.transforms.cursor::(()); @@ -606,6 +618,7 @@ impl WrapSnapshot { Patch::new(wrap_edits) } + #[ztracing::instrument(skip_all)] pub(crate) fn chunks<'a>( &'a self, rows: Range, @@ -640,10 +653,12 @@ impl WrapSnapshot { } } + #[ztracing::instrument(skip_all)] pub fn max_point(&self) -> WrapPoint { WrapPoint(self.transforms.summary().output.lines) } + #[ztracing::instrument(skip_all)] pub fn line_len(&self, row: WrapRow) -> u32 { let (start, _, item) = self.transforms.find::, _>( (), @@ -664,6 +679,7 @@ impl WrapSnapshot { } } + #[ztracing::instrument(skip_all, fields(rows))] pub fn text_summary_for_range(&self, rows: Range) -> TextSummary { let mut summary = TextSummary::default(); @@ -725,6 +741,7 @@ impl WrapSnapshot { summary } + #[ztracing::instrument(skip_all)] pub fn soft_wrap_indent(&self, row: WrapRow) -> Option { let (.., item) = self.transforms.find::( (), @@ -740,10 +757,12 @@ impl WrapSnapshot { }) } + 
#[ztracing::instrument(skip_all)] pub fn longest_row(&self) -> u32 { self.transforms.summary().output.longest_row } + #[ztracing::instrument(skip_all)] pub fn row_infos(&self, start_row: WrapRow) -> WrapRows<'_> { let mut transforms = self .transforms @@ -766,6 +785,7 @@ impl WrapSnapshot { } } + #[ztracing::instrument(skip_all)] pub fn to_tab_point(&self, point: WrapPoint) -> TabPoint { let (start, _, item) = self.transforms @@ -777,15 +797,18 @@ impl WrapSnapshot { TabPoint(tab_point) } + #[ztracing::instrument(skip_all)] pub fn to_point(&self, point: WrapPoint, bias: Bias) -> Point { self.tab_snapshot .tab_point_to_point(self.to_tab_point(point), bias) } + #[ztracing::instrument(skip_all)] pub fn make_wrap_point(&self, point: Point, bias: Bias) -> WrapPoint { self.tab_point_to_wrap_point(self.tab_snapshot.point_to_tab_point(point, bias)) } + #[ztracing::instrument(skip_all)] pub fn tab_point_to_wrap_point(&self, point: TabPoint) -> WrapPoint { let (start, ..) = self.transforms @@ -793,6 +816,7 @@ impl WrapSnapshot { WrapPoint(start.1.0 + (point.0 - start.0.0)) } + #[ztracing::instrument(skip_all)] pub fn wrap_point_cursor(&self) -> WrapPointCursor<'_> { WrapPointCursor { cursor: self @@ -801,6 +825,7 @@ impl WrapSnapshot { } } + #[ztracing::instrument(skip_all)] pub fn clip_point(&self, mut point: WrapPoint, bias: Bias) -> WrapPoint { if bias == Bias::Left { let (start, _, item) = self @@ -815,6 +840,7 @@ impl WrapSnapshot { self.tab_point_to_wrap_point(self.tab_snapshot.clip_point(self.to_tab_point(point), bias)) } + #[ztracing::instrument(skip_all, fields(point, ret))] pub fn prev_row_boundary(&self, mut point: WrapPoint) -> WrapRow { if self.transforms.is_empty() { return WrapRow(0); @@ -841,6 +867,7 @@ impl WrapSnapshot { unreachable!() } + #[ztracing::instrument(skip_all)] pub fn next_row_boundary(&self, mut point: WrapPoint) -> Option { point.0 += Point::new(1, 0); @@ -860,11 +887,13 @@ impl WrapSnapshot { } #[cfg(test)] + #[ztracing::instrument(skip_all)] 
pub fn text(&self) -> String { self.text_chunks(WrapRow(0)).collect() } #[cfg(test)] + #[ztracing::instrument(skip_all)] pub fn text_chunks(&self, wrap_row: WrapRow) -> impl Iterator { self.chunks( wrap_row..self.max_point().row() + WrapRow(1), @@ -874,6 +903,7 @@ impl WrapSnapshot { .map(|h| h.text) } + #[ztracing::instrument(skip_all)] fn check_invariants(&self) { #[cfg(test)] { @@ -927,6 +957,7 @@ pub struct WrapPointCursor<'transforms> { } impl WrapPointCursor<'_> { + #[ztracing::instrument(skip_all)] pub fn map(&mut self, point: TabPoint) -> WrapPoint { let cursor = &mut self.cursor; if cursor.did_seek() { @@ -939,6 +970,7 @@ impl WrapPointCursor<'_> { } impl WrapChunks<'_> { + #[ztracing::instrument(skip_all)] pub(crate) fn seek(&mut self, rows: Range) { let output_start = WrapPoint::new(rows.start, 0); let output_end = WrapPoint::new(rows.end, 0); @@ -961,6 +993,7 @@ impl WrapChunks<'_> { impl<'a> Iterator for WrapChunks<'a> { type Item = Chunk<'a>; + #[ztracing::instrument(skip_all)] fn next(&mut self) -> Option { if self.output_position.row() >= self.max_output_row { return None; @@ -1033,6 +1066,7 @@ impl<'a> Iterator for WrapChunks<'a> { impl Iterator for WrapRows<'_> { type Item = RowInfo; + #[ztracing::instrument(skip_all)] fn next(&mut self) -> Option { if self.output_row > self.max_output_row { return None; @@ -1069,6 +1103,7 @@ impl Iterator for WrapRows<'_> { } impl Transform { + #[ztracing::instrument(skip_all)] fn isomorphic(summary: TextSummary) -> Self { #[cfg(test)] assert!(!summary.lines.is_zero()); @@ -1082,6 +1117,7 @@ impl Transform { } } + #[ztracing::instrument(skip_all)] fn wrap(indent: u32) -> Self { static WRAP_TEXT: LazyLock = LazyLock::new(|| { let mut wrap_text = String::new(); @@ -1134,6 +1170,7 @@ trait SumTreeExt { } impl SumTreeExt for SumTree { + #[ztracing::instrument(skip_all)] fn push_or_extend(&mut self, transform: Transform) { let mut transform = Some(transform); self.update_last( @@ -1197,6 +1234,7 @@ impl<'a> 
sum_tree::Dimension<'a, TransformSummary> for TabPoint { } impl sum_tree::SeekTarget<'_, TransformSummary, TransformSummary> for TabPoint { + #[ztracing::instrument(skip_all)] fn cmp(&self, cursor_location: &TransformSummary, _: ()) -> std::cmp::Ordering { Ord::cmp(&self.0, &cursor_location.input.lines) } diff --git a/crates/git_ui/Cargo.toml b/crates/git_ui/Cargo.toml index 5e96cd3529b48bb401ee14e1a704b9bec485e356..beaf192b0ef538fb524ff4986710255040b89f27 100644 --- a/crates/git_ui/Cargo.toml +++ b/crates/git_ui/Cargo.toml @@ -13,7 +13,6 @@ name = "git_ui" path = "src/git_ui.rs" [features] -default = [] test-support = ["multi_buffer/test-support"] [dependencies] @@ -62,7 +61,8 @@ watch.workspace = true workspace.workspace = true zed_actions.workspace = true zeroize.workspace = true - +ztracing.workspace = true +tracing.workspace = true [target.'cfg(windows)'.dependencies] windows.workspace = true @@ -78,3 +78,6 @@ settings = { workspace = true, features = ["test-support"] } unindent.workspace = true workspace = { workspace = true, features = ["test-support"] } zlog.workspace = true + +[package.metadata.cargo-machete] +ignored = ["tracing"] diff --git a/crates/git_ui/src/project_diff.rs b/crates/git_ui/src/project_diff.rs index 0a8667ba6c753f9b7925948f212388f0668c1c92..f211483c5efeb14fd230def9235d82a1a79f49b4 100644 --- a/crates/git_ui/src/project_diff.rs +++ b/crates/git_ui/src/project_diff.rs @@ -46,6 +46,7 @@ use workspace::{ notifications::NotifyTaskExt, searchable::SearchableItemHandle, }; +use ztracing::instrument; actions!( git, @@ -469,6 +470,7 @@ impl ProjectDiff { } } + #[instrument(skip_all)] fn register_buffer( &mut self, path_key: PathKey, diff --git a/crates/multi_buffer/Cargo.toml b/crates/multi_buffer/Cargo.toml index 93747140c1960b70b9a9ddffe2a609e8a32a7dc7..524c916682f4d17b4e4b598a9af158e259b40ffc 100644 --- a/crates/multi_buffer/Cargo.toml +++ b/crates/multi_buffer/Cargo.toml @@ -42,6 +42,8 @@ sum_tree.workspace = true text.workspace = true 
theme.workspace = true tree-sitter.workspace = true +ztracing.workspace = true +tracing.workspace = true util.workspace = true [dev-dependencies] @@ -56,3 +58,6 @@ settings = { workspace = true, features = ["test-support"] } text = { workspace = true, features = ["test-support"] } util = { workspace = true, features = ["test-support"] } zlog.workspace = true + +[package.metadata.cargo-machete] +ignored = ["tracing"] diff --git a/crates/multi_buffer/src/multi_buffer.rs b/crates/multi_buffer/src/multi_buffer.rs index af36aaadf02b53224c4ef0bcf0a17d3643ab8f0f..24cb55d2f5e7311cc492ec70ab320eb12e78f8ee 100644 --- a/crates/multi_buffer/src/multi_buffer.rs +++ b/crates/multi_buffer/src/multi_buffer.rs @@ -57,6 +57,7 @@ use text::{ }; use theme::SyntaxTheme; use util::post_inc; +use ztracing::instrument; pub use self::path_key::PathKey; @@ -1671,6 +1672,7 @@ impl MultiBuffer { self.insert_excerpts_after(ExcerptId::max(), buffer, ranges, cx) } + #[instrument(skip_all)] fn merge_excerpt_ranges<'a>( expanded_ranges: impl IntoIterator> + 'a, ) -> (Vec>, Vec) { @@ -4483,6 +4485,7 @@ impl MultiBufferSnapshot { self.convert_dimension(point, text::BufferSnapshot::point_utf16_to_point) } + #[instrument(skip_all)] pub fn point_to_offset(&self, point: Point) -> MultiBufferOffset { self.convert_dimension(point, text::BufferSnapshot::point_to_offset) } @@ -4536,6 +4539,7 @@ impl MultiBufferSnapshot { } } + #[instrument(skip_all)] fn convert_dimension( &self, key: MBR1, @@ -6684,6 +6688,7 @@ where MBD: MultiBufferDimension + Ord + Sub + ops::AddAssign<::Output>, BD: TextDimension + AddAssign<::Output>, { + #[instrument(skip_all)] fn seek(&mut self, position: &MBD) { let position = OutputDimension(*position); self.cached_region.take(); diff --git a/crates/multi_buffer/src/path_key.rs b/crates/multi_buffer/src/path_key.rs index 1685e7a27329b1beea5f0d2c9563acfab07d8d8b..82bb902c230180d98c54225e8b57bf85beeedc2d 100644 --- a/crates/multi_buffer/src/path_key.rs +++ 
b/crates/multi_buffer/src/path_key.rs @@ -1,435 +1,437 @@ -use std::{mem, ops::Range, sync::Arc}; - -use collections::HashSet; -use gpui::{App, AppContext, Context, Entity}; -use itertools::Itertools; -use language::{Buffer, BufferSnapshot}; -use rope::Point; -use text::{Bias, BufferId, OffsetRangeExt, locator::Locator}; -use util::{post_inc, rel_path::RelPath}; - -use crate::{ - Anchor, ExcerptId, ExcerptRange, ExpandExcerptDirection, MultiBuffer, build_excerpt_ranges, -}; - -#[derive(PartialEq, Eq, Ord, PartialOrd, Clone, Hash, Debug)] -pub struct PathKey { - // Used by the derived PartialOrd & Ord - pub sort_prefix: Option, - pub path: Arc, -} - -impl PathKey { - pub fn with_sort_prefix(sort_prefix: u64, path: Arc) -> Self { - Self { - sort_prefix: Some(sort_prefix), - path, - } - } - - pub fn for_buffer(buffer: &Entity, cx: &App) -> Self { - if let Some(file) = buffer.read(cx).file() { - Self::with_sort_prefix(file.worktree_id(cx).to_proto(), file.path().clone()) - } else { - Self { - sort_prefix: None, - path: RelPath::unix(&buffer.entity_id().to_string()) - .unwrap() - .into_arc(), - } - } - } -} - -impl MultiBuffer { - pub fn paths(&self) -> impl Iterator + '_ { - self.excerpts_by_path.keys().cloned() - } - - pub fn remove_excerpts_for_path(&mut self, path: PathKey, cx: &mut Context) { - if let Some(to_remove) = self.excerpts_by_path.remove(&path) { - self.remove_excerpts(to_remove, cx) - } - if let Some(follower) = &self.follower { - follower.update(cx, |follower, cx| { - follower.remove_excerpts_for_path(path, cx); - }); - } - } - - pub fn location_for_path(&self, path: &PathKey, cx: &App) -> Option { - let excerpt_id = self.excerpts_by_path.get(path)?.first()?; - let snapshot = self.read(cx); - let excerpt = snapshot.excerpt(*excerpt_id)?; - Some(Anchor::in_buffer(excerpt.id, excerpt.range.context.start)) - } - - pub fn excerpt_paths(&self) -> impl Iterator { - self.excerpts_by_path.keys() - } - - /// Sets excerpts, returns `true` if at least one new 
excerpt was added. - pub fn set_excerpts_for_path( - &mut self, - path: PathKey, - buffer: Entity, - ranges: impl IntoIterator>, - context_line_count: u32, - cx: &mut Context, - ) -> (Vec>, bool) { - let buffer_snapshot = buffer.read(cx).snapshot(); - let excerpt_ranges = build_excerpt_ranges(ranges, context_line_count, &buffer_snapshot); - - let (new, counts) = Self::merge_excerpt_ranges(&excerpt_ranges); - self.set_merged_excerpt_ranges_for_path( - path, - buffer, - excerpt_ranges, - &buffer_snapshot, - new, - counts, - cx, - ) - } - - pub fn set_excerpt_ranges_for_path( - &mut self, - path: PathKey, - buffer: Entity, - buffer_snapshot: &BufferSnapshot, - excerpt_ranges: Vec>, - cx: &mut Context, - ) -> (Vec>, bool) { - let (new, counts) = Self::merge_excerpt_ranges(&excerpt_ranges); - self.set_merged_excerpt_ranges_for_path( - path, - buffer, - excerpt_ranges, - buffer_snapshot, - new, - counts, - cx, - ) - } - - pub fn set_anchored_excerpts_for_path( - &self, - path_key: PathKey, - buffer: Entity, - ranges: Vec>, - context_line_count: u32, - cx: &Context, - ) -> impl Future>> + use<> { - let buffer_snapshot = buffer.read(cx).snapshot(); - let multi_buffer = cx.weak_entity(); - let mut app = cx.to_async(); - async move { - let snapshot = buffer_snapshot.clone(); - let (excerpt_ranges, new, counts) = app - .background_spawn(async move { - let ranges = ranges.into_iter().map(|range| range.to_point(&snapshot)); - let excerpt_ranges = - build_excerpt_ranges(ranges, context_line_count, &snapshot); - let (new, counts) = Self::merge_excerpt_ranges(&excerpt_ranges); - (excerpt_ranges, new, counts) - }) - .await; - - multi_buffer - .update(&mut app, move |multi_buffer, cx| { - let (ranges, _) = multi_buffer.set_merged_excerpt_ranges_for_path( - path_key, - buffer, - excerpt_ranges, - &buffer_snapshot, - new, - counts, - cx, - ); - ranges - }) - .ok() - .unwrap_or_default() - } - } - - pub fn remove_excerpts_for_buffer(&mut self, buffer: BufferId, cx: &mut Context) { - 
self.remove_excerpts( - self.excerpts_for_buffer(buffer, cx) - .into_iter() - .map(|(excerpt, _)| excerpt), - cx, - ); - } - - pub(super) fn expand_excerpts_with_paths( - &mut self, - ids: impl IntoIterator, - line_count: u32, - direction: ExpandExcerptDirection, - cx: &mut Context, - ) { - let grouped = ids - .into_iter() - .chunk_by(|id| self.paths_by_excerpt.get(id).cloned()) - .into_iter() - .filter_map(|(k, v)| Some((k?, v.into_iter().collect::>()))) - .collect::>(); - let snapshot = self.snapshot(cx); - - for (path, ids) in grouped.into_iter() { - let Some(excerpt_ids) = self.excerpts_by_path.get(&path) else { - continue; - }; - - let ids_to_expand = HashSet::from_iter(ids); - let mut excerpt_id_ = None; - let expanded_ranges = excerpt_ids.iter().filter_map(|excerpt_id| { - let excerpt = snapshot.excerpt(*excerpt_id)?; - let excerpt_id = excerpt.id; - if excerpt_id_.is_none() { - excerpt_id_ = Some(excerpt_id); - } - - let mut context = excerpt.range.context.to_point(&excerpt.buffer); - if ids_to_expand.contains(&excerpt_id) { - match direction { - ExpandExcerptDirection::Up => { - context.start.row = context.start.row.saturating_sub(line_count); - context.start.column = 0; - } - ExpandExcerptDirection::Down => { - context.end.row = - (context.end.row + line_count).min(excerpt.buffer.max_point().row); - context.end.column = excerpt.buffer.line_len(context.end.row); - } - ExpandExcerptDirection::UpAndDown => { - context.start.row = context.start.row.saturating_sub(line_count); - context.start.column = 0; - context.end.row = - (context.end.row + line_count).min(excerpt.buffer.max_point().row); - context.end.column = excerpt.buffer.line_len(context.end.row); - } - } - } - - Some(ExcerptRange { - context, - primary: excerpt.range.primary.to_point(&excerpt.buffer), - }) - }); - let mut merged_ranges: Vec> = Vec::new(); - for range in expanded_ranges { - if let Some(last_range) = merged_ranges.last_mut() - && last_range.context.end >= range.context.start - { - 
last_range.context.end = range.context.end; - continue; - } - merged_ranges.push(range) - } - let Some(excerpt_id) = excerpt_id_ else { - continue; - }; - let Some(buffer_id) = &snapshot.buffer_id_for_excerpt(excerpt_id) else { - continue; - }; - - let Some(buffer) = self.buffers.get(buffer_id).map(|b| b.buffer.clone()) else { - continue; - }; - - let buffer_snapshot = buffer.read(cx).snapshot(); - self.update_path_excerpts(path.clone(), buffer, &buffer_snapshot, merged_ranges, cx); - } - } - - /// Sets excerpts, returns `true` if at least one new excerpt was added. - fn set_merged_excerpt_ranges_for_path( - &mut self, - path: PathKey, - buffer: Entity, - ranges: Vec>, - buffer_snapshot: &BufferSnapshot, - new: Vec>, - counts: Vec, - cx: &mut Context, - ) -> (Vec>, bool) { - let (excerpt_ids, added_a_new_excerpt) = - self.update_path_excerpts(path, buffer, buffer_snapshot, new, cx); - - let mut result = Vec::new(); - let mut ranges = ranges.into_iter(); - for (excerpt_id, range_count) in excerpt_ids.into_iter().zip(counts.into_iter()) { - for range in ranges.by_ref().take(range_count) { - let range = Anchor::range_in_buffer( - excerpt_id, - buffer_snapshot.anchor_before(&range.primary.start) - ..buffer_snapshot.anchor_after(&range.primary.end), - ); - result.push(range) - } - } - (result, added_a_new_excerpt) - } - - fn update_path_excerpts( - &mut self, - path: PathKey, - buffer: Entity, - buffer_snapshot: &BufferSnapshot, - new: Vec>, - cx: &mut Context, - ) -> (Vec, bool) { - let mut insert_after = self - .excerpts_by_path - .range(..path.clone()) - .next_back() - .and_then(|(_, value)| value.last().copied()) - .unwrap_or(ExcerptId::min()); - - let existing = self - .excerpts_by_path - .get(&path) - .cloned() - .unwrap_or_default(); - let mut new_iter = new.into_iter().peekable(); - let mut existing_iter = existing.into_iter().peekable(); - - let mut excerpt_ids = Vec::new(); - let mut to_remove = Vec::new(); - let mut to_insert: Vec<(ExcerptId, ExcerptRange)> = 
Vec::new(); - let mut added_a_new_excerpt = false; - let snapshot = self.snapshot(cx); - - let mut next_excerpt_id = - // is this right? What if we remove the last excerpt, then we might reallocate with a wrong mapping? - if let Some(last_entry) = self.snapshot.borrow().excerpt_ids.last() { - last_entry.id.0 + 1 - } else { - 1 - }; - - let mut next_excerpt_id = move || ExcerptId(post_inc(&mut next_excerpt_id)); - - let mut excerpts_cursor = snapshot.excerpts.cursor::>(()); - excerpts_cursor.next(); - - loop { - let existing = if let Some(&existing_id) = existing_iter.peek() { - let locator = snapshot.excerpt_locator_for_id(existing_id); - excerpts_cursor.seek_forward(&Some(locator), Bias::Left); - if let Some(excerpt) = excerpts_cursor.item() { - if excerpt.buffer_id != buffer_snapshot.remote_id() { - to_remove.push(existing_id); - existing_iter.next(); - continue; - } - Some((existing_id, excerpt.range.context.to_point(buffer_snapshot))) - } else { - None - } - } else { - None - }; - - let new = new_iter.peek(); - if let Some((last_id, last)) = to_insert.last_mut() { - if let Some(new) = new - && last.context.end >= new.context.start - { - last.context.end = last.context.end.max(new.context.end); - excerpt_ids.push(*last_id); - new_iter.next(); - continue; - } - if let Some((existing_id, existing_range)) = &existing - && last.context.end >= existing_range.start - { - last.context.end = last.context.end.max(existing_range.end); - to_remove.push(*existing_id); - self.snapshot - .get_mut() - .replaced_excerpts - .insert(*existing_id, *last_id); - existing_iter.next(); - continue; - } - } - - match (new, existing) { - (None, None) => break, - (None, Some((existing_id, _))) => { - existing_iter.next(); - to_remove.push(existing_id); - continue; - } - (Some(_), None) => { - added_a_new_excerpt = true; - let new_id = next_excerpt_id(); - excerpt_ids.push(new_id); - to_insert.push((new_id, new_iter.next().unwrap())); - continue; - } - (Some(new), Some((_, 
existing_range))) => { - if existing_range.end < new.context.start { - let existing_id = existing_iter.next().unwrap(); - to_remove.push(existing_id); - continue; - } else if existing_range.start > new.context.end { - let new_id = next_excerpt_id(); - excerpt_ids.push(new_id); - to_insert.push((new_id, new_iter.next().unwrap())); - continue; - } - - if existing_range.start == new.context.start - && existing_range.end == new.context.end - { - self.insert_excerpts_with_ids_after( - insert_after, - buffer.clone(), - mem::take(&mut to_insert), - cx, - ); - insert_after = existing_iter.next().unwrap(); - excerpt_ids.push(insert_after); - new_iter.next(); - } else { - let existing_id = existing_iter.next().unwrap(); - let new_id = next_excerpt_id(); - self.snapshot - .get_mut() - .replaced_excerpts - .insert(existing_id, new_id); - to_remove.push(existing_id); - let mut range = new_iter.next().unwrap(); - range.context.start = range.context.start.min(existing_range.start); - range.context.end = range.context.end.max(existing_range.end); - excerpt_ids.push(new_id); - to_insert.push((new_id, range)); - } - } - }; - } - - self.insert_excerpts_with_ids_after(insert_after, buffer, to_insert, cx); - // todo(lw): There is a logic bug somewhere that causes the to_remove vector to be not ordered correctly - to_remove.sort_by_cached_key(|&id| snapshot.excerpt_locator_for_id(id)); - self.remove_excerpts(to_remove, cx); - - if excerpt_ids.is_empty() { - self.excerpts_by_path.remove(&path); - } else { - for excerpt_id in &excerpt_ids { - self.paths_by_excerpt.insert(*excerpt_id, path.clone()); - } - let snapshot = &*self.snapshot.get_mut(); - let mut excerpt_ids: Vec<_> = excerpt_ids.iter().dedup().cloned().collect(); - excerpt_ids.sort_by_cached_key(|&id| snapshot.excerpt_locator_for_id(id)); - self.excerpts_by_path.insert(path, excerpt_ids); - } - - (excerpt_ids, added_a_new_excerpt) - } -} +use std::{mem, ops::Range, sync::Arc}; + +use collections::HashSet; +use gpui::{App, 
AppContext, Context, Entity}; +use itertools::Itertools; +use language::{Buffer, BufferSnapshot}; +use rope::Point; +use text::{Bias, BufferId, OffsetRangeExt, locator::Locator}; +use util::{post_inc, rel_path::RelPath}; +use ztracing::instrument; + +use crate::{ + Anchor, ExcerptId, ExcerptRange, ExpandExcerptDirection, MultiBuffer, build_excerpt_ranges, +}; + +#[derive(PartialEq, Eq, Ord, PartialOrd, Clone, Hash, Debug)] +pub struct PathKey { + // Used by the derived PartialOrd & Ord + pub sort_prefix: Option, + pub path: Arc, +} + +impl PathKey { + pub fn with_sort_prefix(sort_prefix: u64, path: Arc) -> Self { + Self { + sort_prefix: Some(sort_prefix), + path, + } + } + + pub fn for_buffer(buffer: &Entity, cx: &App) -> Self { + if let Some(file) = buffer.read(cx).file() { + Self::with_sort_prefix(file.worktree_id(cx).to_proto(), file.path().clone()) + } else { + Self { + sort_prefix: None, + path: RelPath::unix(&buffer.entity_id().to_string()) + .unwrap() + .into_arc(), + } + } + } +} + +impl MultiBuffer { + pub fn paths(&self) -> impl Iterator + '_ { + self.excerpts_by_path.keys().cloned() + } + + pub fn remove_excerpts_for_path(&mut self, path: PathKey, cx: &mut Context) { + if let Some(to_remove) = self.excerpts_by_path.remove(&path) { + self.remove_excerpts(to_remove, cx) + } + if let Some(follower) = &self.follower { + follower.update(cx, |follower, cx| { + follower.remove_excerpts_for_path(path, cx); + }); + } + } + + pub fn location_for_path(&self, path: &PathKey, cx: &App) -> Option { + let excerpt_id = self.excerpts_by_path.get(path)?.first()?; + let snapshot = self.read(cx); + let excerpt = snapshot.excerpt(*excerpt_id)?; + Some(Anchor::in_buffer(excerpt.id, excerpt.range.context.start)) + } + + pub fn excerpt_paths(&self) -> impl Iterator { + self.excerpts_by_path.keys() + } + + /// Sets excerpts, returns `true` if at least one new excerpt was added. 
+ #[instrument(skip_all)] + pub fn set_excerpts_for_path( + &mut self, + path: PathKey, + buffer: Entity, + ranges: impl IntoIterator>, + context_line_count: u32, + cx: &mut Context, + ) -> (Vec>, bool) { + let buffer_snapshot = buffer.read(cx).snapshot(); + let excerpt_ranges = build_excerpt_ranges(ranges, context_line_count, &buffer_snapshot); + + let (new, counts) = Self::merge_excerpt_ranges(&excerpt_ranges); + self.set_merged_excerpt_ranges_for_path( + path, + buffer, + excerpt_ranges, + &buffer_snapshot, + new, + counts, + cx, + ) + } + + pub fn set_excerpt_ranges_for_path( + &mut self, + path: PathKey, + buffer: Entity, + buffer_snapshot: &BufferSnapshot, + excerpt_ranges: Vec>, + cx: &mut Context, + ) -> (Vec>, bool) { + let (new, counts) = Self::merge_excerpt_ranges(&excerpt_ranges); + self.set_merged_excerpt_ranges_for_path( + path, + buffer, + excerpt_ranges, + buffer_snapshot, + new, + counts, + cx, + ) + } + + pub fn set_anchored_excerpts_for_path( + &self, + path_key: PathKey, + buffer: Entity, + ranges: Vec>, + context_line_count: u32, + cx: &Context, + ) -> impl Future>> + use<> { + let buffer_snapshot = buffer.read(cx).snapshot(); + let multi_buffer = cx.weak_entity(); + let mut app = cx.to_async(); + async move { + let snapshot = buffer_snapshot.clone(); + let (excerpt_ranges, new, counts) = app + .background_spawn(async move { + let ranges = ranges.into_iter().map(|range| range.to_point(&snapshot)); + let excerpt_ranges = + build_excerpt_ranges(ranges, context_line_count, &snapshot); + let (new, counts) = Self::merge_excerpt_ranges(&excerpt_ranges); + (excerpt_ranges, new, counts) + }) + .await; + + multi_buffer + .update(&mut app, move |multi_buffer, cx| { + let (ranges, _) = multi_buffer.set_merged_excerpt_ranges_for_path( + path_key, + buffer, + excerpt_ranges, + &buffer_snapshot, + new, + counts, + cx, + ); + ranges + }) + .ok() + .unwrap_or_default() + } + } + + pub fn remove_excerpts_for_buffer(&mut self, buffer: BufferId, cx: &mut Context) 
{ + self.remove_excerpts( + self.excerpts_for_buffer(buffer, cx) + .into_iter() + .map(|(excerpt, _)| excerpt), + cx, + ); + } + + pub(super) fn expand_excerpts_with_paths( + &mut self, + ids: impl IntoIterator, + line_count: u32, + direction: ExpandExcerptDirection, + cx: &mut Context, + ) { + let grouped = ids + .into_iter() + .chunk_by(|id| self.paths_by_excerpt.get(id).cloned()) + .into_iter() + .filter_map(|(k, v)| Some((k?, v.into_iter().collect::>()))) + .collect::>(); + let snapshot = self.snapshot(cx); + + for (path, ids) in grouped.into_iter() { + let Some(excerpt_ids) = self.excerpts_by_path.get(&path) else { + continue; + }; + + let ids_to_expand = HashSet::from_iter(ids); + let mut excerpt_id_ = None; + let expanded_ranges = excerpt_ids.iter().filter_map(|excerpt_id| { + let excerpt = snapshot.excerpt(*excerpt_id)?; + let excerpt_id = excerpt.id; + if excerpt_id_.is_none() { + excerpt_id_ = Some(excerpt_id); + } + + let mut context = excerpt.range.context.to_point(&excerpt.buffer); + if ids_to_expand.contains(&excerpt_id) { + match direction { + ExpandExcerptDirection::Up => { + context.start.row = context.start.row.saturating_sub(line_count); + context.start.column = 0; + } + ExpandExcerptDirection::Down => { + context.end.row = + (context.end.row + line_count).min(excerpt.buffer.max_point().row); + context.end.column = excerpt.buffer.line_len(context.end.row); + } + ExpandExcerptDirection::UpAndDown => { + context.start.row = context.start.row.saturating_sub(line_count); + context.start.column = 0; + context.end.row = + (context.end.row + line_count).min(excerpt.buffer.max_point().row); + context.end.column = excerpt.buffer.line_len(context.end.row); + } + } + } + + Some(ExcerptRange { + context, + primary: excerpt.range.primary.to_point(&excerpt.buffer), + }) + }); + let mut merged_ranges: Vec> = Vec::new(); + for range in expanded_ranges { + if let Some(last_range) = merged_ranges.last_mut() + && last_range.context.end >= range.context.start + { + 
last_range.context.end = range.context.end; + continue; + } + merged_ranges.push(range) + } + let Some(excerpt_id) = excerpt_id_ else { + continue; + }; + let Some(buffer_id) = &snapshot.buffer_id_for_excerpt(excerpt_id) else { + continue; + }; + + let Some(buffer) = self.buffers.get(buffer_id).map(|b| b.buffer.clone()) else { + continue; + }; + + let buffer_snapshot = buffer.read(cx).snapshot(); + self.update_path_excerpts(path.clone(), buffer, &buffer_snapshot, merged_ranges, cx); + } + } + + /// Sets excerpts, returns `true` if at least one new excerpt was added. + fn set_merged_excerpt_ranges_for_path( + &mut self, + path: PathKey, + buffer: Entity, + ranges: Vec>, + buffer_snapshot: &BufferSnapshot, + new: Vec>, + counts: Vec, + cx: &mut Context, + ) -> (Vec>, bool) { + let (excerpt_ids, added_a_new_excerpt) = + self.update_path_excerpts(path, buffer, buffer_snapshot, new, cx); + + let mut result = Vec::new(); + let mut ranges = ranges.into_iter(); + for (excerpt_id, range_count) in excerpt_ids.into_iter().zip(counts.into_iter()) { + for range in ranges.by_ref().take(range_count) { + let range = Anchor::range_in_buffer( + excerpt_id, + buffer_snapshot.anchor_before(&range.primary.start) + ..buffer_snapshot.anchor_after(&range.primary.end), + ); + result.push(range) + } + } + (result, added_a_new_excerpt) + } + + fn update_path_excerpts( + &mut self, + path: PathKey, + buffer: Entity, + buffer_snapshot: &BufferSnapshot, + new: Vec>, + cx: &mut Context, + ) -> (Vec, bool) { + let mut insert_after = self + .excerpts_by_path + .range(..path.clone()) + .next_back() + .and_then(|(_, value)| value.last().copied()) + .unwrap_or(ExcerptId::min()); + + let existing = self + .excerpts_by_path + .get(&path) + .cloned() + .unwrap_or_default(); + let mut new_iter = new.into_iter().peekable(); + let mut existing_iter = existing.into_iter().peekable(); + + let mut excerpt_ids = Vec::new(); + let mut to_remove = Vec::new(); + let mut to_insert: Vec<(ExcerptId, ExcerptRange)> = 
Vec::new(); + let mut added_a_new_excerpt = false; + let snapshot = self.snapshot(cx); + + let mut next_excerpt_id = + // is this right? What if we remove the last excerpt, then we might reallocate with a wrong mapping? + if let Some(last_entry) = self.snapshot.borrow().excerpt_ids.last() { + last_entry.id.0 + 1 + } else { + 1 + }; + + let mut next_excerpt_id = move || ExcerptId(post_inc(&mut next_excerpt_id)); + + let mut excerpts_cursor = snapshot.excerpts.cursor::>(()); + excerpts_cursor.next(); + + loop { + let existing = if let Some(&existing_id) = existing_iter.peek() { + let locator = snapshot.excerpt_locator_for_id(existing_id); + excerpts_cursor.seek_forward(&Some(locator), Bias::Left); + if let Some(excerpt) = excerpts_cursor.item() { + if excerpt.buffer_id != buffer_snapshot.remote_id() { + to_remove.push(existing_id); + existing_iter.next(); + continue; + } + Some((existing_id, excerpt.range.context.to_point(buffer_snapshot))) + } else { + None + } + } else { + None + }; + + let new = new_iter.peek(); + if let Some((last_id, last)) = to_insert.last_mut() { + if let Some(new) = new + && last.context.end >= new.context.start + { + last.context.end = last.context.end.max(new.context.end); + excerpt_ids.push(*last_id); + new_iter.next(); + continue; + } + if let Some((existing_id, existing_range)) = &existing + && last.context.end >= existing_range.start + { + last.context.end = last.context.end.max(existing_range.end); + to_remove.push(*existing_id); + self.snapshot + .get_mut() + .replaced_excerpts + .insert(*existing_id, *last_id); + existing_iter.next(); + continue; + } + } + + match (new, existing) { + (None, None) => break, + (None, Some((existing_id, _))) => { + existing_iter.next(); + to_remove.push(existing_id); + continue; + } + (Some(_), None) => { + added_a_new_excerpt = true; + let new_id = next_excerpt_id(); + excerpt_ids.push(new_id); + to_insert.push((new_id, new_iter.next().unwrap())); + continue; + } + (Some(new), Some((_, 
existing_range))) => { + if existing_range.end < new.context.start { + let existing_id = existing_iter.next().unwrap(); + to_remove.push(existing_id); + continue; + } else if existing_range.start > new.context.end { + let new_id = next_excerpt_id(); + excerpt_ids.push(new_id); + to_insert.push((new_id, new_iter.next().unwrap())); + continue; + } + + if existing_range.start == new.context.start + && existing_range.end == new.context.end + { + self.insert_excerpts_with_ids_after( + insert_after, + buffer.clone(), + mem::take(&mut to_insert), + cx, + ); + insert_after = existing_iter.next().unwrap(); + excerpt_ids.push(insert_after); + new_iter.next(); + } else { + let existing_id = existing_iter.next().unwrap(); + let new_id = next_excerpt_id(); + self.snapshot + .get_mut() + .replaced_excerpts + .insert(existing_id, new_id); + to_remove.push(existing_id); + let mut range = new_iter.next().unwrap(); + range.context.start = range.context.start.min(existing_range.start); + range.context.end = range.context.end.max(existing_range.end); + excerpt_ids.push(new_id); + to_insert.push((new_id, range)); + } + } + }; + } + + self.insert_excerpts_with_ids_after(insert_after, buffer, to_insert, cx); + // todo(lw): There is a logic bug somewhere that causes the to_remove vector to be not ordered correctly + to_remove.sort_by_cached_key(|&id| snapshot.excerpt_locator_for_id(id)); + self.remove_excerpts(to_remove, cx); + + if excerpt_ids.is_empty() { + self.excerpts_by_path.remove(&path); + } else { + for excerpt_id in &excerpt_ids { + self.paths_by_excerpt.insert(*excerpt_id, path.clone()); + } + let snapshot = &*self.snapshot.get_mut(); + let mut excerpt_ids: Vec<_> = excerpt_ids.iter().dedup().cloned().collect(); + excerpt_ids.sort_by_cached_key(|&id| snapshot.excerpt_locator_for_id(id)); + self.excerpts_by_path.insert(path, excerpt_ids); + } + + (excerpt_ids, added_a_new_excerpt) + } +} diff --git a/crates/project/Cargo.toml b/crates/project/Cargo.toml index 
a33efb9896959cc12fd828986c881f73e84e0ec7..9e2789fc109b8217f0f1033cc6d4832105c0ad48 100644 --- a/crates/project/Cargo.toml +++ b/crates/project/Cargo.toml @@ -91,6 +91,8 @@ which.workspace = true worktree.workspace = true zeroize.workspace = true zlog.workspace = true +ztracing.workspace = true +tracing.workspace = true [dev-dependencies] client = { workspace = true, features = ["test-support"] } @@ -113,3 +115,6 @@ snippet_provider = { workspace = true, features = ["test-support"] } unindent.workspace = true util = { workspace = true, features = ["test-support"] } worktree = { workspace = true, features = ["test-support"] } + +[package.metadata.cargo-machete] +ignored = ["tracing"] diff --git a/crates/project/src/git_store/branch_diff.rs b/crates/project/src/git_store/branch_diff.rs index 5065eafe4e185e65ce144f6d797ac8ccd616d5fa..dd0026961ec7ad77b674e2e9506b3133f07ce3f2 100644 --- a/crates/project/src/git_store/branch_diff.rs +++ b/crates/project/src/git_store/branch_diff.rs @@ -14,6 +14,7 @@ use gpui::{ use language::Buffer; use text::BufferId; use util::ResultExt; +use ztracing::instrument; use crate::{ Project, @@ -254,6 +255,7 @@ impl BranchDiff { self.repo.as_ref() } + #[instrument(skip_all)] pub fn load_buffers(&mut self, cx: &mut Context) -> Vec { let mut output = Vec::default(); let Some(repo) = self.repo.clone() else { @@ -318,6 +320,7 @@ impl BranchDiff { output } + #[instrument(skip_all)] fn load_buffer( branch_diff: Option, project_path: crate::ProjectPath, diff --git a/crates/rope/Cargo.toml b/crates/rope/Cargo.toml index 4107c2e012debc13b0cc44003250f4da63e5039f..9f0fc2be8a021a4cd43679beefb18a3567452dde 100644 --- a/crates/rope/Cargo.toml +++ b/crates/rope/Cargo.toml @@ -18,6 +18,8 @@ rayon.workspace = true sum_tree.workspace = true unicode-segmentation.workspace = true util.workspace = true +ztracing.workspace = true +tracing.workspace = true [dev-dependencies] ctor.workspace = true @@ -30,3 +32,6 @@ zlog.workspace = true [[bench]] name = 
"rope_benchmark" harness = false + +[package.metadata.cargo-machete] +ignored = ["tracing"] diff --git a/crates/rope/src/rope.rs b/crates/rope/src/rope.rs index 2d3c811e179fbd47cada7c2bebb89b03acd3eeb0..50f9ba044d90072aa9c6fc2fc4abfd6d0e6b98cb 100644 --- a/crates/rope/src/rope.rs +++ b/crates/rope/src/rope.rs @@ -12,6 +12,7 @@ use std::{ str, }; use sum_tree::{Bias, Dimension, Dimensions, SumTree}; +use ztracing::instrument; pub use chunk::{Chunk, ChunkSlice}; pub use offset_utf16::OffsetUtf16; @@ -428,6 +429,7 @@ impl Rope { }) } + #[instrument(skip_all)] pub fn point_to_offset(&self, point: Point) -> usize { if point >= self.summary().lines { return self.summary().len; diff --git a/crates/sum_tree/Cargo.toml b/crates/sum_tree/Cargo.toml index 81916c842225085ceec4721dbd8d212608f6bcb9..3e06ede162dad37f94017207ccbd6ee5c38f26a5 100644 --- a/crates/sum_tree/Cargo.toml +++ b/crates/sum_tree/Cargo.toml @@ -17,8 +17,13 @@ doctest = false arrayvec = "0.7.1" rayon.workspace = true log.workspace = true +ztracing.workspace = true +tracing.workspace = true [dev-dependencies] ctor.workspace = true rand.workspace = true zlog.workspace = true + +[package.metadata.cargo-machete] +ignored = ["tracing"] diff --git a/crates/sum_tree/src/cursor.rs b/crates/sum_tree/src/cursor.rs index 0ca89d16db9f8b4dae6e8283c673f781dbdd27dc..589ae96a2aa3293490aa91674dd3e0cac127e3cc 100644 --- a/crates/sum_tree/src/cursor.rs +++ b/crates/sum_tree/src/cursor.rs @@ -1,6 +1,7 @@ use super::*; use arrayvec::ArrayVec; use std::{cmp::Ordering, mem, sync::Arc}; +use ztracing::instrument; #[derive(Clone)] struct StackEntry<'a, T: Item, D> { @@ -211,6 +212,7 @@ where } #[track_caller] + #[instrument(skip_all)] pub fn prev(&mut self) { self.search_backward(|_| true) } @@ -394,6 +396,7 @@ where { /// Returns whether we found the item you were seeking for. 
#[track_caller] + #[instrument(skip_all)] pub fn seek(&mut self, pos: &Target, bias: Bias) -> bool where Target: SeekTarget<'a, T::Summary, D>, @@ -408,6 +411,7 @@ where /// /// If we did not seek before, use seek instead in that case. #[track_caller] + #[instrument(skip_all)] pub fn seek_forward(&mut self, pos: &Target, bias: Bias) -> bool where Target: SeekTarget<'a, T::Summary, D>, @@ -449,6 +453,7 @@ where /// Returns whether we found the item you were seeking for. #[track_caller] + #[instrument(skip_all)] fn seek_internal( &mut self, target: &dyn SeekTarget<'a, T::Summary, D>, diff --git a/crates/sum_tree/src/sum_tree.rs b/crates/sum_tree/src/sum_tree.rs index da700201f558a0b29ed4dc45bd3d3d3e7474a297..bfc4587969ec67bbda2fb90d34550c7d464317c9 100644 --- a/crates/sum_tree/src/sum_tree.rs +++ b/crates/sum_tree/src/sum_tree.rs @@ -8,6 +8,7 @@ use std::marker::PhantomData; use std::mem; use std::{cmp::Ordering, fmt, iter::FromIterator, sync::Arc}; pub use tree_map::{MapSeekTarget, TreeMap, TreeSet}; +use ztracing::instrument; #[cfg(test)] pub const TREE_BASE: usize = 2; @@ -379,6 +380,7 @@ impl SumTree { /// A more efficient version of `Cursor::new()` + `Cursor::seek()` + `Cursor::item()`. /// /// Only returns the item that exactly has the target match. 
+ #[instrument(skip_all)] pub fn find_exact<'a, 'slf, D, Target>( &'slf self, cx: ::Context<'a>, @@ -404,6 +406,7 @@ impl SumTree { } /// A more efficient version of `Cursor::new()` + `Cursor::seek()` + `Cursor::item()` + #[instrument(skip_all)] pub fn find<'a, 'slf, D, Target>( &'slf self, cx: ::Context<'a>, diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index 6ee7d0a4ea75ff5e13a4db6f5fe73c2a5ba80193..e304ad7f5cd94c05daab2755cb9e7bed21fe0f8d 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -144,6 +144,8 @@ theme_extension.workspace = true theme_selector.workspace = true time.workspace = true title_bar.workspace = true +ztracing.workspace = true +tracing.workspace = true toolchain_selector.workspace = true ui.workspace = true ui_input.workspace = true @@ -223,4 +225,4 @@ osx_info_plist_exts = ["resources/info/*"] osx_url_schemes = ["zed"] [package.metadata.cargo-machete] -ignored = ["profiling", "zstd"] +ignored = ["profiling", "zstd", "tracing"] diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 10f599e876032bf297d3eaf173093a308d666cc9..7751e6cb0118e3590488600ca2601645d6657fb7 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -162,10 +162,11 @@ fn fail_to_open_window(e: anyhow::Error, _cx: &mut App) { .detach(); } } - pub static STARTUP_TIME: OnceLock = OnceLock::new(); pub fn main() { + ztracing::init(); + STARTUP_TIME.get_or_init(|| Instant::now()); #[cfg(unix)] diff --git a/crates/ztracing/Cargo.toml b/crates/ztracing/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..fbc9dc032d2d485f74a15e5fe3b073a7017911fd --- /dev/null +++ b/crates/ztracing/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "ztracing" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[dependencies] +tracing.workspace = true + +tracing-subscriber = "0.3.22" +tracing-tracy = { workspace = true, features = ["enable", 
"ondemand"] } + +ztracing_macro.workspace = true diff --git a/crates/ztracing/LICENSE-AGPL b/crates/ztracing/LICENSE-AGPL new file mode 120000 index 0000000000000000000000000000000000000000..5f5cf25dc458e75f4050c7378c186fca9b68fd19 --- /dev/null +++ b/crates/ztracing/LICENSE-AGPL @@ -0,0 +1 @@ +../../LICENSE-AGPL \ No newline at end of file diff --git a/crates/ztracing/LICENSE-APACHE b/crates/ztracing/LICENSE-APACHE new file mode 120000 index 0000000000000000000000000000000000000000..1cd601d0a3affae83854be02a0afdec3b7a9ec4d --- /dev/null +++ b/crates/ztracing/LICENSE-APACHE @@ -0,0 +1 @@ +../../LICENSE-APACHE \ No newline at end of file diff --git a/crates/ztracing/LICENSE-GPL b/crates/ztracing/LICENSE-GPL new file mode 120000 index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4 --- /dev/null +++ b/crates/ztracing/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/ztracing/build.rs b/crates/ztracing/build.rs new file mode 100644 index 0000000000000000000000000000000000000000..dc0d0ad704d49c4c0ab639d769024330e10d2481 --- /dev/null +++ b/crates/ztracing/build.rs @@ -0,0 +1,9 @@ +use std::env; + +fn main() { + if env::var_os("ZTRACING").is_some() { + println!(r"cargo::rustc-cfg=ztracing"); + } + println!("cargo::rerun-if-changed=build.rs"); + println!("cargo::rerun-if-env-changed=ZTRACING"); +} diff --git a/crates/ztracing/src/lib.rs b/crates/ztracing/src/lib.rs new file mode 100644 index 0000000000000000000000000000000000000000..1ab687a2f4550e9b08432764dd7f80aedf5791c0 --- /dev/null +++ b/crates/ztracing/src/lib.rs @@ -0,0 +1,16 @@ +#[cfg(ztracing)] +pub use tracing::instrument; +#[cfg(not(ztracing))] +pub use ztracing_macro::instrument; + +#[cfg(ztracing)] +pub fn init() { + use tracing_subscriber::prelude::*; + tracing::subscriber::set_global_default( + tracing_subscriber::registry().with(tracing_tracy::TracyLayer::default()), + ) + .expect("setup tracy layer"); +} + 
+#[cfg(not(ztracing))] +pub fn init() {} diff --git a/crates/ztracing_macro/Cargo.toml b/crates/ztracing_macro/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..dbd7adce5fccd054c3dc87acaf1283e9e7c36889 --- /dev/null +++ b/crates/ztracing_macro/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "ztracing_macro" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lib] +proc-macro = true + +[dependencies] diff --git a/crates/ztracing_macro/LICENSE-AGPL b/crates/ztracing_macro/LICENSE-AGPL new file mode 120000 index 0000000000000000000000000000000000000000..5f5cf25dc458e75f4050c7378c186fca9b68fd19 --- /dev/null +++ b/crates/ztracing_macro/LICENSE-AGPL @@ -0,0 +1 @@ +../../LICENSE-AGPL \ No newline at end of file diff --git a/crates/ztracing_macro/LICENSE-APACHE b/crates/ztracing_macro/LICENSE-APACHE new file mode 120000 index 0000000000000000000000000000000000000000..1cd601d0a3affae83854be02a0afdec3b7a9ec4d --- /dev/null +++ b/crates/ztracing_macro/LICENSE-APACHE @@ -0,0 +1 @@ +../../LICENSE-APACHE \ No newline at end of file diff --git a/crates/ztracing_macro/LICENSE-GPL b/crates/ztracing_macro/LICENSE-GPL new file mode 120000 index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4 --- /dev/null +++ b/crates/ztracing_macro/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/ztracing_macro/src/lib.rs b/crates/ztracing_macro/src/lib.rs new file mode 100644 index 0000000000000000000000000000000000000000..d9b073ed130bdc829e4d5d943b6d4b6a6d802888 --- /dev/null +++ b/crates/ztracing_macro/src/lib.rs @@ -0,0 +1,7 @@ +#[proc_macro_attribute] +pub fn instrument( + _attr: proc_macro::TokenStream, + item: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + item +} diff --git a/docs/src/performance.md b/docs/src/performance.md index a04d7c5c342d4f0dfa506451d4b890bfdfd1013c..4adc38f5eea27de26f1d5818b6787fb78ae1d1ad 100644 
--- a/docs/src/performance.md +++ b/docs/src/performance.md @@ -1,6 +1,6 @@ How to use our internal tools to profile and keep Zed fast. -# Flamechart/CPU profiling +# Rough quick CPU profiling (Flamechart) See what the CPU spends the most time on. Strongly recommend you use [samply](https://github.com/mstange/samply). It opens an interactive profile in @@ -12,6 +12,46 @@ The profile.json does not contain any symbols. Firefox profiler can add the loca image +# In depth CPU profiling (Tracing) + +See how long each annotated function call took and its arguments (if +configured). + +Annotate any function you need appear in the profile with instrument. For more +details see +[tracing-instrument](https://docs.rs/tracing/latest/tracing/attr.instrument.html): + +```rust +#[instrument(skip_all)] +fn should_appear_in_profile(kitty: Cat) { + sleep(QUITE_LONG) +} +``` + +Then either compile Zed with `ZTRACING=1 cargo r --release`. The release build is optional but highly recommended as like every program Zeds performance characteristics change dramatically with optimizations. You do not want to chase slowdowns that do not exist in release. + +## One time Setup/Building the profiler: + +Download the profiler: +[linux x86_64](https://zed-tracy-import-miniprofiler.nyc3.digitaloceanspaces.com/tracy-profiler-linux-x86_64) +[macos aarch64](https://zed-tracy-import-miniprofiler.nyc3.digitaloceanspaces.com/tracy-profiler-0.13.0-macos-aarch64) + +### Alternative: Building it yourself + +- Clone the repo at git@github.com:wolfpld/tracy.git +- `cd profiler && mkdir build && cd build` +- Run cmake to generate build files: `cmake -G Ninja -DCMAKE_BUILD_TYPE=Release ..` +- Build the profiler: `ninja` +- [Optional] move the profiler somewhere nice like ~/.local/bin on linux + +## Usage + +Open the profiler (tracy-profiler), you should see zed in the list of `Discovered clients` click it. 
+image + +To find functions that take a long time follow this image: +image + # Task/Async profiling Get a profile of the zed foreground executor and background executors. Check if @@ -23,11 +63,17 @@ look at the results live. ## Setup/Building the importer: +Download the importer +[linux x86_64](https://zed-tracy-import-miniprofiler.nyc3.digitaloceanspaces.com/tracy-import-miniprofiler-linux-x86_64) +[mac aarch64](https://zed-tracy-import-miniprofiler.nyc3.digitaloceanspaces.com/tracy-import-miniprofiler-macos-aarch64) + +### Alternative: Building it yourself + - Clone the repo at git@github.com:zed-industries/tracy.git on v0.12.2 branch -- `cd profiler && mkdir build && cd build` +- `cd import && mkdir build && cd build` - Run cmake to generate build files: `cmake -G Ninja -DCMAKE_BUILD_TYPE=Release ..` - Build the importer: `ninja` -- Run the impoter on the trace file: `./tracy-import-miniprofiler /path/to/trace.miniprof /path/to/output.tracy` +- Run the importer on the trace file: `./tracy-import-miniprofiler /path/to/trace.miniprof /path/to/output.tracy` - Open the trace in tracy: - If you're on windows download the v0.12.2 version from the releases on the upstream repo - If you're on other platforms open it on the website: https://tracy.nereid.pl/ (the version might mismatch so your luck might vary, we need to host our own ideally..) 
From d76dd86272361eebd96cf9495e4d9899e287ddad Mon Sep 17 00:00:00 2001 From: Dino Date: Fri, 5 Dec 2025 18:18:51 +0000 Subject: [PATCH 21/81] tab_switcher: Add documentation for tab switcher (#44189) Release Notes: - Added documentation for Tab Switcher --- assets/keymaps/default-linux.json | 2 +- assets/keymaps/default-macos.json | 2 +- assets/keymaps/default-windows.json | 2 +- docs/src/SUMMARY.md | 1 + docs/src/tab-switcher.md | 46 +++++++++++++++++++++++++++++ 5 files changed, 50 insertions(+), 3 deletions(-) create mode 100644 docs/src/tab-switcher.md diff --git a/assets/keymaps/default-linux.json b/assets/keymaps/default-linux.json index 41415bf2047e1faadd86dd5be159f526d6c57678..54a4f331c0b0c59eca79065fe42c1a8ecbf646b7 100644 --- a/assets/keymaps/default-linux.json +++ b/assets/keymaps/default-linux.json @@ -616,8 +616,8 @@ "ctrl-alt-super-p": "settings_profile_selector::Toggle", "ctrl-t": "project_symbols::Toggle", "ctrl-p": "file_finder::Toggle", - "ctrl-tab": "tab_switcher::Toggle", "ctrl-shift-tab": ["tab_switcher::Toggle", { "select_last": true }], + "ctrl-tab": "tab_switcher::Toggle", "ctrl-e": "file_finder::Toggle", "f1": "command_palette::Toggle", "ctrl-shift-p": "command_palette::Toggle", diff --git a/assets/keymaps/default-macos.json b/assets/keymaps/default-macos.json index fa8edbe5c23b008eb2c267850e440a851c54087d..060151c647e42370f5aa0be5d2fa186774c2574d 100644 --- a/assets/keymaps/default-macos.json +++ b/assets/keymaps/default-macos.json @@ -684,8 +684,8 @@ "ctrl-alt-cmd-p": "settings_profile_selector::Toggle", "cmd-t": "project_symbols::Toggle", "cmd-p": "file_finder::Toggle", - "ctrl-tab": "tab_switcher::Toggle", "ctrl-shift-tab": ["tab_switcher::Toggle", { "select_last": true }], + "ctrl-tab": "tab_switcher::Toggle", "cmd-shift-p": "command_palette::Toggle", "cmd-shift-m": "diagnostics::Deploy", "cmd-shift-e": "project_panel::ToggleFocus", diff --git a/assets/keymaps/default-windows.json b/assets/keymaps/default-windows.json index 
45f37fbd41af3fcc3108f0ffe150a80ff25332e1..32b52365e08e50266ad5feb7630a7b03f860c8e8 100644 --- a/assets/keymaps/default-windows.json +++ b/assets/keymaps/default-windows.json @@ -608,8 +608,8 @@ "ctrl-alt-super-p": "settings_profile_selector::Toggle", "ctrl-t": "project_symbols::Toggle", "ctrl-p": "file_finder::Toggle", - "ctrl-tab": "tab_switcher::Toggle", "ctrl-shift-tab": ["tab_switcher::Toggle", { "select_last": true }], + "ctrl-tab": "tab_switcher::Toggle", "ctrl-e": "file_finder::Toggle", "f1": "command_palette::Toggle", "ctrl-shift-p": "command_palette::Toggle", diff --git a/docs/src/SUMMARY.md b/docs/src/SUMMARY.md index 0d0cd35f43610d206749dea7a87af553620633f0..9d1f6f61d446b67256c00bf6322aed73af922c5e 100644 --- a/docs/src/SUMMARY.md +++ b/docs/src/SUMMARY.md @@ -41,6 +41,7 @@ - [Debugger](./debugger.md) - [Diagnostics](./diagnostics.md) - [Tasks](./tasks.md) +- [Tab Switcher](./tab-switcher.md) - [Remote Development](./remote-development.md) - [Environment Variables](./environment.md) - [REPL](./repl.md) diff --git a/docs/src/tab-switcher.md b/docs/src/tab-switcher.md new file mode 100644 index 0000000000000000000000000000000000000000..5cc72be449c94c38fbe4814893595289cb499b5a --- /dev/null +++ b/docs/src/tab-switcher.md @@ -0,0 +1,46 @@ +# Tab Switcher + +The Tab Switcher provides a quick way to navigate between open tabs in Zed. It +displays a list of your open tabs sorted by recent usage, making it easy to jump +back to whatever you were just working on. + +![Tab Switcher with multiple panes](https://zed.dev/img/features/tab-switcher.png) + +## Quick Switching + +When the Tab Switcher is opened using {#kb tab_switcher::Toggle}, instead of +running the {#action tab_switcher::Toggle} from the command palette, it'll stay +active as long as the ctrl key is held down. + +While holding down ctrl, each subsequent tab press cycles to the next item (shift to cycle backwards) and, when ctrl is released, the selected item is confirmed and +the switcher is closed. 
+ +## Opening the Tab Switcher + +The Tab Switcher can also be opened with either {#action tab_switcher::Toggle} +or {#action tab_switcher::ToggleAll}. Using {#kb tab_switcher::Toggle} will show +only the tabs for the current pane, while {#kb tab_switcher::ToggleAll} shows +all tabs for all panes. + +While the Tab Switcher is open, you can: + +- Press {#kb menu::SelectNext} to move to the next tab in the list +- Press {#kb menu::SelectPrevious} to move to the previous tab +- Press enter to confirm the selected tab and close the switcher +- Press escape to close the switcher and return to the original tab from which + the switcher was opened +- Press {#kb tab_switcher::CloseSelectedItem} to close the currently selected tab + +As you navigate through the list, Zed will update the pane's active item to +match the selected tab. + +## Action Reference + +| Action | Description | +| ----------------------------------------- | ------------------------------------------------- | +| {#action tab_switcher::Toggle} | Open the Tab Switcher for the current pane | +| {#action tab_switcher::ToggleAll} | Open the Tab Switcher showing tabs from all panes | +| {#action tab_switcher::CloseSelectedItem} | Close the selected tab in the Tab Switcher | From 37b0cdf94ba931db41039a42a78993eb3e7b0bd0 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Fri, 5 Dec 2025 19:20:29 +0100 Subject: [PATCH 22/81] multi_buffer: Remap excerpt ids to latest excerpt in excerpt fetching (#44229) Closes #ISSUE Release Notes: - N/A *or* Added/Fixed/Improved ... 
Co-authored by: Cole Miller --- crates/editor/src/selections_collection.rs | 16 ++++++++++++---- crates/multi_buffer/src/multi_buffer.rs | 5 +++-- crates/multi_buffer/src/path_key.rs | 2 +- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/crates/editor/src/selections_collection.rs b/crates/editor/src/selections_collection.rs index f8ff9da763403b0946e99a4e39c934ff43ad6634..6c6c88faf5e32195e049228fe573d19e13ae111a 100644 --- a/crates/editor/src/selections_collection.rs +++ b/crates/editor/src/selections_collection.rs @@ -419,22 +419,30 @@ impl SelectionsCollection { mutable_collection.disjoint.iter().for_each(|selection| { assert!( snapshot.can_resolve(&selection.start), - "disjoint selection start is not resolvable for the given snapshot:\n{selection:?}", + "disjoint selection start is not resolvable for the given snapshot:\n{selection:?}, {excerpt:?}", + excerpt = snapshot.buffer_for_excerpt(selection.start.excerpt_id).map(|snapshot| snapshot.remote_id()), ); assert!( snapshot.can_resolve(&selection.end), - "disjoint selection end is not resolvable for the given snapshot: {selection:?}", + "disjoint selection end is not resolvable for the given snapshot: {selection:?}, {excerpt:?}", + excerpt = snapshot.buffer_for_excerpt(selection.end.excerpt_id).map(|snapshot| snapshot.remote_id()), ); }); if let Some(pending) = &mutable_collection.pending { let selection = &pending.selection; assert!( snapshot.can_resolve(&selection.start), - "pending selection start is not resolvable for the given snapshot: {pending:?}", + "pending selection start is not resolvable for the given snapshot: {pending:?}, {excerpt:?}", + excerpt = snapshot + .buffer_for_excerpt(selection.start.excerpt_id) + .map(|snapshot| snapshot.remote_id()), ); assert!( snapshot.can_resolve(&selection.end), - "pending selection end is not resolvable for the given snapshot: {pending:?}", + "pending selection end is not resolvable for the given snapshot: {pending:?}, {excerpt:?}", + excerpt = 
snapshot + .buffer_for_excerpt(selection.end.excerpt_id) + .map(|snapshot| snapshot.remote_id()), ); } } diff --git a/crates/multi_buffer/src/multi_buffer.rs b/crates/multi_buffer/src/multi_buffer.rs index 24cb55d2f5e7311cc492ec70ab320eb12e78f8ee..bd163557c4f6239353e7cd5ad08a6120e20e4a3d 100644 --- a/crates/multi_buffer/src/multi_buffer.rs +++ b/crates/multi_buffer/src/multi_buffer.rs @@ -6457,12 +6457,13 @@ impl MultiBufferSnapshot { } /// Returns the excerpt for the given id. The returned excerpt is guaranteed - /// to have the same excerpt id as the one passed in, with the exception of - /// `ExcerptId::max()`. + /// to have the latest excerpt id for the one passed in and will also remap + /// `ExcerptId::max()` to the corresponding excertp ID. /// /// Callers of this function should generally use the resulting excerpt's `id` field /// afterwards. fn excerpt(&self, excerpt_id: ExcerptId) -> Option<&Excerpt> { + let excerpt_id = self.latest_excerpt_id(excerpt_id); let mut cursor = self.excerpts.cursor::>(()); let locator = self.excerpt_locator_for_id(excerpt_id); cursor.seek(&Some(locator), Bias::Left); diff --git a/crates/multi_buffer/src/path_key.rs b/crates/multi_buffer/src/path_key.rs index 82bb902c230180d98c54225e8b57bf85beeedc2d..119194d088c946941b13ffab3f6f2b3ea126cd09 100644 --- a/crates/multi_buffer/src/path_key.rs +++ b/crates/multi_buffer/src/path_key.rs @@ -305,7 +305,7 @@ impl MultiBuffer { let snapshot = self.snapshot(cx); let mut next_excerpt_id = - // is this right? What if we remove the last excerpt, then we might reallocate with a wrong mapping? + // todo(lw): is this right? What if we remove the last excerpt, then we might reallocate with a wrong mapping? 
if let Some(last_entry) = self.snapshot.borrow().excerpt_ids.last() { last_entry.id.0 + 1 } else { From 3bb6c2546a1d104c9198096339e9317080f6ee87 Mon Sep 17 00:00:00 2001 From: Danilo Leal <67129314+danilo-leal@users.noreply.github.com> Date: Fri, 5 Dec 2025 15:46:28 -0300 Subject: [PATCH 23/81] git_ui: Fix history view label truncation (#44218) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit There's still a weird problem happening where the labels (and the label on the tab, too, for what is worth) flicker as the file history view gets smaller. I suspect that problem is related to something else—potentially the truncation algorithm or focus management—so I'm not solving it here. Screenshot 2025-12-05 at 11  24@2x Release Notes: - N/A --- crates/git_ui/src/file_history_view.rs | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/crates/git_ui/src/file_history_view.rs b/crates/git_ui/src/file_history_view.rs index e34806aacae48122caf3a12246b04862898f2bed..5b3588d29678ec406749ec45be3de154fd71c5f8 100644 --- a/crates/git_ui/src/file_history_view.rs +++ b/crates/git_ui/src/file_history_view.rs @@ -267,15 +267,19 @@ impl FileHistoryView { .child(self.render_commit_avatar(&entry.sha, window, cx)) .child( h_flex() + .min_w_0() .w_full() .justify_between() .child( h_flex() + .min_w_0() + .w_full() .gap_1() .child( Label::new(entry.author_name.clone()) .size(LabelSize::Small) - .color(Color::Default), + .color(Color::Default) + .truncate(), ) .child( Label::new(&entry.subject) @@ -285,9 +289,11 @@ impl FileHistoryView { ), ) .child( - Label::new(relative_timestamp) - .size(LabelSize::Small) - .color(Color::Muted), + h_flex().flex_none().child( + Label::new(relative_timestamp) + .size(LabelSize::Small) + .color(Color::Muted), + ), ), ), ) From f9cea5af29a988be7932e48d5a0f1b5e15c51d0d Mon Sep 17 00:00:00 2001 From: Bennet Bo Fenner Date: Fri, 5 Dec 2025 19:53:53 +0100 Subject: [PATCH 24/81] Fix project not 
getting dropped after closing window (#44237) --- crates/agent_ui/src/acp/entry_view_state.rs | 12 ++-- crates/agent_ui/src/acp/message_editor.rs | 44 ++++++------- crates/agent_ui/src/acp/thread_view.rs | 4 +- .../assistant_text_thread/src/text_thread.rs | 16 ++--- .../src/text_thread_store.rs | 65 +++++++++++++------ 5 files changed, 80 insertions(+), 61 deletions(-) diff --git a/crates/agent_ui/src/acp/entry_view_state.rs b/crates/agent_ui/src/acp/entry_view_state.rs index 53f24947658be8def877eb6b3a7d4e29b541d0c0..feae74a86bc241c5d2e01f0941eafc60210f1bf6 100644 --- a/crates/agent_ui/src/acp/entry_view_state.rs +++ b/crates/agent_ui/src/acp/entry_view_state.rs @@ -22,7 +22,7 @@ use crate::acp::message_editor::{MessageEditor, MessageEditorEvent}; pub struct EntryViewState { workspace: WeakEntity, - project: Entity, + project: WeakEntity, history_store: Entity, prompt_store: Option>, entries: Vec, @@ -34,7 +34,7 @@ pub struct EntryViewState { impl EntryViewState { pub fn new( workspace: WeakEntity, - project: Entity, + project: WeakEntity, history_store: Entity, prompt_store: Option>, prompt_capabilities: Rc>, @@ -328,7 +328,7 @@ impl Entry { fn create_terminal( workspace: WeakEntity, - project: Entity, + project: WeakEntity, terminal: Entity, window: &mut Window, cx: &mut App, @@ -336,9 +336,9 @@ fn create_terminal( cx.new(|cx| { let mut view = TerminalView::new( terminal.read(cx).inner().clone(), - workspace.clone(), + workspace, None, - project.downgrade(), + project, window, cx, ); @@ -458,7 +458,7 @@ mod tests { let view_state = cx.new(|_cx| { EntryViewState::new( workspace.downgrade(), - project.clone(), + project.downgrade(), history_store, None, Default::default(), diff --git a/crates/agent_ui/src/acp/message_editor.rs b/crates/agent_ui/src/acp/message_editor.rs index 827990599912fe832d40605fb1dceb58eab4ff2f..875dc495cea710d1df950b47b328042bbda4a287 100644 --- a/crates/agent_ui/src/acp/message_editor.rs +++ b/crates/agent_ui/src/acp/message_editor.rs @@ 
-39,7 +39,6 @@ use zed_actions::agent::Chat; pub struct MessageEditor { mention_set: Entity, editor: Entity, - project: Entity, workspace: WeakEntity, prompt_capabilities: Rc>, available_commands: Rc>>, @@ -98,7 +97,7 @@ impl PromptCompletionProviderDelegate for Entity { impl MessageEditor { pub fn new( workspace: WeakEntity, - project: Entity, + project: WeakEntity, history_store: Entity, prompt_store: Option>, prompt_capabilities: Rc>, @@ -135,13 +134,8 @@ impl MessageEditor { editor.register_addon(MessageEditorAddon::new()); editor }); - let mention_set = cx.new(|_cx| { - MentionSet::new( - project.downgrade(), - history_store.clone(), - prompt_store.clone(), - ) - }); + let mention_set = + cx.new(|_cx| MentionSet::new(project, history_store.clone(), prompt_store.clone())); let completion_provider = Rc::new(PromptCompletionProvider::new( cx.entity(), editor.downgrade(), @@ -199,7 +193,6 @@ impl MessageEditor { Self { editor, - project, mention_set, workspace, prompt_capabilities, @@ -572,17 +565,18 @@ impl MessageEditor { let Some(workspace) = self.workspace.upgrade() else { return; }; - let path_style = self.project.read(cx).path_style(cx); + let project = workspace.read(cx).project().clone(); + let path_style = project.read(cx).path_style(cx); let buffer = self.editor.read(cx).buffer().clone(); let Some(buffer) = buffer.read(cx).as_singleton() else { return; }; let mut tasks = Vec::new(); for path in paths { - let Some(entry) = self.project.read(cx).entry_for_path(&path, cx) else { + let Some(entry) = project.read(cx).entry_for_path(&path, cx) else { continue; }; - let Some(worktree) = self.project.read(cx).worktree_for_id(path.worktree_id, cx) else { + let Some(worktree) = project.read(cx).worktree_for_id(path.worktree_id, cx) else { continue; }; let abs_path = worktree.read(cx).absolutize(&path.path); @@ -690,9 +684,13 @@ impl MessageEditor { window: &mut Window, cx: &mut Context, ) { + let Some(workspace) = self.workspace.upgrade() else { + return; + }; + 
self.clear(window, cx); - let path_style = self.project.read(cx).path_style(cx); + let path_style = workspace.read(cx).project().read(cx).path_style(cx); let mut text = String::new(); let mut mentions = Vec::new(); @@ -935,7 +933,7 @@ mod tests { cx.new(|cx| { MessageEditor::new( workspace.downgrade(), - project.clone(), + project.downgrade(), history_store.clone(), None, Default::default(), @@ -1046,7 +1044,7 @@ mod tests { cx.new(|cx| { MessageEditor::new( workspace_handle.clone(), - project.clone(), + project.downgrade(), history_store.clone(), None, prompt_capabilities.clone(), @@ -1207,7 +1205,7 @@ mod tests { let message_editor = cx.new(|cx| { MessageEditor::new( workspace_handle, - project.clone(), + project.downgrade(), history_store.clone(), None, prompt_capabilities.clone(), @@ -1429,7 +1427,7 @@ mod tests { let message_editor = cx.new(|cx| { MessageEditor::new( workspace_handle, - project.clone(), + project.downgrade(), history_store.clone(), None, prompt_capabilities.clone(), @@ -1920,7 +1918,7 @@ mod tests { cx.new(|cx| { let editor = MessageEditor::new( workspace.downgrade(), - project.clone(), + project.downgrade(), history_store.clone(), None, Default::default(), @@ -2025,7 +2023,7 @@ mod tests { cx.new(|cx| { let mut editor = MessageEditor::new( workspace.downgrade(), - project.clone(), + project.downgrade(), history_store.clone(), None, Default::default(), @@ -2094,7 +2092,7 @@ mod tests { cx.new(|cx| { MessageEditor::new( workspace.downgrade(), - project.clone(), + project.downgrade(), history_store.clone(), None, Default::default(), @@ -2157,7 +2155,7 @@ mod tests { let message_editor = cx.new(|cx| { MessageEditor::new( workspace_handle, - project.clone(), + project.downgrade(), history_store.clone(), None, Default::default(), @@ -2315,7 +2313,7 @@ mod tests { let message_editor = cx.new(|cx| { MessageEditor::new( workspace_handle, - project.clone(), + project.downgrade(), history_store.clone(), None, Default::default(), diff --git 
a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index aedb96bb82f07723f934d0ec73aa1fd545461f00..c917a48ad5bac67e7dcdef94dbace97b26843404 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -344,7 +344,7 @@ impl AcpThreadView { let message_editor = cx.new(|cx| { let mut editor = MessageEditor::new( workspace.clone(), - project.clone(), + project.downgrade(), history_store.clone(), prompt_store.clone(), prompt_capabilities.clone(), @@ -369,7 +369,7 @@ impl AcpThreadView { let entry_view_state = cx.new(|_| { EntryViewState::new( workspace.clone(), - project.clone(), + project.downgrade(), history_store.clone(), prompt_store.clone(), prompt_capabilities.clone(), diff --git a/crates/assistant_text_thread/src/text_thread.rs b/crates/assistant_text_thread/src/text_thread.rs index 7f24c8f665f8d34aed199562dce1131797f13c9d..b808d9fb0019ccad25366d9ae60cc1f765126c74 100644 --- a/crates/assistant_text_thread/src/text_thread.rs +++ b/crates/assistant_text_thread/src/text_thread.rs @@ -14,7 +14,7 @@ use fs::{Fs, RenameOptions}; use futures::{FutureExt, StreamExt, future::Shared}; use gpui::{ App, AppContext as _, Context, Entity, EventEmitter, RenderImage, SharedString, Subscription, - Task, + Task, WeakEntity, }; use itertools::Itertools as _; use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, Point, ToOffset}; @@ -688,7 +688,7 @@ pub struct TextThread { _subscriptions: Vec, telemetry: Option>, language_registry: Arc, - project: Option>, + project: Option>, prompt_builder: Arc, completion_mode: agent_settings::CompletionMode, } @@ -708,7 +708,7 @@ impl EventEmitter for TextThread {} impl TextThread { pub fn local( language_registry: Arc, - project: Option>, + project: Option>, telemetry: Option>, prompt_builder: Arc, slash_commands: Arc, @@ -742,7 +742,7 @@ impl TextThread { language_registry: Arc, prompt_builder: Arc, slash_commands: Arc, - project: Option>, + project: 
Option>, telemetry: Option>, cx: &mut Context, ) -> Self { @@ -873,7 +873,7 @@ impl TextThread { language_registry: Arc, prompt_builder: Arc, slash_commands: Arc, - project: Option>, + project: Option>, telemetry: Option>, cx: &mut Context, ) -> Self { @@ -1167,10 +1167,6 @@ impl TextThread { self.language_registry.clone() } - pub fn project(&self) -> Option> { - self.project.clone() - } - pub fn prompt_builder(&self) -> Arc { self.prompt_builder.clone() } @@ -2967,7 +2963,7 @@ impl TextThread { } fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut App) { - let Some(project) = &self.project else { + let Some(project) = self.project.as_ref().and_then(|project| project.upgrade()) else { return; }; project.read(cx).user_store().update(cx, |user_store, cx| { diff --git a/crates/assistant_text_thread/src/text_thread_store.rs b/crates/assistant_text_thread/src/text_thread_store.rs index 19c317baf0fa728c77faebc388b5e36008aa39b3..71fabed503e8c04a8865bed72c28ae5b30e75574 100644 --- a/crates/assistant_text_thread/src/text_thread_store.rs +++ b/crates/assistant_text_thread/src/text_thread_store.rs @@ -51,7 +51,7 @@ pub struct TextThreadStore { telemetry: Arc, _watch_updates: Task>, client: Arc, - project: Entity, + project: WeakEntity, project_is_shared: bool, client_subscription: Option, _project_subscriptions: Vec, @@ -119,10 +119,10 @@ impl TextThreadStore { ], project_is_shared: false, client: project.read(cx).client(), - project: project.clone(), + project: project.downgrade(), prompt_builder, }; - this.handle_project_shared(project.clone(), cx); + this.handle_project_shared(cx); this.synchronize_contexts(cx); this.register_context_server_handlers(cx); this.reload(cx).detach_and_log_err(cx); @@ -146,7 +146,7 @@ impl TextThreadStore { telemetry: project.read(cx).client().telemetry().clone(), _watch_updates: Task::ready(None), client: project.read(cx).client(), - project, + project: project.downgrade(), project_is_shared: false, 
client_subscription: None, _project_subscriptions: Default::default(), @@ -180,8 +180,10 @@ impl TextThreadStore { ) -> Result { let context_id = TextThreadId::from_proto(envelope.payload.context_id); let operations = this.update(&mut cx, |this, cx| { + let project = this.project.upgrade().context("project not found")?; + anyhow::ensure!( - !this.project.read(cx).is_via_collab(), + !project.read(cx).is_via_collab(), "only the host contexts can be opened" ); @@ -211,8 +213,9 @@ impl TextThreadStore { mut cx: AsyncApp, ) -> Result { let (context_id, operations) = this.update(&mut cx, |this, cx| { + let project = this.project.upgrade().context("project not found")?; anyhow::ensure!( - !this.project.read(cx).is_via_collab(), + !project.read(cx).is_via_collab(), "can only create contexts as the host" ); @@ -255,8 +258,9 @@ impl TextThreadStore { mut cx: AsyncApp, ) -> Result { this.update(&mut cx, |this, cx| { + let project = this.project.upgrade().context("project not found")?; anyhow::ensure!( - !this.project.read(cx).is_via_collab(), + !project.read(cx).is_via_collab(), "only the host can synchronize contexts" ); @@ -293,8 +297,12 @@ impl TextThreadStore { })? 
} - fn handle_project_shared(&mut self, _: Entity, cx: &mut Context) { - let is_shared = self.project.read(cx).is_shared(); + fn handle_project_shared(&mut self, cx: &mut Context) { + let Some(project) = self.project.upgrade() else { + return; + }; + + let is_shared = project.read(cx).is_shared(); let was_shared = mem::replace(&mut self.project_is_shared, is_shared); if is_shared == was_shared { return; @@ -309,7 +317,7 @@ impl TextThreadStore { false } }); - let remote_id = self.project.read(cx).remote_id().unwrap(); + let remote_id = project.read(cx).remote_id().unwrap(); self.client_subscription = self .client .subscribe_to_entity(remote_id) @@ -323,13 +331,13 @@ impl TextThreadStore { fn handle_project_event( &mut self, - project: Entity, + _project: Entity, event: &project::Event, cx: &mut Context, ) { match event { project::Event::RemoteIdChanged(_) => { - self.handle_project_shared(project, cx); + self.handle_project_shared(cx); } project::Event::Reshared => { self.advertise_contexts(cx); @@ -382,7 +390,10 @@ impl TextThreadStore { } pub fn create_remote(&mut self, cx: &mut Context) -> Task>> { - let project = self.project.read(cx); + let Some(project) = self.project.upgrade() else { + return Task::ready(Err(anyhow::anyhow!("project was dropped"))); + }; + let project = project.read(cx); let Some(project_id) = project.remote_id() else { return Task::ready(Err(anyhow::anyhow!("project was not remote"))); }; @@ -541,7 +552,10 @@ impl TextThreadStore { text_thread_id: TextThreadId, cx: &mut Context, ) -> Task>> { - let project = self.project.read(cx); + let Some(project) = self.project.upgrade() else { + return Task::ready(Err(anyhow::anyhow!("project was dropped"))); + }; + let project = project.read(cx); let Some(project_id) = project.remote_id() else { return Task::ready(Err(anyhow::anyhow!("project was not remote"))); }; @@ -618,7 +632,10 @@ impl TextThreadStore { event: &TextThreadEvent, cx: &mut Context, ) { - let Some(project_id) = 
self.project.read(cx).remote_id() else { + let Some(project) = self.project.upgrade() else { + return; + }; + let Some(project_id) = project.read(cx).remote_id() else { return; }; @@ -652,12 +669,14 @@ impl TextThreadStore { } fn advertise_contexts(&self, cx: &App) { - let Some(project_id) = self.project.read(cx).remote_id() else { + let Some(project) = self.project.upgrade() else { + return; + }; + let Some(project_id) = project.read(cx).remote_id() else { return; }; - // For now, only the host can advertise their open contexts. - if self.project.read(cx).is_via_collab() { + if project.read(cx).is_via_collab() { return; } @@ -689,7 +708,10 @@ impl TextThreadStore { } fn synchronize_contexts(&mut self, cx: &mut Context) { - let Some(project_id) = self.project.read(cx).remote_id() else { + let Some(project) = self.project.upgrade() else { + return; + }; + let Some(project_id) = project.read(cx).remote_id() else { return; }; @@ -828,7 +850,10 @@ impl TextThreadStore { } fn register_context_server_handlers(&self, cx: &mut Context) { - let context_server_store = self.project.read(cx).context_server_store(); + let Some(project) = self.project.upgrade() else { + return; + }; + let context_server_store = project.read(cx).context_server_store(); cx.subscribe(&context_server_store, Self::handle_context_server_event) .detach(); From bd6ca841ad48b9fedf868761618ef6f9cccd9f83 Mon Sep 17 00:00:00 2001 From: Danilo Leal <67129314+danilo-leal@users.noreply.github.com> Date: Fri, 5 Dec 2025 17:17:50 -0300 Subject: [PATCH 25/81] git_ui: Improve the branch picker UI (#44217) Follow up to https://github.com/zed-industries/zed/pull/42819 and https://github.com/zed-industries/zed/pull/44206. - Make this picker feel more consistent with other similar pickers (namely, the project picker) - Move actions to the footer and toggle them conditionally - Only show the "Create" and "Create New From: {default}" when we're selecting the "Create" list item _or_ when that item is the only visible. 
This means I'm changing the state transition here to only switch to `NewBranch/NewRemote` if those are the only items available. - Reuse more UI code and use components when available (e.g., `ListHeader`) - Remove secondary actions from the list item Next step (in another PR) will be to refine the same picker in the smaller, panel version. https://github.com/user-attachments/assets/fe72ac06-c1df-4829-a8a4-df8a9222672f Release Notes: - N/A --- crates/git_ui/src/branch_picker.rs | 476 ++++++++++++++++------------- 1 file changed, 260 insertions(+), 216 deletions(-) diff --git a/crates/git_ui/src/branch_picker.rs b/crates/git_ui/src/branch_picker.rs index 33b852c1de9b1bd1a8abcc36dff964d14cbe1807..06405651206befad38c938c9fec35a98dab1ef2c 100644 --- a/crates/git_ui/src/branch_picker.rs +++ b/crates/git_ui/src/branch_picker.rs @@ -17,8 +17,8 @@ use settings::Settings; use std::sync::Arc; use time::OffsetDateTime; use ui::{ - CommonAnimationExt, Divider, HighlightedLabel, KeyBinding, ListItem, ListItemSpacing, Tooltip, - prelude::*, + CommonAnimationExt, Divider, HighlightedLabel, KeyBinding, ListHeader, ListItem, + ListItemSpacing, Tooltip, prelude::*, }; use util::ResultExt; use workspace::notifications::DetachAndPromptErr; @@ -440,13 +440,6 @@ impl BranchListDelegate { cx.emit(DismissEvent); } - fn loader(&self) -> AnyElement { - Icon::new(IconName::LoadCircle) - .size(IconSize::Small) - .with_rotate_animation(3) - .into_any_element() - } - fn delete_at(&self, idx: usize, window: &mut Window, cx: &mut Context>) { let Some(entry) = self.matches.get(idx).cloned() else { return; }; @@ -683,10 +676,16 @@ impl PickerDelegate for BranchListDelegate { } else { Entry::NewBranch { name: query } }; - picker.delegate.state = if is_url { - PickerState::NewRemote + // Only transition to NewBranch/NewRemote states when we only show their list item + // Otherwise, stay in List state so footer buttons remain visible + picker.delegate.state = if matches.is_empty() { + if is_url { + 
PickerState::NewRemote + } else { + PickerState::NewBranch + } } else { - PickerState::NewBranch + PickerState::List }; matches.push(entry); } else { @@ -812,67 +811,35 @@ impl PickerDelegate for BranchListDelegate { }) .unwrap_or_else(|| (None, None, None)); - let icon = if let Some(default_branch) = self.default_branch.clone() - && matches!(entry, Entry::NewBranch { .. }) - { - let tooltip_text = format!("Create branch based off default: {default_branch}"); - - Some( - IconButton::new("branch-from-default", IconName::GitBranchAlt) - .on_click(cx.listener(move |this, _, window, cx| { - this.delegate.set_selected_index(ix, window, cx); - this.delegate.confirm(true, window, cx); - })) - .tooltip(move |_window, cx| { - Tooltip::for_action(tooltip_text.clone(), &menu::SecondaryConfirm, cx) - }), - ) - } else { - None - }; + let entry_icon = match entry { + Entry::NewUrl { .. } | Entry::NewBranch { .. } => { + Icon::new(IconName::Plus).color(Color::Muted) + } - let icon_element = if self.display_remotes { - Icon::new(IconName::Screen) - } else { - Icon::new(IconName::GitBranchAlt) + Entry::Branch { .. } => { + if self.display_remotes { + Icon::new(IconName::Screen).color(Color::Muted) + } else { + Icon::new(IconName::GitBranchAlt).color(Color::Muted) + } + } }; - let entry_name = match entry { - Entry::NewUrl { .. } => h_flex() - .gap_1() - .child( - Icon::new(IconName::Plus) - .size(IconSize::Small) - .color(Color::Muted), - ) - .child( - Label::new("Create remote repository".to_string()) - .single_line() - .truncate(), - ) + let entry_title = match entry { + Entry::NewUrl { .. 
} => Label::new("Create Remote Repository") + .single_line() + .truncate() .into_any_element(), - Entry::NewBranch { name } => h_flex() - .gap_1() - .child( - Icon::new(IconName::Plus) - .size(IconSize::Small) - .color(Color::Muted), - ) - .child( - Label::new(format!("Create branch \"{name}\"…")) - .single_line() - .truncate(), - ) - .into_any_element(), - Entry::Branch { branch, positions } => h_flex() - .max_w_48() - .child(h_flex().mr_1().child(icon_element)) - .child( - HighlightedLabel::new(branch.name().to_string(), positions.clone()) - .single_line() - .truncate(), - ) + Entry::NewBranch { name } => Label::new(format!("Create Branch: \"{name}\"…")) + .single_line() + .truncate() .into_any_element(), + Entry::Branch { branch, positions } => { + HighlightedLabel::new(branch.name().to_string(), positions.clone()) + .single_line() + .truncate() + .into_any_element() + } }; Some( @@ -880,82 +847,96 @@ impl PickerDelegate for BranchListDelegate { .inset(true) .spacing(ListItemSpacing::Sparse) .toggle_state(selected) - .tooltip({ - match entry { - Entry::Branch { branch, .. } => Tooltip::text(branch.name().to_string()), - Entry::NewUrl { .. } => { - Tooltip::text("Create remote repository".to_string()) - } - Entry::NewBranch { name } => { - Tooltip::text(format!("Create branch \"{name}\"")) - } - } - }) .child( - v_flex() + h_flex() .w_full() - .overflow_hidden() + .gap_3() + .flex_grow() + .child(entry_icon) .child( - h_flex() - .gap_6() - .justify_between() - .overflow_x_hidden() - .child(entry_name) - .when_some(commit_time, |label, commit_time| { - label.child( - Label::new(commit_time) - .size(LabelSize::Small) - .color(Color::Muted) - .into_element(), - ) - }), - ) - .when(self.style == BranchListStyle::Modal, |el| { - el.child(div().max_w_96().child({ - let message = match entry { - Entry::NewUrl { url } => format!("based off {url}"), - Entry::NewBranch { .. 
} => { - if let Some(current_branch) = - self.repo.as_ref().and_then(|repo| { - repo.read(cx).branch.as_ref().map(|b| b.name()) - }) - { - format!("based off {}", current_branch) - } else { - "based off the current branch".to_string() - } - } - Entry::Branch { .. } => { - let show_author_name = ProjectSettings::get_global(cx) - .git - .branch_picker - .show_author_name; - - subject.map_or("no commits found".into(), |subject| { - if show_author_name && author_name.is_some() { - format!("{} • {}", author_name.unwrap(), subject) - } else { - subject.to_string() - } + v_flex() + .id("info_container") + .w_full() + .child(entry_title) + .child( + h_flex() + .w_full() + .justify_between() + .gap_1p5() + .when(self.style == BranchListStyle::Modal, |el| { + el.child(div().max_w_96().child({ + let message = match entry { + Entry::NewUrl { url } => { + format!("Based off {url}") + } + Entry::NewBranch { .. } => { + if let Some(current_branch) = + self.repo.as_ref().and_then(|repo| { + repo.read(cx) + .branch + .as_ref() + .map(|b| b.name()) + }) + { + format!("Based off {}", current_branch) + } else { + "Based off the current branch" + .to_string() + } + } + Entry::Branch { .. 
} => { + let show_author_name = + ProjectSettings::get_global(cx) + .git + .branch_picker + .show_author_name; + + subject.map_or( + "No commits found".into(), + |subject| { + if show_author_name + && author_name.is_some() + { + format!( + "{} • {}", + author_name.unwrap(), + subject + ) + } else { + subject.to_string() + } + }, + ) + } + }; + + Label::new(message) + .size(LabelSize::Small) + .color(Color::Muted) + .truncate() + })) }) - } - }; - - Label::new(message) - .size(LabelSize::Small) - .truncate() - .color(Color::Muted) - })) - }), - ) - .end_slot::(icon), + .when_some(commit_time, |label, commit_time| { + label.child( + Label::new(commit_time) + .size(LabelSize::Small) + .color(Color::Muted), + ) + }), + ) + .when_some( + entry.as_branch().map(|b| b.name().to_string()), + |this, branch_name| this.tooltip(Tooltip::text(branch_name)), + ), + ), + ), ) } fn render_header( &self, _window: &mut Window, - cx: &mut Context>, + _cx: &mut Context>, ) -> Option { matches!(self.state, PickerState::List).then(|| { let label = if self.display_remotes { @@ -964,83 +945,54 @@ impl PickerDelegate for BranchListDelegate { "Local" }; - h_flex() - .w_full() - .p_1p5() - .gap_1() - .border_t_1() - .border_color(cx.theme().colors().border_variant) - .child(Label::new(label).size(LabelSize::Small).color(Color::Muted)) - .into_any() + ListHeader::new(label).inset(true).into_any_element() }) } fn render_footer(&self, _: &mut Window, cx: &mut Context>) -> Option { let focus_handle = self.focus_handle.clone(); + let loading_icon = Icon::new(IconName::LoadCircle) + .size(IconSize::Small) + .with_rotate_animation(3); + + let footer_container = || { + h_flex() + .w_full() + .p_1p5() + .border_t_1() + .border_color(cx.theme().colors().border_variant) + }; - if self.loading { - return Some( - h_flex() - .w_full() - .p_1p5() - .gap_1() - .justify_end() - .border_t_1() - .border_color(cx.theme().colors().border_variant) - .child(self.loader()) - .into_any(), - ); - } match self.state { 
- PickerState::List => Some( - h_flex() - .w_full() - .p_1p5() - .gap_0p5() - .border_t_1() - .border_color(cx.theme().colors().border_variant) - .justify_between() - .child({ - let focus_handle = focus_handle.clone(); - Button::new("filter-remotes", "Filter remotes") + PickerState::List => { + let selected_entry = self.matches.get(self.selected_index); + + let branch_from_default_button = self + .default_branch + .as_ref() + .filter(|_| matches!(selected_entry, Some(Entry::NewBranch { .. }))) + .map(|default_branch| { + let button_label = format!("Create New From: {default_branch}"); + + Button::new("branch-from-default", button_label) .key_binding( KeyBinding::for_action_in( - &branch_picker::FilterRemotes, + &menu::SecondaryConfirm, &focus_handle, cx, ) .map(|kb| kb.size(rems_from_px(12.))), ) - .on_click(|_click, window, cx| { - window.dispatch_action( - branch_picker::FilterRemotes.boxed_clone(), - cx, - ); - }) - .disabled(self.loading) - .style(ButtonStyle::Subtle) - .toggle_state(self.display_remotes) - .tooltip({ - let state = self.display_remotes; - - move |_window, cx| { - let tooltip_text = if state { - "Show local branches" - } else { - "Show remote branches" - }; - - Tooltip::for_action_in( - tooltip_text, - &branch_picker::FilterRemotes, - &focus_handle, - cx, - ) - } - }) - }) + .on_click(cx.listener(|this, _, window, cx| { + this.delegate.confirm(true, window, cx); + })) + }); + + let delete_and_select_btns = h_flex() + .gap_0p5() .child( Button::new("delete-branch", "Delete") + .disabled(self.loading) .key_binding( KeyBinding::for_action_in( &branch_picker::DeleteBranch, @@ -1049,43 +1001,134 @@ impl PickerDelegate for BranchListDelegate { ) .map(|kb| kb.size(rems_from_px(12.))), ) - .disabled(self.loading) .on_click(|_, window, cx| { window .dispatch_action(branch_picker::DeleteBranch.boxed_clone(), cx); }), ) - .when(self.loading, |this| this.child(self.loader())) - .into_any(), - ), + .child( + Button::new("select_branch", "Select") + 
.key_binding( + KeyBinding::for_action_in(&menu::Confirm, &focus_handle, cx) + .map(|kb| kb.size(rems_from_px(12.))), + ) + .on_click(cx.listener(|this, _, window, cx| { + this.delegate.confirm(false, window, cx); + })), + ); + + Some( + footer_container() + .map(|this| { + if branch_from_default_button.is_some() { + this.justify_end().when_some( + branch_from_default_button, + |this, button| { + this.child(button).child( + Button::new("create", "Create") + .key_binding( + KeyBinding::for_action_in( + &menu::Confirm, + &focus_handle, + cx, + ) + .map(|kb| kb.size(rems_from_px(12.))), + ) + .on_click(cx.listener(|this, _, window, cx| { + this.delegate.confirm(false, window, cx); + })), + ) + }, + ) + } else if self.loading { + this.justify_between() + .child(loading_icon) + .child(delete_and_select_btns) + } else { + this.justify_between() + .child({ + let focus_handle = focus_handle.clone(); + Button::new("filter-remotes", "Filter Remotes") + .disabled(self.loading) + .toggle_state(self.display_remotes) + .key_binding( + KeyBinding::for_action_in( + &branch_picker::FilterRemotes, + &focus_handle, + cx, + ) + .map(|kb| kb.size(rems_from_px(12.))), + ) + .on_click(|_click, window, cx| { + window.dispatch_action( + branch_picker::FilterRemotes.boxed_clone(), + cx, + ); + }) + }) + .child(delete_and_select_btns) + } + }) + .into_any_element(), + ) + } + PickerState::NewBranch => { + let branch_from_default_button = + self.default_branch.as_ref().map(|default_branch| { + let button_label = format!("Create New From: {default_branch}"); + + Button::new("branch-from-default", button_label) + .key_binding( + KeyBinding::for_action_in( + &menu::SecondaryConfirm, + &focus_handle, + cx, + ) + .map(|kb| kb.size(rems_from_px(12.))), + ) + .on_click(cx.listener(|this, _, window, cx| { + this.delegate.confirm(true, window, cx); + })) + }); + + Some( + footer_container() + .gap_0p5() + .justify_end() + .when_some(branch_from_default_button, |this, button| { + this.child(button) + 
}) + .child( + Button::new("branch-from-default", "Create") + .key_binding( + KeyBinding::for_action_in(&menu::Confirm, &focus_handle, cx) + .map(|kb| kb.size(rems_from_px(12.))), + ) + .on_click(cx.listener(|this, _, window, cx| { + this.delegate.confirm(false, window, cx); + })), + ) + .into_any_element(), + ) + } PickerState::CreateRemote(_) => Some( - h_flex() - .w_full() - .p_1p5() - .gap_1() - .border_t_1() - .border_color(cx.theme().colors().border_variant) + footer_container() + .justify_end() .child( Label::new("Choose a name for this remote repository") .size(LabelSize::Small) .color(Color::Muted), ) .child( - h_flex().w_full().justify_end().child( - Label::new("Save") - .size(LabelSize::Small) - .color(Color::Muted), - ), + Label::new("Save") + .size(LabelSize::Small) + .color(Color::Muted), ) - .into_any(), + .into_any_element(), ), - PickerState::NewRemote | PickerState::NewBranch => None, + PickerState::NewRemote => None, } } - - fn no_matches_text(&self, _window: &mut Window, _cx: &mut App) -> Option { - None - } } #[cfg(test)] @@ -1515,6 +1558,7 @@ mod tests { let last_match = picker.delegate.matches.last().unwrap(); assert!(last_match.is_new_branch()); assert_eq!(last_match.name(), "new-feature-branch"); + // State is NewBranch because no existing branches fuzzy-match the query assert!(matches!(picker.delegate.state, PickerState::NewBranch)); picker.delegate.confirm(false, window, cx); }) From a350438a21c80e1199ceae78f4e9f7e6f7403330 Mon Sep 17 00:00:00 2001 From: Kirill Bulatov Date: Fri, 5 Dec 2025 22:26:42 +0200 Subject: [PATCH 26/81] Specify a schema to use when dealing with JSONC files (#44250) Follow-up of https://github.com/zed-industries/zed/pull/43854 Closes https://github.com/zed-industries/zed/issues/40970 Seems that json language server does not distinguish between JSONC and JSON files in runtime, but there is a static schema, which accepts globs in its `fileMatch` fields. 
Use all glob overrides and file suffixes for JSONC inside those match fields, and provide a grammar for such matches, which accepts trailing commas. Release Notes: - Improved JSONC trailing comma handling --- .../src/json_schema_store.rs | 54 +++++++++++++++++-- crates/language/src/language_registry.rs | 4 +- crates/language/src/language_settings.rs | 9 ++-- crates/languages/src/json.rs | 11 ++-- crates/languages/src/lib.rs | 2 +- 5 files changed, 65 insertions(+), 15 deletions(-) diff --git a/crates/json_schema_store/src/json_schema_store.rs b/crates/json_schema_store/src/json_schema_store.rs index b44efb8b1b135850ab78460a428b5088e5fa0928..18041545ccd404eef0035b9b50ff8244d212fa0b 100644 --- a/crates/json_schema_store/src/json_schema_store.rs +++ b/crates/json_schema_store/src/json_schema_store.rs @@ -3,8 +3,9 @@ use std::{str::FromStr, sync::Arc}; use anyhow::{Context as _, Result}; use gpui::{App, AsyncApp, BorrowAppContext as _, Entity, WeakEntity}; -use language::LanguageRegistry; +use language::{LanguageRegistry, language_settings::all_language_settings}; use project::LspStore; +use util::schemars::{AllowTrailingCommas, DefaultDenyUnknownFields}; // Origin: https://github.com/SchemaStore/schemastore const TSCONFIG_SCHEMA: &str = include_str!("schemas/tsconfig.json"); @@ -159,14 +160,35 @@ pub fn resolve_schema_request_inner( } } "snippets" => snippet_provider::format::VsSnippetsFile::generate_json_schema(), + "jsonc" => jsonc_schema(), _ => { - anyhow::bail!("Unrecognized builtin JSON schema: {}", schema_name); + anyhow::bail!("Unrecognized builtin JSON schema: {schema_name}"); } }; Ok(schema) } -pub fn all_schema_file_associations(cx: &mut App) -> serde_json::Value { +const JSONC_LANGUAGE_NAME: &str = "JSONC"; + +pub fn all_schema_file_associations( + languages: &Arc, + cx: &mut App, +) -> serde_json::Value { + let extension_globs = languages + .available_language_for_name(JSONC_LANGUAGE_NAME) + .map(|language| language.matcher().path_suffixes.clone()) + 
.into_iter() + .flatten() + // Path suffixes can be entire file names or just their extensions. + .flat_map(|path_suffix| [format!("*.{path_suffix}"), path_suffix]); + let override_globs = all_language_settings(None, cx) + .file_types + .get(JSONC_LANGUAGE_NAME) + .into_iter() + .flat_map(|(_, glob_strings)| glob_strings) + .cloned(); + let jsonc_globs = extension_globs.chain(override_globs).collect::>(); + let mut file_associations = serde_json::json!([ { "fileMatch": [ @@ -211,6 +233,10 @@ pub fn all_schema_file_associations(cx: &mut App) -> serde_json::Value { "fileMatch": ["package.json"], "url": "zed://schemas/package_json" }, + { + "fileMatch": &jsonc_globs, + "url": "zed://schemas/jsonc" + }, ]); #[cfg(debug_assertions)] @@ -233,7 +259,7 @@ pub fn all_schema_file_associations(cx: &mut App) -> serde_json::Value { let file_name = normalized_action_name_to_file_name(normalized_name.clone()); serde_json::json!({ "fileMatch": [file_name], - "url": format!("zed://schemas/action/{}", normalized_name) + "url": format!("zed://schemas/action/{normalized_name}") }) }), ); @@ -249,6 +275,26 @@ fn package_json_schema() -> serde_json::Value { serde_json::Value::from_str(PACKAGE_JSON_SCHEMA).unwrap() } +fn jsonc_schema() -> serde_json::Value { + let generator = schemars::generate::SchemaSettings::draft2019_09() + .with_transform(DefaultDenyUnknownFields) + .with_transform(AllowTrailingCommas) + .into_generator(); + let meta_schema = generator + .settings() + .meta_schema + .as_ref() + .expect("meta_schema should be present in schemars settings") + .to_string(); + let defs = generator.definitions(); + let schema = schemars::json_schema!({ + "$schema": meta_schema, + "allowTrailingCommas": true, + "$defs": defs, + }); + serde_json::to_value(schema).unwrap() +} + fn generate_inspector_style_schema() -> serde_json::Value { let schema = schemars::generate::SchemaSettings::draft2019_09() .with_transform(util::schemars::DefaultDenyUnknownFields) diff --git 
a/crates/language/src/language_registry.rs b/crates/language/src/language_registry.rs index a0b04efd1b1366a101812d8656965637c13769a5..af2b66316d133370a3c27f59da645cfff8d8aa66 100644 --- a/crates/language/src/language_registry.rs +++ b/crates/language/src/language_registry.rs @@ -745,7 +745,7 @@ impl LanguageRegistry { self: &Arc, path: &Path, content: Option<&Rope>, - user_file_types: Option<&FxHashMap, GlobSet>>, + user_file_types: Option<&FxHashMap, (GlobSet, Vec)>>, ) -> Option { let filename = path.file_name().and_then(|filename| filename.to_str()); // `Path.extension()` returns None for files with a leading '.' @@ -788,7 +788,7 @@ impl LanguageRegistry { let path_matches_custom_suffix = || { user_file_types .and_then(|types| types.get(language_name.as_ref())) - .map_or(None, |custom_suffixes| { + .map_or(None, |(custom_suffixes, _)| { path_suffixes .iter() .find(|(_, candidate)| custom_suffixes.is_match_candidate(candidate)) diff --git a/crates/language/src/language_settings.rs b/crates/language/src/language_settings.rs index 068f8e1aa39ca3422fda8eb5706c00de6f2f62ce..fccaa545b79c1f24589889df8fcd163fbc5b6c7d 100644 --- a/crates/language/src/language_settings.rs +++ b/crates/language/src/language_settings.rs @@ -51,7 +51,7 @@ pub struct AllLanguageSettings { pub edit_predictions: EditPredictionSettings, pub defaults: LanguageSettings, languages: HashMap, - pub(crate) file_types: FxHashMap, GlobSet>, + pub file_types: FxHashMap, (GlobSet, Vec)>, } #[derive(Debug, Clone, PartialEq)] @@ -656,7 +656,7 @@ impl settings::Settings for AllLanguageSettings { let enabled_in_text_threads = edit_predictions.enabled_in_text_threads.unwrap(); - let mut file_types: FxHashMap, GlobSet> = FxHashMap::default(); + let mut file_types: FxHashMap, (GlobSet, Vec)> = FxHashMap::default(); for (language, patterns) in all_languages.file_types.iter().flatten() { let mut builder = GlobSetBuilder::new(); @@ -665,7 +665,10 @@ impl settings::Settings for AllLanguageSettings { 
builder.add(Glob::new(pattern).unwrap()); } - file_types.insert(language.clone(), builder.build().unwrap()); + file_types.insert( + language.clone(), + (builder.build().unwrap(), patterns.0.clone()), + ); } Self { diff --git a/crates/languages/src/json.rs b/crates/languages/src/json.rs index f695512c1a9ed55289a79bbbd632114a24b82d8d..00bb265967f83ee9a95c034cc0bbcbf63e952647 100644 --- a/crates/languages/src/json.rs +++ b/crates/languages/src/json.rs @@ -7,8 +7,8 @@ use futures::StreamExt; use gpui::{App, AsyncApp, Task}; use http_client::github::{GitHubLspBinaryVersion, latest_github_release}; use language::{ - ContextProvider, LanguageName, LocalFile as _, LspAdapter, LspAdapterDelegate, LspInstaller, - Toolchain, + ContextProvider, LanguageName, LanguageRegistry, LocalFile as _, LspAdapter, + LspAdapterDelegate, LspInstaller, Toolchain, }; use lsp::{LanguageServerBinary, LanguageServerName, Uri}; use node_runtime::{NodeRuntime, VersionStrategy}; @@ -129,14 +129,15 @@ fn server_binary_arguments(server_path: &Path) -> Vec { } pub struct JsonLspAdapter { + languages: Arc, node: NodeRuntime, } impl JsonLspAdapter { const PACKAGE_NAME: &str = "vscode-langservers-extracted"; - pub fn new(node: NodeRuntime) -> Self { - Self { node } + pub fn new(languages: Arc, node: NodeRuntime) -> Self { + Self { languages, node } } } @@ -255,7 +256,7 @@ impl LspAdapter for JsonLspAdapter { cx: &mut AsyncApp, ) -> Result { let mut config = cx.update(|cx| { - let schemas = json_schema_store::all_schema_file_associations(cx); + let schemas = json_schema_store::all_schema_file_associations(&self.languages, cx); // This can be viewed via `dev: open language server logs` -> `json-language-server` -> // `Server Info` diff --git a/crates/languages/src/lib.rs b/crates/languages/src/lib.rs index 9df14fb162e2ed722f5ed7527e179f3aec9b0af6..8ce234a864085a324adeb93a1005a0ed60b1c2b1 100644 --- a/crates/languages/src/lib.rs +++ b/crates/languages/src/lib.rs @@ -89,7 +89,7 @@ pub fn init(languages: 
Arc, fs: Arc, node: NodeRuntime let go_context_provider = Arc::new(go::GoContextProvider); let go_lsp_adapter = Arc::new(go::GoLspAdapter); let json_context_provider = Arc::new(JsonTaskProvider); - let json_lsp_adapter = Arc::new(json::JsonLspAdapter::new(node.clone())); + let json_lsp_adapter = Arc::new(json::JsonLspAdapter::new(languages.clone(), node.clone())); let node_version_lsp_adapter = Arc::new(json::NodeVersionAdapter); let py_lsp_adapter = Arc::new(python::PyLspAdapter::new()); let ty_lsp_adapter = Arc::new(python::TyLspAdapter::new(fs.clone())); From 5cd30e51067c9f4aa57f7e68b9f3aef957916a18 Mon Sep 17 00:00:00 2001 From: Michael Benfield Date: Fri, 5 Dec 2025 13:28:29 -0800 Subject: [PATCH 27/81] inline assistant: Use tools and remove insertion mode (#44248) Co-authored-by: Mikayla Maki Co-authored-by: Danilo Leal Release Notes: - N/A --- assets/prompts/content_prompt_v2.hbs | 44 ++++ crates/agent/src/tools.rs | 4 + crates/agent_ui/src/agent_model_selector.rs | 2 +- crates/agent_ui/src/buffer_codegen.rs | 265 ++++++++++++++++++-- crates/agent_ui/src/inline_assistant.rs | 111 ++++++-- crates/agent_ui/src/inline_prompt_editor.rs | 131 ++++++++-- crates/feature_flags/src/flags.rs | 6 + crates/language_model/src/language_model.rs | 34 +++ crates/prompt_store/src/prompts.rs | 92 +++++++ 9 files changed, 630 insertions(+), 59 deletions(-) create mode 100644 assets/prompts/content_prompt_v2.hbs diff --git a/assets/prompts/content_prompt_v2.hbs b/assets/prompts/content_prompt_v2.hbs new file mode 100644 index 0000000000000000000000000000000000000000..e1b6ddc6f023e9e97c9bb851473ac02e989c8feb --- /dev/null +++ b/assets/prompts/content_prompt_v2.hbs @@ -0,0 +1,44 @@ +{{#if language_name}} +Here's a file of {{language_name}} that the user is going to ask you to make an edit to. +{{else}} +Here's a file of text that the user is going to ask you to make an edit to. +{{/if}} + +The section you'll need to rewrite is marked with tags. 
+ + +{{{document_content}}} + + +{{#if is_truncated}} +The context around the relevant section has been truncated (possibly in the middle of a line) for brevity. +{{/if}} + +{{#if rewrite_section}} +And here's the section to rewrite based on that prompt again for reference: + + +{{{rewrite_section}}} + + +{{#if diagnostic_errors}} +Below are the diagnostic errors visible to the user. If the user requests problems to be fixed, use this information, but do not try to fix these errors if the user hasn't asked you to. + +{{#each diagnostic_errors}} + + {{line_number}} + {{error_message}} + {{code_content}} + +{{/each}} +{{/if}} + +{{/if}} + +Only make changes that are necessary to fulfill the prompt, leave everything else as-is. All surrounding {{content_type}} will be preserved. + +Start at the indentation level in the original file in the rewritten {{content_type}}. + +You must use one of the provided tools to make the rewrite or to provide an explanation as to why the user's request cannot be fulfilled. It is an error if +you simply send back unstructured text. If you need to make a statement or ask a question you must use one of the tools to do so. +It is an error if you try to make a change that cannot be made simply by editing the rewrite_section. 
diff --git a/crates/agent/src/tools.rs b/crates/agent/src/tools.rs index 1d3c0d557716ec3a52f910971547df4ee764cab0..62a52998a705e11d1c9e69cbade7f427cc9cfc32 100644 --- a/crates/agent/src/tools.rs +++ b/crates/agent/src/tools.rs @@ -4,6 +4,7 @@ mod create_directory_tool; mod delete_path_tool; mod diagnostics_tool; mod edit_file_tool; + mod fetch_tool; mod find_path_tool; mod grep_tool; @@ -12,6 +13,7 @@ mod move_path_tool; mod now_tool; mod open_tool; mod read_file_tool; + mod terminal_tool; mod thinking_tool; mod web_search_tool; @@ -25,6 +27,7 @@ pub use create_directory_tool::*; pub use delete_path_tool::*; pub use diagnostics_tool::*; pub use edit_file_tool::*; + pub use fetch_tool::*; pub use find_path_tool::*; pub use grep_tool::*; @@ -33,6 +36,7 @@ pub use move_path_tool::*; pub use now_tool::*; pub use open_tool::*; pub use read_file_tool::*; + pub use terminal_tool::*; pub use thinking_tool::*; pub use web_search_tool::*; diff --git a/crates/agent_ui/src/agent_model_selector.rs b/crates/agent_ui/src/agent_model_selector.rs index 43982cdda7bd887b8fd9970e836090a0e549ae11..3840e40cf4d22db9d52e74ef0489c06ca8a15f26 100644 --- a/crates/agent_ui/src/agent_model_selector.rs +++ b/crates/agent_ui/src/agent_model_selector.rs @@ -98,7 +98,7 @@ impl Render for AgentModelSelector { .child( Icon::new(IconName::ChevronDown) .color(color) - .size(IconSize::XSmall), + .size(IconSize::Small), ), move |_window, cx| { Tooltip::for_action_in("Change Model", &ToggleModelSelector, &focus_handle, cx) diff --git a/crates/agent_ui/src/buffer_codegen.rs b/crates/agent_ui/src/buffer_codegen.rs index 972ead664464876e57d7830b18db3f2b0c49629c..0d014f50294f90aa2bda1f51025c937cc0e2ae56 100644 --- a/crates/agent_ui/src/buffer_codegen.rs +++ b/crates/agent_ui/src/buffer_codegen.rs @@ -5,22 +5,26 @@ use client::telemetry::Telemetry; use cloud_llm_client::CompletionIntent; use collections::HashSet; use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint}; 
+use feature_flags::{FeatureFlagAppExt as _, InlineAssistantV2FeatureFlag}; use futures::{ SinkExt, Stream, StreamExt, TryStreamExt as _, channel::mpsc, future::{LocalBoxFuture, Shared}, join, }; -use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task}; +use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task}; use language::{Buffer, IndentKind, Point, TransactionId, line_diff}; use language_model::{ - LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, - LanguageModelTextStream, Role, report_assistant_event, + LanguageModel, LanguageModelCompletionError, LanguageModelRegistry, LanguageModelRequest, + LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelTextStream, Role, + report_assistant_event, }; use multi_buffer::MultiBufferRow; use parking_lot::Mutex; use prompt_store::PromptBuilder; use rope::Rope; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; use smol::future::FutureExt; use std::{ cmp, @@ -34,6 +38,29 @@ use std::{ }; use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff}; use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase}; +use ui::SharedString; + +/// Use this tool to provide a message to the user when you're unable to complete a task. +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct FailureMessageInput { + /// A brief message to the user explaining why you're unable to fulfill the request or to ask a question about the request. + /// + /// The message may use markdown formatting if you wish. + pub message: String, +} + +/// Replaces text in tags with your replacement_text. +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct RewriteSectionInput { + /// A brief description of the edit you have made. + /// + /// The description may use markdown formatting if you wish. 
+ /// This is optional - if the edit is simple or obvious, you should leave it empty. + pub description: String, + + /// The text to replace the section with. + pub replacement_text: String, +} pub struct BufferCodegen { alternatives: Vec>, @@ -238,6 +265,7 @@ pub struct CodegenAlternative { elapsed_time: Option, completion: Option, pub message_id: Option, + pub model_explanation: Option, } impl EventEmitter for CodegenAlternative {} @@ -288,14 +316,15 @@ impl CodegenAlternative { generation: Task::ready(()), diff: Diff::default(), telemetry, - _subscription: cx.subscribe(&buffer, Self::handle_buffer_event), builder, - active, + active: active, edits: Vec::new(), line_operations: Vec::new(), range, elapsed_time: None, completion: None, + model_explanation: None, + _subscription: cx.subscribe(&buffer, Self::handle_buffer_event), } } @@ -358,18 +387,124 @@ impl CodegenAlternative { let api_key = model.api_key(cx); let telemetry_id = model.telemetry_id(); let provider_id = model.provider_id(); - let stream: LocalBoxFuture> = - if user_prompt.trim().to_lowercase() == "delete" { - async { Ok(LanguageModelTextStream::default()) }.boxed_local() + + if cx.has_flag::() { + let request = self.build_request(&model, user_prompt, context_task, cx)?; + let tool_use = + cx.spawn(async move |_, cx| model.stream_completion_tool(request.await, cx).await); + self.handle_tool_use(telemetry_id, provider_id.to_string(), api_key, tool_use, cx); + } else { + let stream: LocalBoxFuture> = + if user_prompt.trim().to_lowercase() == "delete" { + async { Ok(LanguageModelTextStream::default()) }.boxed_local() + } else { + let request = self.build_request(&model, user_prompt, context_task, cx)?; + cx.spawn(async move |_, cx| { + Ok(model.stream_completion_text(request.await, cx).await?) 
+ }) + .boxed_local() + }; + self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx); + } + + Ok(()) + } + + fn build_request_v2( + &self, + model: &Arc, + user_prompt: String, + context_task: Shared>>, + cx: &mut App, + ) -> Result> { + let buffer = self.buffer.read(cx).snapshot(cx); + let language = buffer.language_at(self.range.start); + let language_name = if let Some(language) = language.as_ref() { + if Arc::ptr_eq(language, &language::PLAIN_TEXT) { + None } else { - let request = self.build_request(&model, user_prompt, context_task, cx)?; - cx.spawn(async move |_, cx| { - Ok(model.stream_completion_text(request.await, cx).await?) - }) - .boxed_local() + Some(language.name()) + } + } else { + None + }; + + let language_name = language_name.as_ref(); + let start = buffer.point_to_buffer_offset(self.range.start); + let end = buffer.point_to_buffer_offset(self.range.end); + let (buffer, range) = if let Some((start, end)) = start.zip(end) { + let (start_buffer, start_buffer_offset) = start; + let (end_buffer, end_buffer_offset) = end; + if start_buffer.remote_id() == end_buffer.remote_id() { + (start_buffer.clone(), start_buffer_offset..end_buffer_offset) + } else { + anyhow::bail!("invalid transformation range"); + } + } else { + anyhow::bail!("invalid transformation range"); + }; + + let system_prompt = self + .builder + .generate_inline_transformation_prompt_v2( + language_name, + buffer, + range.start.0..range.end.0, + ) + .context("generating content prompt")?; + + let temperature = AgentSettings::temperature_for_model(model, cx); + + let tool_input_format = model.tool_input_format(); + + Ok(cx.spawn(async move |_cx| { + let mut messages = vec![LanguageModelRequestMessage { + role: Role::System, + content: vec![system_prompt.into()], + cache: false, + reasoning_details: None, + }]; + + let mut user_message = LanguageModelRequestMessage { + role: Role::User, + content: Vec::new(), + cache: false, + reasoning_details: None, }; - 
self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx); - Ok(()) + + if let Some(context) = context_task.await { + context.add_to_request_message(&mut user_message); + } + + user_message.content.push(user_prompt.into()); + messages.push(user_message); + + let tools = vec![ + LanguageModelRequestTool { + name: "rewrite_section".to_string(), + description: "Replaces text in tags with your replacement_text.".to_string(), + input_schema: language_model::tool_schema::root_schema_for::(tool_input_format).to_value(), + }, + LanguageModelRequestTool { + name: "failure_message".to_string(), + description: "Use this tool to provide a message to the user when you're unable to complete a task.".to_string(), + input_schema: language_model::tool_schema::root_schema_for::(tool_input_format).to_value(), + }, + ]; + + LanguageModelRequest { + thread_id: None, + prompt_id: None, + intent: Some(CompletionIntent::InlineAssist), + mode: None, + tools, + tool_choice: None, + stop: Vec::new(), + temperature, + messages, + thinking_allowed: false, + } + })) } fn build_request( @@ -379,6 +514,10 @@ impl CodegenAlternative { context_task: Shared>>, cx: &mut App, ) -> Result> { + if cx.has_flag::() { + return self.build_request_v2(model, user_prompt, context_task, cx); + } + let buffer = self.buffer.read(cx).snapshot(cx); let language = buffer.language_at(self.range.start); let language_name = if let Some(language) = language.as_ref() { @@ -510,6 +649,7 @@ impl CodegenAlternative { self.generation = cx.spawn(async move |codegen, cx| { let stream = stream.await; + let token_usage = stream .as_ref() .ok() @@ -899,6 +1039,101 @@ impl CodegenAlternative { .ok(); }) } + + fn handle_tool_use( + &mut self, + _telemetry_id: String, + _provider_id: String, + _api_key: Option, + tool_use: impl 'static + + Future< + Output = Result, + >, + cx: &mut Context, + ) { + self.diff = Diff::default(); + self.status = CodegenStatus::Pending; + + self.generation = cx.spawn(async move 
|codegen, cx| { + let finish_with_status = |status: CodegenStatus, cx: &mut AsyncApp| { + let _ = codegen.update(cx, |this, cx| { + this.status = status; + cx.emit(CodegenEvent::Finished); + cx.notify(); + }); + }; + + let tool_use = tool_use.await; + + match tool_use { + Ok(tool_use) if tool_use.name.as_ref() == "rewrite_section" => { + // Parse the input JSON into RewriteSectionInput + match serde_json::from_value::(tool_use.input) { + Ok(input) => { + // Store the description if non-empty + let description = if !input.description.trim().is_empty() { + Some(input.description.clone()) + } else { + None + }; + + // Apply the replacement text to the buffer and compute diff + let batch_diff_task = codegen + .update(cx, |this, cx| { + this.model_explanation = description.map(Into::into); + let range = this.range.clone(); + this.apply_edits( + std::iter::once((range, input.replacement_text)), + cx, + ); + this.reapply_batch_diff(cx) + }) + .ok(); + + // Wait for the diff computation to complete + if let Some(diff_task) = batch_diff_task { + diff_task.await; + } + + finish_with_status(CodegenStatus::Done, cx); + return; + } + Err(e) => { + finish_with_status(CodegenStatus::Error(e.into()), cx); + return; + } + } + } + Ok(tool_use) if tool_use.name.as_ref() == "failure_message" => { + // Handle failure message tool use + match serde_json::from_value::(tool_use.input) { + Ok(input) => { + let _ = codegen.update(cx, |this, _cx| { + // Store the failure message as the tool description + this.model_explanation = Some(input.message.into()); + }); + finish_with_status(CodegenStatus::Done, cx); + return; + } + Err(e) => { + finish_with_status(CodegenStatus::Error(e.into()), cx); + return; + } + } + } + Ok(_tool_use) => { + // Unexpected tool. 
+ finish_with_status(CodegenStatus::Done, cx); + return; + } + Err(e) => { + finish_with_status(CodegenStatus::Error(e.into()), cx); + return; + } + } + }); + cx.notify(); + } } #[derive(Copy, Clone, Debug)] diff --git a/crates/agent_ui/src/inline_assistant.rs b/crates/agent_ui/src/inline_assistant.rs index cbc5891036fdf03ee04cca6b77820748faed2d0a..48da85d38554da8227d76d3cbe290e29ef4fc531 100644 --- a/crates/agent_ui/src/inline_assistant.rs +++ b/crates/agent_ui/src/inline_assistant.rs @@ -387,17 +387,9 @@ impl InlineAssistant { let mut selections = Vec::>::new(); let mut newest_selection = None; for mut selection in initial_selections { - if selection.end > selection.start { - selection.start.column = 0; - // If the selection ends at the start of the line, we don't want to include it. - if selection.end.column == 0 { - selection.end.row -= 1; - } - selection.end.column = snapshot - .buffer_snapshot() - .line_len(MultiBufferRow(selection.end.row)); - } else if let Some(fold) = - snapshot.crease_for_buffer_row(MultiBufferRow(selection.end.row)) + if selection.end == selection.start + && let Some(fold) = + snapshot.crease_for_buffer_row(MultiBufferRow(selection.end.row)) { selection.start = fold.range().start; selection.end = fold.range().end; @@ -424,6 +416,15 @@ impl InlineAssistant { } } } + } else { + selection.start.column = 0; + // If the selection ends at the start of the line, we don't want to include it. 
+ if selection.end.column == 0 && selection.start.row != selection.end.row { + selection.end.row -= 1; + } + selection.end.column = snapshot + .buffer_snapshot() + .line_len(MultiBufferRow(selection.end.row)); } if let Some(prev_selection) = selections.last_mut() @@ -544,14 +545,15 @@ impl InlineAssistant { } } - let [prompt_block_id, end_block_id] = - self.insert_assist_blocks(editor, &range, &prompt_editor, cx); + let [prompt_block_id, tool_description_block_id, end_block_id] = + self.insert_assist_blocks(&editor, &range, &prompt_editor, cx); assists.push(( assist_id, range.clone(), prompt_editor, prompt_block_id, + tool_description_block_id, end_block_id, )); } @@ -570,7 +572,15 @@ impl InlineAssistant { }; let mut assist_group = InlineAssistGroup::new(); - for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists { + for ( + assist_id, + range, + prompt_editor, + prompt_block_id, + tool_description_block_id, + end_block_id, + ) in assists + { let codegen = prompt_editor.read(cx).codegen().clone(); self.assists.insert( @@ -581,6 +591,7 @@ impl InlineAssistant { editor, &prompt_editor, prompt_block_id, + tool_description_block_id, end_block_id, range, codegen, @@ -689,7 +700,7 @@ impl InlineAssistant { range: &Range, prompt_editor: &Entity>, cx: &mut App, - ) -> [CustomBlockId; 2] { + ) -> [CustomBlockId; 3] { let prompt_editor_height = prompt_editor.update(cx, |prompt_editor, cx| { prompt_editor .editor @@ -703,6 +714,14 @@ impl InlineAssistant { render: build_assist_editor_renderer(prompt_editor), priority: 0, }, + // Placeholder for tool description - will be updated dynamically + BlockProperties { + style: BlockStyle::Flex, + placement: BlockPlacement::Below(range.end), + height: Some(0), + render: Arc::new(|_cx| div().into_any_element()), + priority: 0, + }, BlockProperties { style: BlockStyle::Sticky, placement: BlockPlacement::Below(range.end), @@ -721,7 +740,7 @@ impl InlineAssistant { editor.update(cx, |editor, cx| { let block_ids = 
editor.insert_blocks(assist_blocks, None, cx); - [block_ids[0], block_ids[1]] + [block_ids[0], block_ids[1], block_ids[2]] }) } @@ -1113,6 +1132,9 @@ impl InlineAssistant { let mut to_remove = decorations.removed_line_block_ids; to_remove.insert(decorations.prompt_block_id); to_remove.insert(decorations.end_block_id); + if let Some(tool_description_block_id) = decorations.model_explanation { + to_remove.insert(tool_description_block_id); + } editor.remove_blocks(to_remove, None, cx); }); @@ -1433,8 +1455,60 @@ impl InlineAssistant { let old_snapshot = codegen.snapshot(cx); let old_buffer = codegen.old_buffer(cx); let deleted_row_ranges = codegen.diff(cx).deleted_row_ranges.clone(); + // let model_explanation = codegen.model_explanation(cx); editor.update(cx, |editor, cx| { + // Update tool description block + // if let Some(description) = model_explanation { + // if let Some(block_id) = decorations.model_explanation { + // editor.remove_blocks(HashSet::from_iter([block_id]), None, cx); + // let new_block_id = editor.insert_blocks( + // [BlockProperties { + // style: BlockStyle::Flex, + // placement: BlockPlacement::Below(assist.range.end), + // height: Some(1), + // render: Arc::new({ + // let description = description.clone(); + // move |cx| { + // div() + // .w_full() + // .py_1() + // .px_2() + // .bg(cx.theme().colors().editor_background) + // .border_y_1() + // .border_color(cx.theme().status().info_border) + // .child( + // Label::new(description.clone()) + // .color(Color::Muted) + // .size(LabelSize::Small), + // ) + // .into_any_element() + // } + // }), + // priority: 0, + // }], + // None, + // cx, + // ); + // decorations.model_explanation = new_block_id.into_iter().next(); + // } + // } else if let Some(block_id) = decorations.model_explanation { + // // Hide the block if there's no description + // editor.remove_blocks(HashSet::from_iter([block_id]), None, cx); + // let new_block_id = editor.insert_blocks( + // [BlockProperties { + // style: 
BlockStyle::Flex, + // placement: BlockPlacement::Below(assist.range.end), + // height: Some(0), + // render: Arc::new(|_cx| div().into_any_element()), + // priority: 0, + // }], + // None, + // cx, + // ); + // decorations.model_explanation = new_block_id.into_iter().next(); + // } + let old_blocks = mem::take(&mut decorations.removed_line_block_ids); editor.remove_blocks(old_blocks, None, cx); @@ -1686,6 +1760,7 @@ impl InlineAssist { editor: &Entity, prompt_editor: &Entity>, prompt_block_id: CustomBlockId, + tool_description_block_id: CustomBlockId, end_block_id: CustomBlockId, range: Range, codegen: Entity, @@ -1700,7 +1775,8 @@ impl InlineAssist { decorations: Some(InlineAssistDecorations { prompt_block_id, prompt_editor: prompt_editor.clone(), - removed_line_block_ids: HashSet::default(), + removed_line_block_ids: Default::default(), + model_explanation: Some(tool_description_block_id), end_block_id, }), range, @@ -1804,6 +1880,7 @@ struct InlineAssistDecorations { prompt_block_id: CustomBlockId, prompt_editor: Entity>, removed_line_block_ids: HashSet, + model_explanation: Option, end_block_id: CustomBlockId, } diff --git a/crates/agent_ui/src/inline_prompt_editor.rs b/crates/agent_ui/src/inline_prompt_editor.rs index b9e8d9ada230ba497ffcd4e577d3312dd440e604..0083648651645c456acfa19332d61b9f550ed4ed 100644 --- a/crates/agent_ui/src/inline_prompt_editor.rs +++ b/crates/agent_ui/src/inline_prompt_editor.rs @@ -11,9 +11,10 @@ use editor::{ use fs::Fs; use gpui::{ AnyElement, App, Context, CursorStyle, Entity, EventEmitter, FocusHandle, Focusable, - Subscription, TextStyle, WeakEntity, Window, + Subscription, TextStyle, TextStyleRefinement, WeakEntity, Window, }; use language_model::{LanguageModel, LanguageModelRegistry}; +use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle}; use parking_lot::Mutex; use project::Project; use prompt_store::PromptStore; @@ -65,7 +66,7 @@ impl Render for PromptEditor { const RIGHT_PADDING: Pixels = px(9.); - 
let (left_gutter_width, right_padding) = match &self.mode { + let (left_gutter_width, right_padding, explanation) = match &self.mode { PromptEditorMode::Buffer { id: _, codegen, @@ -83,11 +84,17 @@ impl Render for PromptEditor { let left_gutter_width = gutter.full_width() + (gutter.margin / 2.0); let right_padding = editor_margins.right + RIGHT_PADDING; - (left_gutter_width, right_padding) + let explanation = codegen + .active_alternative() + .read(cx) + .model_explanation + .clone(); + + (left_gutter_width, right_padding, explanation) } PromptEditorMode::Terminal { .. } => { // Give the equivalent of the same left-padding that we're using on the right - (Pixels::from(40.0), Pixels::from(24.)) + (Pixels::from(40.0), Pixels::from(24.), None) } }; @@ -111,18 +118,30 @@ impl Render for PromptEditor { this.trigger_completion_menu(window, cx); })); + let markdown = window.use_state(cx, |_, cx| Markdown::new("".into(), None, None, cx)); + + if let Some(explanation) = &explanation { + markdown.update(cx, |markdown, cx| { + markdown.reset(explanation.clone(), cx); + }); + } + + let explanation_label = self + .render_markdown(markdown, markdown_style(window, cx)) + .into_any_element(); + v_flex() .key_context("PromptEditor") .capture_action(cx.listener(Self::paste)) - .bg(cx.theme().colors().editor_background) .block_mouse_except_scroll() - .gap_0p5() - .border_y_1() - .border_color(cx.theme().status().info_border) .size_full() .pt_0p5() .pb(bottom_padding) .pr(right_padding) + .bg(cx.theme().colors().editor_background) + .gap_0p5() + .border_y_1() + .border_color(cx.theme().colors().border) .child( h_flex() .items_start() @@ -139,12 +158,12 @@ impl Render for PromptEditor { .capture_action(cx.listener(Self::cycle_next)) .child( WithRemSize::new(ui_font_size) + .h_full() + .w(left_gutter_width) .flex() .flex_row() .flex_shrink_0() .items_center() - .h_full() - .w(left_gutter_width) .justify_center() .gap_2() .child(self.render_close_button(cx)) @@ -177,26 +196,82 @@ impl 
Render for PromptEditor { .flex_row() .items_center() .gap_1() + .child(add_context_button) + .child(self.model_selector.clone()) .children(buttons), ), ), ) - .child( - WithRemSize::new(ui_font_size) - .flex() - .flex_row() - .items_center() - .child(h_flex().flex_shrink_0().w(left_gutter_width)) - .child( - h_flex() - .w_full() - .pl_1() - .items_start() - .justify_between() - .child(add_context_button) - .child(self.model_selector.clone()), - ), - ) + .when_some(explanation, |this, _| { + this.child( + h_flex() + .size_full() + .child(div().w(left_gutter_width + px(6.))) + .child( + div() + .size_full() + .min_w_0() + .pb_px() + .pl_1() + .flex_1() + .border_t_1() + .border_color(cx.theme().colors().border_variant) + .child(explanation_label), + ), + ) + }) + } +} + +fn markdown_style(window: &Window, cx: &App) -> MarkdownStyle { + let theme_settings = ThemeSettings::get_global(cx); + let colors = cx.theme().colors(); + let mut text_style = window.text_style(); + + text_style.refine(&TextStyleRefinement { + font_family: Some(theme_settings.ui_font.family.clone()), + color: Some(colors.text), + ..Default::default() + }); + + MarkdownStyle { + base_text_style: text_style.clone(), + syntax: cx.theme().syntax().clone(), + selection_background_color: colors.element_selection_background, + heading_level_styles: Some(HeadingLevelStyles { + h1: Some(TextStyleRefinement { + font_size: Some(rems(1.15).into()), + ..Default::default() + }), + h2: Some(TextStyleRefinement { + font_size: Some(rems(1.1).into()), + ..Default::default() + }), + h3: Some(TextStyleRefinement { + font_size: Some(rems(1.05).into()), + ..Default::default() + }), + h4: Some(TextStyleRefinement { + font_size: Some(rems(1.).into()), + ..Default::default() + }), + h5: Some(TextStyleRefinement { + font_size: Some(rems(0.95).into()), + ..Default::default() + }), + h6: Some(TextStyleRefinement { + font_size: Some(rems(0.875).into()), + ..Default::default() + }), + }), + inline_code: TextStyleRefinement { + 
font_family: Some(theme_settings.buffer_font.family.clone()), + font_fallbacks: theme_settings.buffer_font.fallbacks.clone(), + font_features: Some(theme_settings.buffer_font.features.clone()), + background_color: Some(colors.editor_foreground.opacity(0.08)), + ..Default::default() + }, + ..Default::default() } } @@ -759,6 +834,10 @@ impl PromptEditor { }) .into_any_element() } + + fn render_markdown(&self, markdown: Entity, style: MarkdownStyle) -> MarkdownElement { + MarkdownElement::new(markdown, style) + } } pub enum PromptEditorMode { diff --git a/crates/feature_flags/src/flags.rs b/crates/feature_flags/src/flags.rs index 26615aea0f7566ec6dbbd66a128c1a395cc1b9bc..fe11a7b5eaa162a90ae8ba3f691ca804ab64db2d 100644 --- a/crates/feature_flags/src/flags.rs +++ b/crates/feature_flags/src/flags.rs @@ -11,3 +11,9 @@ pub struct PanicFeatureFlag; impl FeatureFlag for PanicFeatureFlag { const NAME: &'static str = "panic"; } + +pub struct InlineAssistantV2FeatureFlag; + +impl FeatureFlag for InlineAssistantV2FeatureFlag { + const NAME: &'static str = "inline-assistant-v2"; +} diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index c9b6391136da1a2b2e9a2ae470229179615a865a..cb03b84cbf34d3003e53befa518ecd91626a13e9 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -707,6 +707,40 @@ pub trait LanguageModel: Send + Sync { .boxed() } + fn stream_completion_tool( + &self, + request: LanguageModelRequest, + cx: &AsyncApp, + ) -> BoxFuture<'static, Result> { + let future = self.stream_completion(request, cx); + + async move { + let events = future.await?; + let mut events = events.fuse(); + + // Iterate through events until we find a complete ToolUse + while let Some(event) = events.next().await { + match event { + Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) + if tool_use.is_input_complete => + { + return Ok(tool_use); + } + Err(err) => { + return Err(err); + } + 
_ => {} + } + } + + // Stream ended without a complete tool use + Err(LanguageModelCompletionError::Other(anyhow::anyhow!( + "Stream ended without receiving a complete tool use" + ))) + } + .boxed() + } + fn cache_configuration(&self) -> Option { None } diff --git a/crates/prompt_store/src/prompts.rs b/crates/prompt_store/src/prompts.rs index 3d47fbce7014e8e791ca8961447c8df1adf45abf..d6a172218a8eb3d4538363e6202a7e721d2b7bc1 100644 --- a/crates/prompt_store/src/prompts.rs +++ b/crates/prompt_store/src/prompts.rs @@ -94,6 +94,16 @@ pub struct ContentPromptContext { pub diagnostic_errors: Vec, } +#[derive(Serialize)] +pub struct ContentPromptContextV2 { + pub content_type: String, + pub language_name: Option, + pub is_truncated: bool, + pub document_content: String, + pub rewrite_section: Option, + pub diagnostic_errors: Vec, +} + #[derive(Serialize)] pub struct TerminalAssistantPromptContext { pub os: String, @@ -276,6 +286,88 @@ impl PromptBuilder { Ok(()) } + pub fn generate_inline_transformation_prompt_v2( + &self, + language_name: Option<&LanguageName>, + buffer: BufferSnapshot, + range: Range, + ) -> Result { + let content_type = match language_name.as_ref().map(|l| l.as_ref()) { + None | Some("Markdown" | "Plain Text") => "text", + Some(_) => "code", + }; + + const MAX_CTX: usize = 50000; + let is_insert = range.is_empty(); + let mut is_truncated = false; + + let before_range = 0..range.start; + let truncated_before = if before_range.len() > MAX_CTX { + is_truncated = true; + let start = buffer.clip_offset(range.start - MAX_CTX, text::Bias::Right); + start..range.start + } else { + before_range + }; + + let after_range = range.end..buffer.len(); + let truncated_after = if after_range.len() > MAX_CTX { + is_truncated = true; + let end = buffer.clip_offset(range.end + MAX_CTX, text::Bias::Left); + range.end..end + } else { + after_range + }; + + let mut document_content = String::new(); + for chunk in buffer.text_for_range(truncated_before) { + 
document_content.push_str(chunk); + } + if is_insert { + document_content.push_str(""); + } else { + document_content.push_str("\n"); + for chunk in buffer.text_for_range(range.clone()) { + document_content.push_str(chunk); + } + document_content.push_str("\n"); + } + for chunk in buffer.text_for_range(truncated_after) { + document_content.push_str(chunk); + } + + let rewrite_section = if !is_insert { + let mut section = String::new(); + for chunk in buffer.text_for_range(range.clone()) { + section.push_str(chunk); + } + Some(section) + } else { + None + }; + let diagnostics = buffer.diagnostics_in_range::<_, Point>(range, false); + let diagnostic_errors: Vec = diagnostics + .map(|entry| { + let start = entry.range.start; + ContentPromptDiagnosticContext { + line_number: (start.row + 1) as usize, + error_message: entry.diagnostic.message.clone(), + code_content: buffer.text_for_range(entry.range).collect(), + } + }) + .collect(); + + let context = ContentPromptContextV2 { + content_type: content_type.to_string(), + language_name: language_name.map(|s| s.to_string()), + is_truncated, + document_content, + rewrite_section, + diagnostic_errors, + }; + self.handlebars.lock().render("content_prompt_v2", &context) + } + pub fn generate_inline_transformation_prompt( &self, user_prompt: String, From f4b8b0f4716d842580e9a4d9a6526c8c3f0553b0 Mon Sep 17 00:00:00 2001 From: Mayank Verma Date: Sat, 6 Dec 2025 03:54:59 +0530 Subject: [PATCH 28/81] settings: Fix inconsistent terminal font weight step size (#44243) Closes #44242 Release Notes: - Fixed inconsistent terminal font weight step size in settings --- crates/settings/src/settings_content/terminal.rs | 5 ++--- crates/terminal/src/terminal_settings.rs | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/crates/settings/src/settings_content/terminal.rs b/crates/settings/src/settings_content/terminal.rs index cd01eb14fa5ce19b077c39b67f8bd90ac93ad35f..1a30eecaa12e1e4a2a9799b2ec752bae2998a257 100644 --- 
a/crates/settings/src/settings_content/terminal.rs +++ b/crates/settings/src/settings_content/terminal.rs @@ -1,7 +1,7 @@ use std::path::PathBuf; use collections::HashMap; -use gpui::{AbsoluteLength, FontFeatures, SharedString, px}; +use gpui::{AbsoluteLength, FontFeatures, FontWeight, SharedString, px}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings_macros::{MergeFrom, with_fallible_options}; @@ -96,8 +96,7 @@ pub struct TerminalSettingsContent { pub line_height: Option, pub font_features: Option, /// Sets the terminal's font weight in CSS weight units 0-900. - #[serde(serialize_with = "crate::serialize_optional_f32_with_two_decimal_places")] - pub font_weight: Option, + pub font_weight: Option, /// Default cursor shape for the terminal. /// Can be "bar", "block", "underline", or "hollow". /// diff --git a/crates/terminal/src/terminal_settings.rs b/crates/terminal/src/terminal_settings.rs index 3b3070c6f680452b43d398786fa2a705a06d3404..3d70d85f35239778bee61113ebc51eea7d87adcb 100644 --- a/crates/terminal/src/terminal_settings.rs +++ b/crates/terminal/src/terminal_settings.rs @@ -95,7 +95,7 @@ impl settings::Settings for TerminalSettings { ) }), font_features: user_content.font_features, - font_weight: user_content.font_weight.map(FontWeight), + font_weight: user_content.font_weight, line_height: user_content.line_height.unwrap(), env: project_content.env.unwrap(), cursor_shape: user_content.cursor_shape.unwrap().into(), From e5f87735d3611b45c778e27b99cc4c6880962901 Mon Sep 17 00:00:00 2001 From: "Oleksii (Alexey) Orlenko" Date: Fri, 5 Dec 2025 23:27:21 +0100 Subject: [PATCH 29/81] markdown_preview: Remove unnecessary vec allocation (#44238) Instead of allocating a one-element vec on the heap, we can just use an array here (since `Editor::edit` accepts anything that implements `IntoIterator`). 
I haven't checked if there are more instances that can be simplified, just accidentally stumbled upon this when working on something else in the markdown preview crate. Release Notes: - N/A --- crates/markdown_preview/src/markdown_preview_view.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/markdown_preview/src/markdown_preview_view.rs b/crates/markdown_preview/src/markdown_preview_view.rs index 4126a31379fa74a750a7d111ac71dc180a3bb0ff..df8201dc7a3dad18c279582d668304ce9e1cf77b 100644 --- a/crates/markdown_preview/src/markdown_preview_view.rs +++ b/crates/markdown_preview/src/markdown_preview_view.rs @@ -524,7 +524,7 @@ impl Render for MarkdownPreviewView { if e.checked() { "[x]" } else { "[ ]" }; editor.edit( - vec![( + [( MultiBufferOffset( e.source_range().start, ) From 4cef8eb47bb157916f10cedd18b0c1a85cd21977 Mon Sep 17 00:00:00 2001 From: Conrad Irwin Date: Fri, 5 Dec 2025 16:55:05 -0700 Subject: [PATCH 30/81] Fix persistence for single-file worktrees (#44257) We were just deleting them before Co-Authored-By: Matthew Chisolm Closes #ISSUE Release Notes: - Fixed restoring window location for single-file worktrees Co-authored-by: Matthew Chisolm --- crates/workspace/src/persistence.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/crates/workspace/src/persistence.rs b/crates/workspace/src/persistence.rs index 824a9be90b6dc33094f854a3a9672db692e2b592..103e51d548648c18b5b2d724362228948a70930b 100644 --- a/crates/workspace/src/persistence.rs +++ b/crates/workspace/src/persistence.rs @@ -1359,11 +1359,11 @@ impl WorkspaceDb { // If a local workspace points to WSL, this check will cause us to wait for the // WSL VM and file server to boot up. This can block for many seconds. // Supported scenarios use remote workspaces. 
- if !has_wsl_path - && paths.paths().iter().all(|path| path.exists()) - && paths.paths().iter().any(|path| path.is_dir()) - { - result.push((id, SerializedWorkspaceLocation::Local, paths)); + if !has_wsl_path && paths.paths().iter().all(|path| path.exists()) { + // Only show directories in recent projects + if paths.paths().iter().any(|path| path.is_dir()) { + result.push((id, SerializedWorkspaceLocation::Local, paths)); + } } else { delete_tasks.push(self.delete_workspace_by_id(id)); } From 98608842175f02f503581737f9eb69eea01b56df Mon Sep 17 00:00:00 2001 From: Serophots <47299955+Serophots@users.noreply.github.com> Date: Sat, 6 Dec 2025 01:08:43 +0000 Subject: [PATCH 31/81] gpui: Make length helpers into const functions (#44259) Make gpui's `rems()`, `phi()`, `auto()` length related helpers into const functions. I can't see why these functions aren't already const except that it must've been overlooked when they were written? In my project I had need for rems() to be const, and I thought I'd do phi() and auto() whilst I was in the neighbourhood Release Notes: - N/A --- crates/gpui/src/geometry.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/gpui/src/geometry.rs b/crates/gpui/src/geometry.rs index 859ecb3d0e6c7b5c33f5765ce4c6295cef7fd566..4daec6d15367f3e12bab3cba658ccb3f261e9f46 100644 --- a/crates/gpui/src/geometry.rs +++ b/crates/gpui/src/geometry.rs @@ -3567,7 +3567,7 @@ pub const fn relative(fraction: f32) -> DefiniteLength { } /// Returns the Golden Ratio, i.e. `~(1.0 + sqrt(5.0)) / 2.0`. -pub fn phi() -> DefiniteLength { +pub const fn phi() -> DefiniteLength { relative(1.618_034) } @@ -3580,7 +3580,7 @@ pub fn phi() -> DefiniteLength { /// # Returns /// /// A `Rems` representing the specified number of rems. -pub fn rems(rems: f32) -> Rems { +pub const fn rems(rems: f32) -> Rems { Rems(rems) } @@ -3608,7 +3608,7 @@ pub const fn px(pixels: f32) -> Pixels { /// # Returns /// /// A `Length` variant set to `Auto`. 
-pub fn auto() -> Length { +pub const fn auto() -> Length { Length::Auto } From 363fbbf0d43d4f39a03be685a6283025243cc36f Mon Sep 17 00:00:00 2001 From: Cole Miller Date: Fri, 5 Dec 2025 21:05:34 -0500 Subject: [PATCH 32/81] git: Fix excerpt ranges in the commit view (#44261) Release Notes: - N/A --- crates/git_ui/src/commit_view.rs | 39 ++++++++++++-------------------- 1 file changed, 15 insertions(+), 24 deletions(-) diff --git a/crates/git_ui/src/commit_view.rs b/crates/git_ui/src/commit_view.rs index 7d191c1ae461ac36007dcbadc0b2e10f7dc53599..c637ea674f7e58954c186e1557df251d0d22d36b 100644 --- a/crates/git_ui/src/commit_view.rs +++ b/crates/git_ui/src/commit_view.rs @@ -1,7 +1,9 @@ use anyhow::{Context as _, Result}; use buffer_diff::{BufferDiff, BufferDiffSnapshot}; use editor::display_map::{BlockPlacement, BlockProperties, BlockStyle}; -use editor::{Editor, EditorEvent, ExcerptId, ExcerptRange, MultiBuffer}; +use editor::{ + Editor, EditorEvent, ExcerptId, ExcerptRange, MultiBuffer, multibuffer_context_lines, +}; use git::repository::{CommitDetails, CommitDiff, RepoPath}; use git::{GitHostingProviderRegistry, GitRemote, parse_git_remote_url}; use gpui::{ @@ -10,8 +12,8 @@ use gpui::{ PromptLevel, Render, Styled, Task, WeakEntity, Window, actions, }; use language::{ - Anchor, Buffer, Capability, DiskState, File, LanguageRegistry, LineEnding, ReplicaId, Rope, - TextBuffer, + Anchor, Buffer, Capability, DiskState, File, LanguageRegistry, LineEnding, OffsetRangeExt as _, + ReplicaId, Rope, TextBuffer, }; use multi_buffer::PathKey; use project::{Project, WorktreeId, git_store::Repository}; @@ -202,33 +204,22 @@ impl CommitView { this.multibuffer.update(cx, |multibuffer, cx| { let snapshot = buffer.read(cx).snapshot(); let path = snapshot.file().unwrap().path().clone(); - - let hunks: Vec<_> = buffer_diff.read(cx).hunks(&snapshot, cx).collect(); - - let excerpt_ranges = if hunks.is_empty() { - vec![language::Point::zero()..snapshot.max_point()] - } else { - hunks - 
.into_iter() - .map(|hunk| { - let start = hunk.range.start.max(language::Point::new( - hunk.range.start.row.saturating_sub(3), - 0, - )); - let end_row = - (hunk.range.end.row + 3).min(snapshot.max_point().row); - let end = - language::Point::new(end_row, snapshot.line_len(end_row)); - start..end - }) - .collect() + let excerpt_ranges = { + let mut hunks = buffer_diff.read(cx).hunks(&snapshot, cx).peekable(); + if hunks.peek().is_none() { + vec![language::Point::zero()..snapshot.max_point()] + } else { + hunks + .map(|hunk| hunk.buffer_range.to_point(&snapshot)) + .collect::>() + } }; let _is_newly_added = multibuffer.set_excerpts_for_path( PathKey::with_sort_prefix(FILE_NAMESPACE_SORT_PREFIX, path), buffer, excerpt_ranges, - 0, + multibuffer_context_lines(cx), cx, ); multibuffer.add_diff(buffer_diff, cx); From 66c7bdf037c51659e5848e72bf27a77980e14df4 Mon Sep 17 00:00:00 2001 From: Cole Miller Date: Fri, 5 Dec 2025 21:20:14 -0500 Subject: [PATCH 33/81] git: For conflicted files, set project diff excerpts using conflicts only (#44263) It's just distracting having excerpts for all the successfully merged hunks. Release Notes: - git: The project diff now focuses on merge conflicts for files that have them. 
--- crates/git_ui/src/project_diff.rs | 89 +++++++++---------------------- 1 file changed, 24 insertions(+), 65 deletions(-) diff --git a/crates/git_ui/src/project_diff.rs b/crates/git_ui/src/project_diff.rs index f211483c5efeb14fd230def9235d82a1a79f49b4..e560bba0d36ad9901fffa9b5aad4dbd88e3108b6 100644 --- a/crates/git_ui/src/project_diff.rs +++ b/crates/git_ui/src/project_diff.rs @@ -34,7 +34,6 @@ use project::{ use settings::{Settings, SettingsStore}; use smol::future::yield_now; use std::any::{Any, TypeId}; -use std::ops::Range; use std::sync::Arc; use theme::ActiveTheme; use ui::{KeyBinding, Tooltip, prelude::*, vertical_divider}; @@ -500,23 +499,30 @@ impl ProjectDiff { let snapshot = buffer.read(cx).snapshot(); let diff_read = diff.read(cx); - let diff_hunk_ranges = diff_read - .hunks_intersecting_range( - Anchor::min_max_range_for_buffer(diff_read.buffer_id), - &snapshot, - cx, - ) - .map(|diff_hunk| diff_hunk.buffer_range); - let conflicts = conflict_addon - .conflict_set(snapshot.remote_id()) - .map(|conflict_set| conflict_set.read(cx).snapshot().conflicts) - .unwrap_or_default(); - let conflicts = conflicts.iter().map(|conflict| conflict.range.clone()); - - let excerpt_ranges = - merge_anchor_ranges(diff_hunk_ranges.into_iter(), conflicts, &snapshot) - .map(|range| range.to_point(&snapshot)) - .collect::>(); + + let excerpt_ranges = { + let diff_hunk_ranges = diff_read + .hunks_intersecting_range( + Anchor::min_max_range_for_buffer(diff_read.buffer_id), + &snapshot, + cx, + ) + .map(|diff_hunk| diff_hunk.buffer_range.to_point(&snapshot)); + let conflicts = conflict_addon + .conflict_set(snapshot.remote_id()) + .map(|conflict_set| conflict_set.read(cx).snapshot().conflicts) + .unwrap_or_default(); + let mut conflicts = conflicts + .iter() + .map(|conflict| conflict.range.to_point(&snapshot)) + .peekable(); + + if conflicts.peek().is_some() { + conflicts.collect::>() + } else { + diff_hunk_ranges.collect() + } + }; let (was_empty, is_excerpt_newly_added) = 
self.multibuffer.update(cx, |multibuffer, cx| { let was_empty = multibuffer.is_empty(); @@ -1544,53 +1550,6 @@ mod preview { } } -fn merge_anchor_ranges<'a>( - left: impl 'a + Iterator>, - right: impl 'a + Iterator>, - snapshot: &'a language::BufferSnapshot, -) -> impl 'a + Iterator> { - let mut left = left.fuse().peekable(); - let mut right = right.fuse().peekable(); - - std::iter::from_fn(move || { - let Some(left_range) = left.peek() else { - return right.next(); - }; - let Some(right_range) = right.peek() else { - return left.next(); - }; - - let mut next_range = if left_range.start.cmp(&right_range.start, snapshot).is_lt() { - left.next().unwrap() - } else { - right.next().unwrap() - }; - - // Extend the basic range while there's overlap with a range from either stream. - loop { - if let Some(left_range) = left - .peek() - .filter(|range| range.start.cmp(&next_range.end, snapshot).is_le()) - .cloned() - { - left.next(); - next_range.end = left_range.end; - } else if let Some(right_range) = right - .peek() - .filter(|range| range.start.cmp(&next_range.end, snapshot).is_le()) - .cloned() - { - right.next(); - next_range.end = right_range.end; - } else { - break; - } - } - - Some(next_range) - }) -} - struct BranchDiffAddon { branch_diff: Entity, } From 51b7d06a27780d007f8391ac7d05313611a27163 Mon Sep 17 00:00:00 2001 From: Haojian Wu Date: Sat, 6 Dec 2025 08:35:18 +0100 Subject: [PATCH 34/81] Fix a typo: to -> two (#44272) Release Notes: - N/A --- docs/src/development/debuggers.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/development/debuggers.md b/docs/src/development/debuggers.md index a5713f6c8aae1123e48ab6ab9f85f2147dfc7819..11f49390d41b89cfb1f527e1adabfd8b1b6d401a 100644 --- a/docs/src/development/debuggers.md +++ b/docs/src/development/debuggers.md @@ -5,7 +5,7 @@ ## Using Zed's built-in debugger -While the Zed project is open you can open the `New Process Modal` and select the `Debug` tab. 
There you can see to debug configurations to debug Zed with, one for GDB and one for LLDB. Select the configuration you want and Zed will build and launch the binary. +While the Zed project is open you can open the `New Process Modal` and select the `Debug` tab. There you can see two debug configurations to debug Zed with, one for GDB and one for LLDB. Select the configuration you want and Zed will build and launch the binary. Please note, GDB isn't supported on arm Macbooks From f08fd732a7ecbfe191563e2498a61a7ae75d5b05 Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Sat, 6 Dec 2025 07:08:44 -0300 Subject: [PATCH 35/81] Add experimental mercury edit prediction provider (#44256) Release Notes: - N/A --------- Co-authored-by: Ben Kunkle Co-authored-by: Max Brunsfeld --- assets/icons/inception.svg | 11 + crates/edit_prediction/src/cursor_excerpt.rs | 78 ++++ crates/edit_prediction/src/edit_prediction.rs | 37 +- .../src/edit_prediction_tests.rs | 2 +- crates/edit_prediction/src/mercury.rs | 340 ++++++++++++++++++ .../edit_prediction/src/open_ai_response.rs | 31 ++ crates/edit_prediction/src/zeta1.rs | 178 ++++++++- .../src/zeta1/input_excerpt.rs | 231 ------------ crates/edit_prediction/src/zeta2.rs | 35 +- crates/edit_prediction_cli/src/predict.rs | 5 +- .../src/edit_prediction_button.rs | 112 +++++- .../src/edit_prediction_ui.rs | 4 +- ...s => external_provider_api_token_modal.rs} | 33 +- crates/icons/src/icons.rs | 15 +- .../language_models/src/provider/open_ai.rs | 2 +- crates/open_ai/src/open_ai.rs | 3 +- .../settings/src/settings_content/language.rs | 8 + .../zed/src/zed/edit_prediction_registry.rs | 5 + 18 files changed, 808 insertions(+), 322 deletions(-) create mode 100644 assets/icons/inception.svg create mode 100644 crates/edit_prediction/src/cursor_excerpt.rs create mode 100644 crates/edit_prediction/src/mercury.rs create mode 100644 crates/edit_prediction/src/open_ai_response.rs delete mode 100644 crates/edit_prediction/src/zeta1/input_excerpt.rs rename 
crates/edit_prediction_ui/src/{sweep_api_token_modal.rs => external_provider_api_token_modal.rs} (72%) diff --git a/assets/icons/inception.svg b/assets/icons/inception.svg new file mode 100644 index 0000000000000000000000000000000000000000..77a96c0b390ab9f2fe89143c2a89ba916000fabc --- /dev/null +++ b/assets/icons/inception.svg @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/crates/edit_prediction/src/cursor_excerpt.rs b/crates/edit_prediction/src/cursor_excerpt.rs new file mode 100644 index 0000000000000000000000000000000000000000..1f2f1d32ebcb2eaa151433bd49d275e0e2a3b817 --- /dev/null +++ b/crates/edit_prediction/src/cursor_excerpt.rs @@ -0,0 +1,78 @@ +use language::{BufferSnapshot, Point}; +use std::ops::Range; + +pub fn editable_and_context_ranges_for_cursor_position( + position: Point, + snapshot: &BufferSnapshot, + editable_region_token_limit: usize, + context_token_limit: usize, +) -> (Range, Range) { + let mut scope_range = position..position; + let mut remaining_edit_tokens = editable_region_token_limit; + + while let Some(parent) = snapshot.syntax_ancestor(scope_range.clone()) { + let parent_tokens = guess_token_count(parent.byte_range().len()); + let parent_point_range = Point::new( + parent.start_position().row as u32, + parent.start_position().column as u32, + ) + ..Point::new( + parent.end_position().row as u32, + parent.end_position().column as u32, + ); + if parent_point_range == scope_range { + break; + } else if parent_tokens <= editable_region_token_limit { + scope_range = parent_point_range; + remaining_edit_tokens = editable_region_token_limit - parent_tokens; + } else { + break; + } + } + + let editable_range = expand_range(snapshot, scope_range, remaining_edit_tokens); + let context_range = expand_range(snapshot, editable_range.clone(), context_token_limit); + (editable_range, context_range) +} + +fn expand_range( + snapshot: &BufferSnapshot, + range: Range, + mut remaining_tokens: usize, +) -> Range { + let mut expanded_range = range; + 
expanded_range.start.column = 0; + expanded_range.end.column = snapshot.line_len(expanded_range.end.row); + loop { + let mut expanded = false; + + if remaining_tokens > 0 && expanded_range.start.row > 0 { + expanded_range.start.row -= 1; + let line_tokens = + guess_token_count(snapshot.line_len(expanded_range.start.row) as usize); + remaining_tokens = remaining_tokens.saturating_sub(line_tokens); + expanded = true; + } + + if remaining_tokens > 0 && expanded_range.end.row < snapshot.max_point().row { + expanded_range.end.row += 1; + expanded_range.end.column = snapshot.line_len(expanded_range.end.row); + let line_tokens = guess_token_count(expanded_range.end.column as usize); + remaining_tokens = remaining_tokens.saturating_sub(line_tokens); + expanded = true; + } + + if !expanded { + break; + } + } + expanded_range +} + +/// Typical number of string bytes per token for the purposes of limiting model input. This is +/// intentionally low to err on the side of underestimating limits. +pub(crate) const BYTES_PER_TOKEN_GUESS: usize = 3; + +pub fn guess_token_count(bytes: usize) -> usize { + bytes / BYTES_PER_TOKEN_GUESS +} diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index ea8f0af7e16dedd30a86284af5386829053d7fab..141fff3063b83d7e0003fddd6b4eba2d213d5fd5 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -51,8 +51,11 @@ use thiserror::Error; use util::{RangeExt as _, ResultExt as _}; use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification}; +mod cursor_excerpt; mod license_detection; +pub mod mercury; mod onboarding_modal; +pub mod open_ai_response; mod prediction; pub mod sweep_ai; pub mod udiff; @@ -65,6 +68,7 @@ pub mod zeta2; mod edit_prediction_tests; use crate::license_detection::LicenseDetectionWatcher; +use crate::mercury::Mercury; use crate::onboarding_modal::ZedPredictModal; pub use 
crate::prediction::EditPrediction; pub use crate::prediction::EditPredictionId; @@ -96,6 +100,12 @@ impl FeatureFlag for SweepFeatureFlag { const NAME: &str = "sweep-ai"; } +pub struct MercuryFeatureFlag; + +impl FeatureFlag for MercuryFeatureFlag { + const NAME: &str = "mercury"; +} + pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions { context: EditPredictionExcerptOptions { max_bytes: 512, @@ -157,6 +167,7 @@ pub struct EditPredictionStore { eval_cache: Option>, edit_prediction_model: EditPredictionModel, pub sweep_ai: SweepAi, + pub mercury: Mercury, data_collection_choice: DataCollectionChoice, reject_predictions_tx: mpsc::UnboundedSender, shown_predictions: VecDeque, @@ -169,6 +180,7 @@ pub enum EditPredictionModel { Zeta1, Zeta2, Sweep, + Mercury, } #[derive(Debug, Clone, PartialEq)] @@ -474,6 +486,7 @@ impl EditPredictionStore { eval_cache: None, edit_prediction_model: EditPredictionModel::Zeta2, sweep_ai: SweepAi::new(cx), + mercury: Mercury::new(cx), data_collection_choice, reject_predictions_tx: reject_tx, rated_predictions: Default::default(), @@ -509,6 +522,15 @@ impl EditPredictionStore { .is_some() } + pub fn has_mercury_api_token(&self) -> bool { + self.mercury + .api_token + .clone() + .now_or_never() + .flatten() + .is_some() + } + #[cfg(feature = "eval-support")] pub fn with_eval_cache(&mut self, cache: Arc) { self.eval_cache = Some(cache); @@ -868,7 +890,7 @@ impl EditPredictionStore { fn accept_current_prediction(&mut self, project: &Entity, cx: &mut Context) { match self.edit_prediction_model { EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {} - EditPredictionModel::Sweep => return, + EditPredictionModel::Sweep | EditPredictionModel::Mercury => return, } let Some(project_state) = self.projects.get_mut(&project.entity_id()) else { @@ -1013,7 +1035,7 @@ impl EditPredictionStore { ) { match self.edit_prediction_model { EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {} - EditPredictionModel::Sweep => return, + 
EditPredictionModel::Sweep | EditPredictionModel::Mercury => return, } self.reject_predictions_tx @@ -1373,6 +1395,17 @@ impl EditPredictionStore { diagnostic_search_range.clone(), cx, ), + EditPredictionModel::Mercury => self.mercury.request_prediction( + &project, + &active_buffer, + snapshot.clone(), + position, + events, + &project_state.recent_paths, + related_files, + diagnostic_search_range.clone(), + cx, + ), }; cx.spawn(async move |this, cx| { diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs index 8d5bad9ed8990769fd512ecfe523cf8d79aebca6..0b7e289bb32b5a10c32a4bd34f118d7cb6c7d43c 100644 --- a/crates/edit_prediction/src/edit_prediction_tests.rs +++ b/crates/edit_prediction/src/edit_prediction_tests.rs @@ -1620,7 +1620,7 @@ async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut Te buffer.edit( [( 0..0, - " ".repeat(MAX_EVENT_TOKENS * zeta1::BYTES_PER_TOKEN_GUESS), + " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS), )], None, cx, diff --git a/crates/edit_prediction/src/mercury.rs b/crates/edit_prediction/src/mercury.rs new file mode 100644 index 0000000000000000000000000000000000000000..40c0fdfac021f937df5172fd423d3b6bfc5f8146 --- /dev/null +++ b/crates/edit_prediction/src/mercury.rs @@ -0,0 +1,340 @@ +use anyhow::{Context as _, Result}; +use cloud_llm_client::predict_edits_v3::Event; +use credentials_provider::CredentialsProvider; +use edit_prediction_context::RelatedFile; +use futures::{AsyncReadExt as _, FutureExt, future::Shared}; +use gpui::{ + App, AppContext as _, Entity, Task, + http_client::{self, AsyncBody, Method}, +}; +use language::{Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _}; +use project::{Project, ProjectPath}; +use std::{ + collections::VecDeque, fmt::Write as _, mem, ops::Range, path::Path, sync::Arc, time::Instant, +}; + +use crate::{ + EditPredictionId, EditPredictionInputs, 
open_ai_response::text_from_response, + prediction::EditPredictionResult, +}; + +const MERCURY_API_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions"; +const MAX_CONTEXT_TOKENS: usize = 150; +const MAX_REWRITE_TOKENS: usize = 350; + +pub struct Mercury { + pub api_token: Shared>>, +} + +impl Mercury { + pub fn new(cx: &App) -> Self { + Mercury { + api_token: load_api_token(cx).shared(), + } + } + + pub fn set_api_token(&mut self, api_token: Option, cx: &mut App) -> Task> { + self.api_token = Task::ready(api_token.clone()).shared(); + store_api_token_in_keychain(api_token, cx) + } + + pub fn request_prediction( + &self, + _project: &Entity, + active_buffer: &Entity, + snapshot: BufferSnapshot, + position: language::Anchor, + events: Vec>, + _recent_paths: &VecDeque, + related_files: Vec, + _diagnostic_search_range: Range, + cx: &mut App, + ) -> Task>> { + let Some(api_token) = self.api_token.clone().now_or_never().flatten() else { + return Task::ready(Ok(None)); + }; + let full_path: Arc = snapshot + .file() + .map(|file| file.full_path(cx)) + .unwrap_or_else(|| "untitled".into()) + .into(); + + let http_client = cx.http_client(); + let cursor_point = position.to_point(&snapshot); + let buffer_snapshotted_at = Instant::now(); + + let result = cx.background_spawn(async move { + let (editable_range, context_range) = + crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position( + cursor_point, + &snapshot, + MAX_CONTEXT_TOKENS, + MAX_REWRITE_TOKENS, + ); + + let offset_range = editable_range.to_offset(&snapshot); + let prompt = build_prompt( + &events, + &related_files, + &snapshot, + full_path.as_ref(), + cursor_point, + editable_range, + context_range.clone(), + ); + + let inputs = EditPredictionInputs { + events: events, + included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile { + path: full_path.clone(), + max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row), + excerpts: 
vec![cloud_llm_client::predict_edits_v3::Excerpt { + start_line: cloud_llm_client::predict_edits_v3::Line( + context_range.start.row, + ), + text: snapshot + .text_for_range(context_range.clone()) + .collect::() + .into(), + }], + }], + cursor_point: cloud_llm_client::predict_edits_v3::Point { + column: cursor_point.column, + line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row), + }, + cursor_path: full_path.clone(), + }; + + let request_body = open_ai::Request { + model: "mercury-coder".into(), + messages: vec![open_ai::RequestMessage::User { + content: open_ai::MessageContent::Plain(prompt), + }], + stream: false, + max_completion_tokens: None, + stop: vec![], + temperature: None, + tool_choice: None, + parallel_tool_calls: None, + tools: vec![], + prompt_cache_key: None, + reasoning_effort: None, + }; + + let buf = serde_json::to_vec(&request_body)?; + let body: AsyncBody = buf.into(); + + let request = http_client::Request::builder() + .uri(MERCURY_API_URL) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_token)) + .header("Connection", "keep-alive") + .method(Method::POST) + .body(body) + .context("Failed to create request")?; + + let mut response = http_client + .send(request) + .await + .context("Failed to send request")?; + + let mut body: Vec = Vec::new(); + response + .body_mut() + .read_to_end(&mut body) + .await + .context("Failed to read response body")?; + + let response_received_at = Instant::now(); + if !response.status().is_success() { + anyhow::bail!( + "Request failed with status: {:?}\nBody: {}", + response.status(), + String::from_utf8_lossy(&body), + ); + }; + + let mut response: open_ai::Response = + serde_json::from_slice(&body).context("Failed to parse response")?; + + let id = mem::take(&mut response.id); + let response_str = text_from_response(response).unwrap_or_default(); + + let response_str = response_str.strip_prefix("```\n").unwrap_or(&response_str); + let response_str = 
response_str.strip_suffix("\n```").unwrap_or(&response_str); + + let mut edits = Vec::new(); + const NO_PREDICTION_OUTPUT: &str = "None"; + + if response_str != NO_PREDICTION_OUTPUT { + let old_text = snapshot + .text_for_range(offset_range.clone()) + .collect::(); + edits.extend( + language::text_diff(&old_text, &response_str) + .into_iter() + .map(|(range, text)| { + ( + snapshot.anchor_after(offset_range.start + range.start) + ..snapshot.anchor_before(offset_range.start + range.end), + text, + ) + }), + ); + } + + anyhow::Ok((id, edits, snapshot, response_received_at, inputs)) + }); + + let buffer = active_buffer.clone(); + + cx.spawn(async move |cx| { + let (id, edits, old_snapshot, response_received_at, inputs) = + result.await.context("Mercury edit prediction failed")?; + anyhow::Ok(Some( + EditPredictionResult::new( + EditPredictionId(id.into()), + &buffer, + &old_snapshot, + edits.into(), + buffer_snapshotted_at, + response_received_at, + inputs, + cx, + ) + .await, + )) + }) + } +} + +fn build_prompt( + events: &[Arc], + related_files: &[RelatedFile], + cursor_buffer: &BufferSnapshot, + cursor_buffer_path: &Path, + cursor_point: Point, + editable_range: Range, + context_range: Range, +) -> String { + const RECENTLY_VIEWED_SNIPPETS_START: &str = "<|recently_viewed_code_snippets|>\n"; + const RECENTLY_VIEWED_SNIPPETS_END: &str = "<|/recently_viewed_code_snippets|>\n"; + const RECENTLY_VIEWED_SNIPPET_START: &str = "<|recently_viewed_code_snippet|>\n"; + const RECENTLY_VIEWED_SNIPPET_END: &str = "<|/recently_viewed_code_snippet|>\n"; + const CURRENT_FILE_CONTENT_START: &str = "<|current_file_content|>\n"; + const CURRENT_FILE_CONTENT_END: &str = "<|/current_file_content|>\n"; + const CODE_TO_EDIT_START: &str = "<|code_to_edit|>\n"; + const CODE_TO_EDIT_END: &str = "<|/code_to_edit|>\n"; + const EDIT_DIFF_HISTORY_START: &str = "<|edit_diff_history|>\n"; + const EDIT_DIFF_HISTORY_END: &str = "<|/edit_diff_history|>\n"; + const CURSOR_TAG: &str = "<|cursor|>"; + 
const CODE_SNIPPET_FILE_PATH_PREFIX: &str = "code_snippet_file_path: "; + const CURRENT_FILE_PATH_PREFIX: &str = "current_file_path: "; + + let mut prompt = String::new(); + + push_delimited( + &mut prompt, + RECENTLY_VIEWED_SNIPPETS_START..RECENTLY_VIEWED_SNIPPETS_END, + |prompt| { + for related_file in related_files { + for related_excerpt in &related_file.excerpts { + push_delimited( + prompt, + RECENTLY_VIEWED_SNIPPET_START..RECENTLY_VIEWED_SNIPPET_END, + |prompt| { + prompt.push_str(CODE_SNIPPET_FILE_PATH_PREFIX); + prompt.push_str(related_file.path.path.as_unix_str()); + prompt.push('\n'); + prompt.push_str(&related_excerpt.text.to_string()); + }, + ); + } + } + }, + ); + + push_delimited( + &mut prompt, + CURRENT_FILE_CONTENT_START..CURRENT_FILE_CONTENT_END, + |prompt| { + prompt.push_str(CURRENT_FILE_PATH_PREFIX); + prompt.push_str(cursor_buffer_path.as_os_str().to_string_lossy().as_ref()); + prompt.push('\n'); + + let prefix_range = context_range.start..editable_range.start; + let suffix_range = editable_range.end..context_range.end; + + prompt.extend(cursor_buffer.text_for_range(prefix_range)); + push_delimited(prompt, CODE_TO_EDIT_START..CODE_TO_EDIT_END, |prompt| { + let range_before_cursor = editable_range.start..cursor_point; + let range_after_cursor = cursor_point..editable_range.end; + prompt.extend(cursor_buffer.text_for_range(range_before_cursor)); + prompt.push_str(CURSOR_TAG); + prompt.extend(cursor_buffer.text_for_range(range_after_cursor)); + }); + prompt.extend(cursor_buffer.text_for_range(suffix_range)); + }, + ); + + push_delimited( + &mut prompt, + EDIT_DIFF_HISTORY_START..EDIT_DIFF_HISTORY_END, + |prompt| { + for event in events { + writeln!(prompt, "{event}").unwrap(); + } + }, + ); + + prompt +} + +fn push_delimited(prompt: &mut String, delimiters: Range<&str>, cb: impl FnOnce(&mut String)) { + prompt.push_str(delimiters.start); + cb(prompt); + prompt.push_str(delimiters.end); +} + +pub const MERCURY_CREDENTIALS_URL: &str = 
"https://api.inceptionlabs.ai/v1/edit/completions"; +pub const MERCURY_CREDENTIALS_USERNAME: &str = "mercury-api-token"; + +pub fn load_api_token(cx: &App) -> Task> { + if let Some(api_token) = std::env::var("MERCURY_AI_TOKEN") + .ok() + .filter(|value| !value.is_empty()) + { + return Task::ready(Some(api_token)); + } + let credentials_provider = ::global(cx); + cx.spawn(async move |cx| { + let (_, credentials) = credentials_provider + .read_credentials(MERCURY_CREDENTIALS_URL, &cx) + .await + .ok()??; + String::from_utf8(credentials).ok() + }) +} + +fn store_api_token_in_keychain(api_token: Option, cx: &App) -> Task> { + let credentials_provider = ::global(cx); + + cx.spawn(async move |cx| { + if let Some(api_token) = api_token { + credentials_provider + .write_credentials( + MERCURY_CREDENTIALS_URL, + MERCURY_CREDENTIALS_USERNAME, + api_token.as_bytes(), + cx, + ) + .await + .context("Failed to save Mercury API token to system keychain") + } else { + credentials_provider + .delete_credentials(MERCURY_CREDENTIALS_URL, cx) + .await + .context("Failed to delete Mercury API token from system keychain") + } + }) +} diff --git a/crates/edit_prediction/src/open_ai_response.rs b/crates/edit_prediction/src/open_ai_response.rs new file mode 100644 index 0000000000000000000000000000000000000000..c7e3350936dd89c89849130ba279ad2914dd2bd8 --- /dev/null +++ b/crates/edit_prediction/src/open_ai_response.rs @@ -0,0 +1,31 @@ +pub fn text_from_response(mut res: open_ai::Response) -> Option { + let choice = res.choices.pop()?; + let output_text = match choice.message { + open_ai::RequestMessage::Assistant { + content: Some(open_ai::MessageContent::Plain(content)), + .. + } => content, + open_ai::RequestMessage::Assistant { + content: Some(open_ai::MessageContent::Multipart(mut content)), + .. 
+ } => { + if content.is_empty() { + log::error!("No output from Baseten completion response"); + return None; + } + + match content.remove(0) { + open_ai::MessagePart::Text { text } => text, + open_ai::MessagePart::Image { .. } => { + log::error!("Expected text, got an image"); + return None; + } + } + } + _ => { + log::error!("Invalid response message: {:?}", choice.message); + return None; + } + }; + Some(output_text) +} diff --git a/crates/edit_prediction/src/zeta1.rs b/crates/edit_prediction/src/zeta1.rs index 06248603464563db12fa55a90c9c0bccf153c5f4..20f70421810c6d1678f844d1ec4c968b1ca96678 100644 --- a/crates/edit_prediction/src/zeta1.rs +++ b/crates/edit_prediction/src/zeta1.rs @@ -1,9 +1,8 @@ -mod input_excerpt; - use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant}; use crate::{ EditPredictionId, EditPredictionStore, ZedUpdateRequiredError, + cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count}, prediction::{EditPredictionInputs, EditPredictionResult}, }; use anyhow::{Context as _, Result}; @@ -12,7 +11,6 @@ use cloud_llm_client::{ predict_edits_v3::Event, }; use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task}; -use input_excerpt::excerpt_for_cursor_position; use language::{ Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _, text_diff, }; @@ -495,10 +493,174 @@ pub fn format_event(event: &Event) -> String { } } -/// Typical number of string bytes per token for the purposes of limiting model input. This is -/// intentionally low to err on the side of underestimating limits. 
-pub(crate) const BYTES_PER_TOKEN_GUESS: usize = 3; +#[derive(Debug)] +pub struct InputExcerpt { + pub context_range: Range, + pub editable_range: Range, + pub prompt: String, +} + +pub fn excerpt_for_cursor_position( + position: Point, + path: &str, + snapshot: &BufferSnapshot, + editable_region_token_limit: usize, + context_token_limit: usize, +) -> InputExcerpt { + let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position( + position, + snapshot, + editable_region_token_limit, + context_token_limit, + ); + + let mut prompt = String::new(); + + writeln!(&mut prompt, "```{path}").unwrap(); + if context_range.start == Point::zero() { + writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap(); + } + + for chunk in snapshot.chunks(context_range.start..editable_range.start, false) { + prompt.push_str(chunk.text); + } + + push_editable_range(position, snapshot, editable_range.clone(), &mut prompt); + + for chunk in snapshot.chunks(editable_range.end..context_range.end, false) { + prompt.push_str(chunk.text); + } + write!(prompt, "\n```").unwrap(); + + InputExcerpt { + context_range, + editable_range, + prompt, + } +} + +fn push_editable_range( + cursor_position: Point, + snapshot: &BufferSnapshot, + editable_range: Range, + prompt: &mut String, +) { + writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap(); + for chunk in snapshot.chunks(editable_range.start..cursor_position, false) { + prompt.push_str(chunk.text); + } + prompt.push_str(CURSOR_MARKER); + for chunk in snapshot.chunks(cursor_position..editable_range.end, false) { + prompt.push_str(chunk.text); + } + write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap(); +} + +#[cfg(test)] +mod tests { + use super::*; + use gpui::{App, AppContext}; + use indoc::indoc; + use language::{Buffer, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust}; + use std::sync::Arc; + + #[gpui::test] + fn test_excerpt_for_cursor_position(cx: &mut App) { + let text = indoc! 
{r#" + fn foo() { + let x = 42; + println!("Hello, world!"); + } + + fn bar() { + let x = 42; + let mut sum = 0; + for i in 0..x { + sum += i; + } + println!("Sum: {}", sum); + return sum; + } + + fn generate_random_numbers() -> Vec { + let mut rng = rand::thread_rng(); + let mut numbers = Vec::new(); + for _ in 0..5 { + numbers.push(rng.random_range(1..101)); + } + numbers + } + "#}; + let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); + let snapshot = buffer.read(cx).snapshot(); + + // Ensure we try to fit the largest possible syntax scope, resorting to line-based expansion + // when a larger scope doesn't fit the editable region. + let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32); + assert_eq!( + excerpt.prompt, + indoc! {r#" + ```main.rs + let x = 42; + println!("Hello, world!"); + <|editable_region_start|> + } + + fn bar() { + let x = 42; + let mut sum = 0; + for i in 0..x { + sum += i; + } + println!("Sum: {}", sum); + r<|user_cursor_is_here|>eturn sum; + } -fn guess_token_count(bytes: usize) -> usize { - bytes / BYTES_PER_TOKEN_GUESS + fn generate_random_numbers() -> Vec { + <|editable_region_end|> + let mut rng = rand::thread_rng(); + let mut numbers = Vec::new(); + ```"#} + ); + + // The `bar` function won't fit within the editable region, so we resort to line-based expansion. + let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32); + assert_eq!( + excerpt.prompt, + indoc! 
{r#" + ```main.rs + fn bar() { + let x = 42; + let mut sum = 0; + <|editable_region_start|> + for i in 0..x { + sum += i; + } + println!("Sum: {}", sum); + r<|user_cursor_is_here|>eturn sum; + } + + fn generate_random_numbers() -> Vec { + let mut rng = rand::thread_rng(); + <|editable_region_end|> + let mut numbers = Vec::new(); + for _ in 0..5 { + numbers.push(rng.random_range(1..101)); + ```"#} + ); + } + + fn rust_lang() -> Language { + Language::new( + LanguageConfig { + name: "Rust".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + ..Default::default() + }, + Some(tree_sitter_rust::LANGUAGE.into()), + ) + } } diff --git a/crates/edit_prediction/src/zeta1/input_excerpt.rs b/crates/edit_prediction/src/zeta1/input_excerpt.rs deleted file mode 100644 index 853d74da463c19de4f1d3915cb703a53b6c43c61..0000000000000000000000000000000000000000 --- a/crates/edit_prediction/src/zeta1/input_excerpt.rs +++ /dev/null @@ -1,231 +0,0 @@ -use super::{ - CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER, START_OF_FILE_MARKER, - guess_token_count, -}; -use language::{BufferSnapshot, Point}; -use std::{fmt::Write, ops::Range}; - -#[derive(Debug)] -pub struct InputExcerpt { - pub context_range: Range, - pub editable_range: Range, - pub prompt: String, -} - -pub fn excerpt_for_cursor_position( - position: Point, - path: &str, - snapshot: &BufferSnapshot, - editable_region_token_limit: usize, - context_token_limit: usize, -) -> InputExcerpt { - let mut scope_range = position..position; - let mut remaining_edit_tokens = editable_region_token_limit; - - while let Some(parent) = snapshot.syntax_ancestor(scope_range.clone()) { - let parent_tokens = guess_token_count(parent.byte_range().len()); - let parent_point_range = Point::new( - parent.start_position().row as u32, - parent.start_position().column as u32, - ) - ..Point::new( - parent.end_position().row as u32, - parent.end_position().column as u32, - ); - 
if parent_point_range == scope_range { - break; - } else if parent_tokens <= editable_region_token_limit { - scope_range = parent_point_range; - remaining_edit_tokens = editable_region_token_limit - parent_tokens; - } else { - break; - } - } - - let editable_range = expand_range(snapshot, scope_range, remaining_edit_tokens); - let context_range = expand_range(snapshot, editable_range.clone(), context_token_limit); - - let mut prompt = String::new(); - - writeln!(&mut prompt, "```{path}").unwrap(); - if context_range.start == Point::zero() { - writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap(); - } - - for chunk in snapshot.chunks(context_range.start..editable_range.start, false) { - prompt.push_str(chunk.text); - } - - push_editable_range(position, snapshot, editable_range.clone(), &mut prompt); - - for chunk in snapshot.chunks(editable_range.end..context_range.end, false) { - prompt.push_str(chunk.text); - } - write!(prompt, "\n```").unwrap(); - - InputExcerpt { - context_range, - editable_range, - prompt, - } -} - -fn push_editable_range( - cursor_position: Point, - snapshot: &BufferSnapshot, - editable_range: Range, - prompt: &mut String, -) { - writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap(); - for chunk in snapshot.chunks(editable_range.start..cursor_position, false) { - prompt.push_str(chunk.text); - } - prompt.push_str(CURSOR_MARKER); - for chunk in snapshot.chunks(cursor_position..editable_range.end, false) { - prompt.push_str(chunk.text); - } - write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap(); -} - -fn expand_range( - snapshot: &BufferSnapshot, - range: Range, - mut remaining_tokens: usize, -) -> Range { - let mut expanded_range = range; - expanded_range.start.column = 0; - expanded_range.end.column = snapshot.line_len(expanded_range.end.row); - loop { - let mut expanded = false; - - if remaining_tokens > 0 && expanded_range.start.row > 0 { - expanded_range.start.row -= 1; - let line_tokens = - 
guess_token_count(snapshot.line_len(expanded_range.start.row) as usize); - remaining_tokens = remaining_tokens.saturating_sub(line_tokens); - expanded = true; - } - - if remaining_tokens > 0 && expanded_range.end.row < snapshot.max_point().row { - expanded_range.end.row += 1; - expanded_range.end.column = snapshot.line_len(expanded_range.end.row); - let line_tokens = guess_token_count(expanded_range.end.column as usize); - remaining_tokens = remaining_tokens.saturating_sub(line_tokens); - expanded = true; - } - - if !expanded { - break; - } - } - expanded_range -} - -#[cfg(test)] -mod tests { - use super::*; - use gpui::{App, AppContext}; - use indoc::indoc; - use language::{Buffer, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust}; - use std::sync::Arc; - - #[gpui::test] - fn test_excerpt_for_cursor_position(cx: &mut App) { - let text = indoc! {r#" - fn foo() { - let x = 42; - println!("Hello, world!"); - } - - fn bar() { - let x = 42; - let mut sum = 0; - for i in 0..x { - sum += i; - } - println!("Sum: {}", sum); - return sum; - } - - fn generate_random_numbers() -> Vec { - let mut rng = rand::thread_rng(); - let mut numbers = Vec::new(); - for _ in 0..5 { - numbers.push(rng.random_range(1..101)); - } - numbers - } - "#}; - let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); - let snapshot = buffer.read(cx).snapshot(); - - // Ensure we try to fit the largest possible syntax scope, resorting to line-based expansion - // when a larger scope doesn't fit the editable region. - let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32); - assert_eq!( - excerpt.prompt, - indoc! 
{r#" - ```main.rs - let x = 42; - println!("Hello, world!"); - <|editable_region_start|> - } - - fn bar() { - let x = 42; - let mut sum = 0; - for i in 0..x { - sum += i; - } - println!("Sum: {}", sum); - r<|user_cursor_is_here|>eturn sum; - } - - fn generate_random_numbers() -> Vec { - <|editable_region_end|> - let mut rng = rand::thread_rng(); - let mut numbers = Vec::new(); - ```"#} - ); - - // The `bar` function won't fit within the editable region, so we resort to line-based expansion. - let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32); - assert_eq!( - excerpt.prompt, - indoc! {r#" - ```main.rs - fn bar() { - let x = 42; - let mut sum = 0; - <|editable_region_start|> - for i in 0..x { - sum += i; - } - println!("Sum: {}", sum); - r<|user_cursor_is_here|>eturn sum; - } - - fn generate_random_numbers() -> Vec { - let mut rng = rand::thread_rng(); - <|editable_region_end|> - let mut numbers = Vec::new(); - for _ in 0..5 { - numbers.push(rng.random_range(1..101)); - ```"#} - ); - } - - fn rust_lang() -> Language { - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - ) - } -} diff --git a/crates/edit_prediction/src/zeta2.rs b/crates/edit_prediction/src/zeta2.rs index 4808f38fc529b1c34212dd0198d15fb03a0baddf..e542bc7e86e6e381766bbedac6a15f431e0693f1 100644 --- a/crates/edit_prediction/src/zeta2.rs +++ b/crates/edit_prediction/src/zeta2.rs @@ -1,5 +1,6 @@ #[cfg(feature = "eval-support")] use crate::EvalCacheEntryKind; +use crate::open_ai_response::text_from_response; use crate::prediction::EditPredictionResult; use crate::{ DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionId, EditPredictionInputs, @@ -199,7 +200,7 @@ pub fn request_prediction_with_zeta2( stream: false, max_completion_tokens: None, stop: generation_params.stop.unwrap_or_default(), - 
temperature: generation_params.temperature.unwrap_or(0.7), + temperature: generation_params.temperature.or(Some(0.7)), tool_choice: None, parallel_tool_calls: None, tools: vec![], @@ -324,35 +325,3 @@ pub fn request_prediction_with_zeta2( )) }) } - -pub fn text_from_response(mut res: open_ai::Response) -> Option { - let choice = res.choices.pop()?; - let output_text = match choice.message { - open_ai::RequestMessage::Assistant { - content: Some(open_ai::MessageContent::Plain(content)), - .. - } => content, - open_ai::RequestMessage::Assistant { - content: Some(open_ai::MessageContent::Multipart(mut content)), - .. - } => { - if content.is_empty() { - log::error!("No output from Baseten completion response"); - return None; - } - - match content.remove(0) { - open_ai::MessagePart::Text { text } => text, - open_ai::MessagePart::Image { .. } => { - log::error!("Expected text, got an image"); - return None; - } - } - } - _ => { - log::error!("Invalid response message: {:?}", choice.message); - return None; - } - }; - Some(output_text) -} diff --git a/crates/edit_prediction_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs index db1fed70d82a1a19713dfe54dfd6cea2bfa03d5d..74e939b887ce15790993ec15f5973c7f5fd01866 100644 --- a/crates/edit_prediction_cli/src/predict.rs +++ b/crates/edit_prediction_cli/src/predict.rs @@ -198,8 +198,9 @@ pub async fn perform_predict( let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?; - let response = edit_prediction::zeta2::text_from_response(response) - .unwrap_or_default(); + let response = + edit_prediction::open_ai_response::text_from_response(response) + .unwrap_or_default(); let prediction_finished_at = Instant::now(); fs::write(example_run_dir.join("prediction_response.md"), &response)?; diff --git a/crates/edit_prediction_ui/src/edit_prediction_button.rs b/crates/edit_prediction_ui/src/edit_prediction_button.rs index dd3ebab42029f5adb7570b71ae0cd662aff3328e..04c7614689c5fdc076ab0aa9c4b4fe7d68e2f582 
100644 --- a/crates/edit_prediction_ui/src/edit_prediction_button.rs +++ b/crates/edit_prediction_ui/src/edit_prediction_button.rs @@ -3,7 +3,7 @@ use client::{Client, UserStore, zed_urls}; use cloud_llm_client::UsageLimit; use codestral::CodestralEditPredictionDelegate; use copilot::{Copilot, Status}; -use edit_prediction::{SweepFeatureFlag, Zeta2FeatureFlag}; +use edit_prediction::{MercuryFeatureFlag, SweepFeatureFlag, Zeta2FeatureFlag}; use edit_prediction_types::EditPredictionDelegateHandle; use editor::{ Editor, MultiBufferOffset, SelectionEffects, actions::ShowEditPrediction, scroll::Autoscroll, @@ -23,6 +23,7 @@ use language::{ use project::DisableAiSettings; use regex::Regex; use settings::{ + EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME, EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, Settings, SettingsStore, update_settings_file, @@ -44,7 +45,7 @@ use workspace::{ use zed_actions::OpenBrowser; use crate::{ - RatePredictions, SweepApiKeyModal, + ExternalProviderApiKeyModal, RatePredictions, rate_prediction_modal::PredictEditsRatePredictionsFeatureFlag, }; @@ -311,21 +312,31 @@ impl Render for EditPredictionButton { provider @ (EditPredictionProvider::Experimental(_) | EditPredictionProvider::Zed) => { let enabled = self.editor_enabled.unwrap_or(true); - let is_sweep = matches!( - provider, - EditPredictionProvider::Experimental( - EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME - ) - ); - - let sweep_missing_token = is_sweep - && !edit_prediction::EditPredictionStore::try_global(cx) - .map_or(false, |ep_store| ep_store.read(cx).has_sweep_api_token()); + let ep_icon; + let mut missing_token = false; - let ep_icon = match (is_sweep, enabled) { - (true, _) => IconName::SweepAi, - (false, true) => IconName::ZedPredict, - (false, false) => IconName::ZedPredictDisabled, + match provider { + EditPredictionProvider::Experimental( + EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, + ) => { + ep_icon = 
IconName::SweepAi; + missing_token = edit_prediction::EditPredictionStore::try_global(cx) + .is_some_and(|ep_store| !ep_store.read(cx).has_sweep_api_token()); + } + EditPredictionProvider::Experimental( + EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME, + ) => { + ep_icon = IconName::Inception; + missing_token = edit_prediction::EditPredictionStore::try_global(cx) + .is_some_and(|ep_store| !ep_store.read(cx).has_mercury_api_token()); + } + _ => { + ep_icon = if enabled { + IconName::ZedPredict + } else { + IconName::ZedPredictDisabled + }; + } }; if edit_prediction::should_show_upsell_modal() { @@ -369,7 +380,7 @@ impl Render for EditPredictionButton { let show_editor_predictions = self.editor_show_predictions; let user = self.user_store.read(cx).current_user(); - let indicator_color = if sweep_missing_token { + let indicator_color = if missing_token { Some(Color::Error) } else if enabled && (!show_editor_predictions || over_limit) { Some(if over_limit { @@ -532,6 +543,12 @@ impl EditPredictionButton { )); } + if cx.has_flag::() { + providers.push(EditPredictionProvider::Experimental( + EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME, + )); + } + if cx.has_flag::() { providers.push(EditPredictionProvider::Experimental( EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, @@ -628,7 +645,66 @@ impl EditPredictionButton { if let Some(workspace) = window.root::().flatten() { workspace.update(cx, |workspace, cx| { workspace.toggle_modal(window, cx, |window, cx| { - SweepApiKeyModal::new(window, cx) + ExternalProviderApiKeyModal::new( + window, + cx, + |api_key, store, cx| { + store + .sweep_ai + .set_api_token(api_key, cx) + .detach_and_log_err(cx); + }, + ) + }); + }); + }; + } else { + set_completion_provider(fs.clone(), cx, provider); + } + }); + + menu.item(entry) + } + EditPredictionProvider::Experimental( + EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME, + ) => { + let has_api_token = edit_prediction::EditPredictionStore::try_global(cx) + .map_or(false, 
|ep_store| ep_store.read(cx).has_mercury_api_token()); + + let should_open_modal = !has_api_token || is_current; + + let entry = if has_api_token { + ContextMenuEntry::new("Mercury") + .toggleable(IconPosition::Start, is_current) + } else { + ContextMenuEntry::new("Mercury") + .icon(IconName::XCircle) + .icon_color(Color::Error) + .documentation_aside( + DocumentationSide::Left, + DocumentationEdge::Bottom, + |_| { + Label::new("Click to configure your Mercury API token") + .into_any_element() + }, + ) + }; + + let entry = entry.handler(move |window, cx| { + if should_open_modal { + if let Some(workspace) = window.root::().flatten() { + workspace.update(cx, |workspace, cx| { + workspace.toggle_modal(window, cx, |window, cx| { + ExternalProviderApiKeyModal::new( + window, + cx, + |api_key, store, cx| { + store + .mercury + .set_api_token(api_key, cx) + .detach_and_log_err(cx); + }, + ) }); }); }; diff --git a/crates/edit_prediction_ui/src/edit_prediction_ui.rs b/crates/edit_prediction_ui/src/edit_prediction_ui.rs index 51b491c6b3512968bca4ce2e7ed73a505bd73a00..c177b5233c33feb4f5ff82f60bf3fb6981cf3ee8 100644 --- a/crates/edit_prediction_ui/src/edit_prediction_ui.rs +++ b/crates/edit_prediction_ui/src/edit_prediction_ui.rs @@ -1,7 +1,7 @@ mod edit_prediction_button; mod edit_prediction_context_view; +mod external_provider_api_token_modal; mod rate_prediction_modal; -mod sweep_api_token_modal; use std::any::{Any as _, TypeId}; @@ -17,7 +17,7 @@ use ui::{App, prelude::*}; use workspace::{SplitDirection, Workspace}; pub use edit_prediction_button::{EditPredictionButton, ToggleMenu}; -pub use sweep_api_token_modal::SweepApiKeyModal; +pub use external_provider_api_token_modal::ExternalProviderApiKeyModal; use crate::rate_prediction_modal::PredictEditsRatePredictionsFeatureFlag; diff --git a/crates/edit_prediction_ui/src/sweep_api_token_modal.rs b/crates/edit_prediction_ui/src/external_provider_api_token_modal.rs similarity index 72% rename from 
crates/edit_prediction_ui/src/sweep_api_token_modal.rs rename to crates/edit_prediction_ui/src/external_provider_api_token_modal.rs index 80366fc2ac691f165d44e1e6a29a633522146984..bc312836e9fdd30237156ac532a055d1e23a2589 100644 --- a/crates/edit_prediction_ui/src/sweep_api_token_modal.rs +++ b/crates/edit_prediction_ui/src/external_provider_api_token_modal.rs @@ -6,18 +6,24 @@ use ui::{Button, ButtonStyle, Clickable, Headline, HeadlineSize, prelude::*}; use ui_input::InputField; use workspace::ModalView; -pub struct SweepApiKeyModal { +pub struct ExternalProviderApiKeyModal { api_key_input: Entity, focus_handle: FocusHandle, + on_confirm: Box, &mut EditPredictionStore, &mut App)>, } -impl SweepApiKeyModal { - pub fn new(window: &mut Window, cx: &mut Context) -> Self { - let api_key_input = cx.new(|cx| InputField::new(window, cx, "Enter your Sweep API token")); +impl ExternalProviderApiKeyModal { + pub fn new( + window: &mut Window, + cx: &mut Context, + on_confirm: impl Fn(Option, &mut EditPredictionStore, &mut App) + 'static, + ) -> Self { + let api_key_input = cx.new(|cx| InputField::new(window, cx, "Enter your API key")); Self { api_key_input, focus_handle: cx.focus_handle(), + on_confirm: Box::new(on_confirm), } } @@ -30,39 +36,34 @@ impl SweepApiKeyModal { let api_key = (!api_key.trim().is_empty()).then_some(api_key); if let Some(ep_store) = EditPredictionStore::try_global(cx) { - ep_store.update(cx, |ep_store, cx| { - ep_store - .sweep_ai - .set_api_token(api_key, cx) - .detach_and_log_err(cx); - }); + ep_store.update(cx, |ep_store, cx| (self.on_confirm)(api_key, ep_store, cx)) } cx.emit(DismissEvent); } } -impl EventEmitter for SweepApiKeyModal {} +impl EventEmitter for ExternalProviderApiKeyModal {} -impl ModalView for SweepApiKeyModal {} +impl ModalView for ExternalProviderApiKeyModal {} -impl Focusable for SweepApiKeyModal { +impl Focusable for ExternalProviderApiKeyModal { fn focus_handle(&self, _cx: &App) -> FocusHandle { self.focus_handle.clone() } } 
-impl Render for SweepApiKeyModal { +impl Render for ExternalProviderApiKeyModal { fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { v_flex() - .key_context("SweepApiKeyModal") + .key_context("ExternalApiKeyModal") .on_action(cx.listener(Self::cancel)) .on_action(cx.listener(Self::confirm)) .elevation_2(cx) .w(px(400.)) .p_4() .gap_3() - .child(Headline::new("Sweep API Token").size(HeadlineSize::Small)) + .child(Headline::new("API Token").size(HeadlineSize::Small)) .child(self.api_key_input.clone()) .child( h_flex() diff --git a/crates/icons/src/icons.rs b/crates/icons/src/icons.rs index d28e2c1030c3c2378aa7997f4799c503cee97105..d79660356f04fd42425d9e549764a4c202d29d43 100644 --- a/crates/icons/src/icons.rs +++ b/crates/icons/src/icons.rs @@ -34,8 +34,8 @@ pub enum IconName { ArrowRightLeft, ArrowUp, ArrowUpRight, - Attach, AtSign, + Attach, AudioOff, AudioOn, Backspace, @@ -45,8 +45,8 @@ pub enum IconName { BellRing, Binary, Blocks, - BoltOutlined, BoltFilled, + BoltOutlined, Book, BookCopy, CaseSensitive, @@ -80,9 +80,9 @@ pub enum IconName { Debug, DebugBreakpoint, DebugContinue, + DebugDetach, DebugDisabledBreakpoint, DebugDisabledLogBreakpoint, - DebugDetach, DebugIgnoreBreakpoints, DebugLogBreakpoint, DebugPause, @@ -140,6 +140,7 @@ pub enum IconName { Hash, HistoryRerun, Image, + Inception, Indicator, Info, Json, @@ -147,6 +148,7 @@ pub enum IconName { Library, LineHeight, Link, + Linux, ListCollapse, ListFilter, ListTodo, @@ -172,8 +174,8 @@ pub enum IconName { PencilUnavailable, Person, Pin, - PlayOutlined, PlayFilled, + PlayOutlined, Plus, Power, Public, @@ -259,15 +261,14 @@ pub enum IconName { ZedAssistant, ZedBurnMode, ZedBurnModeOn, - ZedSrcCustom, - ZedSrcExtension, ZedPredict, ZedPredictDisabled, ZedPredictDown, ZedPredictError, ZedPredictUp, + ZedSrcCustom, + ZedSrcExtension, ZedXCopilot, - Linux, } impl IconName { diff --git a/crates/language_models/src/provider/open_ai.rs 
b/crates/language_models/src/provider/open_ai.rs index 46cea34e3e01cb0f8ad0f859827881f3ec74cad7..32ee95ce9bd423bf7f66efc1bc7440455380ab5c 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -438,7 +438,7 @@ pub fn into_open_ai( messages, stream, stop: request.stop, - temperature: request.temperature.unwrap_or(1.0), + temperature: request.temperature.or(Some(1.0)), max_completion_tokens: max_output_tokens, parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() { // Disable parallel tool calls, as the Agent currently expects a maximum of one per turn. diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs index 6fdb393c9a13c7ff6a6981f949b4d0c865b9bff8..8ed70c9dd514cb59f5c7a160169031cbc28428e6 100644 --- a/crates/open_ai/src/open_ai.rs +++ b/crates/open_ai/src/open_ai.rs @@ -266,7 +266,8 @@ pub struct Request { pub max_completion_tokens: Option, #[serde(default, skip_serializing_if = "Vec::is_empty")] pub stop: Vec, - pub temperature: f32, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub temperature: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub tool_choice: Option, /// Whether to enable parallel function calling during tool use. 
diff --git a/crates/settings/src/settings_content/language.rs b/crates/settings/src/settings_content/language.rs index b466b4e0dd88bf41e0f77f67a38842305c11906f..25ff60e9f46cf797b815227222a3d27a6353c396 100644 --- a/crates/settings/src/settings_content/language.rs +++ b/crates/settings/src/settings_content/language.rs @@ -81,6 +81,7 @@ pub enum EditPredictionProvider { pub const EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME: &str = "sweep"; pub const EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME: &str = "zeta2"; +pub const EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME: &str = "mercury"; impl<'de> Deserialize<'de> for EditPredictionProvider { fn deserialize(deserializer: D) -> Result @@ -111,6 +112,13 @@ impl<'de> Deserialize<'de> for EditPredictionProvider { EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, ) } + Content::Experimental(name) + if name == EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME => + { + EditPredictionProvider::Experimental( + EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME, + ) + } Content::Experimental(name) if name == EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME => { diff --git a/crates/zed/src/zed/edit_prediction_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs index 2d5746b87ab20de5d0aca47a4d5da60b9ec33d2a..77a1f71596f9cf1d2f4e32137580d0e3648359f5 100644 --- a/crates/zed/src/zed/edit_prediction_registry.rs +++ b/crates/zed/src/zed/edit_prediction_registry.rs @@ -9,6 +9,7 @@ use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity}; use language::language_settings::{EditPredictionProvider, all_language_settings}; use language_models::MistralLanguageModelProvider; use settings::{ + EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME, EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, SettingsStore, }; @@ -219,6 +220,10 @@ fn assign_edit_prediction_provider( && cx.has_flag::() { edit_prediction::EditPredictionModel::Zeta2 + } else if name == 
EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME + && cx.has_flag::() + { + edit_prediction::EditPredictionModel::Mercury } else { return false; } From e1d8c1a6a1af1821a6ab4cbdb87199c38ce1434f Mon Sep 17 00:00:00 2001 From: Danilo Leal <67129314+danilo-leal@users.noreply.github.com> Date: Sat, 6 Dec 2025 09:06:43 -0300 Subject: [PATCH 36/81] Improve visual alignment on the inline assistant (#44265) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Just making all of the elements in the inline assistant more vertically centered. Screenshot 2025-12-06 at 12  02@2x Release Notes: - N/A --- crates/agent_ui/src/inline_prompt_editor.rs | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/crates/agent_ui/src/inline_prompt_editor.rs b/crates/agent_ui/src/inline_prompt_editor.rs index 0083648651645c456acfa19332d61b9f550ed4ed..b9852ea727c7974e3564fadc652f132076c01f09 100644 --- a/crates/agent_ui/src/inline_prompt_editor.rs +++ b/crates/agent_ui/src/inline_prompt_editor.rs @@ -10,8 +10,8 @@ use editor::{ }; use fs::Fs; use gpui::{ - AnyElement, App, Context, CursorStyle, Entity, EventEmitter, FocusHandle, Focusable, - Subscription, TextStyle, TextStyleRefinement, WeakEntity, Window, + AnyElement, App, Context, Entity, EventEmitter, FocusHandle, Focusable, Subscription, + TextStyle, TextStyleRefinement, WeakEntity, Window, }; use language_model::{LanguageModel, LanguageModelRegistry}; use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle}; @@ -100,7 +100,7 @@ impl Render for PromptEditor { let bottom_padding = match &self.mode { PromptEditorMode::Buffer { .. } => rems_from_px(2.0), - PromptEditorMode::Terminal { .. } => rems_from_px(8.0), + PromptEditorMode::Terminal { .. 
} => rems_from_px(4.0), }; buttons.extend(self.render_buttons(window, cx)); @@ -138,14 +138,13 @@ impl Render for PromptEditor { .pt_0p5() .pb(bottom_padding) .pr(right_padding) - .bg(cx.theme().colors().editor_background) .gap_0p5() + .justify_center() .border_y_1() .border_color(cx.theme().colors().border) + .bg(cx.theme().colors().editor_background) .child( h_flex() - .items_start() - .cursor(CursorStyle::Arrow) .on_action(cx.listener(|this, _: &ToggleModelSelector, window, cx| { this.model_selector .update(cx, |model_selector, cx| model_selector.toggle(window, cx)); @@ -165,7 +164,7 @@ impl Render for PromptEditor { .flex_shrink_0() .items_center() .justify_center() - .gap_2() + .gap_1() .child(self.render_close_button(cx)) .map(|el| { let CodegenStatus::Error(error) = self.codegen_status(cx) else { @@ -206,13 +205,14 @@ impl Render for PromptEditor { this.child( h_flex() .size_full() + .justify_center() .child(div().w(left_gutter_width + px(6.))) .child( div() .size_full() .min_w_0() - .pb_px() - .pl_1() + .pt(rems_from_px(3.)) + .pl_0p5() .flex_1() .border_t_1() .border_color(cx.theme().colors().border_variant) From 0565992d7a0bb53ad9b620196ad23ae0ed02ebab Mon Sep 17 00:00:00 2001 From: Danilo Leal <67129314+danilo-leal@users.noreply.github.com> Date: Sat, 6 Dec 2025 09:06:51 -0300 Subject: [PATCH 37/81] project picker: Improve tooltip on secondary actions (#44264) This PR adds the keybinding for the "open in project window" button on the project picker as well as makes the tooltip for the content bit on the active list item only show up for the content container. 
https://github.com/user-attachments/assets/42944cf7-e4e7-4bf8-8695-48df8b3a35eb Release Notes: - N/A --- crates/recent_projects/src/recent_projects.rs | 45 ++++++++++++++----- 1 file changed, 33 insertions(+), 12 deletions(-) diff --git a/crates/recent_projects/src/recent_projects.rs b/crates/recent_projects/src/recent_projects.rs index 280bf17a385db09c10c2844ac7126b3aac7adafb..8c081205444fbc13fb1d94c297946261fcab7fb3 100644 --- a/crates/recent_projects/src/recent_projects.rs +++ b/crates/recent_projects/src/recent_projects.rs @@ -132,7 +132,8 @@ pub fn init(cx: &mut App) { let create_new_window = open_recent.create_new_window; with_active_or_new_workspace(cx, move |workspace, window, cx| { let Some(recent_projects) = workspace.active_modal::(cx) else { - RecentProjects::open(workspace, create_new_window, window, cx); + let focus_handle = workspace.focus_handle(cx); + RecentProjects::open(workspace, create_new_window, window, focus_handle, cx); return; }; @@ -246,11 +247,12 @@ impl RecentProjects { workspace: &mut Workspace, create_new_window: bool, window: &mut Window, + focus_handle: FocusHandle, cx: &mut Context, ) { let weak = cx.entity().downgrade(); workspace.toggle_modal(window, cx, |window, cx| { - let delegate = RecentProjectsDelegate::new(weak, create_new_window, true); + let delegate = RecentProjectsDelegate::new(weak, create_new_window, true, focus_handle); Self::new(delegate, 34., window, cx) }) @@ -289,10 +291,16 @@ pub struct RecentProjectsDelegate { // Flag to reset index when there is a new query vs not reset index when user delete an item reset_selected_match_index: bool, has_any_non_local_projects: bool, + focus_handle: FocusHandle, } impl RecentProjectsDelegate { - fn new(workspace: WeakEntity, create_new_window: bool, render_paths: bool) -> Self { + fn new( + workspace: WeakEntity, + create_new_window: bool, + render_paths: bool, + focus_handle: FocusHandle, + ) -> Self { Self { workspace, workspaces: Vec::new(), @@ -302,6 +310,7 @@ impl 
RecentProjectsDelegate { render_paths, reset_selected_match_index: true, has_any_non_local_projects: false, + focus_handle, } } @@ -544,12 +553,23 @@ impl PickerDelegate for RecentProjectsDelegate { paths, }; + let focus_handle = self.focus_handle.clone(); + let secondary_actions = h_flex() .gap_px() .child( IconButton::new("open_new_window", IconName::ArrowUpRight) .icon_size(IconSize::XSmall) - .tooltip(Tooltip::text("Open Project in New Window")) + .tooltip({ + move |_, cx| { + Tooltip::for_action_in( + "Open Project in New Window", + &menu::SecondaryConfirm, + &focus_handle, + cx, + ) + } + }) .on_click(cx.listener(move |this, _event, window, cx| { cx.stop_propagation(); window.prevent_default(); @@ -577,8 +597,9 @@ impl PickerDelegate for RecentProjectsDelegate { .spacing(ListItemSpacing::Sparse) .child( h_flex() - .flex_grow() + .id("projecy_info_container") .gap_3() + .flex_grow() .when(self.has_any_non_local_projects, |this| { this.child(match location { SerializedWorkspaceLocation::Local => Icon::new(IconName::Screen) @@ -600,6 +621,13 @@ impl PickerDelegate for RecentProjectsDelegate { highlighted.paths.clear(); } highlighted.render(window, cx) + }) + .tooltip(move |_, cx| { + let tooltip_highlighted_location = highlighted_match.clone(); + cx.new(|_| MatchTooltip { + highlighted_location: tooltip_highlighted_location, + }) + .into() }), ) .map(|el| { @@ -608,13 +636,6 @@ impl PickerDelegate for RecentProjectsDelegate { } else { el.end_hover_slot(secondary_actions) } - }) - .tooltip(move |_, cx| { - let tooltip_highlighted_location = highlighted_match.clone(); - cx.new(|_| MatchTooltip { - highlighted_location: tooltip_highlighted_location, - }) - .into() }), ) } From d72746773faf458452ee393cf3ec01a164f98b37 Mon Sep 17 00:00:00 2001 From: David Kleingeld Date: Sat, 6 Dec 2025 13:08:01 +0100 Subject: [PATCH 38/81] Put tracy dependency behind feature tracy (#44277) It broke CI, now it no longer does :tada: Proper fix followes after the weekend. 
Release Notes: - N/A --- Cargo.toml | 1 - crates/zed/Cargo.toml | 3 +++ crates/ztracing/Cargo.toml | 5 ++++- docs/src/performance.md | 2 +- 4 files changed, 8 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 858da1dc460cda2fecbaf2ed94d437bfd25d9644..be78357b2515b12acad808f436cf7359877b5418 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -699,7 +699,6 @@ tree-sitter-rust = "0.24" tree-sitter-typescript = { git = "https://github.com/zed-industries/tree-sitter-typescript", rev = "e2c53597d6a5d9cf7bbe8dccde576fe1e46c5899" } # https://github.com/tree-sitter/tree-sitter-typescript/pull/347 tree-sitter-yaml = { git = "https://github.com/zed-industries/tree-sitter-yaml", rev = "baff0b51c64ef6a1fb1f8390f3ad6015b83ec13a" } tracing = "0.1.40" -tracing-tracy = "0.11.4" unicase = "2.6" unicode-script = "0.5.7" unicode-segmentation = "1.10" diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index e304ad7f5cd94c05daab2755cb9e7bed21fe0f8d..a9a8ba87c645e99a68409865a95737e3222c87b3 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -10,6 +10,9 @@ authors = ["Zed Team "] [lints] workspace = true +[features] +tracy = ["ztracing/tracy"] + [[bin]] name = "zed" path = "src/zed-main.rs" diff --git a/crates/ztracing/Cargo.toml b/crates/ztracing/Cargo.toml index fbc9dc032d2d485f74a15e5fe3b073a7017911fd..c68ac15423cf3a26a8dc855769ba44b9ac29696a 100644 --- a/crates/ztracing/Cargo.toml +++ b/crates/ztracing/Cargo.toml @@ -8,10 +8,13 @@ license = "GPL-3.0-or-later" [lints] workspace = true +[features] +tracy = ["tracing-tracy"] + [dependencies] tracing.workspace = true tracing-subscriber = "0.3.22" -tracing-tracy = { workspace = true, features = ["enable", "ondemand"] } +tracing-tracy = { version = "0.11.4", optional = true, features = ["enable", "ondemand"] } ztracing_macro.workspace = true diff --git a/docs/src/performance.md b/docs/src/performance.md index 4adc38f5eea27de26f1d5818b6787fb78ae1d1ad..544e39e94babbf9c335a847af8819ad5b00494d1 100644 --- 
a/docs/src/performance.md +++ b/docs/src/performance.md @@ -28,7 +28,7 @@ fn should_appear_in_profile(kitty: Cat) { } ``` -Then either compile Zed with `ZTRACING=1 cargo r --release`. The release build is optional but highly recommended as like every program Zeds performance characteristics change dramatically with optimizations. You do not want to chase slowdowns that do not exist in release. +Then either compile Zed with `ZTRACING=1 cargo r --features tracy --release`. The release build is optional but highly recommended as like every program Zeds performance characteristics change dramatically with optimizations. You do not want to chase slowdowns that do not exist in release. ## One time Setup/Building the profiler: From a0848daab44c05f69e6adfcfa0682b84a0bd06d7 Mon Sep 17 00:00:00 2001 From: Danilo Leal <67129314+danilo-leal@users.noreply.github.com> Date: Sat, 6 Dec 2025 09:43:37 -0300 Subject: [PATCH 39/81] agent ui: Fix clicks on the notification sometimes not being triggered (#44280) Closes https://github.com/zed-industries/zed/issues/43292 We were seeing clicks on the "View Panel" and "Dismiss" buttons sometimes not being triggered. I believe this was happening because the overall parent also had an on_click, which due to this being a popup window, was causing conflicts with the buttons' on click handlers. This should hopefully fix that issue. Release Notes: - agent: Fixed an issue where clicking on the agent notification buttons would sometimes not trigger their actions. 
--- crates/agent_ui/src/ui/agent_notification.rs | 3 --- 1 file changed, 3 deletions(-) diff --git a/crates/agent_ui/src/ui/agent_notification.rs b/crates/agent_ui/src/ui/agent_notification.rs index af2a022f147b79a0a299c17dd26c7e9a8b62aeb9..34ca0bb32a82aa23d1b954554ce2dfec436bfe1c 100644 --- a/crates/agent_ui/src/ui/agent_notification.rs +++ b/crates/agent_ui/src/ui/agent_notification.rs @@ -106,9 +106,6 @@ impl Render for AgentNotification { .font(ui_font) .border_color(cx.theme().colors().border) .rounded_xl() - .on_click(cx.listener(|_, _, _, cx| { - cx.emit(AgentNotificationEvent::Accepted); - })) .child( h_flex() .items_start() From 9e33243015d39ac54060c074d275aca3de77f2d9 Mon Sep 17 00:00:00 2001 From: John Tur Date: Sat, 6 Dec 2025 11:31:05 -0500 Subject: [PATCH 40/81] Fix unregistration logic for pull diagnostics (#44294) Even if `workspace_diagnostics_refresh_tasks` is empty, registrations which didn't advertise support for workspace diagnostics may still exist. Release Notes: - N/A --- crates/project/src/lsp_store.rs | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/crates/project/src/lsp_store.rs b/crates/project/src/lsp_store.rs index 59b7a6932d4733a78959e9e4f481a63589811a52..1ae6d1295f37df31aac03e2019cb5510c836fb1c 100644 --- a/crates/project/src/lsp_store.rs +++ b/crates/project/src/lsp_store.rs @@ -12647,30 +12647,29 @@ impl LspStore { .language_servers .get_mut(&server_id) .context("Could not obtain Language Servers state")?; - local + let registrations = local .language_server_dynamic_registrations .get_mut(&server_id) .with_context(|| { format!("Expected dynamic registration to exist for server {server_id}") - })?.diagnostics + })?; + registrations.diagnostics .remove(&Some(unreg.id.clone())) .with_context(|| format!( "Attempted to unregister non-existent diagnostic registration with ID {}", unreg.id) )?; + let removed_last_diagnostic_provider = registrations.diagnostics.is_empty(); - let mut 
has_any_diagnostic_providers_still = true; if let LanguageServerState::Running { workspace_diagnostics_refresh_tasks, .. } = state { workspace_diagnostics_refresh_tasks.remove(&Some(unreg.id.clone())); - has_any_diagnostic_providers_still = - !workspace_diagnostics_refresh_tasks.is_empty(); } - if !has_any_diagnostic_providers_still { + if removed_last_diagnostic_provider { server.update_capabilities(|capabilities| { debug_assert!(capabilities.diagnostic_provider.is_some()); capabilities.diagnostic_provider = None; From b2e35b5f999b1640251d155d13a0b3914e7c96a1 Mon Sep 17 00:00:00 2001 From: Ben Kunkle Date: Sat, 6 Dec 2025 09:56:49 -0800 Subject: [PATCH 41/81] zlog: Fix dynamic mod path filtering (#44296) Closes #ISSUE Release Notes: - Linux: cleaned up noisy logs from `zbus` --- crates/zlog/src/filter.rs | 30 +++++++++++++++++++----------- crates/zlog/src/sink.rs | 6 +++--- crates/zlog/src/zlog.rs | 20 ++++++++++++++------ 3 files changed, 36 insertions(+), 20 deletions(-) diff --git a/crates/zlog/src/filter.rs b/crates/zlog/src/filter.rs index e2ca04be60f4fe7eba7cdb2fc9eb983092d2331a..0be6f4ead5bf64aa47f7a60391bf377c9998cfb4 100644 --- a/crates/zlog/src/filter.rs +++ b/crates/zlog/src/filter.rs @@ -5,12 +5,12 @@ use std::sync::{ atomic::{AtomicU8, Ordering}, }; -use crate::{SCOPE_DEPTH_MAX, SCOPE_STRING_SEP_STR, Scope, ScopeAlloc, env_config, private}; +use crate::{SCOPE_DEPTH_MAX, SCOPE_STRING_SEP_STR, ScopeAlloc, ScopeRef, env_config, private}; use log; static ENV_FILTER: OnceLock = OnceLock::new(); -static SCOPE_MAP: RwLock> = RwLock::new(None); +static SCOPE_MAP: RwLock = RwLock::new(ScopeMap::empty()); pub const LEVEL_ENABLED_MAX_DEFAULT: log::LevelFilter = log::LevelFilter::Info; /// The maximum log level of verbosity that is enabled by default. 
@@ -59,7 +59,11 @@ pub fn is_possibly_enabled_level(level: log::Level) -> bool { level as u8 <= LEVEL_ENABLED_MAX_CONFIG.load(Ordering::Acquire) } -pub fn is_scope_enabled(scope: &Scope, module_path: Option<&str>, level: log::Level) -> bool { +pub fn is_scope_enabled( + scope: &ScopeRef<'_>, + module_path: Option<&str>, + level: log::Level, +) -> bool { // TODO: is_always_allowed_level that checks against LEVEL_ENABLED_MIN_CONFIG if !is_possibly_enabled_level(level) { // [FAST PATH] @@ -74,16 +78,11 @@ pub fn is_scope_enabled(scope: &Scope, module_path: Option<&str>, level: log::Le err.into_inner() }); - let Some(map) = global_scope_map.as_ref() else { - // on failure, return false because it's not <= LEVEL_ENABLED_MAX_STATIC - return is_enabled_by_default; - }; - - if map.is_empty() { + if global_scope_map.is_empty() { // if no scopes are enabled, return false because it's not <= LEVEL_ENABLED_MAX_STATIC return is_enabled_by_default; } - let enabled_status = map.is_enabled(scope, module_path, level); + let enabled_status = global_scope_map.is_enabled(scope, module_path, level); match enabled_status { EnabledStatus::NotConfigured => is_enabled_by_default, EnabledStatus::Enabled => true, @@ -107,7 +106,7 @@ pub fn refresh_from_settings(settings: &HashMap) { SCOPE_MAP.clear_poison(); err.into_inner() }); - global_map.replace(map_new); + *global_map = map_new; } log::trace!("Log configuration updated"); } @@ -395,12 +394,21 @@ impl ScopeMap { } EnabledStatus::NotConfigured } + + const fn empty() -> ScopeMap { + ScopeMap { + entries: vec![], + modules: vec![], + root_count: 0, + } + } } #[cfg(test)] mod tests { use log::LevelFilter; + use crate::Scope; use crate::private::scope_new; use super::*; diff --git a/crates/zlog/src/sink.rs b/crates/zlog/src/sink.rs index 303e3139bc7cdb95ae01c7e87fff8f9bc6d100c2..07e87be1b071f2538e716bb8fd2b692527363fc4 100644 --- a/crates/zlog/src/sink.rs +++ b/crates/zlog/src/sink.rs @@ -8,7 +8,7 @@ use std::{ }, }; -use 
crate::{SCOPE_STRING_SEP_CHAR, Scope}; +use crate::{SCOPE_STRING_SEP_CHAR, ScopeRef}; // ANSI color escape codes for log levels const ANSI_RESET: &str = "\x1b[0m"; @@ -35,7 +35,7 @@ static SINK_FILE_SIZE_BYTES: AtomicU64 = AtomicU64::new(0); const SINK_FILE_SIZE_BYTES_MAX: u64 = 1024 * 1024; // 1 MB pub struct Record<'a> { - pub scope: Scope, + pub scope: ScopeRef<'a>, pub level: log::Level, pub message: &'a std::fmt::Arguments<'a>, pub module_path: Option<&'a str>, @@ -208,7 +208,7 @@ pub fn flush() { } struct SourceFmt<'a> { - scope: Scope, + scope: ScopeRef<'a>, module_path: Option<&'a str>, line: Option, ansi: bool, diff --git a/crates/zlog/src/zlog.rs b/crates/zlog/src/zlog.rs index bcd13216252e0b45f6dc553160e17c7216a87f27..3c154f790845da74dcf3a4f9bfdd830d2d32c9ec 100644 --- a/crates/zlog/src/zlog.rs +++ b/crates/zlog/src/zlog.rs @@ -70,15 +70,18 @@ impl log::Log for Zlog { if !self.enabled(record.metadata()) { return; } - let (crate_name_scope, module_scope) = match record.module_path_static() { + let module_path = record.module_path().or(record.file()); + let (crate_name_scope, module_scope) = match module_path { Some(module_path) => { let crate_name = private::extract_crate_name_from_module_path(module_path); - let crate_name_scope = private::scope_new(&[crate_name]); - let module_scope = private::scope_new(&[module_path]); + let crate_name_scope = private::scope_ref_new(&[crate_name]); + let module_scope = private::scope_ref_new(&[module_path]); (crate_name_scope, module_scope) } - // TODO: when do we hit this - None => (private::scope_new(&[]), private::scope_new(&["*unknown*"])), + None => { + // TODO: when do we hit this + (private::scope_new(&[]), private::scope_new(&["*unknown*"])) + } }; let level = record.metadata().level(); if !filter::is_scope_enabled(&crate_name_scope, Some(record.target()), level) { @@ -89,7 +92,7 @@ impl log::Log for Zlog { level, message: record.args(), // PERF(batching): store non-static paths in a cache + leak them and pass 
static str here - module_path: record.module_path().or(record.file()), + module_path, line: record.line(), }); } @@ -252,6 +255,10 @@ pub mod private { } pub const fn scope_new(scopes: &[&'static str]) -> Scope { + scope_ref_new(scopes) + } + + pub const fn scope_ref_new<'a>(scopes: &[&'a str]) -> ScopeRef<'a> { assert!(scopes.len() <= SCOPE_DEPTH_MAX); let mut scope = [""; SCOPE_DEPTH_MAX]; let mut i = 0; @@ -275,6 +282,7 @@ pub mod private { } pub type Scope = [&'static str; SCOPE_DEPTH_MAX]; +pub type ScopeRef<'a> = [&'a str; SCOPE_DEPTH_MAX]; pub type ScopeAlloc = [String; SCOPE_DEPTH_MAX]; const SCOPE_STRING_SEP_STR: &str = "."; const SCOPE_STRING_SEP_CHAR: char = '.'; From 16666f5357a7cb7ad69d55095f27affafdf06724 Mon Sep 17 00:00:00 2001 From: Kirill Bulatov Date: Sat, 6 Dec 2025 20:49:21 +0200 Subject: [PATCH 42/81] Use single `languages::{rust_lang, markdown_lang}` in tests across the codebase (#44282) This allows referencing proper queries and keeping the tests up-to-date. Release Notes: - N/A --- crates/agent/src/tools/grep_tool.rs | 19 +- crates/agent/src/tools/read_file_tool.rs | 46 +-- crates/agent_ui/src/buffer_codegen.rs | 34 +-- crates/collab/src/tests.rs | 17 -- crates/collab/src/tests/editor_tests.rs | 7 +- crates/collab/src/tests/integration_tests.rs | 4 +- crates/debugger_ui/Cargo.toml | 1 + crates/debugger_ui/src/tests/inline_values.rs | 121 ++++---- crates/edit_prediction/src/zeta1.rs | 19 +- .../src/edit_prediction_context_tests.rs | 24 +- crates/edit_prediction_context/src/excerpt.rs | 20 +- crates/editor/src/items.rs | 20 +- crates/language/src/buffer_tests.rs | 183 +++-------- crates/language/src/language.rs | 32 +- .../src/syntax_map/syntax_map_tests.rs | 66 ++-- crates/markdown_preview/Cargo.toml | 1 + .../markdown_preview/src/markdown_parser.rs | 21 +- crates/outline/src/outline.rs | 88 +----- crates/outline_panel/src/outline_panel.rs | 288 ++++-------------- crates/project/src/project_tests.rs | 16 +- crates/vim/src/object.rs | 7 +- 
crates/zed/src/zed.rs | 63 +--- 22 files changed, 267 insertions(+), 830 deletions(-) diff --git a/crates/agent/src/tools/grep_tool.rs b/crates/agent/src/tools/grep_tool.rs index ec61b013e87ccb3afc133ee0a264e55a6d8baee9..0caba91564fd1fc9e670909490d4e776b8ad6f11 100644 --- a/crates/agent/src/tools/grep_tool.rs +++ b/crates/agent/src/tools/grep_tool.rs @@ -322,7 +322,6 @@ mod tests { use super::*; use gpui::{TestAppContext, UpdateGlobal}; - use language::{Language, LanguageConfig, LanguageMatcher}; use project::{FakeFs, Project}; use serde_json::json; use settings::SettingsStore; @@ -564,7 +563,7 @@ mod tests { let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; project.update(cx, |project, _cx| { - project.languages().add(rust_lang().into()) + project.languages().add(language::rust_lang()) }); project @@ -793,22 +792,6 @@ mod tests { }); } - fn rust_lang() -> Language { - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - ) - .with_outline_query(include_str!("../../../languages/src/rust/outline.scm")) - .unwrap() - } - #[gpui::test] async fn test_grep_security_boundaries(cx: &mut TestAppContext) { init_test(cx); diff --git a/crates/agent/src/tools/read_file_tool.rs b/crates/agent/src/tools/read_file_tool.rs index 4457a6e5ca21a2fc88c76c718160d1d59171e66a..5b19bf36ee3a0949910d217880e2e95c49f021fc 100644 --- a/crates/agent/src/tools/read_file_tool.rs +++ b/crates/agent/src/tools/read_file_tool.rs @@ -302,7 +302,6 @@ mod test { use super::*; use crate::{ContextServerRegistry, Templates, Thread}; use gpui::{AppContext, TestAppContext, UpdateGlobal as _}; - use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust}; use language_model::fake_provider::FakeLanguageModel; use project::{FakeFs, Project}; use prompt_store::ProjectContext; @@ -406,7 +405,7 @@ 
mod test { .await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let language_registry = project.read_with(cx, |project, _| project.languages().clone()); - language_registry.add(Arc::new(rust_lang())); + language_registry.add(language::rust_lang()); let action_log = cx.new(|_| ActionLog::new(project.clone())); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); @@ -596,49 +595,6 @@ mod test { }); } - fn rust_lang() -> Language { - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - ) - .with_outline_query( - r#" - (line_comment) @annotation - - (struct_item - "struct" @context - name: (_) @name) @item - (enum_item - "enum" @context - name: (_) @name) @item - (enum_variant - name: (_) @name) @item - (field_declaration - name: (_) @name) @item - (impl_item - "impl" @context - trait: (_)? @name - "for"? 
@context - type: (_) @name - body: (_ "{" (_)* "}")) @item - (function_item - "fn" @context - name: (_) @name) @item - (mod_item - "mod" @context - name: (_) @name) @item - "#, - ) - .unwrap() - } - #[gpui::test] async fn test_read_file_security(cx: &mut TestAppContext) { init_test(cx); diff --git a/crates/agent_ui/src/buffer_codegen.rs b/crates/agent_ui/src/buffer_codegen.rs index 0d014f50294f90aa2bda1f51025c937cc0e2ae56..f7e7884310458e97421768882df57934a19b4430 100644 --- a/crates/agent_ui/src/buffer_codegen.rs +++ b/crates/agent_ui/src/buffer_codegen.rs @@ -1295,8 +1295,9 @@ mod tests { }; use gpui::TestAppContext; use indoc::indoc; - use language::{Buffer, Language, LanguageConfig, LanguageMatcher, Point, tree_sitter_rust}; + use language::{Buffer, Point}; use language_model::{LanguageModelRegistry, TokenUsage}; + use languages::rust_lang; use rand::prelude::*; use settings::SettingsStore; use std::{future, sync::Arc}; @@ -1313,7 +1314,7 @@ mod tests { } } "}; - let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx)); let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx)); let range = buffer.read_with(cx, |buffer, cx| { let snapshot = buffer.snapshot(cx); @@ -1375,7 +1376,7 @@ mod tests { le } "}; - let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx)); let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx)); let range = buffer.read_with(cx, |buffer, cx| { let snapshot = buffer.snapshot(cx); @@ -1439,7 +1440,7 @@ mod tests { " \n", "}\n" // ); - let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx)); let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx)); let range = buffer.read_with(cx, 
|buffer, cx| { let snapshot = buffer.snapshot(cx); @@ -1555,7 +1556,7 @@ mod tests { let x = 0; } "}; - let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx)); let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx)); let range = buffer.read_with(cx, |buffer, cx| { let snapshot = buffer.snapshot(cx); @@ -1672,27 +1673,4 @@ mod tests { }); chunks_tx } - - fn rust_lang() -> Language { - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - ) - .with_indents_query( - r#" - (call_expression) @indent - (field_expression) @indent - (_ "(" ")" @end) @indent - (_ "{" "}" @end) @indent - "#, - ) - .unwrap() - } } diff --git a/crates/collab/src/tests.rs b/crates/collab/src/tests.rs index 7d07360b8042ed54a9f19a82a2876e448e8a14a4..3785ee0b7abaeddeac5c9acb1718407ab5bd54f2 100644 --- a/crates/collab/src/tests.rs +++ b/crates/collab/src/tests.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use call::Room; use client::ChannelId; use gpui::{Entity, TestAppContext}; @@ -18,7 +16,6 @@ mod randomized_test_helpers; mod remote_editing_collaboration_tests; mod test_server; -use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust}; pub use randomized_test_helpers::{ RandomizedTest, TestError, UserTestPlan, run_randomized_test, save_randomized_test_plan, }; @@ -51,17 +48,3 @@ fn room_participants(room: &Entity, cx: &mut TestAppContext) -> RoomPartic fn channel_id(room: &Entity, cx: &mut TestAppContext) -> Option { cx.read(|cx| room.read(cx).channel_id()) } - -fn rust_lang() -> Arc { - Arc::new(Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - 
Some(tree_sitter_rust::LANGUAGE.into()), - )) -} diff --git a/crates/collab/src/tests/editor_tests.rs b/crates/collab/src/tests/editor_tests.rs index 149a48db7439cc28e76fac5aae8b6e11f0837991..ba92e868126c7f27fb5051021fce44fe43c8d5e7 100644 --- a/crates/collab/src/tests/editor_tests.rs +++ b/crates/collab/src/tests/editor_tests.rs @@ -1,7 +1,4 @@ -use crate::{ - rpc::RECONNECT_TIMEOUT, - tests::{TestServer, rust_lang}, -}; +use crate::{rpc::RECONNECT_TIMEOUT, tests::TestServer}; use call::ActiveCall; use editor::{ DocumentColorsRenderMode, Editor, FETCH_COLORS_DEBOUNCE_TIMEOUT, MultiBufferOffset, RowInfo, @@ -23,7 +20,7 @@ use gpui::{ App, Rgba, SharedString, TestAppContext, UpdateGlobal, VisualContext, VisualTestContext, }; use indoc::indoc; -use language::FakeLspAdapter; +use language::{FakeLspAdapter, rust_lang}; use lsp::LSP_REQUEST_TIMEOUT; use pretty_assertions::assert_eq; use project::{ diff --git a/crates/collab/src/tests/integration_tests.rs b/crates/collab/src/tests/integration_tests.rs index fcda8688d427f3e6b937f00edc7c3586dfdbef36..391e7355ea196dfe25d363472918837ea817f450 100644 --- a/crates/collab/src/tests/integration_tests.rs +++ b/crates/collab/src/tests/integration_tests.rs @@ -2,7 +2,7 @@ use crate::{ rpc::{CLEANUP_TIMEOUT, RECONNECT_TIMEOUT}, tests::{ RoomParticipants, TestClient, TestServer, channel_id, following_tests::join_channel, - room_participants, rust_lang, + room_participants, }, }; use anyhow::{Result, anyhow}; @@ -26,7 +26,7 @@ use language::{ Diagnostic, DiagnosticEntry, DiagnosticSourceKind, FakeLspAdapter, Language, LanguageConfig, LanguageMatcher, LineEnding, OffsetRangeExt, Point, Rope, language_settings::{Formatter, FormatterList}, - tree_sitter_rust, tree_sitter_typescript, + rust_lang, tree_sitter_rust, tree_sitter_typescript, }; use lsp::{LanguageServerId, OneOf}; use parking_lot::Mutex; diff --git a/crates/debugger_ui/Cargo.toml b/crates/debugger_ui/Cargo.toml index 
325bcc300ae637ab46c36b7a3e7875e197f7d3d2..25d23b96b897001faec39498c5b08ef08b09a3a1 100644 --- a/crates/debugger_ui/Cargo.toml +++ b/crates/debugger_ui/Cargo.toml @@ -82,6 +82,7 @@ dap_adapters = { workspace = true, features = ["test-support"] } debugger_tools = { workspace = true, features = ["test-support"] } editor = { workspace = true, features = ["test-support"] } gpui = { workspace = true, features = ["test-support"] } +language = { workspace = true, features = ["test-support"] } project = { workspace = true, features = ["test-support"] } tree-sitter-go.workspace = true unindent.workspace = true diff --git a/crates/debugger_ui/src/tests/inline_values.rs b/crates/debugger_ui/src/tests/inline_values.rs index 801e6d43623b50d69ea3ce297c274c2d7e5a8b14..379bc4c98f5341b089b5936ed8571da5a6280723 100644 --- a/crates/debugger_ui/src/tests/inline_values.rs +++ b/crates/debugger_ui/src/tests/inline_values.rs @@ -4,7 +4,7 @@ use dap::{Scope, StackFrame, Variable, requests::Variables}; use editor::{Editor, EditorMode, MultiBuffer}; use gpui::{BackgroundExecutor, TestAppContext, VisualTestContext}; use language::{ - Language, LanguageConfig, LanguageMatcher, tree_sitter_python, tree_sitter_rust, + Language, LanguageConfig, LanguageMatcher, rust_lang, tree_sitter_python, tree_sitter_typescript, }; use project::{FakeFs, Project}; @@ -224,7 +224,7 @@ fn main() { .unwrap(); buffer.update(cx, |buffer, cx| { - buffer.set_language(Some(Arc::new(rust_lang())), cx); + buffer.set_language(Some(rust_lang()), cx); }); let (editor, cx) = cx.add_window_view(|window, cx| { @@ -1521,23 +1521,6 @@ fn main() { }); } -fn rust_lang() -> Language { - let debug_variables_query = include_str!("../../../languages/src/rust/debugger.scm"); - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - ) - 
.with_debug_variables_query(debug_variables_query) - .unwrap() -} - #[gpui::test] async fn test_python_inline_values(executor: BackgroundExecutor, cx: &mut TestAppContext) { init_test(cx); @@ -1859,21 +1842,23 @@ fn python_lang() -> Language { .unwrap() } -fn go_lang() -> Language { +fn go_lang() -> Arc { let debug_variables_query = include_str!("../../../languages/src/go/debugger.scm"); - Language::new( - LanguageConfig { - name: "Go".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["go".to_string()], + Arc::new( + Language::new( + LanguageConfig { + name: "Go".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["go".to_string()], + ..Default::default() + }, ..Default::default() }, - ..Default::default() - }, - Some(tree_sitter_go::LANGUAGE.into()), + Some(tree_sitter_go::LANGUAGE.into()), + ) + .with_debug_variables_query(debug_variables_query) + .unwrap(), ) - .with_debug_variables_query(debug_variables_query) - .unwrap() } /// Test utility function for inline values testing @@ -1891,7 +1876,7 @@ async fn test_inline_values_util( before: &str, after: &str, active_debug_line: Option, - language: Language, + language: Arc, executor: BackgroundExecutor, cx: &mut TestAppContext, ) { @@ -2091,7 +2076,7 @@ async fn test_inline_values_util( .unwrap(); buffer.update(cx, |buffer, cx| { - buffer.set_language(Some(Arc::new(language)), cx); + buffer.set_language(Some(language), cx); }); let (editor, cx) = cx.add_window_view(|window, cx| { @@ -2276,55 +2261,61 @@ fn main() { .await; } -fn javascript_lang() -> Language { +fn javascript_lang() -> Arc { let debug_variables_query = include_str!("../../../languages/src/javascript/debugger.scm"); - Language::new( - LanguageConfig { - name: "JavaScript".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["js".to_string()], + Arc::new( + Language::new( + LanguageConfig { + name: "JavaScript".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["js".to_string()], + ..Default::default() + }, 
..Default::default() }, - ..Default::default() - }, - Some(tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into()), + Some(tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into()), + ) + .with_debug_variables_query(debug_variables_query) + .unwrap(), ) - .with_debug_variables_query(debug_variables_query) - .unwrap() } -fn typescript_lang() -> Language { +fn typescript_lang() -> Arc { let debug_variables_query = include_str!("../../../languages/src/typescript/debugger.scm"); - Language::new( - LanguageConfig { - name: "TypeScript".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["ts".to_string()], + Arc::new( + Language::new( + LanguageConfig { + name: "TypeScript".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["ts".to_string()], + ..Default::default() + }, ..Default::default() }, - ..Default::default() - }, - Some(tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into()), + Some(tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into()), + ) + .with_debug_variables_query(debug_variables_query) + .unwrap(), ) - .with_debug_variables_query(debug_variables_query) - .unwrap() } -fn tsx_lang() -> Language { +fn tsx_lang() -> Arc { let debug_variables_query = include_str!("../../../languages/src/tsx/debugger.scm"); - Language::new( - LanguageConfig { - name: "TSX".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["tsx".to_string()], + Arc::new( + Language::new( + LanguageConfig { + name: "TSX".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["tsx".to_string()], + ..Default::default() + }, ..Default::default() }, - ..Default::default() - }, - Some(tree_sitter_typescript::LANGUAGE_TSX.into()), + Some(tree_sitter_typescript::LANGUAGE_TSX.into()), + ) + .with_debug_variables_query(debug_variables_query) + .unwrap(), ) - .with_debug_variables_query(debug_variables_query) - .unwrap() } #[gpui::test] diff --git a/crates/edit_prediction/src/zeta1.rs b/crates/edit_prediction/src/zeta1.rs index 
20f70421810c6d1678f844d1ec4c968b1ca96678..ad630484d392d75849bd33a52a55e63ea77ca23f 100644 --- a/crates/edit_prediction/src/zeta1.rs +++ b/crates/edit_prediction/src/zeta1.rs @@ -561,8 +561,7 @@ mod tests { use super::*; use gpui::{App, AppContext}; use indoc::indoc; - use language::{Buffer, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust}; - use std::sync::Arc; + use language::Buffer; #[gpui::test] fn test_excerpt_for_cursor_position(cx: &mut App) { @@ -591,7 +590,7 @@ mod tests { numbers } "#}; - let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(language::rust_lang(), cx)); let snapshot = buffer.read(cx).snapshot(); // Ensure we try to fit the largest possible syntax scope, resorting to line-based expansion @@ -649,18 +648,4 @@ mod tests { ```"#} ); } - - fn rust_lang() -> Language { - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - ) - } } diff --git a/crates/edit_prediction_context/src/edit_prediction_context_tests.rs b/crates/edit_prediction_context/src/edit_prediction_context_tests.rs index f62df37e551db19145e9ea631b6ab6a16fefda78..dba8d89e593ccb60e7eae5d091708e82debef0f5 100644 --- a/crates/edit_prediction_context/src/edit_prediction_context_tests.rs +++ b/crates/edit_prediction_context/src/edit_prediction_context_tests.rs @@ -2,12 +2,12 @@ use super::*; use futures::channel::mpsc::UnboundedReceiver; use gpui::TestAppContext; use indoc::indoc; -use language::{Language, LanguageConfig, LanguageMatcher, Point, ToPoint as _, tree_sitter_rust}; +use language::{Point, ToPoint as _, rust_lang}; use lsp::FakeLanguageServer; use project::{FakeFs, LocationLink, Project}; use serde_json::json; use settings::SettingsStore; -use std::{fmt::Write as _, sync::Arc}; +use 
std::fmt::Write as _; use util::{path, test::marked_text_ranges}; #[gpui::test] @@ -508,23 +508,3 @@ fn format_excerpts(buffer: &Buffer, excerpts: &[RelatedExcerpt]) -> String { } output } - -pub(crate) fn rust_lang() -> Arc { - Arc::new( - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - first_line_pattern: None, - }, - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - ) - .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm")) - .unwrap() - .with_outline_query(include_str!("../../languages/src/rust/outline.scm")) - .unwrap(), - ) -} diff --git a/crates/edit_prediction_context/src/excerpt.rs b/crates/edit_prediction_context/src/excerpt.rs index 55a3d8f03b277d0ce40f1d2ac947c55abf93f1c9..3fc7eed4ace5a83992bf496aef3e364aea96e215 100644 --- a/crates/edit_prediction_context/src/excerpt.rs +++ b/crates/edit_prediction_context/src/excerpt.rs @@ -419,30 +419,14 @@ fn node_line_end(node: Node) -> Point { mod tests { use super::*; use gpui::{AppContext, TestAppContext}; - use language::{Buffer, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust}; + use language::Buffer; use util::test::{generate_marked_text, marked_text_offsets_by}; fn create_buffer(text: &str, cx: &mut TestAppContext) -> BufferSnapshot { - let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang().into(), cx)); + let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(language::rust_lang(), cx)); buffer.read_with(cx, |buffer, _| buffer.snapshot()) } - fn rust_lang() -> Language { - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - ) - .with_outline_query(include_str!("../../languages/src/rust/outline.scm")) - .unwrap() - } - fn cursor_and_excerpt_range(text: &str) -> (String, 
usize, Range) { let (text, offsets) = marked_text_offsets_by(text, vec!['ˇ', '«', '»']); (text, offsets[&'ˇ'][0], offsets[&'«'][0]..offsets[&'»'][0]) diff --git a/crates/editor/src/items.rs b/crates/editor/src/items.rs index ca8937bebe3d3578c7fe2fdec2c6252bdd395e6d..3b9c17f80f10116f2302bab203966922cbf0bcb2 100644 --- a/crates/editor/src/items.rs +++ b/crates/editor/src/items.rs @@ -1951,7 +1951,7 @@ mod tests { use super::*; use fs::MTime; use gpui::{App, VisualTestContext}; - use language::{LanguageMatcher, TestFile}; + use language::TestFile; use project::FakeFs; use std::path::{Path, PathBuf}; use util::{path, rel_path::RelPath}; @@ -1991,20 +1991,6 @@ mod tests { .unwrap() } - fn rust_language() -> Arc { - Arc::new(language::Language::new( - language::LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - )) - } - #[gpui::test] async fn test_deserialize(cx: &mut gpui::TestAppContext) { init_test(cx, |_| {}); @@ -2086,7 +2072,9 @@ mod tests { { let project = Project::test(fs.clone(), [path!("/file.rs").as_ref()], cx).await; // Add Rust to the language, so that we can restore the language of the buffer - project.read_with(cx, |project, _| project.languages().add(rust_language())); + project.read_with(cx, |project, _| { + project.languages().add(languages::rust_lang()) + }); let (workspace, cx) = cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); diff --git a/crates/language/src/buffer_tests.rs b/crates/language/src/buffer_tests.rs index e95bc544a56ecf9d561936ca48b10ccffcb23e72..6b5d2450fe72f46b728be0f5b151801fe2e7fa70 100644 --- a/crates/language/src/buffer_tests.rs +++ b/crates/language/src/buffer_tests.rs @@ -6,6 +6,7 @@ use futures::FutureExt as _; use gpui::{App, AppContext as _, BorrowAppContext, Entity}; use gpui::{HighlightStyle, TestAppContext}; use indoc::indoc; +use 
pretty_assertions::assert_eq; use proto::deserialize_operation; use rand::prelude::*; use regex::RegexBuilder; @@ -46,8 +47,7 @@ fn test_line_endings(cx: &mut gpui::App) { init_settings(cx, |_| {}); cx.new(|cx| { - let mut buffer = - Buffer::local("one\r\ntwo\rthree", cx).with_language(Arc::new(rust_lang()), cx); + let mut buffer = Buffer::local("one\r\ntwo\rthree", cx).with_language(rust_lang(), cx); assert_eq!(buffer.text(), "one\ntwo\nthree"); assert_eq!(buffer.line_ending(), LineEnding::Windows); @@ -608,7 +608,7 @@ async fn test_normalize_whitespace(cx: &mut gpui::TestAppContext) { #[gpui::test] async fn test_reparse(cx: &mut gpui::TestAppContext) { let text = "fn a() {}"; - let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx)); // Wait for the initial text to parse cx.executor().run_until_parked(); @@ -735,7 +735,7 @@ async fn test_reparse(cx: &mut gpui::TestAppContext) { #[gpui::test] async fn test_resetting_language(cx: &mut gpui::TestAppContext) { let buffer = cx.new(|cx| { - let mut buffer = Buffer::local("{}", cx).with_language(Arc::new(rust_lang()), cx); + let mut buffer = Buffer::local("{}", cx).with_language(rust_lang(), cx); buffer.set_sync_parse_timeout(Duration::ZERO); buffer }); @@ -783,11 +783,11 @@ async fn test_outline(cx: &mut gpui::TestAppContext) { "# .unindent(); - let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx)); let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot()); let outline = snapshot.outline(None); - pretty_assertions::assert_eq!( + assert_eq!( outline .items .iter() @@ -819,7 +819,7 @@ async fn test_outline(cx: &mut gpui::TestAppContext) { ("LoggedIn", 2, Some("person: Person, time: Instant,".to_string())), ("person", 3, None), ("time", 3, None), - ("impl Eq for Person", 0, None), 
+ ("impl Eq for Person", 0, Some("".to_string())), ( "impl Drop for Person", 0, @@ -890,7 +890,7 @@ async fn test_outline_nodes_with_newlines(cx: &mut gpui::TestAppContext) { "# .unindent(); - let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx)); let outline = buffer.update(cx, |buffer, _| buffer.snapshot().outline(None)); assert_eq!( @@ -970,7 +970,7 @@ fn test_outline_annotations(cx: &mut App) { "# .unindent(); - let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx)); let outline = buffer.update(cx, |buffer, _| buffer.snapshot().outline(None)); assert_eq!( @@ -1018,7 +1018,7 @@ async fn test_symbols_containing(cx: &mut gpui::TestAppContext) { "# .unindent(); - let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx)); let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot()); // point is at the start of an item @@ -1093,7 +1093,7 @@ async fn test_symbols_containing(cx: &mut gpui::TestAppContext) { " .unindent(), ); - let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx)); let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot()); // note, it would be nice to actually return the method test in this @@ -1112,8 +1112,7 @@ fn test_text_objects(cx: &mut App) { false, ); - let buffer = - cx.new(|cx| Buffer::local(text.clone(), cx).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.new(|cx| Buffer::local(text.clone(), cx).with_language(rust_lang(), cx)); let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot()); let matches = snapshot @@ -1130,6 +1129,14 @@ fn 
test_text_objects(cx: &mut App) { "fn say() -> u8 { return /* hi */ 1 }", TextObject::AroundFunction ), + ( + "fn say() -> u8 { return /* hi */ 1 }", + TextObject::InsideClass + ), + ( + "impl Hello {\n fn say() -> u8 { return /* hi */ 1 }\n}", + TextObject::AroundClass + ), ], ) } @@ -1260,7 +1267,12 @@ fn test_enclosing_bracket_ranges(cx: &mut App) { #[gpui::test] fn test_enclosing_bracket_ranges_where_brackets_are_not_outermost_children(cx: &mut App) { let mut assert = |selection_text, bracket_pair_texts| { - assert_bracket_pairs(selection_text, bracket_pair_texts, javascript_lang(), cx) + assert_bracket_pairs( + selection_text, + bracket_pair_texts, + Arc::new(javascript_lang()), + cx, + ) }; assert( @@ -1293,7 +1305,7 @@ fn test_enclosing_bracket_ranges_where_brackets_are_not_outermost_children(cx: & fn test_range_for_syntax_ancestor(cx: &mut App) { cx.new(|cx| { let text = "fn a() { b(|c| {}) }"; - let buffer = Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx); + let buffer = Buffer::local(text, cx).with_language(rust_lang(), cx); let snapshot = buffer.snapshot(); assert_eq!( @@ -1345,7 +1357,7 @@ fn test_autoindent_with_soft_tabs(cx: &mut App) { cx.new(|cx| { let text = "fn a() {}"; - let mut buffer = Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx); + let mut buffer = Buffer::local(text, cx).with_language(rust_lang(), cx); buffer.edit([(8..8, "\n\n")], Some(AutoindentMode::EachLine), cx); assert_eq!(buffer.text(), "fn a() {\n \n}"); @@ -1387,7 +1399,7 @@ fn test_autoindent_with_hard_tabs(cx: &mut App) { cx.new(|cx| { let text = "fn a() {}"; - let mut buffer = Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx); + let mut buffer = Buffer::local(text, cx).with_language(rust_lang(), cx); buffer.edit([(8..8, "\n\n")], Some(AutoindentMode::EachLine), cx); assert_eq!(buffer.text(), "fn a() {\n\t\n}"); @@ -1436,7 +1448,7 @@ fn test_autoindent_does_not_adjust_lines_with_unchanged_suggestion(cx: &mut App) .unindent(), cx, 
) - .with_language(Arc::new(rust_lang()), cx); + .with_language(rust_lang(), cx); // Lines 2 and 3 don't match the indentation suggestion. When editing these lines, // their indentation is not adjusted. @@ -1577,7 +1589,7 @@ fn test_autoindent_does_not_adjust_lines_with_unchanged_suggestion(cx: &mut App) .unindent(), cx, ) - .with_language(Arc::new(rust_lang()), cx); + .with_language(rust_lang(), cx); // Insert a closing brace. It is outdented. buffer.edit_via_marked_text( @@ -1640,7 +1652,7 @@ fn test_autoindent_does_not_adjust_lines_within_newly_created_errors(cx: &mut Ap .unindent(), cx, ) - .with_language(Arc::new(rust_lang()), cx); + .with_language(rust_lang(), cx); // Regression test: line does not get outdented due to syntax error buffer.edit_via_marked_text( @@ -1699,7 +1711,7 @@ fn test_autoindent_adjusts_lines_when_only_text_changes(cx: &mut App) { .unindent(), cx, ) - .with_language(Arc::new(rust_lang()), cx); + .with_language(rust_lang(), cx); buffer.edit_via_marked_text( &" @@ -1749,7 +1761,7 @@ fn test_autoindent_with_edit_at_end_of_buffer(cx: &mut App) { cx.new(|cx| { let text = "a\nb"; - let mut buffer = Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx); + let mut buffer = Buffer::local(text, cx).with_language(rust_lang(), cx); buffer.edit( [(0..1, "\n"), (2..3, "\n")], Some(AutoindentMode::EachLine), @@ -1775,7 +1787,7 @@ fn test_autoindent_multi_line_insertion(cx: &mut App) { " .unindent(); - let mut buffer = Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx); + let mut buffer = Buffer::local(text, cx).with_language(rust_lang(), cx); buffer.edit( [(Point::new(3, 0)..Point::new(3, 0), "e(\n f()\n);\n")], Some(AutoindentMode::EachLine), @@ -1812,7 +1824,7 @@ fn test_autoindent_block_mode(cx: &mut App) { } "# .unindent(); - let mut buffer = Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx); + let mut buffer = Buffer::local(text, cx).with_language(rust_lang(), cx); // When this text was copied, both of the 
quotation marks were at the same // indent level, but the indentation of the first line was not included in @@ -1895,7 +1907,7 @@ fn test_autoindent_block_mode_with_newline(cx: &mut App) { } "# .unindent(); - let mut buffer = Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx); + let mut buffer = Buffer::local(text, cx).with_language(rust_lang(), cx); // First line contains just '\n', it's indentation is stored in "original_indent_columns" let original_indent_columns = vec![Some(4)]; @@ -1947,7 +1959,7 @@ fn test_autoindent_block_mode_without_original_indent_columns(cx: &mut App) { } "# .unindent(); - let mut buffer = Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx); + let mut buffer = Buffer::local(text, cx).with_language(rust_lang(), cx); // The original indent columns are not known, so this text is // auto-indented in a block as if the first line was copied in @@ -2038,7 +2050,7 @@ fn test_autoindent_block_mode_multiple_adjacent_ranges(cx: &mut App) { false, ); - let mut buffer = Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx); + let mut buffer = Buffer::local(text, cx).with_language(rust_lang(), cx); buffer.edit( [ @@ -2052,7 +2064,7 @@ fn test_autoindent_block_mode_multiple_adjacent_ranges(cx: &mut App) { cx, ); - pretty_assertions::assert_eq!( + assert_eq!( buffer.text(), " mod numbers { @@ -2246,7 +2258,7 @@ async fn test_async_autoindents_preserve_preview(cx: &mut TestAppContext) { // Then we request that a preview tab be preserved for the new version, even though it's edited. let buffer = cx.new(|cx| { let text = "fn a() {}"; - let mut buffer = Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx); + let mut buffer = Buffer::local(text, cx).with_language(rust_lang(), cx); // This causes autoindent to be async. 
buffer.set_sync_parse_timeout(Duration::ZERO); @@ -2704,7 +2716,7 @@ fn test_language_at_with_hidden_languages(cx: &mut App) { .unindent(); let language_registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); - language_registry.add(Arc::new(markdown_lang())); + language_registry.add(markdown_lang()); language_registry.add(Arc::new(markdown_inline_lang())); let mut buffer = Buffer::local(text, cx); @@ -2746,9 +2758,9 @@ fn test_language_at_for_markdown_code_block(cx: &mut App) { .unindent(); let language_registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); - language_registry.add(Arc::new(markdown_lang())); + language_registry.add(markdown_lang()); language_registry.add(Arc::new(markdown_inline_lang())); - language_registry.add(Arc::new(rust_lang())); + language_registry.add(rust_lang()); let mut buffer = Buffer::local(text, cx); buffer.set_language_registry(language_registry.clone()); @@ -3145,7 +3157,7 @@ async fn test_preview_edits(cx: &mut TestAppContext) { cx: &mut TestAppContext, assert_fn: impl Fn(HighlightedText), ) { - let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx)); let edits = buffer.read_with(cx, |buffer, _| { edits .into_iter() @@ -3556,7 +3568,7 @@ let word=öäpple.bar你 Öäpple word2-öÄpPlE-Pizza-word ÖÄPPLE word "#; let buffer = cx.new(|cx| { - let buffer = Buffer::local(contents, cx).with_language(Arc::new(rust_lang()), cx); + let buffer = Buffer::local(contents, cx).with_language(rust_lang(), cx); assert_eq!(buffer.text(), contents); buffer.check_invariants(); buffer @@ -3781,78 +3793,6 @@ fn erb_lang() -> Language { .unwrap() } -fn rust_lang() -> Language { - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - 
Some(tree_sitter_rust::LANGUAGE.into()), - ) - .with_indents_query( - r#" - (call_expression) @indent - (field_expression) @indent - (_ "(" ")" @end) @indent - (_ "{" "}" @end) @indent - "#, - ) - .unwrap() - .with_brackets_query( - r#" - ("{" @open "}" @close) - "#, - ) - .unwrap() - .with_text_object_query( - r#" - (function_item - body: (_ - "{" - (_)* @function.inside - "}" )) @function.around - - (line_comment)+ @comment.around - - (block_comment) @comment.around - "#, - ) - .unwrap() - .with_outline_query( - r#" - (line_comment) @annotation - - (struct_item - "struct" @context - name: (_) @name) @item - (enum_item - "enum" @context - name: (_) @name) @item - (enum_variant - name: (_) @name) @item - (field_declaration - name: (_) @name) @item - (impl_item - "impl" @context - trait: (_)? @name - "for"? @context - type: (_) @name - body: (_ "{" (_)* "}")) @item - (function_item - "fn" @context - name: (_) @name) @item - (mod_item - "mod" @context - name: (_) @name) @item - "#, - ) - .unwrap() -} - fn json_lang() -> Language { Language::new( LanguageConfig { @@ -3890,32 +3830,6 @@ fn javascript_lang() -> Language { .unwrap() } -pub fn markdown_lang() -> Language { - Language::new( - LanguageConfig { - name: "Markdown".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["md".into()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_md::LANGUAGE.into()), - ) - .with_injection_query( - r#" - (fenced_code_block - (info_string - (language) @injection.language) - (code_fence_content) @injection.content) - - ((inline) @injection.content - (#set! 
injection.language "markdown-inline")) - "#, - ) - .unwrap() -} - pub fn markdown_inline_lang() -> Language { Language::new( LanguageConfig { @@ -3942,12 +3856,11 @@ fn get_tree_sexp(buffer: &Entity, cx: &mut gpui::TestAppContext) -> Stri fn assert_bracket_pairs( selection_text: &'static str, bracket_pair_texts: Vec<&'static str>, - language: Language, + language: Arc, cx: &mut App, ) { let (expected_text, selection_ranges) = marked_text_ranges(selection_text, false); - let buffer = - cx.new(|cx| Buffer::local(expected_text.clone(), cx).with_language(Arc::new(language), cx)); + let buffer = cx.new(|cx| Buffer::local(expected_text.clone(), cx).with_language(language, cx)); let buffer = buffer.update(cx, |buffer, _cx| buffer.snapshot()); let selection_range = selection_ranges[0].clone(); diff --git a/crates/language/src/language.rs b/crates/language/src/language.rs index 0451be3ee164aa70b549f3502a45f5e52fbafce3..891e4842a49b81659c9e4a9bf42a0655ef30abcb 100644 --- a/crates/language/src/language.rs +++ b/crates/language/src/language.rs @@ -2656,7 +2656,28 @@ pub fn rust_lang() -> Arc { text_objects: Some(Cow::from(include_str!( "../../languages/src/rust/textobjects.scm" ))), - ..LanguageQueries::default() + highlights: Some(Cow::from(include_str!( + "../../languages/src/rust/highlights.scm" + ))), + embedding: Some(Cow::from(include_str!( + "../../languages/src/rust/embedding.scm" + ))), + injections: Some(Cow::from(include_str!( + "../../languages/src/rust/injections.scm" + ))), + overrides: Some(Cow::from(include_str!( + "../../languages/src/rust/overrides.scm" + ))), + redactions: None, + runnables: Some(Cow::from(include_str!( + "../../languages/src/rust/runnables.scm" + ))), + debugger: Some(Cow::from(include_str!( + "../../languages/src/rust/debugger.scm" + ))), + imports: Some(Cow::from(include_str!( + "../../languages/src/rust/imports.scm" + ))), }) .expect("Could not parse queries"); Arc::new(language) @@ -2685,6 +2706,15 @@ pub fn markdown_lang() -> Arc { 
injections: Some(Cow::from(include_str!( "../../languages/src/markdown/injections.scm" ))), + highlights: Some(Cow::from(include_str!( + "../../languages/src/markdown/highlights.scm" + ))), + indents: Some(Cow::from(include_str!( + "../../languages/src/markdown/indents.scm" + ))), + outline: Some(Cow::from(include_str!( + "../../languages/src/markdown/outline.scm" + ))), ..LanguageQueries::default() }) .expect("Could not parse markdown queries"); diff --git a/crates/language/src/syntax_map/syntax_map_tests.rs b/crates/language/src/syntax_map/syntax_map_tests.rs index 9c4eecad363de386cddc6e943e20e5762634d713..1eb63772760719a381d16795ecde0c4a3293c789 100644 --- a/crates/language/src/syntax_map/syntax_map_tests.rs +++ b/crates/language/src/syntax_map/syntax_map_tests.rs @@ -1,9 +1,9 @@ use super::*; use crate::{ - LanguageConfig, LanguageMatcher, - buffer_tests::{markdown_inline_lang, markdown_lang}, + LanguageConfig, LanguageMatcher, buffer_tests::markdown_inline_lang, markdown_lang, rust_lang, }; use gpui::App; +use pretty_assertions::assert_eq; use rand::rngs::StdRng; use std::{env, ops::Range, sync::Arc}; use text::{Buffer, BufferId, ReplicaId}; @@ -84,7 +84,7 @@ fn test_splice_included_ranges() { #[gpui::test] fn test_syntax_map_layers_for_range(cx: &mut App) { let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); - let language = Arc::new(rust_lang()); + let language = rust_lang(); registry.add(language.clone()); let mut buffer = Buffer::new( @@ -181,11 +181,11 @@ fn test_syntax_map_layers_for_range(cx: &mut App) { #[gpui::test] fn test_dynamic_language_injection(cx: &mut App) { let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); - let markdown = Arc::new(markdown_lang()); + let markdown = markdown_lang(); let markdown_inline = Arc::new(markdown_inline_lang()); registry.add(markdown.clone()); registry.add(markdown_inline.clone()); - registry.add(Arc::new(rust_lang())); + registry.add(rust_lang()); 
registry.add(Arc::new(ruby_lang())); let mut buffer = Buffer::new( @@ -291,7 +291,7 @@ fn test_typing_multiple_new_injections(cx: &mut App) { assert_capture_ranges( &syntax_map, &buffer, - &["field"], + &["property"], "fn a() { test_macro!(b.«c»(vec![d.«e»])) }", ); } @@ -329,16 +329,16 @@ fn test_pasting_new_injection_line_between_others(cx: &mut App) { assert_capture_ranges( &syntax_map, &buffer, - &["struct"], + &["type"], " fn a() { - b!(«B {}»); - c!(«C {}»); - d!(«D {}»); - h!(«H {}»); - e!(«E {}»); - f!(«F {}»); - g!(«G {}»); + b!(«B» {}); + c!(«C» {}); + d!(«D» {}); + h!(«H» {}); + e!(«E» {}); + f!(«F» {}); + g!(«G» {}); } ", ); @@ -376,7 +376,7 @@ fn test_joining_injections_with_child_injections(cx: &mut App) { assert_capture_ranges( &syntax_map, &buffer, - &["field"], + &["property"], " fn a() { b!( @@ -900,7 +900,7 @@ fn test_random_syntax_map_edits_rust_macros(rng: StdRng, cx: &mut App) { .repeat(2); let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); - let language = Arc::new(rust_lang()); + let language = rust_lang(); registry.add(language.clone()); test_random_edits(text, registry, language, rng); @@ -1147,11 +1147,11 @@ fn test_edit_sequence(language_name: &str, steps: &[&str], cx: &mut App) -> (Buf let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); registry.add(Arc::new(elixir_lang())); registry.add(Arc::new(heex_lang())); - registry.add(Arc::new(rust_lang())); + registry.add(rust_lang()); registry.add(Arc::new(ruby_lang())); registry.add(Arc::new(html_lang())); registry.add(Arc::new(erb_lang())); - registry.add(Arc::new(markdown_lang())); + registry.add(markdown_lang()); registry.add(Arc::new(markdown_inline_lang())); let language = registry @@ -1287,35 +1287,6 @@ fn erb_lang() -> Language { .unwrap() } -fn rust_lang() -> Language { - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - 
}, - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - ) - .with_highlights_query( - r#" - (field_identifier) @field - (struct_expression) @struct - "#, - ) - .unwrap() - .with_injection_query( - r#" - (macro_invocation - (token_tree) @injection.content - (#set! injection.language "rust")) - "#, - ) - .unwrap() -} - fn elixir_lang() -> Language { Language::new( LanguageConfig { @@ -1425,6 +1396,7 @@ fn assert_capture_ranges( actual_ranges.push(capture.node.byte_range()); } } + actual_ranges.dedup(); let (text, expected_ranges) = marked_text_ranges(&marked_string.unindent(), false); assert_eq!(text, buffer.text()); diff --git a/crates/markdown_preview/Cargo.toml b/crates/markdown_preview/Cargo.toml index 89e5ec5921a3ad330a75343e980dfeff0f535b00..d61ec00cc8cfd5e04768381b64d5230682924623 100644 --- a/crates/markdown_preview/Cargo.toml +++ b/crates/markdown_preview/Cargo.toml @@ -37,3 +37,4 @@ workspace.workspace = true [dev-dependencies] editor = { workspace = true, features = ["test-support"] } +language = { workspace = true, features = ["test-support"] } diff --git a/crates/markdown_preview/src/markdown_parser.rs b/crates/markdown_preview/src/markdown_parser.rs index 7b3886d10f5c8977f8766bddc39fb81f6d8f316f..b17ee5cac455605ce49d0dd436d163e49f2954bd 100644 --- a/crates/markdown_preview/src/markdown_parser.rs +++ b/crates/markdown_preview/src/markdown_parser.rs @@ -1467,9 +1467,7 @@ mod tests { use ParsedMarkdownListItemType::*; use core::panic; use gpui::{AbsoluteLength, BackgroundExecutor, DefiniteLength}; - use language::{ - HighlightId, Language, LanguageConfig, LanguageMatcher, LanguageRegistry, tree_sitter_rust, - }; + use language::{HighlightId, LanguageRegistry}; use pretty_assertions::assert_eq; async fn parse(input: &str) -> ParsedMarkdown { @@ -3053,7 +3051,7 @@ fn main() { #[gpui::test] async fn test_code_block_with_language(executor: BackgroundExecutor) { let language_registry = Arc::new(LanguageRegistry::test(executor.clone())); - 
language_registry.add(rust_lang()); + language_registry.add(language::rust_lang()); let parsed = parse_markdown( "\ @@ -3079,21 +3077,6 @@ fn main() { ); } - fn rust_lang() -> Arc { - Arc::new(Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".into()], - ..Default::default() - }, - collapsed_placeholder: " /* ... */ ".to_string(), - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - )) - } - fn h1(contents: MarkdownParagraph, source_range: Range) -> ParsedMarkdownElement { ParsedMarkdownElement::Heading(ParsedMarkdownHeading { source_range, diff --git a/crates/outline/src/outline.rs b/crates/outline/src/outline.rs index 7127627226d3aa55877f067038b69e6e848e1c3a..1f5cf1edab15a190a9f15d6106190eae637b9f3d 100644 --- a/crates/outline/src/outline.rs +++ b/crates/outline/src/outline.rs @@ -391,7 +391,6 @@ mod tests { use super::*; use gpui::{TestAppContext, VisualTestContext}; use indoc::indoc; - use language::{Language, LanguageConfig, LanguageMatcher}; use project::{FakeFs, Project}; use serde_json::json; use util::{path, rel_path::rel_path}; @@ -418,7 +417,9 @@ mod tests { .await; let project = Project::test(fs, [path!("/dir").as_ref()], cx).await; - project.read_with(cx, |project, _| project.languages().add(rust_lang())); + project.read_with(cx, |project, _| { + project.languages().add(language::rust_lang()) + }); let (workspace, cx) = cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); @@ -581,89 +582,6 @@ mod tests { }) } - fn rust_lang() -> Arc { - Arc::new( - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - ) - .with_outline_query( - r#"(struct_item - (visibility_modifier)? @context - "struct" @context - name: (_) @name) @item - - (enum_item - (visibility_modifier)? 
@context - "enum" @context - name: (_) @name) @item - - (enum_variant - (visibility_modifier)? @context - name: (_) @name) @item - - (impl_item - "impl" @context - trait: (_)? @name - "for"? @context - type: (_) @name) @item - - (trait_item - (visibility_modifier)? @context - "trait" @context - name: (_) @name) @item - - (function_item - (visibility_modifier)? @context - (function_modifiers)? @context - "fn" @context - name: (_) @name) @item - - (function_signature_item - (visibility_modifier)? @context - (function_modifiers)? @context - "fn" @context - name: (_) @name) @item - - (macro_definition - . "macro_rules!" @context - name: (_) @name) @item - - (mod_item - (visibility_modifier)? @context - "mod" @context - name: (_) @name) @item - - (type_item - (visibility_modifier)? @context - "type" @context - name: (_) @name) @item - - (associated_type - "type" @context - name: (_) @name) @item - - (const_item - (visibility_modifier)? @context - "const" @context - name: (_) @name) @item - - (field_declaration - (visibility_modifier)? 
@context - name: (_) @name) @item -"#, - ) - .unwrap(), - ) - } - #[track_caller] fn assert_single_caret_at_row( editor: &Entity, diff --git a/crates/outline_panel/src/outline_panel.rs b/crates/outline_panel/src/outline_panel.rs index 6e78b8a1e1f573d9870d42c6a5e99c8574e6979a..85cca3c2b1273d6abcd85af6db8df7fdcb411220 100644 --- a/crates/outline_panel/src/outline_panel.rs +++ b/crates/outline_panel/src/outline_panel.rs @@ -5220,7 +5220,7 @@ impl GenerationState { mod tests { use db::indoc; use gpui::{TestAppContext, VisualTestContext, WindowHandle}; - use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust}; + use language::rust_lang; use pretty_assertions::assert_eq; use project::FakeFs; use search::{ @@ -5243,9 +5243,7 @@ mod tests { let root = path!("/rust-analyzer"); populate_with_test_ra_project(&fs, root).await; let project = Project::test(fs.clone(), [Path::new(root)], cx).await; - project.read_with(cx, |project, _| { - project.languages().add(Arc::new(rust_lang())) - }); + project.read_with(cx, |project, _| project.languages().add(rust_lang())); let workspace = add_outline_panel(&project, cx).await; let cx = &mut VisualTestContext::from_window(*workspace, cx); let outline_panel = outline_panel(&workspace, cx); @@ -5478,9 +5476,7 @@ mod tests { let root = path!("/rust-analyzer"); populate_with_test_ra_project(&fs, root).await; let project = Project::test(fs.clone(), [Path::new(root)], cx).await; - project.read_with(cx, |project, _| { - project.languages().add(Arc::new(rust_lang())) - }); + project.read_with(cx, |project, _| project.languages().add(rust_lang())); let workspace = add_outline_panel(&project, cx).await; let cx = &mut VisualTestContext::from_window(*workspace, cx); let outline_panel = outline_panel(&workspace, cx); @@ -5617,9 +5613,7 @@ mod tests { let root = path!("/rust-analyzer"); populate_with_test_ra_project(&fs, root).await; let project = Project::test(fs.clone(), [Path::new(root)], cx).await; - project.read_with(cx, 
|project, _| { - project.languages().add(Arc::new(rust_lang())) - }); + project.read_with(cx, |project, _| project.languages().add(rust_lang())); let workspace = add_outline_panel(&project, cx).await; let cx = &mut VisualTestContext::from_window(*workspace, cx); let outline_panel = outline_panel(&workspace, cx); @@ -5816,7 +5810,8 @@ mod tests { outline_panel.selected_entry(), cx, ), - "fn_lifetime_fn.rs <==== selected" + "outline: pub(super) fn hints +outline: fn hints_lifetimes_named <==== selected" ); assert_eq!( selected_row_text(&new_active_editor, cx), @@ -6029,24 +6024,7 @@ struct OutlineEntryExcerpt { ) .await; let project = Project::test(fs.clone(), [Path::new(root)], cx).await; - project.read_with(cx, |project, _| { - project.languages().add(Arc::new( - rust_lang() - .with_outline_query( - r#" - (struct_item - (visibility_modifier)? @context - "struct" @context - name: (_) @name) @item - - (field_declaration - (visibility_modifier)? @context - name: (_) @name) @item -"#, - ) - .unwrap(), - )) - }); + project.read_with(cx, |project, _| project.languages().add(rust_lang())); let workspace = add_outline_panel(&project, cx).await; let cx = &mut VisualTestContext::from_window(*workspace, cx); let outline_panel = outline_panel(&workspace, cx); @@ -6992,35 +6970,6 @@ outline: struct OutlineEntryExcerpt .await; } - fn rust_lang() -> Language { - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - ) - .with_highlights_query( - r#" - (field_identifier) @field - (struct_expression) @struct - "#, - ) - .unwrap() - .with_injection_query( - r#" - (macro_invocation - (token_tree) @injection.content - (#set! 
injection.language "rust")) - "#, - ) - .unwrap() - } - fn snapshot(outline_panel: &OutlinePanel, cx: &App) -> MultiBufferSnapshot { outline_panel .active_editor() @@ -7086,44 +7035,7 @@ outline: struct OutlineEntryExcerpt .await; let project = Project::test(fs.clone(), ["/test".as_ref()], cx).await; - project.read_with(cx, |project, _| { - project.languages().add(Arc::new( - rust_lang() - .with_outline_query( - r#" - (struct_item - (visibility_modifier)? @context - "struct" @context - name: (_) @name) @item - (impl_item - "impl" @context - trait: (_)? @context - "for"? @context - type: (_) @context - body: (_)) @item - (function_item - (visibility_modifier)? @context - "fn" @context - name: (_) @name - parameters: (_) @context) @item - (mod_item - (visibility_modifier)? @context - "mod" @context - name: (_) @name) @item - (enum_item - (visibility_modifier)? @context - "enum" @context - name: (_) @name) @item - (field_declaration - (visibility_modifier)? @context - name: (_) @name - ":" @context - type: (_) @context) @item - "#, - ) - .unwrap(), - )) - }); + project.read_with(cx, |project, _| project.languages().add(rust_lang())); let workspace = add_outline_panel(&project, cx).await; let cx = &mut VisualTestContext::from_window(*workspace, cx); let outline_panel = outline_panel(&workspace, cx); @@ -7174,15 +7086,15 @@ outline: struct OutlineEntryExcerpt " outline: mod outer <==== selected outline: pub struct OuterStruct - outline: field: String + outline: field outline: impl OuterStruct - outline: pub fn new() - outline: pub fn method(&self) + outline: pub fn new + outline: pub fn method outline: mod inner - outline: pub fn inner_function() + outline: pub fn inner_function outline: pub struct InnerStruct - outline: value: i32 -outline: fn main()" + outline: value +outline: fn main" ) ); }); @@ -7232,7 +7144,7 @@ outline: fn main()" indoc!( " outline: mod outer <==== selected -outline: fn main()" +outline: fn main" ) ); }); @@ -7257,15 +7169,15 @@ outline: fn 
main()" " outline: mod outer <==== selected outline: pub struct OuterStruct - outline: field: String + outline: field outline: impl OuterStruct - outline: pub fn new() - outline: pub fn method(&self) + outline: pub fn new + outline: pub fn method outline: mod inner - outline: pub fn inner_function() + outline: pub fn inner_function outline: pub struct InnerStruct - outline: value: i32 -outline: fn main()" + outline: value +outline: fn main" ) ); }); @@ -7321,7 +7233,7 @@ outline: fn main()" indoc!( " outline: mod outer -outline: fn main()" +outline: fn main" ) ); }); @@ -7378,44 +7290,7 @@ outline: fn main()" .await; let project = Project::test(fs.clone(), ["/test".as_ref()], cx).await; - project.read_with(cx, |project, _| { - project.languages().add(Arc::new( - rust_lang() - .with_outline_query( - r#" - (struct_item - (visibility_modifier)? @context - "struct" @context - name: (_) @name) @item - (impl_item - "impl" @context - trait: (_)? @context - "for"? @context - type: (_) @context - body: (_)) @item - (function_item - (visibility_modifier)? @context - "fn" @context - name: (_) @name - parameters: (_) @context) @item - (mod_item - (visibility_modifier)? @context - "mod" @context - name: (_) @name) @item - (enum_item - (visibility_modifier)? @context - "enum" @context - name: (_) @name) @item - (field_declaration - (visibility_modifier)? 
@context - name: (_) @name - ":" @context - type: (_) @context) @item - "#, - ) - .unwrap(), - )) - }); + project.read_with(cx, |project, _| project.languages().add(rust_lang())); let workspace = add_outline_panel(&project, cx).await; let cx = &mut VisualTestContext::from_window(*workspace, cx); @@ -7462,14 +7337,16 @@ outline: fn main()" indoc!( " outline: struct Config - outline: name: String - outline: value: i32 + outline: name + outline: value outline: impl Config - outline: fn new(name: String) - outline: fn get_value(&self) + outline: fn new + outline: fn get_value outline: enum Status -outline: fn process_config(config: Config) -outline: fn main()" + outline: Active + outline: Inactive +outline: fn process_config +outline: fn main" ) ); }); @@ -7500,14 +7377,16 @@ outline: fn main()" indoc!( " outline: struct Config <==== selected - outline: name: String - outline: value: i32 + outline: name + outline: value outline: impl Config - outline: fn new(name: String) - outline: fn get_value(&self) + outline: fn new + outline: fn get_value outline: enum Status -outline: fn process_config(config: Config) -outline: fn main()" + outline: Active + outline: Inactive +outline: fn process_config +outline: fn main" ) ); }); @@ -7535,11 +7414,13 @@ outline: fn main()" " outline: struct Config <==== selected outline: impl Config - outline: fn new(name: String) - outline: fn get_value(&self) + outline: fn new + outline: fn get_value outline: enum Status -outline: fn process_config(config: Config) -outline: fn main()" + outline: Active + outline: Inactive +outline: fn process_config +outline: fn main" ) ); }); @@ -7566,14 +7447,16 @@ outline: fn main()" indoc!( " outline: struct Config <==== selected - outline: name: String - outline: value: i32 + outline: name + outline: value outline: impl Config - outline: fn new(name: String) - outline: fn get_value(&self) + outline: fn new + outline: fn get_value outline: enum Status -outline: fn process_config(config: Config) -outline: 
fn main()" + outline: Active + outline: Inactive +outline: fn process_config +outline: fn main" ) ); }); @@ -7622,44 +7505,7 @@ outline: fn main()" .await; let project = Project::test(fs.clone(), ["/test".as_ref()], cx).await; - project.read_with(cx, |project, _| { - project.languages().add(Arc::new( - rust_lang() - .with_outline_query( - r#" - (struct_item - (visibility_modifier)? @context - "struct" @context - name: (_) @name) @item - (impl_item - "impl" @context - trait: (_)? @context - "for"? @context - type: (_) @context - body: (_)) @item - (function_item - (visibility_modifier)? @context - "fn" @context - name: (_) @name - parameters: (_) @context) @item - (mod_item - (visibility_modifier)? @context - "mod" @context - name: (_) @name) @item - (enum_item - (visibility_modifier)? @context - "enum" @context - name: (_) @name) @item - (field_declaration - (visibility_modifier)? @context - name: (_) @name - ":" @context - type: (_) @context) @item - "#, - ) - .unwrap(), - )) - }); + project.read_with(cx, |project, _| project.languages().add(rust_lang())); let workspace = add_outline_panel(&project, cx).await; let cx = &mut VisualTestContext::from_window(*workspace, cx); let outline_panel = outline_panel(&workspace, cx); @@ -7710,15 +7556,15 @@ outline: fn main()" " outline: mod outer <==== selected outline: pub struct OuterStruct - outline: field: String + outline: field outline: impl OuterStruct - outline: pub fn new() - outline: pub fn method(&self) + outline: pub fn new + outline: pub fn method outline: mod inner - outline: pub fn inner_function() + outline: pub fn inner_function outline: pub struct InnerStruct - outline: value: i32 -outline: fn main()" + outline: value +outline: fn main" ) ); }); @@ -7759,7 +7605,7 @@ outline: fn main()" let expected_collapsed_output = indoc!( " outline: mod outer <==== selected - outline: fn main()" + outline: fn main" ); outline_panel.update(cx, |panel, cx| { @@ -7787,15 +7633,15 @@ outline: fn main()" " outline: mod outer 
<==== selected outline: pub struct OuterStruct - outline: field: String + outline: field outline: impl OuterStruct - outline: pub fn new() - outline: pub fn method(&self) + outline: pub fn new + outline: pub fn method outline: mod inner - outline: pub fn inner_function() + outline: pub fn inner_function outline: pub struct InnerStruct - outline: value: i32 - outline: fn main()" + outline: value + outline: fn main" ); outline_panel.update(cx, |panel, cx| { diff --git a/crates/project/src/project_tests.rs b/crates/project/src/project_tests.rs index 8adba2dea16391c35096c487c4eff0098d52df56..24b2280edee55a0131c73f6b91b3cea7adc6bbad 100644 --- a/crates/project/src/project_tests.rs +++ b/crates/project/src/project_tests.rs @@ -28,7 +28,7 @@ use language::{ ManifestName, ManifestProvider, ManifestQuery, OffsetRangeExt, Point, ToPoint, ToolchainList, ToolchainLister, language_settings::{LanguageSettingsContent, language_settings}, - tree_sitter_rust, tree_sitter_typescript, + rust_lang, tree_sitter_typescript, }; use lsp::{ DiagnosticSeverity, DocumentChanges, FileOperationFilter, NumberOrString, TextDocumentEdit, @@ -10468,20 +10468,6 @@ fn js_lang() -> Arc { )) } -fn rust_lang() -> Arc { - Arc::new(Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - )) -} - fn python_lang(fs: Arc) -> Arc { struct PythonMootToolchainLister(Arc); #[async_trait] diff --git a/crates/vim/src/object.rs b/crates/vim/src/object.rs index 2f5ccac07bfe5f6f11b048e317523292dd74294d..f11386d02d6846343645b6c7514603f16396163c 100644 --- a/crates/vim/src/object.rs +++ b/crates/vim/src/object.rs @@ -2382,9 +2382,10 @@ mod test { Mode::Insert, ); - cx.set_state("let a = (test::call(), 'p', my_macro!{ˇ});", Mode::Normal); - cx.simulate_keystrokes("c a a"); - cx.assert_state("let a = (test::call(), 'p'ˇ);", Mode::Insert); + // TODO 
regressed with the up-to-date Rust grammar. + // cx.set_state("let a = (test::call(), 'p', my_macro!{ˇ});", Mode::Normal); + // cx.simulate_keystrokes("c a a"); + // cx.assert_state("let a = (test::call(), 'p'ˇ);", Mode::Insert); cx.set_state("let a = [test::call(ˇ), 300];", Mode::Normal); cx.simulate_keystrokes("c i a"); diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index 164d6b8383fe940e3a92d5461edbff878300474a..1361fcdba788752099c8e5b37b51e751fccf4dfd 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -2255,7 +2255,8 @@ mod tests { Action, AnyWindowHandle, App, AssetSource, BorrowAppContext, TestAppContext, UpdateGlobal, VisualTestContext, WindowHandle, actions, }; - use language::{LanguageMatcher, LanguageRegistry}; + use language::LanguageRegistry; + use languages::{markdown_lang, rust_lang}; use pretty_assertions::{assert_eq, assert_ne}; use project::{Project, ProjectPath}; use semver::Version; @@ -2895,9 +2896,7 @@ mod tests { .await; let project = Project::test(app_state.fs.clone(), [path!("/root").as_ref()], cx).await; - project.update(cx, |project, _cx| { - project.languages().add(markdown_language()) - }); + project.update(cx, |project, _cx| project.languages().add(markdown_lang())); let window = cx.add_window(|window, cx| Workspace::test_new(project, window, cx)); let workspace = window.root(cx).unwrap(); @@ -3327,9 +3326,7 @@ mod tests { .await; let project = Project::test(app_state.fs.clone(), [path!("/root").as_ref()], cx).await; - project.update(cx, |project, _cx| { - project.languages().add(markdown_language()) - }); + project.update(cx, |project, _cx| project.languages().add(markdown_lang())); let window = cx.add_window(|window, cx| Workspace::test_new(project, window, cx)); let workspace = window.root(cx).unwrap(); @@ -3421,9 +3418,7 @@ mod tests { .await; let project = Project::test(app_state.fs.clone(), [path!("/root").as_ref()], cx).await; - project.update(cx, |project, _cx| { - 
project.languages().add(markdown_language()) - }); + project.update(cx, |project, _cx| project.languages().add(markdown_lang())); let window = cx.add_window(|window, cx| Workspace::test_new(project, window, cx)); let workspace = window.root(cx).unwrap(); @@ -3494,7 +3489,7 @@ mod tests { let project = Project::test(app_state.fs.clone(), [path!("/root").as_ref()], cx).await; project.update(cx, |project, _| { - project.languages().add(markdown_language()); + project.languages().add(markdown_lang()); project.languages().add(rust_lang()); }); let window = cx.add_window(|window, cx| Workspace::test_new(project, window, cx)); @@ -3647,8 +3642,8 @@ mod tests { let project = Project::test(app_state.fs.clone(), [], cx).await; project.update(cx, |project, _| { - project.languages().add(rust_lang()); - project.languages().add(markdown_language()); + project.languages().add(language::rust_lang()); + project.languages().add(language::markdown_lang()); }); let window = cx.add_window(|window, cx| Workspace::test_new(project, window, cx)); @@ -3727,9 +3722,7 @@ mod tests { .await; let project = Project::test(app_state.fs.clone(), [path!("/root").as_ref()], cx).await; - project.update(cx, |project, _cx| { - project.languages().add(markdown_language()) - }); + project.update(cx, |project, _cx| project.languages().add(markdown_lang())); let window = cx.add_window(|window, cx| Workspace::test_new(project, window, cx)); let workspace = window.root(cx).unwrap(); @@ -3831,9 +3824,7 @@ mod tests { .await; let project = Project::test(app_state.fs.clone(), [path!("/root").as_ref()], cx).await; - project.update(cx, |project, _cx| { - project.languages().add(markdown_language()) - }); + project.update(cx, |project, _cx| project.languages().add(markdown_lang())); let workspace = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx)); let pane = workspace @@ -4225,9 +4216,7 @@ mod tests { .await; let project = Project::test(app_state.fs.clone(), [path!("/root").as_ref()], 
cx).await; - project.update(cx, |project, _cx| { - project.languages().add(markdown_language()) - }); + project.update(cx, |project, _cx| project.languages().add(markdown_lang())); let workspace = cx.add_window(|window, cx| Workspace::test_new(project, window, cx)); let pane = workspace .read_with(cx, |workspace, _| workspace.active_pane().clone()) @@ -4914,7 +4903,7 @@ mod tests { let state = Arc::get_mut(&mut app_state).unwrap(); state.build_window_options = build_window_options; - app_state.languages.add(markdown_language()); + app_state.languages.add(markdown_lang()); gpui_tokio::init(cx); theme::init(theme::LoadThemes::JustBase, cx); @@ -4965,34 +4954,6 @@ mod tests { }) } - fn rust_lang() -> Arc { - Arc::new(language::Language::new( - language::LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - )) - } - - fn markdown_language() -> Arc { - Arc::new(language::Language::new( - language::LanguageConfig { - name: "Markdown".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["md".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_md::LANGUAGE.into()), - )) - } - #[track_caller] fn assert_key_bindings_for( window: AnyWindowHandle, From a574ae877922c6451645df7abeef1a2f45d5c572 Mon Sep 17 00:00:00 2001 From: Remco Smits Date: Sat, 6 Dec 2025 20:31:08 +0100 Subject: [PATCH 43/81] debugger: Start work on adding session snapshot feature (#44298) This PR adds the basic logic for a feature that allows you to visit any stopped information back in time. We will follow up with PRs to improve this and actually add UI for it so the UX is better. 
https://github.com/user-attachments/assets/42d8a5b3-8ab8-471a-bdd0-f579662eadd6 Edit Anthony: We feature flagged this so external users won't be able to access this until the feature is polished Release Notes: - N/A --------- Co-authored-by: Anthony Eid --- Cargo.lock | 1 + crates/debugger_ui/Cargo.toml | 1 + crates/debugger_ui/src/debugger_panel.rs | 37 ++- crates/debugger_ui/src/session/running.rs | 2 +- crates/project/src/debugger/dap_store.rs | 2 +- crates/project/src/debugger/session.rs | 266 ++++++++++++++-------- 6 files changed, 215 insertions(+), 94 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a8f0096a7a1219ee30b287c61efd9f77f4b9d223..0bbde0bdfddb0b11b715bce230cb82cb4c74cb0f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4583,6 +4583,7 @@ dependencies = [ "db", "debugger_tools", "editor", + "feature_flags", "file_icons", "futures 0.3.31", "fuzzy", diff --git a/crates/debugger_ui/Cargo.toml b/crates/debugger_ui/Cargo.toml index 25d23b96b897001faec39498c5b08ef08b09a3a1..fb79b1b0790b28d7204774720bf9c413cfed64e6 100644 --- a/crates/debugger_ui/Cargo.toml +++ b/crates/debugger_ui/Cargo.toml @@ -37,6 +37,7 @@ dap_adapters = { workspace = true, optional = true } db.workspace = true debugger_tools.workspace = true editor.workspace = true +feature_flags.workspace = true file_icons.workspace = true futures.workspace = true fuzzy.workspace = true diff --git a/crates/debugger_ui/src/debugger_panel.rs b/crates/debugger_ui/src/debugger_panel.rs index ffdd4a22e3d092eb5d3d6626dcfe8b167ae03936..fe81ac641196dbbc5ceecaede0785ca72336c261 100644 --- a/crates/debugger_ui/src/debugger_panel.rs +++ b/crates/debugger_ui/src/debugger_panel.rs @@ -15,6 +15,7 @@ use dap::adapters::DebugAdapterName; use dap::{DapRegistry, StartDebuggingRequestArguments}; use dap::{client::SessionId, debugger_settings::DebuggerSettings}; use editor::{Editor, MultiBufferOffset, ToPoint}; +use feature_flags::{FeatureFlag, FeatureFlagAppExt as _}; use gpui::{ Action, App, AsyncWindowContext, 
ClipboardItem, Context, DismissEvent, Entity, EntityId, EventEmitter, FocusHandle, Focusable, MouseButton, MouseDownEvent, Point, Subscription, Task, @@ -42,6 +43,12 @@ use workspace::{ }; use zed_actions::ToggleFocus; +pub struct DebuggerHistoryFeatureFlag; + +impl FeatureFlag for DebuggerHistoryFeatureFlag { + const NAME: &'static str = "debugger-history"; +} + const DEBUG_PANEL_KEY: &str = "DebugPanel"; pub struct DebugPanel { @@ -284,7 +291,7 @@ impl DebugPanel { } }); - session.update(cx, |session, _| match &mut session.mode { + session.update(cx, |session, _| match &mut session.state { SessionState::Booting(state_task) => { *state_task = Some(boot_task); } @@ -805,6 +812,34 @@ impl DebugPanel { } }), ) + .when(cx.has_flag::(), |this| { + this.child( + IconButton::new( + "debug-back-in-history", + IconName::HistoryRerun, + ) + .icon_size(IconSize::Small) + .on_click( + window.listener_for( + running_state, + |this, _, _window, cx| { + this.session().update(cx, |session, cx| { + let ix = session + .active_history() + .unwrap_or_else(|| { + session.history().len() + }); + + session.go_back_to_history( + Some(ix.saturating_sub(1)), + cx, + ); + }) + }, + ), + ), + ) + }) .child(Divider::vertical()) .child( IconButton::new("debug-restart", IconName::RotateCcw) diff --git a/crates/debugger_ui/src/session/running.rs b/crates/debugger_ui/src/session/running.rs index b82f839edee82f884c1419d44a2344c39c8abd0d..bc99d6ac8e42b0a706df4a09177ae2103d5939e2 100644 --- a/crates/debugger_ui/src/session/running.rs +++ b/crates/debugger_ui/src/session/running.rs @@ -1743,7 +1743,7 @@ impl RunningState { let is_building = self.session.update(cx, |session, cx| { session.shutdown(cx).detach(); - matches!(session.mode, session::SessionState::Booting(_)) + matches!(session.state, session::SessionState::Booting(_)) }); if is_building { diff --git a/crates/project/src/debugger/dap_store.rs b/crates/project/src/debugger/dap_store.rs index 
a82286441d625561009f4f9259f5c06fe424ff10..4a588e7c436f5f29fffd953b8fce988daa4655d8 100644 --- a/crates/project/src/debugger/dap_store.rs +++ b/crates/project/src/debugger/dap_store.rs @@ -692,7 +692,7 @@ impl DapStore { } VariableLookupKind::Expression => { let Ok(eval_task) = session.read_with(cx, |session, _| { - session.mode.request_dap(EvaluateCommand { + session.state.request_dap(EvaluateCommand { expression: inline_value_location.variable_name.clone(), frame_id: Some(stack_frame_id), source: None, diff --git a/crates/project/src/debugger/session.rs b/crates/project/src/debugger/session.rs index 47fe98827cbc163682ef6f002eff4008967d4ced..a63e9066c9a30233ee1edb15aac13da145cb76b2 100644 --- a/crates/project/src/debugger/session.rs +++ b/crates/project/src/debugger/session.rs @@ -1,7 +1,3 @@ -use crate::debugger::breakpoint_store::BreakpointSessionState; -use crate::debugger::dap_command::{DataBreakpointContext, ReadMemory}; -use crate::debugger::memory::{self, Memory, MemoryIterator, MemoryPageBuilder, PageAddress}; - use super::breakpoint_store::{ BreakpointStore, BreakpointStoreEvent, BreakpointUpdatedReason, SourceBreakpoint, }; @@ -14,6 +10,9 @@ use super::dap_command::{ TerminateCommand, TerminateThreadsCommand, ThreadsCommand, VariablesCommand, }; use super::dap_store::DapStore; +use crate::debugger::breakpoint_store::BreakpointSessionState; +use crate::debugger::dap_command::{DataBreakpointContext, ReadMemory}; +use crate::debugger::memory::{self, Memory, MemoryIterator, MemoryPageBuilder, PageAddress}; use anyhow::{Context as _, Result, anyhow, bail}; use base64::Engine; use collections::{HashMap, HashSet, IndexMap}; @@ -42,15 +41,13 @@ use gpui::{ Task, WeakEntity, }; use http_client::HttpClient; - use node_runtime::NodeRuntime; use remote::RemoteClient; -use rpc::ErrorExt; use serde::{Deserialize, Serialize}; use serde_json::Value; use smol::net::{TcpListener, TcpStream}; use std::any::TypeId; -use std::collections::BTreeMap; +use 
std::collections::{BTreeMap, VecDeque}; use std::net::Ipv4Addr; use std::ops::RangeInclusive; use std::path::PathBuf; @@ -71,6 +68,9 @@ use util::command::new_smol_command; use util::{ResultExt, debug_panic, maybe}; use worktree::Worktree; +const MAX_TRACKED_OUTPUT_EVENTS: usize = 5000; +const DEBUG_HISTORY_LIMIT: usize = 10; + #[derive(Debug, Copy, Clone, Hash, PartialEq, PartialOrd, Ord, Eq)] #[repr(transparent)] pub struct ThreadId(pub i64); @@ -118,11 +118,11 @@ impl ThreadStatus { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Thread { dap: dap::Thread, stack_frames: Vec, - stack_frames_error: Option, + stack_frames_error: Option, _has_stopped: bool, } @@ -672,7 +672,18 @@ impl ThreadStates { .any(|status| *status == ThreadStatus::Stopped) } } -const MAX_TRACKED_OUTPUT_EVENTS: usize = 5000; + +// TODO(debugger): Wrap dap types with reference counting so the UI doesn't have to clone them on refresh +#[derive(Default)] +pub struct SessionSnapshot { + threads: IndexMap, + thread_states: ThreadStates, + variables: HashMap>, + stack_frames: IndexMap, + locations: HashMap, + modules: Vec, + loaded_sources: Vec, +} type IsEnabled = bool; @@ -680,23 +691,19 @@ type IsEnabled = bool; pub struct OutputToken(pub usize); /// Represents a current state of a single debug adapter and provides ways to mutate it. 
pub struct Session { - pub mode: SessionState, + pub state: SessionState, + active_snapshot: SessionSnapshot, + snapshots: VecDeque, + selected_snapshot_index: Option, id: SessionId, label: Option, adapter: DebugAdapterName, pub(super) capabilities: Capabilities, child_session_ids: HashSet, parent_session: Option>, - modules: Vec, - loaded_sources: Vec, output_token: OutputToken, output: Box>, - threads: IndexMap, - thread_states: ThreadStates, watchers: HashMap, - variables: HashMap>, - stack_frames: IndexMap, - locations: HashMap, is_session_terminated: bool, requests: HashMap>>>>, pub(crate) breakpoint_store: Entity, @@ -858,24 +865,20 @@ impl Session { .detach(); Self { - mode: SessionState::Booting(None), + state: SessionState::Booting(None), + snapshots: VecDeque::with_capacity(DEBUG_HISTORY_LIMIT), + selected_snapshot_index: None, + active_snapshot: Default::default(), id: session_id, child_session_ids: HashSet::default(), parent_session, capabilities: Capabilities::default(), watchers: HashMap::default(), - variables: Default::default(), - stack_frames: Default::default(), - thread_states: ThreadStates::default(), output_token: OutputToken(0), output: circular_buffer::CircularBuffer::boxed(), requests: HashMap::default(), - modules: Vec::default(), - loaded_sources: Vec::default(), - threads: IndexMap::default(), background_tasks: Vec::default(), restart_task: None, - locations: Default::default(), is_session_terminated: false, ignore_breakpoints: false, breakpoint_store, @@ -899,7 +902,7 @@ impl Session { } pub fn worktree(&self) -> Option> { - match &self.mode { + match &self.state { SessionState::Booting(_) => None, SessionState::Running(local_mode) => local_mode.worktree.upgrade(), } @@ -960,7 +963,7 @@ impl Session { ) .await?; this.update(cx, |this, cx| { - match &mut this.mode { + match &mut this.state { SessionState::Booting(task) if task.is_some() => { task.take().unwrap().detach_and_log_err(cx); } @@ -969,7 +972,7 @@ impl Session { 
debug_panic!("Attempting to boot a session that is already running"); } }; - this.mode = SessionState::Running(mode); + this.state = SessionState::Running(mode); cx.emit(SessionStateEvent::Running); })?; @@ -1061,7 +1064,7 @@ impl Session { } pub fn binary(&self) -> Option<&DebugAdapterBinary> { - match &self.mode { + match &self.state { SessionState::Booting(_) => None, SessionState::Running(running_mode) => Some(&running_mode.binary), } @@ -1107,25 +1110,25 @@ impl Session { } pub fn is_started(&self) -> bool { - match &self.mode { + match &self.state { SessionState::Booting(_) => false, SessionState::Running(running) => running.is_started, } } pub fn is_building(&self) -> bool { - matches!(self.mode, SessionState::Booting(_)) + matches!(self.state, SessionState::Booting(_)) } pub fn as_running_mut(&mut self) -> Option<&mut RunningMode> { - match &mut self.mode { + match &mut self.state { SessionState::Running(local_mode) => Some(local_mode), SessionState::Booting(_) => None, } } pub fn as_running(&self) -> Option<&RunningMode> { - match &self.mode { + match &self.state { SessionState::Running(local_mode) => Some(local_mode), SessionState::Booting(_) => None, } @@ -1269,7 +1272,7 @@ impl Session { let adapter_id = self.adapter().to_string(); let request = Initialize { adapter_id }; - let SessionState::Running(running) = &self.mode else { + let SessionState::Running(running) = &self.state else { return Task::ready(Err(anyhow!( "Cannot send initialize request, task still building" ))); @@ -1317,7 +1320,7 @@ impl Session { dap_store: WeakEntity, cx: &mut Context, ) -> Task> { - match &self.mode { + match &self.state { SessionState::Running(local_mode) => { local_mode.initialize_sequence(&self.capabilities, initialize_rx, dap_store, cx) } @@ -1333,10 +1336,12 @@ impl Session { active_thread_id: ThreadId, cx: &mut Context, ) { - match &mut self.mode { + match &mut self.state { SessionState::Running(local_mode) => { if !matches!( - 
self.thread_states.thread_state(active_thread_id), + self.active_snapshot + .thread_states + .thread_state(active_thread_id), Some(ThreadStatus::Stopped) ) { return; @@ -1411,8 +1416,51 @@ impl Session { }) } + fn session_state(&self) -> &SessionSnapshot { + self.selected_snapshot_index + .and_then(|ix| self.snapshots.get(ix)) + .unwrap_or_else(|| &self.active_snapshot) + } + + fn push_to_history(&mut self) { + if !self.has_ever_stopped() { + return; + } + + while self.snapshots.len() >= DEBUG_HISTORY_LIMIT { + self.snapshots.pop_front(); + } + + self.snapshots + .push_back(std::mem::take(&mut self.active_snapshot)); + } + + pub fn history(&self) -> &VecDeque { + &self.snapshots + } + + pub fn go_back_to_history(&mut self, ix: Option, cx: &mut Context<'_, Session>) { + if self.selected_snapshot_index == ix { + return; + } + + self.selected_snapshot_index = ix; + + if ix.is_some() { + cx.emit(SessionEvent::Stopped(None)); + } + + cx.notify(); + } + + pub fn active_history(&self) -> Option { + self.selected_snapshot_index + } + fn handle_stopped_event(&mut self, event: StoppedEvent, cx: &mut Context) { - self.mode.stopped(); + self.push_to_history(); + + self.state.stopped(); // todo(debugger): Find a clean way to get around the clone let breakpoint_store = self.breakpoint_store.clone(); if let Some((local, path)) = self.as_running_mut().and_then(|local| { @@ -1431,14 +1479,16 @@ impl Session { }; if event.all_threads_stopped.unwrap_or_default() || event.thread_id.is_none() { - self.thread_states.stop_all_threads(); + self.active_snapshot.thread_states.stop_all_threads(); self.invalidate_command_type::(); } // Event if we stopped all threads we still need to insert the thread_id // to our own data if let Some(thread_id) = event.thread_id { - self.thread_states.stop_thread(ThreadId(thread_id)); + self.active_snapshot + .thread_states + .stop_thread(ThreadId(thread_id)); self.invalidate_state( &StackTraceCommand { @@ -1451,8 +1501,8 @@ impl Session { } 
self.invalidate_generic(); - self.threads.clear(); - self.variables.clear(); + self.active_snapshot.threads.clear(); + self.active_snapshot.variables.clear(); cx.emit(SessionEvent::Stopped( event .thread_id @@ -1474,12 +1524,13 @@ impl Session { Events::Stopped(event) => self.handle_stopped_event(event, cx), Events::Continued(event) => { if event.all_threads_continued.unwrap_or_default() { - self.thread_states.continue_all_threads(); + self.active_snapshot.thread_states.continue_all_threads(); self.breakpoint_store.update(cx, |store, cx| { store.remove_active_position(Some(self.session_id()), cx) }); } else { - self.thread_states + self.active_snapshot + .thread_states .continue_thread(ThreadId(event.thread_id)); } // todo(debugger): We should be able to get away with only invalidating generic if all threads were continued @@ -1496,10 +1547,12 @@ impl Session { match event.reason { dap::ThreadEventReason::Started => { - self.thread_states.continue_thread(thread_id); + self.active_snapshot + .thread_states + .continue_thread(thread_id); } dap::ThreadEventReason::Exited => { - self.thread_states.exit_thread(thread_id); + self.active_snapshot.thread_states.exit_thread(thread_id); } reason => { log::error!("Unhandled thread event reason {:?}", reason); @@ -1526,10 +1579,11 @@ impl Session { Events::Module(event) => { match event.reason { dap::ModuleEventReason::New => { - self.modules.push(event.module); + self.active_snapshot.modules.push(event.module); } dap::ModuleEventReason::Changed => { if let Some(module) = self + .active_snapshot .modules .iter_mut() .find(|other| event.module.id == other.id) @@ -1538,7 +1592,9 @@ impl Session { } } dap::ModuleEventReason::Removed => { - self.modules.retain(|other| event.module.id != other.id); + self.active_snapshot + .modules + .retain(|other| event.module.id != other.id); } } @@ -1612,9 +1668,16 @@ impl Session { ); } - if !self.thread_states.any_stopped_thread() + if self.selected_snapshot_index.is_some() { + return; + } + 
+ if self.is_session_terminated { + return; + } + + if !self.active_snapshot.thread_states.any_stopped_thread() && request.type_id() != TypeId::of::() - || self.is_session_terminated { return; } @@ -1629,7 +1692,7 @@ impl Session { let task = Self::request_inner::>( &self.capabilities, - &self.mode, + &self.state, command, |this, result, cx| { process_result(this, result, cx); @@ -1697,7 +1760,7 @@ impl Session { + 'static, cx: &mut Context, ) -> Task> { - Self::request_inner(&self.capabilities, &self.mode, request, process_result, cx) + Self::request_inner(&self.capabilities, &self.state, request, process_result, cx) } fn invalidate_command_type(&mut self) { @@ -1730,11 +1793,11 @@ impl Session { } pub fn any_stopped_thread(&self) -> bool { - self.thread_states.any_stopped_thread() + self.active_snapshot.thread_states.any_stopped_thread() } pub fn thread_status(&self, thread_id: ThreadId) -> ThreadStatus { - self.thread_states.thread_status(thread_id) + self.active_snapshot.thread_states.thread_status(thread_id) } pub fn threads(&mut self, cx: &mut Context) -> Vec<(dap::Thread, ThreadStatus)> { @@ -1745,7 +1808,7 @@ impl Session { return; }; - this.threads = result + this.active_snapshot.threads = result .into_iter() .map(|thread| (ThreadId(thread.id), Thread::from(thread))) .collect(); @@ -1757,12 +1820,14 @@ impl Session { cx, ); - self.threads + let state = self.session_state(); + state + .threads .values() .map(|thread| { ( thread.dap.clone(), - self.thread_states.thread_status(ThreadId(thread.dap.id)), + state.thread_states.thread_status(ThreadId(thread.dap.id)), ) }) .collect() @@ -1776,14 +1841,14 @@ impl Session { return; }; - this.modules = result; + this.active_snapshot.modules = result; cx.emit(SessionEvent::Modules); cx.notify(); }, cx, ); - &self.modules + &self.session_state().modules } // CodeLLDB returns the size of a pointed-to-memory, which we can use to make the experience of go-to-memory better. 
@@ -2034,14 +2099,13 @@ impl Session { let Some(result) = result.log_err() else { return; }; - this.loaded_sources = result; + this.active_snapshot.loaded_sources = result; cx.emit(SessionEvent::LoadedSources); cx.notify(); }, cx, ); - - &self.loaded_sources + &self.session_state().loaded_sources } fn fallback_to_manual_restart( @@ -2073,7 +2137,7 @@ impl Session { Some(response) } None => { - this.thread_states.stop_thread(thread_id); + this.active_snapshot.thread_states.stop_thread(thread_id); cx.notify(); None } @@ -2149,10 +2213,10 @@ impl Session { } self.is_session_terminated = true; - self.thread_states.exit_all_threads(); + self.active_snapshot.thread_states.exit_all_threads(); cx.notify(); - let task = match &mut self.mode { + let task = match &mut self.state { SessionState::Running(_) => { if self .capabilities @@ -2213,9 +2277,13 @@ impl Session { } pub fn continue_thread(&mut self, thread_id: ThreadId, cx: &mut Context) { + self.go_back_to_history(None, cx); + let supports_single_thread_execution_requests = self.capabilities.supports_single_thread_execution_requests; - self.thread_states.continue_thread(thread_id); + self.active_snapshot + .thread_states + .continue_thread(thread_id); self.request( ContinueCommand { args: ContinueArguments { @@ -2230,21 +2298,24 @@ impl Session { } pub fn adapter_client(&self) -> Option> { - match self.mode { + match self.state { SessionState::Running(ref local) => Some(local.client.clone()), SessionState::Booting(_) => None, } } pub fn has_ever_stopped(&self) -> bool { - self.mode.has_ever_stopped() + self.state.has_ever_stopped() } + pub fn step_over( &mut self, thread_id: ThreadId, granularity: SteppingGranularity, cx: &mut Context, ) { + self.go_back_to_history(None, cx); + let supports_single_thread_execution_requests = self.capabilities.supports_single_thread_execution_requests; let supports_stepping_granularity = self @@ -2260,7 +2331,7 @@ impl Session { }, }; - self.thread_states.process_step(thread_id); + 
self.active_snapshot.thread_states.process_step(thread_id); self.request( command, Self::on_step_response::(thread_id), @@ -2275,6 +2346,8 @@ impl Session { granularity: SteppingGranularity, cx: &mut Context, ) { + self.go_back_to_history(None, cx); + let supports_single_thread_execution_requests = self.capabilities.supports_single_thread_execution_requests; let supports_stepping_granularity = self @@ -2290,7 +2363,7 @@ impl Session { }, }; - self.thread_states.process_step(thread_id); + self.active_snapshot.thread_states.process_step(thread_id); self.request( command, Self::on_step_response::(thread_id), @@ -2305,6 +2378,8 @@ impl Session { granularity: SteppingGranularity, cx: &mut Context, ) { + self.go_back_to_history(None, cx); + let supports_single_thread_execution_requests = self.capabilities.supports_single_thread_execution_requests; let supports_stepping_granularity = self @@ -2320,7 +2395,7 @@ impl Session { }, }; - self.thread_states.process_step(thread_id); + self.active_snapshot.thread_states.process_step(thread_id); self.request( command, Self::on_step_response::(thread_id), @@ -2335,6 +2410,8 @@ impl Session { granularity: SteppingGranularity, cx: &mut Context, ) { + self.go_back_to_history(None, cx); + let supports_single_thread_execution_requests = self.capabilities.supports_single_thread_execution_requests; let supports_stepping_granularity = self @@ -2350,7 +2427,7 @@ impl Session { }, }; - self.thread_states.process_step(thread_id); + self.active_snapshot.thread_states.process_step(thread_id); self.request( command, @@ -2365,9 +2442,9 @@ impl Session { thread_id: ThreadId, cx: &mut Context, ) -> Result> { - if self.thread_states.thread_status(thread_id) == ThreadStatus::Stopped + if self.active_snapshot.thread_states.thread_status(thread_id) == ThreadStatus::Stopped && self.requests.contains_key(&ThreadsCommand.type_id()) - && self.threads.contains_key(&thread_id) + && self.active_snapshot.threads.contains_key(&thread_id) // ^ todo(debugger): We 
need a better way to check that we're not querying stale data // We could still be using an old thread id and have sent a new thread's request // This isn't the biggest concern right now because it hasn't caused any issues outside of tests @@ -2381,7 +2458,8 @@ impl Session { }, move |this, stack_frames, cx| { let entry = - this.threads + this.active_snapshot + .threads .entry(thread_id) .and_modify(|thread| match &stack_frames { Ok(stack_frames) => { @@ -2394,7 +2472,7 @@ impl Session { } Err(error) => { thread.stack_frames.clear(); - thread.stack_frames_error = Some(error.cloned()); + thread.stack_frames_error = Some(error.to_string().into()); } }); debug_assert!( @@ -2402,7 +2480,7 @@ impl Session { "Sent request for thread_id that doesn't exist" ); if let Ok(stack_frames) = stack_frames { - this.stack_frames.extend( + this.active_snapshot.stack_frames.extend( stack_frames .into_iter() .filter(|frame| { @@ -2427,10 +2505,10 @@ impl Session { ); } - match self.threads.get(&thread_id) { + match self.active_snapshot.threads.get(&thread_id) { Some(thread) => { if let Some(error) = &thread.stack_frames_error { - Err(error.cloned()) + Err(anyhow!(error.to_string())) } else { Ok(thread.stack_frames.clone()) } @@ -2457,6 +2535,7 @@ impl Session { } let entry = this + .active_snapshot .stack_frames .entry(stack_frame_id) .and_modify(|stack_frame| { @@ -2474,7 +2553,8 @@ impl Session { ); } - self.stack_frames + self.session_state() + .stack_frames .get(&stack_frame_id) .map(|frame| frame.scopes.as_slice()) .unwrap_or_default() @@ -2486,7 +2566,8 @@ impl Session { globals: bool, locals: bool, ) -> Vec { - let Some(stack_frame) = self.stack_frames.get(&stack_frame_id) else { + let state = self.session_state(); + let Some(stack_frame) = state.stack_frames.get(&stack_frame_id) else { return Vec::new(); }; @@ -2497,7 +2578,7 @@ impl Session { (scope.name.to_lowercase().contains("local") && locals) || (scope.name.to_lowercase().contains("global") && globals) }) - 
.filter_map(|scope| self.variables.get(&scope.variables_reference)) + .filter_map(|scope| state.variables.get(&scope.variables_reference)) .flatten() .cloned() .collect() @@ -2513,7 +2594,7 @@ impl Session { frame_id: u64, cx: &mut Context, ) -> Task> { - let request = self.mode.request_dap(EvaluateCommand { + let request = self.state.request_dap(EvaluateCommand { expression: expression.to_string(), context: Some(EvaluateArgumentsContext::Watch), frame_id: Some(frame_id), @@ -2570,7 +2651,9 @@ impl Session { return; }; - this.variables.insert(variables_reference, variables); + this.active_snapshot + .variables + .insert(variables_reference, variables); cx.emit(SessionEvent::Variables); cx.emit(SessionEvent::InvalidateInlineValue); @@ -2578,7 +2661,8 @@ impl Session { cx, ); - self.variables + self.session_state() + .variables .get(&variables_reference) .cloned() .unwrap_or_default() @@ -2645,7 +2729,7 @@ impl Session { location_reference: None, }; self.push_output(event); - let request = self.mode.request_dap(EvaluateCommand { + let request = self.state.request_dap(EvaluateCommand { expression, context, frame_id, @@ -2705,15 +2789,15 @@ impl Session { let Some(response) = response.log_err() else { return; }; - this.locations.insert(reference, response); + this.active_snapshot.locations.insert(reference, response); }, cx, ); - self.locations.get(&reference).cloned() + self.session_state().locations.get(&reference).cloned() } pub fn is_attached(&self) -> bool { - let SessionState::Running(local_mode) = &self.mode else { + let SessionState::Running(local_mode) = &self.state else { return false; }; local_mode.binary.request_args.request == StartDebuggingRequestArgumentsRequest::Attach @@ -2749,7 +2833,7 @@ impl Session { } pub fn thread_state(&self, thread_id: ThreadId) -> Option { - self.thread_states.thread_state(thread_id) + self.session_state().thread_states.thread_state(thread_id) } pub fn quirks(&self) -> SessionQuirks { From 
4577e1bf8fb42ec96d6054f2da4de89df2d822cd Mon Sep 17 00:00:00 2001 From: Remco Smits Date: Sat, 6 Dec 2025 21:34:19 +0100 Subject: [PATCH 44/81] debugger: Get stack frame list working with historic snapshot feature (#44303) This PR fixes an issue where the stack frame list would not update when viewing a historic snapshot. We now also show the right active debug line based on the currently selected history. https://github.com/user-attachments/assets/baccd078-23ed-4db3-9959-f83dc2be8309 Release Notes: - N/A --------- Co-authored-by: Anthony Eid --- .../src/session/running/loaded_source_list.rs | 4 +++- .../src/session/running/module_list.rs | 4 +++- .../src/session/running/stack_frame_list.rs | 4 +++- .../src/session/running/variable_list.rs | 7 ++++++- crates/project/src/debugger/session.rs | 19 +++++++------------ 5 files changed, 22 insertions(+), 16 deletions(-) diff --git a/crates/debugger_ui/src/session/running/loaded_source_list.rs b/crates/debugger_ui/src/session/running/loaded_source_list.rs index 921ebd8b5f5bdfe8a3c8a8f7bb1625bd1ffad7fb..e55fad336b5ee6dfbee1cb0c90ea3d19f561a2ba 100644 --- a/crates/debugger_ui/src/session/running/loaded_source_list.rs +++ b/crates/debugger_ui/src/session/running/loaded_source_list.rs @@ -17,7 +17,9 @@ impl LoadedSourceList { let list = ListState::new(0, gpui::ListAlignment::Top, px(1000.)); let _subscription = cx.subscribe(&session, |this, _, event, cx| match event { - SessionEvent::Stopped(_) | SessionEvent::LoadedSources => { + SessionEvent::Stopped(_) + | SessionEvent::HistoricSnapshotSelected + | SessionEvent::LoadedSources => { this.invalidate = true; cx.notify(); } diff --git a/crates/debugger_ui/src/session/running/module_list.rs b/crates/debugger_ui/src/session/running/module_list.rs index 19f407eb23f8acf0aa665f5119ecfd2156eb685f..7d0228fc6851185d10a3a237257d6244d5a90c76 100644 --- a/crates/debugger_ui/src/session/running/module_list.rs +++ b/crates/debugger_ui/src/session/running/module_list.rs @@ -32,7 +32,9 @@ 
impl ModuleList { let focus_handle = cx.focus_handle(); let _subscription = cx.subscribe(&session, |this, _, event, cx| match event { - SessionEvent::Stopped(_) | SessionEvent::Modules => { + SessionEvent::Stopped(_) + | SessionEvent::HistoricSnapshotSelected + | SessionEvent::Modules => { if this._rebuild_task.is_some() { this.schedule_rebuild(cx); } diff --git a/crates/debugger_ui/src/session/running/stack_frame_list.rs b/crates/debugger_ui/src/session/running/stack_frame_list.rs index 96a910af4dd0ac901c6802c139ddd5b8b3d728bc..5ecdc0f74be97c01ace933fd3513535040599bac 100644 --- a/crates/debugger_ui/src/session/running/stack_frame_list.rs +++ b/crates/debugger_ui/src/session/running/stack_frame_list.rs @@ -97,7 +97,9 @@ impl StackFrameList { SessionEvent::Threads => { this.schedule_refresh(false, window, cx); } - SessionEvent::Stopped(..) | SessionEvent::StackTrace => { + SessionEvent::Stopped(..) + | SessionEvent::StackTrace + | SessionEvent::HistoricSnapshotSelected => { this.schedule_refresh(true, window, cx); } _ => {} diff --git a/crates/debugger_ui/src/session/running/variable_list.rs b/crates/debugger_ui/src/session/running/variable_list.rs index 1b455b59d7d12712a3d4adc713a6ed15e8166c6e..7b23cd685d93e6353d68dc57cd3998099ea56ad7 100644 --- a/crates/debugger_ui/src/session/running/variable_list.rs +++ b/crates/debugger_ui/src/session/running/variable_list.rs @@ -217,6 +217,12 @@ impl VariableList { let _subscriptions = vec![ cx.subscribe(&stack_frame_list, Self::handle_stack_frame_list_events), cx.subscribe(&session, |this, _, event, cx| match event { + SessionEvent::HistoricSnapshotSelected => { + this.selection.take(); + this.edited_path.take(); + this.selected_stack_frame_id.take(); + this.build_entries(cx); + } SessionEvent::Stopped(_) => { this.selection.take(); this.edited_path.take(); @@ -225,7 +231,6 @@ impl VariableList { SessionEvent::Variables | SessionEvent::Watchers => { this.build_entries(cx); } - _ => {} }), cx.on_focus_out(&focus_handle, 
window, |this, _, _, cx| { diff --git a/crates/project/src/debugger/session.rs b/crates/project/src/debugger/session.rs index a63e9066c9a30233ee1edb15aac13da145cb76b2..9d4d307f990bfc5f00190f74ce3f1f957e71bacc 100644 --- a/crates/project/src/debugger/session.rs +++ b/crates/project/src/debugger/session.rs @@ -808,6 +808,7 @@ pub enum SessionEvent { }, DataBreakpointInfo, ConsoleOutput, + HistoricSnapshotSelected, } #[derive(Clone, Debug, PartialEq, Eq)] @@ -1447,7 +1448,7 @@ impl Session { self.selected_snapshot_index = ix; if ix.is_some() { - cx.emit(SessionEvent::Stopped(None)); + cx.emit(SessionEvent::HistoricSnapshotSelected); } cx.notify(); @@ -1668,16 +1669,10 @@ impl Session { ); } - if self.selected_snapshot_index.is_some() { - return; - } - - if self.is_session_terminated { - return; - } - - if !self.active_snapshot.thread_states.any_stopped_thread() - && request.type_id() != TypeId::of::() + if (!self.active_snapshot.thread_states.any_stopped_thread() + && request.type_id() != TypeId::of::()) + || self.selected_snapshot_index.is_some() + || self.is_session_terminated { return; } @@ -2505,7 +2500,7 @@ impl Session { ); } - match self.active_snapshot.threads.get(&thread_id) { + match self.session_state().threads.get(&thread_id) { Some(thread) => { if let Some(error) = &thread.stack_frames_error { Err(anyhow!(error.to_string())) From ef76f07b1ec8e4bdf996666b5522c08add4b2288 Mon Sep 17 00:00:00 2001 From: Remco Smits Date: Sat, 6 Dec 2025 22:08:33 +0100 Subject: [PATCH 45/81] debugger: Make historic snapshot button a dropdown menu (#44307) This allows users to select any snapshot in the debugger history feature and go back to the active session snapshot. We also change variable names to use hsitoric snapshot instead of history and move the snapshot icon to the back of the debugger top control strip. 
https://github.com/user-attachments/assets/805de8d0-30c1-4719-8af7-2d47e1df1da4 Release Notes: - N/A Co-authored-by: Anthony Eid --- crates/debugger_ui/src/debugger_panel.rs | 212 ++++++++++++------ .../src/session/running/stack_frame_list.rs | 1 - crates/project/src/debugger/session.rs | 24 +- .../ui/src/components/button/split_button.rs | 32 ++- 4 files changed, 190 insertions(+), 79 deletions(-) diff --git a/crates/debugger_ui/src/debugger_panel.rs b/crates/debugger_ui/src/debugger_panel.rs index fe81ac641196dbbc5ceecaede0785ca72336c261..bdb308aafd0d2899f17bef732ac38239c4df6dda 100644 --- a/crates/debugger_ui/src/debugger_panel.rs +++ b/crates/debugger_ui/src/debugger_panel.rs @@ -17,9 +17,9 @@ use dap::{client::SessionId, debugger_settings::DebuggerSettings}; use editor::{Editor, MultiBufferOffset, ToPoint}; use feature_flags::{FeatureFlag, FeatureFlagAppExt as _}; use gpui::{ - Action, App, AsyncWindowContext, ClipboardItem, Context, DismissEvent, Entity, EntityId, - EventEmitter, FocusHandle, Focusable, MouseButton, MouseDownEvent, Point, Subscription, Task, - WeakEntity, anchored, deferred, + Action, App, AsyncWindowContext, ClipboardItem, Context, Corner, DismissEvent, Entity, + EntityId, EventEmitter, FocusHandle, Focusable, MouseButton, MouseDownEvent, Point, + Subscription, Task, WeakEntity, anchored, deferred, }; use itertools::Itertools as _; @@ -32,7 +32,9 @@ use settings::Settings; use std::sync::{Arc, LazyLock}; use task::{DebugScenario, TaskContext}; use tree_sitter::{Query, StreamingIterator as _}; -use ui::{ContextMenu, Divider, PopoverMenuHandle, Tab, Tooltip, prelude::*}; +use ui::{ + ContextMenu, Divider, PopoverMenu, PopoverMenuHandle, SplitButton, Tab, Tooltip, prelude::*, +}; use util::rel_path::RelPath; use util::{ResultExt, debug_panic, maybe}; use workspace::SplitDirection; @@ -669,6 +671,12 @@ impl DebugPanel { ) }; + let thread_status = active_session + .as_ref() + .map(|session| session.read(cx).running_state()) + .and_then(|state| 
state.read(cx).thread_status(cx)) + .unwrap_or(project::debugger::session::ThreadStatus::Exited); + Some( div.w_full() .py_1() @@ -686,10 +694,6 @@ impl DebugPanel { .as_ref() .map(|session| session.read(cx).running_state()), |this, running_state| { - let thread_status = - running_state.read(cx).thread_status(cx).unwrap_or( - project::debugger::session::ThreadStatus::Exited, - ); let capabilities = running_state.read(cx).capabilities(cx); let supports_detach = running_state.read(cx).session().read(cx).is_attached(); @@ -812,34 +816,6 @@ impl DebugPanel { } }), ) - .when(cx.has_flag::(), |this| { - this.child( - IconButton::new( - "debug-back-in-history", - IconName::HistoryRerun, - ) - .icon_size(IconSize::Small) - .on_click( - window.listener_for( - running_state, - |this, _, _window, cx| { - this.session().update(cx, |session, cx| { - let ix = session - .active_history() - .unwrap_or_else(|| { - session.history().len() - }); - - session.go_back_to_history( - Some(ix.saturating_sub(1)), - cx, - ); - }) - }, - ), - ), - ) - }) .child(Divider::vertical()) .child( IconButton::new("debug-restart", IconName::RotateCcw) @@ -906,36 +882,53 @@ impl DebugPanel { } }), ) + .when(supports_detach, |div| { + div.child( + IconButton::new( + "debug-disconnect", + IconName::DebugDetach, + ) + .disabled( + thread_status != ThreadStatus::Stopped + && thread_status != ThreadStatus::Running, + ) + .icon_size(IconSize::Small) + .on_click(window.listener_for( + running_state, + |this, _, _, cx| { + this.detach_client(cx); + }, + )) + .tooltip({ + let focus_handle = focus_handle.clone(); + move |_window, cx| { + Tooltip::for_action_in( + "Detach", + &Detach, + &focus_handle, + cx, + ) + } + }), + ) + }) .when( - supports_detach, - |div| { - div.child( - IconButton::new( - "debug-disconnect", - IconName::DebugDetach, - ) - .disabled( - thread_status != ThreadStatus::Stopped - && thread_status != ThreadStatus::Running, + cx.has_flag::(), + |this| { + this.child(Divider::vertical()).child( 
+ SplitButton::new( + self.render_history_button( + &running_state, + thread_status, + window, + ), + self.render_history_toggle_button( + thread_status, + &running_state, + ) + .into_any_element(), ) - .icon_size(IconSize::Small) - .on_click(window.listener_for( - running_state, - |this, _, _, cx| { - this.detach_client(cx); - }, - )) - .tooltip({ - let focus_handle = focus_handle.clone(); - move |_window, cx| { - Tooltip::for_action_in( - "Detach", - &Detach, - &focus_handle, - cx, - ) - } - }), + .style(ui::SplitButtonStyle::Outlined), ) }, ) @@ -1352,6 +1345,97 @@ impl DebugPanel { }); } } + + fn render_history_button( + &self, + running_state: &Entity, + thread_status: ThreadStatus, + window: &mut Window, + ) -> IconButton { + IconButton::new("debug-back-in-history", IconName::HistoryRerun) + .icon_size(IconSize::Small) + .on_click(window.listener_for(running_state, |this, _, _window, cx| { + this.session().update(cx, |session, cx| { + let ix = session + .active_snapshot_index() + .unwrap_or_else(|| session.historic_snapshots().len()); + + session.select_historic_snapshot(Some(ix.saturating_sub(1)), cx); + }) + })) + .disabled( + thread_status == ThreadStatus::Running || thread_status == ThreadStatus::Stepping, + ) + } + + fn render_history_toggle_button( + &self, + thread_status: ThreadStatus, + running_state: &Entity, + ) -> impl IntoElement { + PopoverMenu::new("debug-back-in-history-menu") + .trigger( + ui::ButtonLike::new_rounded_right("debug-back-in-history-menu-trigger") + .layer(ui::ElevationIndex::ModalSurface) + .size(ui::ButtonSize::None) + .child( + div() + .px_1() + .child(Icon::new(IconName::ChevronDown).size(IconSize::XSmall)), + ) + .disabled( + thread_status == ThreadStatus::Running + || thread_status == ThreadStatus::Stepping, + ), + ) + .menu({ + let running_state = running_state.clone(); + move |window, cx| { + let handler = + |ix: Option, running_state: Entity, cx: &mut App| { + running_state.update(cx, |state, cx| { + 
state.session().update(cx, |session, cx| { + session.select_historic_snapshot(ix, cx); + }) + }) + }; + + let running_state = running_state.clone(); + Some(ContextMenu::build( + window, + cx, + move |mut context_menu, _window, cx| { + let history = running_state + .read(cx) + .session() + .read(cx) + .historic_snapshots(); + + context_menu = context_menu.entry("Current State", None, { + let running_state = running_state.clone(); + move |_window, cx| { + handler(None, running_state.clone(), cx); + } + }); + context_menu = context_menu.separator(); + + for (ix, _) in history.iter().enumerate().rev() { + context_menu = + context_menu.entry(format!("history-{}", ix + 1), None, { + let running_state = running_state.clone(); + move |_window, cx| { + handler(Some(ix), running_state.clone(), cx); + } + }); + } + + context_menu + }, + )) + } + }) + .anchor(Corner::TopRight) + } } async fn register_session_inner( diff --git a/crates/debugger_ui/src/session/running/stack_frame_list.rs b/crates/debugger_ui/src/session/running/stack_frame_list.rs index 5ecdc0f74be97c01ace933fd3513535040599bac..a715e2248d14e253a9762c1bcf9f50c1db09d64c 100644 --- a/crates/debugger_ui/src/session/running/stack_frame_list.rs +++ b/crates/debugger_ui/src/session/running/stack_frame_list.rs @@ -227,7 +227,6 @@ impl StackFrameList { } this.update_in(cx, |this, window, cx| { this.build_entries(select_first, window, cx); - cx.notify(); }) .ok(); }) diff --git a/crates/project/src/debugger/session.rs b/crates/project/src/debugger/session.rs index 9d4d307f990bfc5f00190f74ce3f1f957e71bacc..65e903e178f6bb010c34315c1c5d5a7bf9cbe44e 100644 --- a/crates/project/src/debugger/session.rs +++ b/crates/project/src/debugger/session.rs @@ -1436,15 +1436,23 @@ impl Session { .push_back(std::mem::take(&mut self.active_snapshot)); } - pub fn history(&self) -> &VecDeque { + pub fn historic_snapshots(&self) -> &VecDeque { &self.snapshots } - pub fn go_back_to_history(&mut self, ix: Option, cx: &mut Context<'_, Session>) { 
+ pub fn select_historic_snapshot(&mut self, ix: Option, cx: &mut Context) { if self.selected_snapshot_index == ix { return; } + if self + .selected_snapshot_index + .is_some_and(|ix| self.snapshots.len() <= ix) + { + debug_panic!("Attempted to select a debug session with an out of bounds index"); + return; + } + self.selected_snapshot_index = ix; if ix.is_some() { @@ -1454,7 +1462,7 @@ impl Session { cx.notify(); } - pub fn active_history(&self) -> Option { + pub fn active_snapshot_index(&self) -> Option { self.selected_snapshot_index } @@ -2272,7 +2280,7 @@ impl Session { } pub fn continue_thread(&mut self, thread_id: ThreadId, cx: &mut Context) { - self.go_back_to_history(None, cx); + self.select_historic_snapshot(None, cx); let supports_single_thread_execution_requests = self.capabilities.supports_single_thread_execution_requests; @@ -2309,7 +2317,7 @@ impl Session { granularity: SteppingGranularity, cx: &mut Context, ) { - self.go_back_to_history(None, cx); + self.select_historic_snapshot(None, cx); let supports_single_thread_execution_requests = self.capabilities.supports_single_thread_execution_requests; @@ -2341,7 +2349,7 @@ impl Session { granularity: SteppingGranularity, cx: &mut Context, ) { - self.go_back_to_history(None, cx); + self.select_historic_snapshot(None, cx); let supports_single_thread_execution_requests = self.capabilities.supports_single_thread_execution_requests; @@ -2373,7 +2381,7 @@ impl Session { granularity: SteppingGranularity, cx: &mut Context, ) { - self.go_back_to_history(None, cx); + self.select_historic_snapshot(None, cx); let supports_single_thread_execution_requests = self.capabilities.supports_single_thread_execution_requests; @@ -2405,7 +2413,7 @@ impl Session { granularity: SteppingGranularity, cx: &mut Context, ) { - self.go_back_to_history(None, cx); + self.select_historic_snapshot(None, cx); let supports_single_thread_execution_requests = self.capabilities.supports_single_thread_execution_requests; diff --git 
a/crates/ui/src/components/button/split_button.rs b/crates/ui/src/components/button/split_button.rs index 14b9fd153cd5ad662467c75ff81700587667cee3..48f06ff3789e69b6d19cde2322932f4bd6e89f97 100644 --- a/crates/ui/src/components/button/split_button.rs +++ b/crates/ui/src/components/button/split_button.rs @@ -4,7 +4,7 @@ use gpui::{ }; use theme::ActiveTheme; -use crate::{ElevationIndex, h_flex}; +use crate::{ElevationIndex, IconButton, h_flex}; use super::ButtonLike; @@ -15,6 +15,23 @@ pub enum SplitButtonStyle { Transparent, } +pub enum SplitButtonKind { + ButtonLike(ButtonLike), + IconButton(IconButton), +} + +impl From for SplitButtonKind { + fn from(icon_button: IconButton) -> Self { + Self::IconButton(icon_button) + } +} + +impl From for SplitButtonKind { + fn from(button_like: ButtonLike) -> Self { + Self::ButtonLike(button_like) + } +} + /// /// A button with two parts: a primary action on the left and a secondary action on the right. /// /// The left side is a [`ButtonLike`] with the main action, while the right side can contain @@ -23,15 +40,15 @@ pub enum SplitButtonStyle { /// The two sections are visually separated by a divider, but presented as a unified control. 
#[derive(IntoElement)] pub struct SplitButton { - pub left: ButtonLike, - pub right: AnyElement, + left: SplitButtonKind, + right: AnyElement, style: SplitButtonStyle, } impl SplitButton { - pub fn new(left: ButtonLike, right: AnyElement) -> Self { + pub fn new(left: impl Into, right: AnyElement) -> Self { Self { - left, + left: left.into(), right, style: SplitButtonStyle::Filled, } @@ -56,7 +73,10 @@ impl RenderOnce for SplitButton { this.border_1() .border_color(cx.theme().colors().border.opacity(0.8)) }) - .child(div().flex_grow().child(self.left)) + .child(div().flex_grow().child(match self.left { + SplitButtonKind::ButtonLike(button) => button.into_any_element(), + SplitButtonKind::IconButton(icon) => icon.into_any_element(), + })) .child( div() .h_full() From 9f344f093e1b5fee08937111569b106dbeee2410 Mon Sep 17 00:00:00 2001 From: Kunall Banerjee Date: Sat, 6 Dec 2025 19:14:13 -0500 Subject: [PATCH 46/81] docs: Point to the right URL for Astro LSP (#44314) The original URL points to a deprecated repo. Release Notes: - N/A --- docs/src/languages/astro.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/languages/astro.md b/docs/src/languages/astro.md index 5691a0de4844b2e2d924713d523f4651da6fe984..cbfe8de74e7444e2e02f6240265e00eb043a2084 100644 --- a/docs/src/languages/astro.md +++ b/docs/src/languages/astro.md @@ -3,7 +3,7 @@ Astro support is available through the [Astro extension](https://github.com/zed-extensions/astro). - Tree-sitter: [virchau13/tree-sitter-astro](https://github.com/virchau13/tree-sitter-astro) -- Language Server: [withastro/language-tools](https://github.com/withastro/language-tools) +- Language Server: [withastro/language-tools](https://github.com/withastro/astro/tree/main/packages/language-tools/language-server)