From 24a6008e5c0bc691a8ee68f4cbfafd919367efde Mon Sep 17 00:00:00 2001 From: MostlyK <135974627+MostlyKIGuess@users.noreply.github.com> Date: Tue, 7 Apr 2026 11:12:50 +0530 Subject: [PATCH 01/22] repl: Improve iopub connection error messages (#53014) Coming from #51834, these would be more helpful than just that it failed! Self-Review Checklist: - [x] I've reviewed my own diff for quality, security, and reliability - [x] Unsafe blocks (if any) have justifying comments - [x] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) - [x] Tests cover the new/changed behavior - [x] Performance impact has been considered and is acceptable Release Notes: - N/A --- crates/repl/src/kernels/ssh_kernel.rs | 2 +- crates/repl/src/kernels/wsl_kernel.rs | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/crates/repl/src/kernels/ssh_kernel.rs b/crates/repl/src/kernels/ssh_kernel.rs index 53be6622379cfcbf3ceeb6db425eeede9b226860..797b111a14345267e01c60c6803787c8f1d0f6a2 100644 --- a/crates/repl/src/kernels/ssh_kernel.rs +++ b/crates/repl/src/kernels/ssh_kernel.rs @@ -215,7 +215,7 @@ impl SshRunningKernel { &session_id, ) .await - .context("failed to create iopub connection")?; + .context("Failed to create iopub connection. Is `ipykernel` installed in the remote environment? Try running `pip install ipykernel` on the remote host.")?; let peer_identity = runtimelib::peer_identity_for_session(&session_id)?; let shell_socket = runtimelib::create_client_shell_connection_with_identity( diff --git a/crates/repl/src/kernels/wsl_kernel.rs b/crates/repl/src/kernels/wsl_kernel.rs index d9ac05c5fc8c2cb756898ff449d6714b78cb7997..be76d7ddccb7f199a368b76a1f21bf65fe6f2902 100644 --- a/crates/repl/src/kernels/wsl_kernel.rs +++ b/crates/repl/src/kernels/wsl_kernel.rs @@ -354,7 +354,8 @@ impl WslRunningKernel { "", &session_id, ) - .await?; + .await + .context("Failed to create iopub connection. Is `ipykernel` installed in the WSL environment? Try running `pip install ipykernel` inside your WSL distribution.")?; let peer_identity = runtimelib::peer_identity_for_session(&session_id)?; let shell_socket = runtimelib::create_client_shell_connection_with_identity( From 6f7fab1d68f1fa4945c7717a595b9e9776a14521 Mon Sep 17 00:00:00 2001 From: Smit Barmase Date: Tue, 7 Apr 2026 11:19:21 +0530 Subject: [PATCH 02/22] http_client: Fix GitHub download unpack failures on some filesystems (#53286) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Disable mtime preservation when unpacking tar archives, as some filesystems error when asked to set it. Follows how [cargo](https://github.com/rust-lang/cargo/blob/1ad92f77a819953bcef75a24019b66681ff28b1c/src/cargo/ops/cargo_package/verify.rs#L59 ) and [uv](https://github.com/astral-sh/uv/blob/0da0cd8b4310d3ac4be96223bd1e24ada109af9e/crates/uv-extract/src/stream.rs#L658) handle it. > Caused by:     0: extracting https://github.com/microsoft/vscode-eslint/archive/refs/tags/release%2F3.0.24.tar.gz to "/Users/user-name-here/Library/Application Support/Zed/languages/eslint/.tmp-github-download-pYkrYP"     1: failed to unpack `/Users/user-name-here/Library/Application Support/Zed/languages/eslint/.tmp-github-download-pYkrYP/vscode-eslint-release-3.0.24/package-lock.json`     2: failed to set mtime for `/Users/user-name-here/Library/Application Support/Zed/languages/eslint/.tmp-github-download-pYkrYP/vscode-eslint-release-3.0.24/package-lock.json`     3: No such file or directory (os error 2) Self-Review Checklist: - [x] I've reviewed my own diff for quality, security, and reliability - [x] Unsafe blocks (if any) have justifying comments - [x] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) - [x] Tests cover the new/changed behavior - [x] Performance impact has been considered and is acceptable Release Notes: - N/A --- crates/http_client/src/github_download.rs | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/crates/http_client/src/github_download.rs b/crates/http_client/src/github_download.rs index 47ae2c2b36b1ab37b56ab70735c2ce018bc5e275..5d11f3e11b7ea951c6bc9c143c266d8802f88cc3 100644 --- a/crates/http_client/src/github_download.rs +++ b/crates/http_client/src/github_download.rs @@ -207,11 +207,7 @@ async fn extract_tar_gz( from: impl AsyncRead + Unpin, ) -> Result<(), anyhow::Error> { let decompressed_bytes = GzipDecoder::new(BufReader::new(from)); - let archive = async_tar::Archive::new(decompressed_bytes); - archive - .unpack(&destination_path) - .await - .with_context(|| format!("extracting {url} to {destination_path:?}"))?; + unpack_tar_archive(destination_path, url, decompressed_bytes).await?; Ok(()) } @@ -221,7 +217,21 @@ async fn extract_tar_bz2( from: impl AsyncRead + Unpin, ) -> Result<(), anyhow::Error> { let decompressed_bytes = BzDecoder::new(BufReader::new(from)); - let archive = async_tar::Archive::new(decompressed_bytes); + unpack_tar_archive(destination_path, url, decompressed_bytes).await?; + Ok(()) +} + +async fn unpack_tar_archive( + destination_path: &Path, + url: &str, + archive_bytes: impl AsyncRead + Unpin, +) -> Result<(), anyhow::Error> { + // We don't need to set the modified time. It's irrelevant to downloaded + // archive verification, and some filesystems return errors when asked to + // apply it after extraction. + let archive = async_tar::ArchiveBuilder::new(archive_bytes) + .set_preserve_mtime(false) + .build(); archive .unpack(&destination_path) .await From 818991db7781db11bd8b1dea9eb27179713156f1 Mon Sep 17 00:00:00 2001 From: Saketh <126517689+SAKETH11111@users.noreply.github.com> Date: Tue, 7 Apr 2026 01:14:00 -0500 Subject: [PATCH 03/22] tasks_ui: Fix previously used task tooltip (#53104) Closes #52941 ## Summary - update the task picker delete button tooltip to describe the recently used task entry it removes - keep the change scoped to the inaccurate user-facing copy in the tasks modal ## Testing - cargo test -p tasks_ui Release Notes: - N/A --- crates/tasks_ui/src/modal.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/crates/tasks_ui/src/modal.rs b/crates/tasks_ui/src/modal.rs index 285a07c9562849b26b4cbba3de3979614384d875..3b7edef415f10f8723ab041e5a81ac672d603371 100644 --- a/crates/tasks_ui/src/modal.rs +++ b/crates/tasks_ui/src/modal.rs @@ -566,9 +566,7 @@ impl PickerDelegate for TasksModalDelegate { .checked_sub(1); picker.refresh(window, cx); })) - .tooltip(|_, cx| { - Tooltip::simple("Delete Previously Scheduled Task", cx) - }), + .tooltip(|_, cx| Tooltip::simple("Delete from Recent Tasks", cx)), ); item.end_slot_on_hover(delete_button) } else { From ee6495dce4012019ccc235486afa800da443d680 Mon Sep 17 00:00:00 2001 From: Cameron Mcloughlin Date: Tue, 7 Apr 2026 09:02:40 +0100 Subject: [PATCH 04/22] collab: Fix UI font size scaling (#53290) --- crates/collab_ui/src/collab_panel.rs | 7 ------- 1 file changed, 7 deletions(-) diff --git a/crates/collab_ui/src/collab_panel.rs b/crates/collab_ui/src/collab_panel.rs index 8d0cdf351163dadf0ac8cbf6a8dc04886f30f583..1e1aab3b9d4aa0e48ad4a84ec77bdc6dff51c7f5 100644 --- a/crates/collab_ui/src/collab_panel.rs +++ b/crates/collab_ui/src/collab_panel.rs @@ -1181,7 +1181,6 @@ impl CollabPanel { .into(); ListItem::new(project_id as usize) - .height(px(24.)) .toggle_state(is_selected) .on_click(cx.listener(move |this, _, window, cx| { this.workspace @@ -1222,7 +1221,6 @@ impl CollabPanel { let id = peer_id.map_or(usize::MAX, |id| id.as_u64() as usize); ListItem::new(("screen", id)) - .height(px(24.)) .toggle_state(is_selected) .start_slot( h_flex() @@ -1269,7 +1267,6 @@ impl CollabPanel { let has_channel_buffer_changed = channel_store.has_channel_buffer_changed(channel_id); ListItem::new("channel-notes") - .height(px(24.)) .toggle_state(is_selected) .on_click(cx.listener(move |this, _, window, cx| { this.open_channel_notes(channel_id, window, cx); @@ -3210,12 +3207,9 @@ impl CollabPanel { (IconName::Star, Color::Default, "Add to Favorites") }; - let height = px(24.); - h_flex() .id(ix) .group("") - .h(height) .w_full() .overflow_hidden() .when(!channel.is_root_channel(), |el| { @@ -3245,7 +3239,6 @@ impl CollabPanel { ) .child( ListItem::new(ix) - .height(height) // Add one level of depth for the disclosure arrow. .indent_level(depth + 1) .indent_step_size(px(20.)) From 614f67ed2aa7378e5f11359ea01ba873b6a2a103 Mon Sep 17 00:00:00 2001 From: "Angel P." Date: Tue, 7 Apr 2026 05:00:22 -0400 Subject: [PATCH 05/22] markdown_preview: Fix HTML alignment styles not being applied (#53196) ## What This PR Does This PR adds support for HTML alignment styles to be applied to Paragraph and Heading elements and their children. Here is what this looks like before vs after this PR (both images use the same markdown below): ```markdown

``` **BEFORE:** image **AFTER:** image ## Notes I used `style="text-align: center|left|right;"` instead of `align="center|right|left"` since `align` has been [deprecated in HTML5](https://www.w3.org/TR/2011/WD-html5-author-20110809/obsolete.html) for block-level elements. The issue this PR solves mentioned that github supports the `align="center|right|left"` attribute, so I'm unsure if the Zed team would want to have parity there. Feel free to let me know if that would be something that should be added, however for now I've decided to follow the HTML5 standard. Self-Review Checklist: - [x] I've reviewed my own diff for quality, security, and reliability - [x] Unsafe blocks (if any) have justifying comments - [x] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) - [x] Tests cover the new/changed behavior - [x] Performance impact has been considered and is acceptable Closes https://github.com/zed-industries/zed/issues/51062 Release Notes: - Fixed HTML alignment styles not being applied in markdown previews --------- Co-authored-by: Smit Barmase --- crates/markdown/src/html/html_parser.rs | 117 ++++++++++++++++++--- crates/markdown/src/html/html_rendering.rs | 18 +++- crates/markdown/src/markdown.rs | 69 +++++++++--- 3 files changed, 172 insertions(+), 32 deletions(-) diff --git a/crates/markdown/src/html/html_parser.rs b/crates/markdown/src/html/html_parser.rs index 20338ec2abef2314b7cd6ca91e45ee05be909745..8aa5da0cea7ea160721875fa889a720fe4c8bed1 100644 --- a/crates/markdown/src/html/html_parser.rs +++ b/crates/markdown/src/html/html_parser.rs @@ -1,6 +1,6 @@ use std::{cell::RefCell, collections::HashMap, mem, ops::Range}; -use gpui::{DefiniteLength, FontWeight, SharedString, px, relative}; +use gpui::{DefiniteLength, FontWeight, SharedString, TextAlign, px, relative}; use html5ever::{ Attribute, LocalName, ParseOpts, local_name, parse_document, tendril::TendrilSink, }; @@ -24,10 +24,17 @@ pub(crate) enum ParsedHtmlElement { List(ParsedHtmlList), Table(ParsedHtmlTable), BlockQuote(ParsedHtmlBlockQuote), - Paragraph(HtmlParagraph), + Paragraph(ParsedHtmlParagraph), Image(HtmlImage), } +#[derive(Debug, Clone)] +#[cfg_attr(test, derive(PartialEq))] +pub(crate) struct ParsedHtmlParagraph { + pub text_align: Option, + pub contents: HtmlParagraph, +} + impl ParsedHtmlElement { pub fn source_range(&self) -> Option> { Some(match self { @@ -35,7 +42,7 @@ impl ParsedHtmlElement { Self::List(list) => list.source_range.clone(), Self::Table(table) => table.source_range.clone(), Self::BlockQuote(block_quote) => block_quote.source_range.clone(), - Self::Paragraph(text) => match text.first()? { + Self::Paragraph(paragraph) => match paragraph.contents.first()? { HtmlParagraphChunk::Text(text) => text.source_range.clone(), HtmlParagraphChunk::Image(image) => image.source_range.clone(), }, @@ -83,6 +90,7 @@ pub(crate) struct ParsedHtmlHeading { pub source_range: Range, pub level: HeadingLevel, pub contents: HtmlParagraph, + pub text_align: Option, } #[derive(Debug, Clone)] @@ -236,20 +244,21 @@ fn parse_html_node( consume_children(source_range, node, elements, context); } NodeData::Text { contents } => { - elements.push(ParsedHtmlElement::Paragraph(vec![ - HtmlParagraphChunk::Text(ParsedHtmlText { + elements.push(ParsedHtmlElement::Paragraph(ParsedHtmlParagraph { + text_align: None, + contents: vec![HtmlParagraphChunk::Text(ParsedHtmlText { source_range, highlights: Vec::default(), links: Vec::default(), contents: contents.borrow().to_string().into(), - }), - ])); + })], + })); } NodeData::Comment { .. } => {} NodeData::Element { name, attrs, .. } => { - let mut styles = if let Some(styles) = - html_style_from_html_styles(extract_styles_from_attributes(attrs)) - { + let styles_map = extract_styles_from_attributes(attrs); + let text_align = text_align_from_attributes(attrs, &styles_map); + let mut styles = if let Some(styles) = html_style_from_html_styles(styles_map) { vec![styles] } else { Vec::default() @@ -270,7 +279,10 @@ fn parse_html_node( ); if !paragraph.is_empty() { - elements.push(ParsedHtmlElement::Paragraph(paragraph)); + elements.push(ParsedHtmlElement::Paragraph(ParsedHtmlParagraph { + text_align, + contents: paragraph, + })); } } else if matches!( name.local, @@ -303,6 +315,7 @@ fn parse_html_node( _ => unreachable!(), }, contents: paragraph, + text_align, })); } } else if name.local == local_name!("ul") || name.local == local_name!("ol") { @@ -589,6 +602,30 @@ fn html_style_from_html_styles(styles: HashMap) -> Option Option { + match value.trim().to_ascii_lowercase().as_str() { + "left" => Some(TextAlign::Left), + "center" => Some(TextAlign::Center), + "right" => Some(TextAlign::Right), + _ => None, + } +} + +fn text_align_from_styles(styles: &HashMap) -> Option { + styles + .get("text-align") + .and_then(|value| parse_text_align(value)) +} + +fn text_align_from_attributes( + attrs: &RefCell>, + styles: &HashMap, +) -> Option { + text_align_from_styles(styles).or_else(|| { + attr_value(attrs, local_name!("align")).and_then(|value| parse_text_align(&value)) + }) +} + fn extract_styles_from_attributes(attrs: &RefCell>) -> HashMap { let mut styles = HashMap::new(); @@ -770,6 +807,7 @@ fn extract_html_table(node: &Node, source_range: Range) -> Optionx

", 0..40).unwrap(); + let ParsedHtmlElement::Paragraph(paragraph) = &parsed.children[0] else { + panic!("expected paragraph"); + }; + assert_eq!(paragraph.text_align, Some(TextAlign::Center)); + } + + #[test] + fn parses_heading_text_align_from_style() { + let parsed = parse_html_block("

Title

", 0..45).unwrap(); + let ParsedHtmlElement::Heading(heading) = &parsed.children[0] else { + panic!("expected heading"); + }; + assert_eq!(heading.text_align, Some(TextAlign::Right)); + } + + #[test] + fn parses_paragraph_text_align_from_align_attribute() { + let parsed = parse_html_block("

x

", 0..24).unwrap(); + let ParsedHtmlElement::Paragraph(paragraph) = &parsed.children[0] else { + panic!("expected paragraph"); + }; + assert_eq!(paragraph.text_align, Some(TextAlign::Center)); + } + + #[test] + fn parses_heading_text_align_from_align_attribute() { + let parsed = parse_html_block("

Title

", 0..30).unwrap(); + let ParsedHtmlElement::Heading(heading) = &parsed.children[0] else { + panic!("expected heading"); + }; + assert_eq!(heading.text_align, Some(TextAlign::Right)); + } + + #[test] + fn prefers_style_text_align_over_align_attribute() { + let parsed = parse_html_block( + "

x

", + 0..50, + ) + .unwrap(); + let ParsedHtmlElement::Paragraph(paragraph) = &parsed.children[0] else { + panic!("expected paragraph"); + }; + assert_eq!(paragraph.text_align, Some(TextAlign::Center)); + } } diff --git a/crates/markdown/src/html/html_rendering.rs b/crates/markdown/src/html/html_rendering.rs index 103e2a6accb7dce9bc429419aafd27cbdf5080ce..6ae25eff0b4ba2ec8dedde8118ebd8d60e8fce7d 100644 --- a/crates/markdown/src/html/html_rendering.rs +++ b/crates/markdown/src/html/html_rendering.rs @@ -79,9 +79,20 @@ impl MarkdownElement { match element { ParsedHtmlElement::Paragraph(paragraph) => { - self.push_markdown_paragraph(builder, &source_range, markdown_end); - self.render_html_paragraph(paragraph, source_allocator, builder, cx, markdown_end); - builder.pop_div(); + self.push_markdown_paragraph( + builder, + &source_range, + markdown_end, + paragraph.text_align, + ); + self.render_html_paragraph( + ¶graph.contents, + source_allocator, + builder, + cx, + markdown_end, + ); + self.pop_markdown_paragraph(builder); } ParsedHtmlElement::Heading(heading) => { self.push_markdown_heading( @@ -89,6 +100,7 @@ impl MarkdownElement { heading.level, &heading.source_range, markdown_end, + heading.text_align, ); self.render_html_paragraph( &heading.contents, diff --git a/crates/markdown/src/markdown.rs b/crates/markdown/src/markdown.rs index 247c082d223005a7e0bd6d57696751ce76cc4d86..e6ad1b1f2ac9154eaabc6d18dbcb9c8695ae019d 100644 --- a/crates/markdown/src/markdown.rs +++ b/crates/markdown/src/markdown.rs @@ -36,8 +36,8 @@ use gpui::{ FocusHandle, Focusable, FontStyle, FontWeight, GlobalElementId, Hitbox, Hsla, Image, ImageFormat, ImageSource, KeyContext, Length, MouseButton, MouseDownEvent, MouseEvent, MouseMoveEvent, MouseUpEvent, Point, ScrollHandle, Stateful, StrikethroughStyle, - StyleRefinement, StyledText, Task, TextLayout, TextRun, TextStyle, TextStyleRefinement, - actions, img, point, quad, + StyleRefinement, StyledText, Task, TextAlign, TextLayout, TextRun, TextStyle, + TextStyleRefinement, actions, img, point, quad, }; use language::{CharClassifier, Language, LanguageRegistry, Rope}; use parser::CodeBlockMetadata; @@ -1025,8 +1025,17 @@ impl MarkdownElement { width: Option, height: Option, ) { + let align = builder.text_style().text_align; builder.modify_current_div(|el| { - el.items_center().flex().flex_row().child( + let mut image_container = el.flex().flex_row().items_center(); + + image_container = match align { + TextAlign::Left => image_container.justify_start(), + TextAlign::Center => image_container.justify_center(), + TextAlign::Right => image_container.justify_end(), + }; + + image_container.child( img(source) .max_w_full() .when_some(height, |this, height| this.h(height)) @@ -1041,14 +1050,29 @@ impl MarkdownElement { builder: &mut MarkdownElementBuilder, range: &Range, markdown_end: usize, + text_align_override: Option, ) { - builder.push_div( - div().when(!self.style.height_is_multiple_of_line_height, |el| { - el.mb_2().line_height(rems(1.3)) - }), - range, - markdown_end, - ); + let align = text_align_override.unwrap_or(self.style.base_text_style.text_align); + let mut paragraph = div().when(!self.style.height_is_multiple_of_line_height, |el| { + el.mb_2().line_height(rems(1.3)) + }); + + paragraph = match align { + TextAlign::Center => paragraph.text_center(), + TextAlign::Left => paragraph.text_left(), + TextAlign::Right => paragraph.text_right(), + }; + + builder.push_text_style(TextStyleRefinement { + text_align: Some(align), + ..Default::default() + }); + builder.push_div(paragraph, range, markdown_end); + } + + fn pop_markdown_paragraph(&self, builder: &mut MarkdownElementBuilder) { + builder.pop_div(); + builder.pop_text_style(); } fn push_markdown_heading( @@ -1057,15 +1081,26 @@ impl MarkdownElement { level: pulldown_cmark::HeadingLevel, range: &Range, markdown_end: usize, + text_align_override: Option, ) { + let align = text_align_override.unwrap_or(self.style.base_text_style.text_align); let mut heading = div().mb_2(); heading = apply_heading_style(heading, level, self.style.heading_level_styles.as_ref()); + heading = match align { + TextAlign::Center => heading.text_center(), + TextAlign::Left => heading.text_left(), + TextAlign::Right => heading.text_right(), + }; + let mut heading_style = self.style.heading.clone(); let heading_text_style = heading_style.text_style().clone(); heading.style().refine(&heading_style); - builder.push_text_style(heading_text_style); + builder.push_text_style(TextStyleRefinement { + text_align: Some(align), + ..heading_text_style + }); builder.push_div(heading, range, markdown_end); } @@ -1571,10 +1606,16 @@ impl Element for MarkdownElement { } } MarkdownTag::Paragraph => { - self.push_markdown_paragraph(&mut builder, range, markdown_end); + self.push_markdown_paragraph(&mut builder, range, markdown_end, None); } MarkdownTag::Heading { level, .. } => { - self.push_markdown_heading(&mut builder, *level, range, markdown_end); + self.push_markdown_heading( + &mut builder, + *level, + range, + markdown_end, + None, + ); } MarkdownTag::BlockQuote => { self.push_markdown_block_quote(&mut builder, range, markdown_end); @@ -1826,7 +1867,7 @@ impl Element for MarkdownElement { current_img_block_range.take(); } MarkdownTagEnd::Paragraph => { - builder.pop_div(); + self.pop_markdown_paragraph(&mut builder); } MarkdownTagEnd::Heading(_) => { self.pop_markdown_heading(&mut builder); From ccb9e60a6258d57104cc56db87fe03024dd231ef Mon Sep 17 00:00:00 2001 From: Anthony Eid <56899983+Anthony-Eid@users.noreply.github.com> Date: Tue, 7 Apr 2026 05:21:47 -0400 Subject: [PATCH 06/22] agent_panel: Add new thread git worktree/branch pickers (#52979) This PR allows users to create a new thread based off a git worktree that already exists or has a custom name. User's can also choose what branch they want the newly generated worktree to be based off of. The UI still needs some polish, but I'm merging this early to get the team using this before our preview launch. I'll be active today and tomorrow before launch to fix any nits we have with the UI. Functionality of this feature works! And I have a basic test to prevent regressions Self-Review Checklist: - [x] I've reviewed my own diff for quality, security, and reliability - [x] Unsafe blocks (if any) have justifying comments - [x] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) - [x] Tests cover the new/changed behavior - [x] Performance impact has been considered and is acceptable Closes #ISSUE Release Notes: - N/A or Added/Fixed/Improved ... --------- Co-authored-by: cameron --- crates/agent_ui/src/agent_panel.rs | 673 +++++++++++------ crates/agent_ui/src/agent_ui.rs | 37 +- .../src/conversation_view/thread_view.rs | 5 +- crates/agent_ui/src/thread_branch_picker.rs | 695 ++++++++++++++++++ crates/agent_ui/src/thread_worktree_picker.rs | 485 ++++++++++++ crates/collab/tests/integration/git_tests.rs | 12 +- .../remote_editing_collaboration_tests.rs | 6 +- crates/fs/src/fake_git_repo.rs | 113 ++- crates/fs/tests/integration/fake_git_repo.rs | 12 +- crates/git/src/repository.rs | 120 ++- crates/git_ui/src/worktree_picker.rs | 9 +- crates/project/src/git_store.rs | 102 ++- crates/project/tests/integration/git_store.rs | 12 +- crates/proto/proto/git.proto | 1 + crates/zed/src/visual_test_runner.rs | 18 +- 15 files changed, 1941 insertions(+), 359 deletions(-) create mode 100644 crates/agent_ui/src/thread_branch_picker.rs create mode 100644 crates/agent_ui/src/thread_worktree_picker.rs diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 41900e71e5d3ad7e5327ee7e04f73cb05eed5a5b..8f456e0e955b823a5bbaf2815df3b409441bb0af 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -28,21 +28,20 @@ use zed_actions::agent::{ use crate::thread_metadata_store::ThreadMetadataStore; use crate::{ AddContextServer, AgentDiffPane, ConversationView, CopyThreadToClipboard, CycleStartThreadIn, - Follow, InlineAssistant, LoadThreadFromClipboard, NewThread, OpenActiveThreadAsMarkdown, - OpenAgentDiff, OpenHistory, ResetTrialEndUpsell, ResetTrialUpsell, StartThreadIn, - ToggleNavigationMenu, ToggleNewThreadMenu, ToggleOptionsMenu, + Follow, InlineAssistant, LoadThreadFromClipboard, NewThread, NewWorktreeBranchTarget, + OpenActiveThreadAsMarkdown, OpenAgentDiff, OpenHistory, ResetTrialEndUpsell, ResetTrialUpsell, + StartThreadIn, ToggleNavigationMenu, ToggleNewThreadMenu, ToggleOptionsMenu, agent_configuration::{AgentConfiguration, AssistantConfigurationEvent}, conversation_view::{AcpThreadViewEvent, ThreadView}, + thread_branch_picker::ThreadBranchPicker, + thread_worktree_picker::ThreadWorktreePicker, ui::EndTrialUpsell, }; use crate::{ Agent, AgentInitialContent, ExternalSourcePrompt, NewExternalAgentThread, NewNativeAgentThreadFromSummary, }; -use crate::{ - DEFAULT_THREAD_TITLE, - ui::{AcpOnboardingModal, HoldForDefault}, -}; +use crate::{DEFAULT_THREAD_TITLE, ui::AcpOnboardingModal}; use crate::{ExpandMessageEditor, ThreadHistoryView}; use crate::{ManageProfiles, ThreadHistoryViewEvent}; use crate::{ThreadHistory, agent_connection_store::AgentConnectionStore}; @@ -73,8 +72,8 @@ use terminal::terminal_settings::TerminalSettings; use terminal_view::{TerminalView, terminal_panel::TerminalPanel}; use theme_settings::ThemeSettings; use ui::{ - Button, Callout, CommonAnimationExt, ContextMenu, ContextMenuEntry, DocumentationSide, - PopoverMenu, PopoverMenuHandle, Tab, Tooltip, prelude::*, utils::WithRemSize, + Button, Callout, CommonAnimationExt, ContextMenu, ContextMenuEntry, PopoverMenu, + PopoverMenuHandle, Tab, Tooltip, prelude::*, utils::WithRemSize, }; use util::{ResultExt as _, debug_panic}; use workspace::{ @@ -620,7 +619,31 @@ impl StartThreadIn { fn label(&self) -> SharedString { match self { Self::LocalProject => "Current Worktree".into(), - Self::NewWorktree => "New Git Worktree".into(), + Self::NewWorktree { + worktree_name: Some(worktree_name), + .. + } => format!("New: {worktree_name}").into(), + Self::NewWorktree { .. } => "New Git Worktree".into(), + Self::LinkedWorktree { display_name, .. } => format!("From: {}", &display_name).into(), + } + } + + fn worktree_branch_label(&self, default_branch_label: SharedString) -> Option { + match self { + Self::NewWorktree { branch_target, .. } => match branch_target { + NewWorktreeBranchTarget::CurrentBranch => Some(default_branch_label), + NewWorktreeBranchTarget::ExistingBranch { name } => { + Some(format!("From: {name}").into()) + } + NewWorktreeBranchTarget::CreateBranch { name, from_ref } => { + if let Some(from_ref) = from_ref { + Some(format!("From: {from_ref}").into()) + } else { + Some(format!("From: {name}").into()) + } + } + }, + _ => None, } } } @@ -632,6 +655,17 @@ pub enum WorktreeCreationStatus { Error(SharedString), } +#[derive(Clone, Debug)] +enum WorktreeCreationArgs { + New { + worktree_name: Option, + branch_target: NewWorktreeBranchTarget, + }, + Linked { + worktree_path: PathBuf, + }, +} + impl ActiveView { pub fn which_font_size_used(&self) -> WhichFontSize { match self { @@ -662,7 +696,8 @@ pub struct AgentPanel { previous_view: Option, background_threads: HashMap>, new_thread_menu_handle: PopoverMenuHandle, - start_thread_in_menu_handle: PopoverMenuHandle, + start_thread_in_menu_handle: PopoverMenuHandle, + thread_branch_menu_handle: PopoverMenuHandle, agent_panel_menu_handle: PopoverMenuHandle, agent_navigation_menu_handle: PopoverMenuHandle, agent_navigation_menu: Option>, @@ -689,7 +724,7 @@ impl AgentPanel { }; let selected_agent = self.selected_agent.clone(); - let start_thread_in = Some(self.start_thread_in); + let start_thread_in = Some(self.start_thread_in.clone()); let last_active_thread = self.active_agent_thread(cx).map(|thread| { let thread = thread.read(cx); @@ -794,18 +829,21 @@ impl AgentPanel { } else if let Some(agent) = global_fallback { panel.selected_agent = agent; } - if let Some(start_thread_in) = serialized_panel.start_thread_in { + if let Some(ref start_thread_in) = serialized_panel.start_thread_in { let is_worktree_flag_enabled = cx.has_flag::(); let is_valid = match &start_thread_in { StartThreadIn::LocalProject => true, - StartThreadIn::NewWorktree => { + StartThreadIn::NewWorktree { .. } => { let project = panel.project.read(cx); is_worktree_flag_enabled && !project.is_via_collab() } + StartThreadIn::LinkedWorktree { path, .. } => { + is_worktree_flag_enabled && path.exists() + } }; if is_valid { - panel.start_thread_in = start_thread_in; + panel.start_thread_in = start_thread_in.clone(); } else { log::info!( "deserialized start_thread_in {:?} is no longer valid, falling back to LocalProject", @@ -979,6 +1017,7 @@ impl AgentPanel { background_threads: HashMap::default(), new_thread_menu_handle: PopoverMenuHandle::default(), start_thread_in_menu_handle: PopoverMenuHandle::default(), + thread_branch_menu_handle: PopoverMenuHandle::default(), agent_panel_menu_handle: PopoverMenuHandle::default(), agent_navigation_menu_handle: PopoverMenuHandle::default(), agent_navigation_menu: None, @@ -1948,24 +1987,43 @@ impl AgentPanel { window: &mut Window, cx: &mut Context, ) { - if matches!(action, StartThreadIn::NewWorktree) && !cx.has_flag::() { - return; - } - - let new_target = match *action { + let new_target = match action { StartThreadIn::LocalProject => StartThreadIn::LocalProject, - StartThreadIn::NewWorktree => { + StartThreadIn::NewWorktree { .. } => { + if !cx.has_flag::() { + return; + } + if !self.project_has_git_repository(cx) { + log::error!( + "set_start_thread_in: cannot use worktree mode without a git repository" + ); + return; + } + if self.project.read(cx).is_via_collab() { + log::error!( + "set_start_thread_in: cannot use worktree mode in a collab project" + ); + return; + } + action.clone() + } + StartThreadIn::LinkedWorktree { .. } => { + if !cx.has_flag::() { + return; + } if !self.project_has_git_repository(cx) { log::error!( - "set_start_thread_in: cannot use NewWorktree without a git repository" + "set_start_thread_in: cannot use LinkedWorktree without a git repository" ); return; } if self.project.read(cx).is_via_collab() { - log::error!("set_start_thread_in: cannot use NewWorktree in a collab project"); + log::error!( + "set_start_thread_in: cannot use LinkedWorktree in a collab project" + ); return; } - StartThreadIn::NewWorktree + action.clone() } }; self.start_thread_in = new_target; @@ -1977,9 +2035,14 @@ impl AgentPanel { } fn cycle_start_thread_in(&mut self, window: &mut Window, cx: &mut Context) { - let next = match self.start_thread_in { - StartThreadIn::LocalProject => StartThreadIn::NewWorktree, - StartThreadIn::NewWorktree => StartThreadIn::LocalProject, + let next = match &self.start_thread_in { + StartThreadIn::LocalProject => StartThreadIn::NewWorktree { + worktree_name: None, + branch_target: NewWorktreeBranchTarget::default(), + }, + StartThreadIn::NewWorktree { .. } | StartThreadIn::LinkedWorktree { .. } => { + StartThreadIn::LocalProject + } }; self.set_start_thread_in(&next, window, cx); } @@ -1991,7 +2054,10 @@ impl AgentPanel { NewThreadLocation::LocalProject => StartThreadIn::LocalProject, NewThreadLocation::NewWorktree => { if self.project_has_git_repository(cx) { - StartThreadIn::NewWorktree + StartThreadIn::NewWorktree { + worktree_name: None, + branch_target: NewWorktreeBranchTarget::default(), + } } else { StartThreadIn::LocalProject } @@ -2219,15 +2285,39 @@ impl AgentPanel { window: &mut Window, cx: &mut Context, ) { - if self.start_thread_in == StartThreadIn::NewWorktree { - self.handle_worktree_creation_requested(content, window, cx); - } else { - cx.defer_in(window, move |_this, window, cx| { - thread_view.update(cx, |thread_view, cx| { - let editor = thread_view.message_editor.clone(); - thread_view.send_impl(editor, window, cx); + match &self.start_thread_in { + StartThreadIn::NewWorktree { + worktree_name, + branch_target, + } => { + self.handle_worktree_requested( + content, + WorktreeCreationArgs::New { + worktree_name: worktree_name.clone(), + branch_target: branch_target.clone(), + }, + window, + cx, + ); + } + StartThreadIn::LinkedWorktree { path, .. } => { + self.handle_worktree_requested( + content, + WorktreeCreationArgs::Linked { + worktree_path: path.clone(), + }, + window, + cx, + ); + } + StartThreadIn::LocalProject => { + cx.defer_in(window, move |_this, window, cx| { + thread_view.update(cx, |thread_view, cx| { + let editor = thread_view.message_editor.clone(); + thread_view.send_impl(editor, window, cx); + }); }); - }); + } } } @@ -2289,6 +2379,33 @@ impl AgentPanel { (git_repos, non_git_paths) } + fn resolve_worktree_branch_target( + branch_target: &NewWorktreeBranchTarget, + existing_branches: &HashSet, + occupied_branches: &HashSet, + ) -> Result<(String, bool, Option)> { + let generate_branch_name = || -> Result { + let refs: Vec<&str> = existing_branches.iter().map(|s| s.as_str()).collect(); + let mut rng = rand::rng(); + crate::branch_names::generate_branch_name(&refs, &mut rng) + .ok_or_else(|| anyhow!("Failed to generate a unique branch name")) + }; + + match branch_target { + NewWorktreeBranchTarget::CreateBranch { name, from_ref } => { + Ok((name.clone(), false, from_ref.clone())) + } + NewWorktreeBranchTarget::ExistingBranch { name } => { + if occupied_branches.contains(name) { + Ok((generate_branch_name()?, false, Some(name.clone()))) + } else { + Ok((name.clone(), true, None)) + } + } + NewWorktreeBranchTarget::CurrentBranch => Ok((generate_branch_name()?, false, None)), + } + } + /// Kicks off an async git-worktree creation for each repository. Returns: /// /// - `creation_infos`: a vec of `(repo, new_path, receiver)` tuples—the @@ -2297,7 +2414,10 @@ impl AgentPanel { /// later to remap open editor tabs into the new workspace. fn start_worktree_creations( git_repos: &[Entity], + worktree_name: Option, branch_name: &str, + use_existing_branch: bool, + start_point: Option, worktree_directory_setting: &str, cx: &mut Context, ) -> Result<( @@ -2311,12 +2431,27 @@ impl AgentPanel { let mut creation_infos = Vec::new(); let mut path_remapping = Vec::new(); + let worktree_name = worktree_name.unwrap_or_else(|| branch_name.to_string()); + for repo in git_repos { let (work_dir, new_path, receiver) = repo.update(cx, |repo, _cx| { let new_path = - repo.path_for_new_linked_worktree(branch_name, worktree_directory_setting)?; - let receiver = - repo.create_worktree(branch_name.to_string(), new_path.clone(), None); + repo.path_for_new_linked_worktree(&worktree_name, worktree_directory_setting)?; + let target = if use_existing_branch { + debug_assert!( + git_repos.len() == 1, + "use_existing_branch should only be true for a single repo" + ); + git::repository::CreateWorktreeTarget::ExistingBranch { + branch_name: branch_name.to_string(), + } + } else { + git::repository::CreateWorktreeTarget::NewBranch { + branch_name: branch_name.to_string(), + base_sha: start_point.clone(), + } + }; + let receiver = repo.create_worktree(target, new_path.clone()); let work_dir = repo.work_directory_abs_path.clone(); anyhow::Ok((work_dir, new_path, receiver)) })?; @@ -2419,9 +2554,10 @@ impl AgentPanel { cx.notify(); } - fn handle_worktree_creation_requested( + fn handle_worktree_requested( &mut self, content: Vec, + args: WorktreeCreationArgs, window: &mut Window, cx: &mut Context, ) { @@ -2437,7 +2573,7 @@ impl AgentPanel { let (git_repos, non_git_paths) = self.classify_worktrees(cx); - if git_repos.is_empty() { + if matches!(args, WorktreeCreationArgs::New { .. }) && git_repos.is_empty() { self.set_worktree_creation_error( "No git repositories found in the project".into(), window, @@ -2446,17 +2582,31 @@ impl AgentPanel { return; } - // Kick off branch listing as early as possible so it can run - // concurrently with the remaining synchronous setup work. - let branch_receivers: Vec<_> = git_repos - .iter() - .map(|repo| repo.update(cx, |repo, _cx| repo.branches())) - .collect(); - - let worktree_directory_setting = ProjectSettings::get_global(cx) - .git - .worktree_directory - .clone(); + let (branch_receivers, worktree_receivers, worktree_directory_setting) = + if matches!(args, WorktreeCreationArgs::New { .. }) { + ( + Some( + git_repos + .iter() + .map(|repo| repo.update(cx, |repo, _cx| repo.branches())) + .collect::>(), + ), + Some( + git_repos + .iter() + .map(|repo| repo.update(cx, |repo, _cx| repo.worktrees())) + .collect::>(), + ), + Some( + ProjectSettings::get_global(cx) + .git + .worktree_directory + .clone(), + ), + ) + } else { + (None, None, None) + }; let active_file_path = self.workspace.upgrade().and_then(|workspace| { let workspace = workspace.read(cx); @@ -2476,77 +2626,124 @@ impl AgentPanel { let selected_agent = self.selected_agent(); let task = cx.spawn_in(window, async move |this, cx| { - // Await the branch listings we kicked off earlier. - let mut existing_branches = Vec::new(); - for result in futures::future::join_all(branch_receivers).await { - match result { - Ok(Ok(branches)) => { - for branch in branches { - existing_branches.push(branch.name().to_string()); + let (all_paths, path_remapping, has_non_git) = match args { + WorktreeCreationArgs::New { + worktree_name, + branch_target, + } => { + let branch_receivers = branch_receivers + .expect("branch receivers must be prepared for new worktree creation"); + let worktree_receivers = worktree_receivers + .expect("worktree receivers must be prepared for new worktree creation"); + let worktree_directory_setting = worktree_directory_setting + .expect("worktree directory must be prepared for new worktree creation"); + + let mut existing_branches = HashSet::default(); + for result in futures::future::join_all(branch_receivers).await { + match result { + Ok(Ok(branches)) => { + for branch in branches { + existing_branches.insert(branch.name().to_string()); + } + } + Ok(Err(err)) => { + Err::<(), _>(err).log_err(); + } + Err(_) => {} } } - Ok(Err(err)) => { - Err::<(), _>(err).log_err(); + + let mut occupied_branches = HashSet::default(); + for result in futures::future::join_all(worktree_receivers).await { + match result { + Ok(Ok(worktrees)) => { + for worktree in worktrees { + if let Some(branch_name) = worktree.branch_name() { + occupied_branches.insert(branch_name.to_string()); + } + } + } + Ok(Err(err)) => { + Err::<(), _>(err).log_err(); + } + Err(_) => {} + } } - Err(_) => {} - } - } - let existing_branch_refs: Vec<&str> = - existing_branches.iter().map(|s| s.as_str()).collect(); - let mut rng = rand::rng(); - let branch_name = - match crate::branch_names::generate_branch_name(&existing_branch_refs, &mut rng) { - Some(name) => name, - None => { - this.update_in(cx, |this, window, cx| { - this.set_worktree_creation_error( - "Failed to generate a unique branch name".into(), - window, + let (branch_name, use_existing_branch, start_point) = + match Self::resolve_worktree_branch_target( + &branch_target, + &existing_branches, + &occupied_branches, + ) { + Ok(target) => target, + Err(err) => { + this.update_in(cx, |this, window, cx| { + this.set_worktree_creation_error( + err.to_string().into(), + window, + cx, + ); + })?; + return anyhow::Ok(()); + } + }; + + let (creation_infos, path_remapping) = + match this.update_in(cx, |_this, _window, cx| { + Self::start_worktree_creations( + &git_repos, + worktree_name, + &branch_name, + use_existing_branch, + start_point, + &worktree_directory_setting, cx, - ); - })?; - return anyhow::Ok(()); - } - }; + ) + }) { + Ok(Ok(result)) => result, + Ok(Err(err)) | Err(err) => { + this.update_in(cx, |this, window, cx| { + this.set_worktree_creation_error( + format!("Failed to validate worktree directory: {err}") + .into(), + window, + cx, + ); + }) + .log_err(); + return anyhow::Ok(()); + } + }; - let (creation_infos, path_remapping) = match this.update_in(cx, |_this, _window, cx| { - Self::start_worktree_creations( - &git_repos, - &branch_name, - &worktree_directory_setting, - cx, - ) - }) { - Ok(Ok(result)) => result, - Ok(Err(err)) | Err(err) => { - this.update_in(cx, |this, window, cx| { - this.set_worktree_creation_error( - format!("Failed to validate worktree directory: {err}").into(), - window, - cx, - ); - }) - .log_err(); - return anyhow::Ok(()); - } - }; + let created_paths = + match Self::await_and_rollback_on_failure(creation_infos, cx).await { + Ok(paths) => paths, + Err(err) => { + this.update_in(cx, |this, window, cx| { + this.set_worktree_creation_error( + format!("{err}").into(), + window, + cx, + ); + })?; + return anyhow::Ok(()); + } + }; - let created_paths = match Self::await_and_rollback_on_failure(creation_infos, cx).await - { - Ok(paths) => paths, - Err(err) => { - this.update_in(cx, |this, window, cx| { - this.set_worktree_creation_error(format!("{err}").into(), window, cx); - })?; - return anyhow::Ok(()); + let mut all_paths = created_paths; + let has_non_git = !non_git_paths.is_empty(); + all_paths.extend(non_git_paths.iter().cloned()); + (all_paths, path_remapping, has_non_git) + } + WorktreeCreationArgs::Linked { worktree_path } => { + let mut all_paths = vec![worktree_path]; + let has_non_git = !non_git_paths.is_empty(); + all_paths.extend(non_git_paths.iter().cloned()); + (all_paths, Vec::new(), has_non_git) } }; - let mut all_paths = created_paths; - let has_non_git = !non_git_paths.is_empty(); - all_paths.extend(non_git_paths.iter().cloned()); - let app_state = match workspace.upgrade() { Some(workspace) => cx.update(|_, cx| workspace.read(cx).app_state().clone())?, None => { @@ -2562,7 +2759,7 @@ impl AgentPanel { }; let this_for_error = this.clone(); - if let Err(err) = Self::setup_new_workspace( + if let Err(err) = Self::open_worktree_workspace_and_start_thread( this, all_paths, app_state, @@ -2595,7 +2792,7 @@ impl AgentPanel { })); } - async fn setup_new_workspace( + async fn open_worktree_workspace_and_start_thread( this: WeakEntity, all_paths: Vec, app_state: Arc, @@ -3149,25 +3346,15 @@ impl AgentPanel { } fn render_start_thread_in_selector(&self, cx: &mut Context) -> impl IntoElement { - use settings::{NewThreadLocation, Settings}; - let focus_handle = self.focus_handle(cx); - let has_git_repo = self.project_has_git_repository(cx); - let is_via_collab = self.project.read(cx).is_via_collab(); - let fs = self.fs.clone(); let is_creating = matches!( self.worktree_creation_status, Some(WorktreeCreationStatus::Creating) ); - let current_target = self.start_thread_in; let trigger_label = self.start_thread_in.label(); - let new_thread_location = AgentSettings::get_global(cx).new_thread_location; - let is_local_default = new_thread_location == NewThreadLocation::LocalProject; - let is_new_worktree_default = new_thread_location == NewThreadLocation::NewWorktree; - let icon = if self.start_thread_in_menu_handle.is_deployed() { IconName::ChevronUp } else { @@ -3178,13 +3365,9 @@ impl AgentPanel { .end_icon(Icon::new(icon).size(IconSize::XSmall).color(Color::Muted)) .disabled(is_creating); - let dock_position = AgentSettings::get_global(cx).dock; - let documentation_side = match dock_position { - settings::DockPosition::Left => DocumentationSide::Right, - settings::DockPosition::Bottom | settings::DockPosition::Right => { - DocumentationSide::Left - } - }; + let project = self.project.clone(); + let current_target = self.start_thread_in.clone(); + let fs = self.fs.clone(); PopoverMenu::new("thread-target-selector") .trigger_with_tooltip(trigger_button, { @@ -3198,89 +3381,66 @@ impl AgentPanel { } }) .menu(move |window, cx| { - let is_local_selected = current_target == StartThreadIn::LocalProject; - let is_new_worktree_selected = current_target == StartThreadIn::NewWorktree; let fs = fs.clone(); + Some(cx.new(|cx| { + ThreadWorktreePicker::new(project.clone(), ¤t_target, fs, window, cx) + })) + }) + .with_handle(self.start_thread_in_menu_handle.clone()) + .anchor(Corner::TopLeft) + .offset(gpui::Point { + x: px(1.0), + y: px(1.0), + }) + } - Some(ContextMenu::build(window, cx, move |menu, _window, _cx| { - let new_worktree_disabled = !has_git_repo || is_via_collab; + fn render_new_worktree_branch_selector(&self, cx: &mut Context) -> impl IntoElement { + let is_creating = matches!( + self.worktree_creation_status, + Some(WorktreeCreationStatus::Creating) + ); + let default_branch_label = if self.project.read(cx).repositories(cx).len() > 1 { + SharedString::from("From: current branches") + } else { + self.project + .read(cx) + .active_repository(cx) + .and_then(|repo| { + repo.read(cx) + .branch + .as_ref() + .map(|branch| SharedString::from(format!("From: {}", branch.name()))) + }) + .unwrap_or_else(|| SharedString::from("From: HEAD")) + }; + let trigger_label = self + .start_thread_in + .worktree_branch_label(default_branch_label) + .unwrap_or_else(|| SharedString::from("From: HEAD")); + let icon = if self.thread_branch_menu_handle.is_deployed() { + IconName::ChevronUp + } else { + IconName::ChevronDown + }; + let trigger_button = Button::new("thread-branch-trigger", trigger_label) + .start_icon( + Icon::new(IconName::GitBranch) + .size(IconSize::Small) + .color(Color::Muted), + ) + .end_icon(Icon::new(icon).size(IconSize::XSmall).color(Color::Muted)) + .disabled(is_creating); + let project = self.project.clone(); + let current_target = self.start_thread_in.clone(); - menu.header("Start Thread In…") - .item( - ContextMenuEntry::new("Current Worktree") - .toggleable(IconPosition::End, is_local_selected) - .documentation_aside(documentation_side, move |_| { - HoldForDefault::new(is_local_default) - .more_content(false) - .into_any_element() - }) - .handler({ - let fs = fs.clone(); - move |window, cx| { - if window.modifiers().secondary() { - update_settings_file(fs.clone(), cx, |settings, _| { - settings - .agent - .get_or_insert_default() - .set_new_thread_location( - NewThreadLocation::LocalProject, - ); - }); - } - window.dispatch_action( - Box::new(StartThreadIn::LocalProject), - cx, - ); - } - }), - ) - .item({ - let entry = ContextMenuEntry::new("New Git Worktree") - .toggleable(IconPosition::End, is_new_worktree_selected) - .disabled(new_worktree_disabled) - .handler({ - let fs = fs.clone(); - move |window, cx| { - if window.modifiers().secondary() { - update_settings_file(fs.clone(), cx, |settings, _| { - settings - .agent - .get_or_insert_default() - .set_new_thread_location( - NewThreadLocation::NewWorktree, - ); - }); - } - window.dispatch_action( - Box::new(StartThreadIn::NewWorktree), - cx, - ); - } - }); - - if new_worktree_disabled { - entry.documentation_aside(documentation_side, move |_| { - let reason = if !has_git_repo { - "No git repository found in this project." - } else { - "Not available for remote/collab projects yet." - }; - Label::new(reason) - .color(Color::Muted) - .size(LabelSize::Small) - .into_any_element() - }) - } else { - entry.documentation_aside(documentation_side, move |_| { - HoldForDefault::new(is_new_worktree_default) - .more_content(false) - .into_any_element() - }) - } - }) + PopoverMenu::new("thread-branch-selector") + .trigger_with_tooltip(trigger_button, Tooltip::text("Choose Worktree Branch…")) + .menu(move |window, cx| { + Some(cx.new(|cx| { + ThreadBranchPicker::new(project.clone(), ¤t_target, window, cx) })) }) - .with_handle(self.start_thread_in_menu_handle.clone()) + .with_handle(self.thread_branch_menu_handle.clone()) .anchor(Corner::TopLeft) .offset(gpui::Point { x: px(1.0), @@ -3621,6 +3781,14 @@ impl AgentPanel { .when( has_visible_worktrees && self.project_has_git_repository(cx), |this| this.child(self.render_start_thread_in_selector(cx)), + ) + .when( + has_v2_flag + && matches!( + self.start_thread_in, + StartThreadIn::NewWorktree { .. } + ), + |this| this.child(self.render_new_worktree_branch_selector(cx)), ), ) .child( @@ -5265,13 +5433,23 @@ mod tests { // Change thread target to NewWorktree. panel.update_in(cx, |panel, window, cx| { - panel.set_start_thread_in(&StartThreadIn::NewWorktree, window, cx); + panel.set_start_thread_in( + &StartThreadIn::NewWorktree { + worktree_name: None, + branch_target: NewWorktreeBranchTarget::default(), + }, + window, + cx, + ); }); panel.read_with(cx, |panel, _cx| { assert_eq!( *panel.start_thread_in(), - StartThreadIn::NewWorktree, + StartThreadIn::NewWorktree { + worktree_name: None, + branch_target: NewWorktreeBranchTarget::default(), + }, "thread target should be NewWorktree after set_thread_target" ); }); @@ -5289,7 +5467,10 @@ mod tests { loaded_panel.read_with(cx, |panel, _cx| { assert_eq!( *panel.start_thread_in(), - StartThreadIn::NewWorktree, + StartThreadIn::NewWorktree { + worktree_name: None, + branch_target: NewWorktreeBranchTarget::default(), + }, "thread target should survive serialization round-trip" ); }); @@ -5420,6 +5601,53 @@ mod tests { ); } + #[test] + fn test_resolve_worktree_branch_target() { + let existing_branches = HashSet::from_iter([ + "main".to_string(), + "feature".to_string(), + "origin/main".to_string(), + ]); + + let resolved = AgentPanel::resolve_worktree_branch_target( + &NewWorktreeBranchTarget::CreateBranch { + name: "new-branch".to_string(), + from_ref: Some("main".to_string()), + }, + &existing_branches, + &HashSet::from_iter(["main".to_string()]), + ) + .unwrap(); + assert_eq!( + resolved, + ("new-branch".to_string(), false, Some("main".to_string())) + ); + + let resolved = AgentPanel::resolve_worktree_branch_target( + &NewWorktreeBranchTarget::ExistingBranch { + name: "feature".to_string(), + }, + &existing_branches, + &HashSet::default(), + ) + .unwrap(); + assert_eq!(resolved, ("feature".to_string(), true, None)); + + let resolved = AgentPanel::resolve_worktree_branch_target( + &NewWorktreeBranchTarget::ExistingBranch { + name: "main".to_string(), + }, + &existing_branches, + &HashSet::from_iter(["main".to_string()]), + ) + .unwrap(); + assert_eq!(resolved.1, false); + assert_eq!(resolved.2, Some("main".to_string())); + assert_ne!(resolved.0, "main"); + assert!(existing_branches.contains("main")); + assert!(!existing_branches.contains(&resolved.0)); + } + #[gpui::test] async fn test_worktree_creation_preserves_selected_agent(cx: &mut TestAppContext) { init_test(cx); @@ -5513,7 +5741,14 @@ mod tests { panel.selected_agent = Agent::Custom { id: CODEX_ID.into(), }; - panel.set_start_thread_in(&StartThreadIn::NewWorktree, window, cx); + panel.set_start_thread_in( + &StartThreadIn::NewWorktree { + worktree_name: None, + branch_target: NewWorktreeBranchTarget::default(), + }, + window, + cx, + ); }); // Verify the panel has the Codex agent selected. @@ -5532,7 +5767,15 @@ mod tests { "Hello from test", ))]; panel.update_in(cx, |panel, window, cx| { - panel.handle_worktree_creation_requested(content, window, cx); + panel.handle_worktree_requested( + content, + WorktreeCreationArgs::New { + worktree_name: None, + branch_target: NewWorktreeBranchTarget::default(), + }, + window, + cx, + ); }); // Let the async worktree creation + workspace setup complete. diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index 5cff5bfc38d4512d659d919c6e7c4ff02fcc0caf..9daa7c6cd83c276aa99adc9e3aae3e6c82c5ba88 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -28,13 +28,16 @@ mod terminal_codegen; mod terminal_inline_assistant; #[cfg(any(test, feature = "test-support"))] pub mod test_support; +mod thread_branch_picker; mod thread_history; mod thread_history_view; mod thread_import; pub mod thread_metadata_store; +mod thread_worktree_picker; pub mod threads_archive_view; mod ui; +use std::path::PathBuf; use std::rc::Rc; use std::sync::Arc; @@ -314,16 +317,42 @@ impl Agent { } } +/// Describes which branch to use when creating a new git worktree. +#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "snake_case", tag = "kind")] +pub enum NewWorktreeBranchTarget { + /// Create a new randomly named branch from the current HEAD. + /// Will match worktree name if the newly created worktree was also randomly named. + #[default] + CurrentBranch, + /// Check out an existing branch, or create a new branch from it if it's + /// already occupied by another worktree. + ExistingBranch { name: String }, + /// Create a new branch with an explicit name, optionally from a specific ref. + CreateBranch { + name: String, + #[serde(default)] + from_ref: Option, + }, +} + /// Sets where new threads will run. -#[derive( - Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Action, -)] +#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Action)] #[action(namespace = agent)] #[serde(rename_all = "snake_case", tag = "kind")] pub enum StartThreadIn { #[default] LocalProject, - NewWorktree, + NewWorktree { + /// When this is None, Zed will randomly generate a worktree name + /// otherwise, the provided name will be used. + #[serde(default)] + worktree_name: Option, + #[serde(default)] + branch_target: NewWorktreeBranchTarget, + }, + /// A linked worktree that already exists on disk. + LinkedWorktree { path: PathBuf, display_name: String }, } /// Content to initialize new external agent with. diff --git a/crates/agent_ui/src/conversation_view/thread_view.rs b/crates/agent_ui/src/conversation_view/thread_view.rs index 685621eb3c93632f1e7410bbbad22b623d5e18c7..ff3dab1170064e058c0ebb44505c0906349517ee 100644 --- a/crates/agent_ui/src/conversation_view/thread_view.rs +++ b/crates/agent_ui/src/conversation_view/thread_view.rs @@ -869,7 +869,10 @@ impl ThreadView { .upgrade() .and_then(|workspace| workspace.read(cx).panel::(cx)) .is_some_and(|panel| { - panel.read(cx).start_thread_in() == &StartThreadIn::NewWorktree + !matches!( + panel.read(cx).start_thread_in(), + StartThreadIn::LocalProject + ) }); if intercept_first_send { diff --git a/crates/agent_ui/src/thread_branch_picker.rs b/crates/agent_ui/src/thread_branch_picker.rs new file mode 100644 index 0000000000000000000000000000000000000000..d69cbb4a60054ad83d767928c880f3a43caef4f1 --- /dev/null +++ b/crates/agent_ui/src/thread_branch_picker.rs @@ -0,0 +1,695 @@ +use std::collections::{HashMap, HashSet}; + +use collections::HashSet as CollectionsHashSet; +use std::path::PathBuf; +use std::sync::Arc; + +use fuzzy::StringMatchCandidate; +use git::repository::Branch as GitBranch; +use gpui::{ + App, Context, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, IntoElement, + ParentElement, Render, SharedString, Styled, Task, Window, rems, +}; +use picker::{Picker, PickerDelegate, PickerEditorPosition}; +use project::Project; +use ui::{ + HighlightedLabel, Icon, IconName, Label, LabelCommon, ListItem, ListItemSpacing, Tooltip, + prelude::*, +}; +use util::ResultExt as _; + +use crate::{NewWorktreeBranchTarget, StartThreadIn}; + +pub(crate) struct ThreadBranchPicker { + picker: Entity>, + focus_handle: FocusHandle, + _subscription: gpui::Subscription, +} + +impl ThreadBranchPicker { + pub fn new( + project: Entity, + current_target: &StartThreadIn, + window: &mut Window, + cx: &mut Context, + ) -> Self { + let project_worktree_paths: HashSet = project + .read(cx) + .visible_worktrees(cx) + .map(|worktree| worktree.read(cx).abs_path().to_path_buf()) + .collect(); + + let has_multiple_repositories = project.read(cx).repositories(cx).len() > 1; + let current_branch_name = project + .read(cx) + .active_repository(cx) + .and_then(|repo| { + repo.read(cx) + .branch + .as_ref() + .map(|branch| branch.name().to_string()) + }) + .unwrap_or_else(|| "HEAD".to_string()); + + let repository = if has_multiple_repositories { + None + } else { + project.read(cx).active_repository(cx) + }; + let branches_request = repository + .clone() + .map(|repo| repo.update(cx, |repo, _| repo.branches())); + let default_branch_request = repository + .clone() + .map(|repo| repo.update(cx, |repo, _| repo.default_branch(false))); + let worktrees_request = repository.map(|repo| repo.update(cx, |repo, _| repo.worktrees())); + + let (worktree_name, branch_target) = match current_target { + StartThreadIn::NewWorktree { + worktree_name, + branch_target, + } => (worktree_name.clone(), branch_target.clone()), + _ => (None, NewWorktreeBranchTarget::default()), + }; + + let delegate = ThreadBranchPickerDelegate { + matches: vec![ThreadBranchEntry::CurrentBranch], + all_branches: None, + occupied_branches: None, + selected_index: 0, + worktree_name, + branch_target, + project_worktree_paths, + current_branch_name, + default_branch_name: None, + has_multiple_repositories, + }; + + let picker = cx.new(|cx| { + Picker::list(delegate, window, cx) + .list_measure_all() + .modal(false) + .max_height(Some(rems(20.).into())) + }); + + let focus_handle = picker.focus_handle(cx); + + if let (Some(branches_request), Some(default_branch_request), Some(worktrees_request)) = + (branches_request, default_branch_request, worktrees_request) + { + let picker_handle = picker.downgrade(); + cx.spawn_in(window, async move |_this, cx| { + let branches = branches_request.await??; + let default_branch = default_branch_request.await.ok().and_then(Result::ok).flatten(); + let worktrees = worktrees_request.await??; + + let remote_upstreams: CollectionsHashSet<_> = branches + .iter() + .filter_map(|branch| { + branch + .upstream + .as_ref() + .filter(|upstream| upstream.is_remote()) + .map(|upstream| upstream.ref_name.clone()) + }) + .collect(); + + let mut occupied_branches = HashMap::new(); + for worktree in worktrees { + let Some(branch_name) = worktree.branch_name().map(ToOwned::to_owned) else { + continue; + }; + + let reason = if picker_handle + .read_with(cx, |picker, _| { + picker + .delegate + .project_worktree_paths + .contains(&worktree.path) + }) + .unwrap_or(false) + { + format!( + "This branch is already checked out in the current project worktree at {}.", + worktree.path.display() + ) + } else { + format!( + "This branch is already checked out in a linked worktree at {}.", + worktree.path.display() + ) + }; + + occupied_branches.insert(branch_name, reason); + } + + let mut all_branches: Vec<_> = branches + .into_iter() + .filter(|branch| !remote_upstreams.contains(&branch.ref_name)) + .collect(); + all_branches.sort_by_key(|branch| { + ( + branch.is_remote(), + !branch.is_head, + branch + .most_recent_commit + .as_ref() + .map(|commit| 0 - commit.commit_timestamp), + ) + }); + + picker_handle.update_in(cx, |picker, window, cx| { + picker.delegate.all_branches = Some(all_branches); + picker.delegate.occupied_branches = Some(occupied_branches); + picker.delegate.default_branch_name = default_branch.map(|branch| branch.to_string()); + picker.refresh(window, cx); + })?; + + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + } + + let subscription = cx.subscribe(&picker, |_, _, _, cx| { + cx.emit(DismissEvent); + }); + + Self { + picker, + focus_handle, + _subscription: subscription, + } + } +} + +impl Focusable for ThreadBranchPicker { + fn focus_handle(&self, _cx: &App) -> FocusHandle { + self.focus_handle.clone() + } +} + +impl EventEmitter for ThreadBranchPicker {} + +impl Render for ThreadBranchPicker { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + v_flex() + .w(rems(22.)) + .elevation_3(cx) + .child(self.picker.clone()) + .on_mouse_down_out(cx.listener(|_, _, _, cx| { + cx.emit(DismissEvent); + })) + } +} + +#[derive(Clone)] +enum ThreadBranchEntry { + CurrentBranch, + DefaultBranch, + ExistingBranch { + branch: GitBranch, + positions: Vec, + occupied_reason: Option, + }, + CreateNamed { + name: String, + }, +} + +pub(crate) struct ThreadBranchPickerDelegate { + matches: Vec, + all_branches: Option>, + occupied_branches: Option>, + selected_index: usize, + worktree_name: Option, + branch_target: NewWorktreeBranchTarget, + project_worktree_paths: HashSet, + current_branch_name: String, + default_branch_name: Option, + has_multiple_repositories: bool, +} + +impl ThreadBranchPickerDelegate { + fn new_worktree_action(&self, branch_target: NewWorktreeBranchTarget) -> StartThreadIn { + StartThreadIn::NewWorktree { + worktree_name: self.worktree_name.clone(), + branch_target, + } + } + + fn selected_entry_name(&self) -> Option<&str> { + match &self.branch_target { + NewWorktreeBranchTarget::CurrentBranch => None, + NewWorktreeBranchTarget::ExistingBranch { name } => Some(name), + NewWorktreeBranchTarget::CreateBranch { + from_ref: Some(from_ref), + .. + } => Some(from_ref), + NewWorktreeBranchTarget::CreateBranch { name, .. } => Some(name), + } + } + + fn prefer_create_entry(&self) -> bool { + matches!( + &self.branch_target, + NewWorktreeBranchTarget::CreateBranch { from_ref: None, .. } + ) + } + + fn fixed_matches(&self) -> Vec { + let mut matches = vec![ThreadBranchEntry::CurrentBranch]; + if !self.has_multiple_repositories + && self + .default_branch_name + .as_ref() + .is_some_and(|default_branch_name| default_branch_name != &self.current_branch_name) + { + matches.push(ThreadBranchEntry::DefaultBranch); + } + matches + } + + fn current_branch_label(&self) -> SharedString { + if self.has_multiple_repositories { + SharedString::from("New branch from: current branches") + } else { + SharedString::from(format!("New branch from: {}", self.current_branch_name)) + } + } + + fn default_branch_label(&self) -> Option { + let default_branch_name = self + .default_branch_name + .as_ref() + .filter(|name| *name != &self.current_branch_name)?; + let is_occupied = self + .occupied_branches + .as_ref() + .is_some_and(|occupied| occupied.contains_key(default_branch_name)); + let prefix = if is_occupied { + "New branch from" + } else { + "From" + }; + Some(SharedString::from(format!( + "{prefix}: {default_branch_name}" + ))) + } + + fn branch_label_prefix(&self, branch_name: &str) -> &'static str { + let is_occupied = self + .occupied_branches + .as_ref() + .is_some_and(|occupied| occupied.contains_key(branch_name)); + if is_occupied { + "New branch from: " + } else { + "From: " + } + } + + fn sync_selected_index(&mut self) { + let selected_entry_name = self.selected_entry_name().map(ToOwned::to_owned); + let prefer_create = self.prefer_create_entry(); + + if prefer_create { + if let Some(ref selected_entry_name) = selected_entry_name { + if let Some(index) = self.matches.iter().position(|entry| { + matches!( + entry, + ThreadBranchEntry::CreateNamed { name } if name == selected_entry_name + ) + }) { + self.selected_index = index; + return; + } + } + } else if let Some(ref selected_entry_name) = selected_entry_name { + if selected_entry_name == &self.current_branch_name { + if let Some(index) = self + .matches + .iter() + .position(|entry| matches!(entry, ThreadBranchEntry::CurrentBranch)) + { + self.selected_index = index; + return; + } + } + + if self + .default_branch_name + .as_ref() + .is_some_and(|default_branch_name| default_branch_name == selected_entry_name) + { + if let Some(index) = self + .matches + .iter() + .position(|entry| matches!(entry, ThreadBranchEntry::DefaultBranch)) + { + self.selected_index = index; + return; + } + } + + if let Some(index) = self.matches.iter().position(|entry| { + matches!( + entry, + ThreadBranchEntry::ExistingBranch { branch, .. } + if branch.name() == selected_entry_name.as_str() + ) + }) { + self.selected_index = index; + return; + } + } + + if self.matches.len() > 1 + && self + .matches + .iter() + .skip(1) + .all(|entry| matches!(entry, ThreadBranchEntry::CreateNamed { .. })) + { + self.selected_index = 1; + return; + } + + self.selected_index = 0; + } +} + +impl PickerDelegate for ThreadBranchPickerDelegate { + type ListItem = ListItem; + + fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc { + "Search branches…".into() + } + + fn editor_position(&self) -> PickerEditorPosition { + PickerEditorPosition::Start + } + + fn match_count(&self) -> usize { + self.matches.len() + } + + fn selected_index(&self) -> usize { + self.selected_index + } + + fn set_selected_index( + &mut self, + ix: usize, + _window: &mut Window, + _cx: &mut Context>, + ) { + self.selected_index = ix; + } + + fn update_matches( + &mut self, + query: String, + window: &mut Window, + cx: &mut Context>, + ) -> Task<()> { + if self.has_multiple_repositories { + let mut matches = self.fixed_matches(); + + if query.is_empty() { + if let Some(name) = self.selected_entry_name().map(ToOwned::to_owned) { + if self.prefer_create_entry() { + matches.push(ThreadBranchEntry::CreateNamed { name }); + } + } + } else { + matches.push(ThreadBranchEntry::CreateNamed { + name: query.replace(' ', "-"), + }); + } + + self.matches = matches; + self.sync_selected_index(); + return Task::ready(()); + } + + let Some(all_branches) = self.all_branches.clone() else { + self.matches = self.fixed_matches(); + self.selected_index = 0; + return Task::ready(()); + }; + let occupied_branches = self.occupied_branches.clone().unwrap_or_default(); + + if query.is_empty() { + let mut matches = self.fixed_matches(); + for branch in all_branches.into_iter().filter(|branch| { + branch.name() != self.current_branch_name + && self + .default_branch_name + .as_ref() + .is_none_or(|default_branch_name| branch.name() != default_branch_name) + }) { + matches.push(ThreadBranchEntry::ExistingBranch { + occupied_reason: occupied_branches.get(branch.name()).cloned(), + branch, + positions: Vec::new(), + }); + } + + if let Some(selected_entry_name) = self.selected_entry_name().map(ToOwned::to_owned) { + let has_existing = matches.iter().any(|entry| { + matches!( + entry, + ThreadBranchEntry::ExistingBranch { branch, .. } + if branch.name() == selected_entry_name + ) + }); + if self.prefer_create_entry() && !has_existing { + matches.push(ThreadBranchEntry::CreateNamed { + name: selected_entry_name, + }); + } + } + + self.matches = matches; + self.sync_selected_index(); + return Task::ready(()); + } + + let candidates: Vec<_> = all_branches + .iter() + .enumerate() + .map(|(ix, branch)| StringMatchCandidate::new(ix, branch.name())) + .collect(); + let executor = cx.background_executor().clone(); + let query_clone = query.clone(); + let normalized_query = query.replace(' ', "-"); + + let task = cx.background_executor().spawn(async move { + fuzzy::match_strings( + &candidates, + &query_clone, + true, + true, + 10000, + &Default::default(), + executor, + ) + .await + }); + + let all_branches_clone = all_branches; + cx.spawn_in(window, async move |picker, cx| { + let fuzzy_matches = task.await; + + picker + .update_in(cx, |picker, _window, cx| { + let mut matches = picker.delegate.fixed_matches(); + + for candidate in &fuzzy_matches { + let branch = all_branches_clone[candidate.candidate_id].clone(); + if branch.name() == picker.delegate.current_branch_name + || picker.delegate.default_branch_name.as_ref().is_some_and( + |default_branch_name| branch.name() == default_branch_name, + ) + { + continue; + } + let occupied_reason = occupied_branches.get(branch.name()).cloned(); + matches.push(ThreadBranchEntry::ExistingBranch { + branch, + positions: candidate.positions.clone(), + occupied_reason, + }); + } + + if fuzzy_matches.is_empty() { + matches.push(ThreadBranchEntry::CreateNamed { + name: normalized_query.clone(), + }); + } + + picker.delegate.matches = matches; + if let Some(index) = + picker.delegate.matches.iter().position(|entry| { + matches!(entry, ThreadBranchEntry::ExistingBranch { .. }) + }) + { + picker.delegate.selected_index = index; + } else if !fuzzy_matches.is_empty() { + picker.delegate.selected_index = 0; + } else if let Some(index) = + picker.delegate.matches.iter().position(|entry| { + matches!(entry, ThreadBranchEntry::CreateNamed { .. }) + }) + { + picker.delegate.selected_index = index; + } else { + picker.delegate.sync_selected_index(); + } + cx.notify(); + }) + .log_err(); + }) + } + + fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context>) { + let Some(entry) = self.matches.get(self.selected_index) else { + return; + }; + + match entry { + ThreadBranchEntry::CurrentBranch => { + window.dispatch_action( + Box::new(self.new_worktree_action(NewWorktreeBranchTarget::CurrentBranch)), + cx, + ); + } + ThreadBranchEntry::DefaultBranch => { + let Some(default_branch_name) = self.default_branch_name.clone() else { + return; + }; + window.dispatch_action( + Box::new( + self.new_worktree_action(NewWorktreeBranchTarget::ExistingBranch { + name: default_branch_name, + }), + ), + cx, + ); + } + ThreadBranchEntry::ExistingBranch { branch, .. } => { + let branch_target = if branch.is_remote() { + let branch_name = branch + .ref_name + .as_ref() + .strip_prefix("refs/remotes/") + .and_then(|stripped| stripped.split_once('/').map(|(_, name)| name)) + .unwrap_or(branch.name()) + .to_string(); + NewWorktreeBranchTarget::CreateBranch { + name: branch_name, + from_ref: Some(branch.name().to_string()), + } + } else { + NewWorktreeBranchTarget::ExistingBranch { + name: branch.name().to_string(), + } + }; + window.dispatch_action(Box::new(self.new_worktree_action(branch_target)), cx); + } + ThreadBranchEntry::CreateNamed { name } => { + window.dispatch_action( + Box::new( + self.new_worktree_action(NewWorktreeBranchTarget::CreateBranch { + name: name.clone(), + from_ref: None, + }), + ), + cx, + ); + } + } + + cx.emit(DismissEvent); + } + + fn dismissed(&mut self, _window: &mut Window, _cx: &mut Context>) {} + + fn separators_after_indices(&self) -> Vec { + let fixed_count = self.fixed_matches().len(); + if self.matches.len() > fixed_count { + vec![fixed_count - 1] + } else { + Vec::new() + } + } + + fn render_match( + &self, + ix: usize, + selected: bool, + _window: &mut Window, + _cx: &mut Context>, + ) -> Option { + let entry = self.matches.get(ix)?; + + match entry { + ThreadBranchEntry::CurrentBranch => Some( + ListItem::new("current-branch") + .inset(true) + .spacing(ListItemSpacing::Sparse) + .toggle_state(selected) + .start_slot(Icon::new(IconName::GitBranch).color(Color::Muted)) + .child(Label::new(self.current_branch_label())), + ), + ThreadBranchEntry::DefaultBranch => Some( + ListItem::new("default-branch") + .inset(true) + .spacing(ListItemSpacing::Sparse) + .toggle_state(selected) + .start_slot(Icon::new(IconName::GitBranch).color(Color::Muted)) + .child(Label::new(self.default_branch_label()?)), + ), + ThreadBranchEntry::ExistingBranch { + branch, + positions, + occupied_reason, + } => { + let prefix = self.branch_label_prefix(branch.name()); + let branch_name = branch.name().to_string(); + let full_label = format!("{prefix}{branch_name}"); + let adjusted_positions: Vec = + positions.iter().map(|&p| p + prefix.len()).collect(); + + let item = ListItem::new(SharedString::from(format!("branch-{ix}"))) + .inset(true) + .spacing(ListItemSpacing::Sparse) + .toggle_state(selected) + .start_slot(Icon::new(IconName::GitBranch).color(Color::Muted)) + .child(HighlightedLabel::new(full_label, adjusted_positions).truncate()); + + Some(if let Some(reason) = occupied_reason.clone() { + item.tooltip(Tooltip::text(reason)) + } else if branch.is_remote() { + item.tooltip(Tooltip::text( + "Create a new local branch from this remote branch", + )) + } else { + item + }) + } + ThreadBranchEntry::CreateNamed { name } => Some( + ListItem::new("create-named-branch") + .inset(true) + .spacing(ListItemSpacing::Sparse) + .toggle_state(selected) + .start_slot(Icon::new(IconName::Plus).color(Color::Accent)) + .child(Label::new(format!("Create Branch: \"{name}\"…"))), + ), + } + } + + fn no_matches_text(&self, _window: &mut Window, _cx: &mut App) -> Option { + None + } +} diff --git a/crates/agent_ui/src/thread_worktree_picker.rs b/crates/agent_ui/src/thread_worktree_picker.rs new file mode 100644 index 0000000000000000000000000000000000000000..47a6a12d71822e13ab3523a3a6b0bb1ee57c7b4b --- /dev/null +++ b/crates/agent_ui/src/thread_worktree_picker.rs @@ -0,0 +1,485 @@ +use std::path::PathBuf; +use std::sync::Arc; + +use agent_settings::AgentSettings; +use fs::Fs; +use fuzzy::StringMatchCandidate; +use git::repository::Worktree as GitWorktree; +use gpui::{ + App, Context, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, IntoElement, + ParentElement, Render, SharedString, Styled, Task, Window, rems, +}; +use picker::{Picker, PickerDelegate, PickerEditorPosition}; +use project::{Project, git_store::RepositoryId}; +use settings::{NewThreadLocation, Settings, update_settings_file}; +use ui::{ + HighlightedLabel, Icon, IconName, Label, LabelCommon, ListItem, ListItemSpacing, Tooltip, + prelude::*, +}; +use util::ResultExt as _; + +use crate::ui::HoldForDefault; +use crate::{NewWorktreeBranchTarget, StartThreadIn}; + +pub(crate) struct ThreadWorktreePicker { + picker: Entity>, + focus_handle: FocusHandle, + _subscription: gpui::Subscription, +} + +impl ThreadWorktreePicker { + pub fn new( + project: Entity, + current_target: &StartThreadIn, + fs: Arc, + window: &mut Window, + cx: &mut Context, + ) -> Self { + let project_worktree_paths: Vec = project + .read(cx) + .visible_worktrees(cx) + .map(|wt| wt.read(cx).abs_path().to_path_buf()) + .collect(); + + let preserved_branch_target = match current_target { + StartThreadIn::NewWorktree { branch_target, .. } => branch_target.clone(), + _ => NewWorktreeBranchTarget::default(), + }; + + let delegate = ThreadWorktreePickerDelegate { + matches: vec![ + ThreadWorktreeEntry::CurrentWorktree, + ThreadWorktreeEntry::NewWorktree, + ], + all_worktrees: project + .read(cx) + .repositories(cx) + .iter() + .map(|(repo_id, repo)| (*repo_id, repo.read(cx).linked_worktrees.clone())) + .collect(), + project_worktree_paths, + selected_index: match current_target { + StartThreadIn::LocalProject => 0, + StartThreadIn::NewWorktree { .. } => 1, + _ => 0, + }, + project: project.clone(), + preserved_branch_target, + fs, + }; + + let picker = cx.new(|cx| { + Picker::list(delegate, window, cx) + .list_measure_all() + .modal(false) + .max_height(Some(rems(20.).into())) + }); + + let subscription = cx.subscribe(&picker, |_, _, _, cx| { + cx.emit(DismissEvent); + }); + + Self { + focus_handle: picker.focus_handle(cx), + picker, + _subscription: subscription, + } + } +} + +impl Focusable for ThreadWorktreePicker { + fn focus_handle(&self, _cx: &App) -> FocusHandle { + self.focus_handle.clone() + } +} + +impl EventEmitter for ThreadWorktreePicker {} + +impl Render for ThreadWorktreePicker { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + v_flex() + .w(rems(20.)) + .elevation_3(cx) + .child(self.picker.clone()) + .on_mouse_down_out(cx.listener(|_, _, _, cx| { + cx.emit(DismissEvent); + })) + } +} + +#[derive(Clone)] +enum ThreadWorktreeEntry { + CurrentWorktree, + NewWorktree, + LinkedWorktree { + worktree: GitWorktree, + positions: Vec, + }, + CreateNamed { + name: String, + disabled_reason: Option, + }, +} + +pub(crate) struct ThreadWorktreePickerDelegate { + matches: Vec, + all_worktrees: Vec<(RepositoryId, Arc<[GitWorktree]>)>, + project_worktree_paths: Vec, + selected_index: usize, + preserved_branch_target: NewWorktreeBranchTarget, + project: Entity, + fs: Arc, +} + +impl ThreadWorktreePickerDelegate { + fn new_worktree_action(&self, worktree_name: Option) -> StartThreadIn { + StartThreadIn::NewWorktree { + worktree_name, + branch_target: self.preserved_branch_target.clone(), + } + } + + fn sync_selected_index(&mut self) { + if let Some(index) = self + .matches + .iter() + .position(|entry| matches!(entry, ThreadWorktreeEntry::LinkedWorktree { .. })) + { + self.selected_index = index; + } else if let Some(index) = self + .matches + .iter() + .position(|entry| matches!(entry, ThreadWorktreeEntry::CreateNamed { .. })) + { + self.selected_index = index; + } else { + self.selected_index = 0; + } + } +} + +impl PickerDelegate for ThreadWorktreePickerDelegate { + type ListItem = ListItem; + + fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc { + "Search or create worktrees…".into() + } + + fn editor_position(&self) -> PickerEditorPosition { + PickerEditorPosition::Start + } + + fn match_count(&self) -> usize { + self.matches.len() + } + + fn selected_index(&self) -> usize { + self.selected_index + } + + fn set_selected_index( + &mut self, + ix: usize, + _window: &mut Window, + _cx: &mut Context>, + ) { + self.selected_index = ix; + } + + fn separators_after_indices(&self) -> Vec { + if self.matches.len() > 2 { + vec![1] + } else { + Vec::new() + } + } + + fn update_matches( + &mut self, + query: String, + window: &mut Window, + cx: &mut Context>, + ) -> Task<()> { + let has_multiple_repositories = self.all_worktrees.len() > 1; + + let linked_worktrees: Vec<_> = if has_multiple_repositories { + Vec::new() + } else { + self.all_worktrees + .iter() + .flat_map(|(_, worktrees)| worktrees.iter()) + .filter(|worktree| { + !self + .project_worktree_paths + .iter() + .any(|project_path| project_path == &worktree.path) + }) + .cloned() + .collect() + }; + + let normalized_query = query.replace(' ', "-"); + let has_named_worktree = self.all_worktrees.iter().any(|(_, worktrees)| { + worktrees + .iter() + .any(|worktree| worktree.display_name() == normalized_query) + }); + let create_named_disabled_reason = if has_multiple_repositories { + Some("Cannot create a named worktree in a project with multiple repositories".into()) + } else if has_named_worktree { + Some("A worktree with this name already exists".into()) + } else { + None + }; + + let mut matches = vec![ + ThreadWorktreeEntry::CurrentWorktree, + ThreadWorktreeEntry::NewWorktree, + ]; + + if query.is_empty() { + for worktree in &linked_worktrees { + matches.push(ThreadWorktreeEntry::LinkedWorktree { + worktree: worktree.clone(), + positions: Vec::new(), + }); + } + } else if linked_worktrees.is_empty() { + matches.push(ThreadWorktreeEntry::CreateNamed { + name: normalized_query, + disabled_reason: create_named_disabled_reason, + }); + } else { + let candidates: Vec<_> = linked_worktrees + .iter() + .enumerate() + .map(|(ix, worktree)| StringMatchCandidate::new(ix, worktree.display_name())) + .collect(); + + let executor = cx.background_executor().clone(); + let query_clone = query.clone(); + + let task = cx.background_executor().spawn(async move { + fuzzy::match_strings( + &candidates, + &query_clone, + true, + true, + 10000, + &Default::default(), + executor, + ) + .await + }); + + let linked_worktrees_clone = linked_worktrees; + return cx.spawn_in(window, async move |picker, cx| { + let fuzzy_matches = task.await; + + picker + .update_in(cx, |picker, _window, cx| { + let mut new_matches = vec![ + ThreadWorktreeEntry::CurrentWorktree, + ThreadWorktreeEntry::NewWorktree, + ]; + + for candidate in &fuzzy_matches { + new_matches.push(ThreadWorktreeEntry::LinkedWorktree { + worktree: linked_worktrees_clone[candidate.candidate_id].clone(), + positions: candidate.positions.clone(), + }); + } + + let has_exact_match = linked_worktrees_clone + .iter() + .any(|worktree| worktree.display_name() == query); + + if !has_exact_match { + new_matches.push(ThreadWorktreeEntry::CreateNamed { + name: normalized_query.clone(), + disabled_reason: create_named_disabled_reason.clone(), + }); + } + + picker.delegate.matches = new_matches; + picker.delegate.sync_selected_index(); + + cx.notify(); + }) + .log_err(); + }); + } + + self.matches = matches; + self.sync_selected_index(); + + Task::ready(()) + } + + fn confirm(&mut self, secondary: bool, window: &mut Window, cx: &mut Context>) { + let Some(entry) = self.matches.get(self.selected_index) else { + return; + }; + + match entry { + ThreadWorktreeEntry::CurrentWorktree => { + if secondary { + update_settings_file(self.fs.clone(), cx, |settings, _| { + settings + .agent + .get_or_insert_default() + .set_new_thread_location(NewThreadLocation::LocalProject); + }); + } + window.dispatch_action(Box::new(StartThreadIn::LocalProject), cx); + } + ThreadWorktreeEntry::NewWorktree => { + if secondary { + update_settings_file(self.fs.clone(), cx, |settings, _| { + settings + .agent + .get_or_insert_default() + .set_new_thread_location(NewThreadLocation::NewWorktree); + }); + } + window.dispatch_action(Box::new(self.new_worktree_action(None)), cx); + } + ThreadWorktreeEntry::LinkedWorktree { worktree, .. } => { + window.dispatch_action( + Box::new(StartThreadIn::LinkedWorktree { + path: worktree.path.clone(), + display_name: worktree.display_name().to_string(), + }), + cx, + ); + } + ThreadWorktreeEntry::CreateNamed { + name, + disabled_reason: None, + } => { + window.dispatch_action(Box::new(self.new_worktree_action(Some(name.clone()))), cx); + } + ThreadWorktreeEntry::CreateNamed { + disabled_reason: Some(_), + .. + } => { + return; + } + } + + cx.emit(DismissEvent); + } + + fn dismissed(&mut self, _window: &mut Window, _cx: &mut Context>) {} + + fn render_match( + &self, + ix: usize, + selected: bool, + _window: &mut Window, + cx: &mut Context>, + ) -> Option { + let entry = self.matches.get(ix)?; + let project = self.project.read(cx); + let is_new_worktree_disabled = + project.repositories(cx).is_empty() || project.is_via_collab(); + let new_thread_location = AgentSettings::get_global(cx).new_thread_location; + let is_local_default = new_thread_location == NewThreadLocation::LocalProject; + let is_new_worktree_default = new_thread_location == NewThreadLocation::NewWorktree; + + match entry { + ThreadWorktreeEntry::CurrentWorktree => Some( + ListItem::new("current-worktree") + .inset(true) + .spacing(ListItemSpacing::Sparse) + .toggle_state(selected) + .start_slot(Icon::new(IconName::Folder).color(Color::Muted)) + .child(Label::new("Current Worktree")) + .end_slot(HoldForDefault::new(is_local_default).more_content(false)) + .tooltip(Tooltip::text("Use the current project worktree")), + ), + ThreadWorktreeEntry::NewWorktree => { + let item = ListItem::new("new-worktree") + .inset(true) + .spacing(ListItemSpacing::Sparse) + .toggle_state(selected) + .disabled(is_new_worktree_disabled) + .start_slot( + Icon::new(IconName::Plus).color(if is_new_worktree_disabled { + Color::Disabled + } else { + Color::Muted + }), + ) + .child( + Label::new("New Git Worktree").color(if is_new_worktree_disabled { + Color::Disabled + } else { + Color::Default + }), + ); + + Some(if is_new_worktree_disabled { + item.tooltip(Tooltip::text("Requires a Git repository in the project")) + } else { + item.end_slot(HoldForDefault::new(is_new_worktree_default).more_content(false)) + .tooltip(Tooltip::text("Start a thread in a new Git worktree")) + }) + } + ThreadWorktreeEntry::LinkedWorktree { + worktree, + positions, + } => { + let display_name = worktree.display_name(); + let first_line = display_name.lines().next().unwrap_or(display_name); + let positions: Vec<_> = positions + .iter() + .copied() + .filter(|&pos| pos < first_line.len()) + .collect(); + + Some( + ListItem::new(SharedString::from(format!("linked-worktree-{ix}"))) + .inset(true) + .spacing(ListItemSpacing::Sparse) + .toggle_state(selected) + .start_slot(Icon::new(IconName::GitWorktree).color(Color::Muted)) + .child(HighlightedLabel::new(first_line.to_owned(), positions).truncate()), + ) + } + ThreadWorktreeEntry::CreateNamed { + name, + disabled_reason, + } => { + let is_disabled = disabled_reason.is_some(); + let item = ListItem::new("create-named-worktree") + .inset(true) + .spacing(ListItemSpacing::Sparse) + .toggle_state(selected) + .disabled(is_disabled) + .start_slot(Icon::new(IconName::Plus).color(if is_disabled { + Color::Disabled + } else { + Color::Accent + })) + .child(Label::new(format!("Create Worktree: \"{name}\"…")).color( + if is_disabled { + Color::Disabled + } else { + Color::Default + }, + )); + + Some(if let Some(reason) = disabled_reason.clone() { + item.tooltip(Tooltip::text(reason)) + } else { + item + }) + } + } + } + + fn no_matches_text(&self, _window: &mut Window, _cx: &mut App) -> Option { + None + } +} diff --git a/crates/collab/tests/integration/git_tests.rs b/crates/collab/tests/integration/git_tests.rs index 2fa67b072f1c3d49ef5ca1b90056fd08d57df1ba..c273005264d0a53b6a083a4013f7597a56919016 100644 --- a/crates/collab/tests/integration/git_tests.rs +++ b/crates/collab/tests/integration/git_tests.rs @@ -269,9 +269,11 @@ async fn test_remote_git_worktrees( cx_b.update(|cx| { repo_b.update(cx, |repository, _| { repository.create_worktree( - "feature-branch".to_string(), + git::repository::CreateWorktreeTarget::NewBranch { + branch_name: "feature-branch".to_string(), + base_sha: Some("abc123".to_string()), + }, worktree_directory.join("feature-branch"), - Some("abc123".to_string()), ) }) }) @@ -323,9 +325,11 @@ async fn test_remote_git_worktrees( cx_b.update(|cx| { repo_b.update(cx, |repository, _| { repository.create_worktree( - "bugfix-branch".to_string(), + git::repository::CreateWorktreeTarget::NewBranch { + branch_name: "bugfix-branch".to_string(), + base_sha: None, + }, worktree_directory.join("bugfix-branch"), - None, ) }) }) diff --git a/crates/collab/tests/integration/remote_editing_collaboration_tests.rs b/crates/collab/tests/integration/remote_editing_collaboration_tests.rs index 0796323fc5b3d8f6b1cbcb0e108a7d573240f446..d478402a9d66ca9fba4e8f9517cb62898754e677 100644 --- a/crates/collab/tests/integration/remote_editing_collaboration_tests.rs +++ b/crates/collab/tests/integration/remote_editing_collaboration_tests.rs @@ -473,9 +473,11 @@ async fn test_ssh_collaboration_git_worktrees( cx_b.update(|cx| { repo_b.update(cx, |repo, _| { repo.create_worktree( - "feature-branch".to_string(), + git::repository::CreateWorktreeTarget::NewBranch { + branch_name: "feature-branch".to_string(), + base_sha: Some("abc123".to_string()), + }, worktree_directory.join("feature-branch"), - Some("abc123".to_string()), ) }) }) diff --git a/crates/fs/src/fake_git_repo.rs b/crates/fs/src/fake_git_repo.rs index 751796fb83164b78dc5d6789f0ae7870eff16ce1..fbebeabf0ac15dde80016958eb358f792f46dd50 100644 --- a/crates/fs/src/fake_git_repo.rs +++ b/crates/fs/src/fake_git_repo.rs @@ -6,9 +6,10 @@ use git::{ Oid, RunHook, blame::Blame, repository::{ - AskPassDelegate, Branch, CommitDataReader, CommitDetails, CommitOptions, FetchOptions, - GRAPH_CHUNK_SIZE, GitRepository, GitRepositoryCheckpoint, InitialGraphCommitData, LogOrder, - LogSource, PushOptions, Remote, RepoPath, ResetMode, SearchCommitArgs, Worktree, + AskPassDelegate, Branch, CommitDataReader, CommitDetails, CommitOptions, + CreateWorktreeTarget, FetchOptions, GRAPH_CHUNK_SIZE, GitRepository, + GitRepositoryCheckpoint, InitialGraphCommitData, LogOrder, LogSource, PushOptions, Remote, + RepoPath, ResetMode, SearchCommitArgs, Worktree, }, stash::GitStash, status::{ @@ -540,9 +541,8 @@ impl GitRepository for FakeGitRepository { fn create_worktree( &self, - branch_name: Option, + target: CreateWorktreeTarget, path: PathBuf, - from_commit: Option, ) -> BoxFuture<'_, Result<()>> { let fs = self.fs.clone(); let executor = self.executor.clone(); @@ -550,30 +550,82 @@ impl GitRepository for FakeGitRepository { let common_dir_path = self.common_dir_path.clone(); async move { executor.simulate_random_delay().await; - // Check for simulated error and duplicate branch before any side effects. - fs.with_git_state(&dot_git_path, false, |state| { - if let Some(message) = &state.simulated_create_worktree_error { - anyhow::bail!("{message}"); - } - if let Some(ref name) = branch_name { - if state.branches.contains(name) { - bail!("a branch named '{}' already exists", name); + + let branch_name = target.branch_name().map(ToOwned::to_owned); + let create_branch_ref = matches!(target, CreateWorktreeTarget::NewBranch { .. }); + + // Check for simulated error and validate branch state before any side effects. + fs.with_git_state(&dot_git_path, false, { + let branch_name = branch_name.clone(); + move |state| { + if let Some(message) = &state.simulated_create_worktree_error { + anyhow::bail!("{message}"); } + + match (create_branch_ref, branch_name.as_ref()) { + (true, Some(branch_name)) => { + if state.branches.contains(branch_name) { + bail!("a branch named '{}' already exists", branch_name); + } + } + (false, Some(branch_name)) => { + if !state.branches.contains(branch_name) { + bail!("no branch named '{}' exists", branch_name); + } + } + (false, None) => {} + (true, None) => bail!("branch name is required to create a branch"), + } + + Ok(()) } - Ok(()) })??; + let (branch_name, sha, create_branch_ref) = match target { + CreateWorktreeTarget::ExistingBranch { branch_name } => { + let ref_name = format!("refs/heads/{branch_name}"); + let sha = fs.with_git_state(&dot_git_path, false, { + move |state| { + Ok::<_, anyhow::Error>( + state + .refs + .get(&ref_name) + .cloned() + .unwrap_or_else(|| "fake-sha".to_string()), + ) + } + })??; + (Some(branch_name), sha, false) + } + CreateWorktreeTarget::NewBranch { + branch_name, + base_sha: start_point, + } => ( + Some(branch_name), + start_point.unwrap_or_else(|| "fake-sha".to_string()), + true, + ), + CreateWorktreeTarget::Detached { + base_sha: start_point, + } => ( + None, + start_point.unwrap_or_else(|| "fake-sha".to_string()), + false, + ), + }; + // Create the worktree checkout directory. fs.create_dir(&path).await?; // Create .git/worktrees// directory with HEAD, commondir, gitdir. - let worktree_entry_name = branch_name - .as_deref() - .unwrap_or_else(|| path.file_name().unwrap().to_str().unwrap()); + let worktree_entry_name = branch_name.as_deref().unwrap_or_else(|| { + path.file_name() + .and_then(|name| name.to_str()) + .unwrap_or("detached") + }); let worktrees_entry_dir = common_dir_path.join("worktrees").join(worktree_entry_name); fs.create_dir(&worktrees_entry_dir).await?; - let sha = from_commit.unwrap_or_else(|| "fake-sha".to_string()); let head_content = if let Some(ref branch_name) = branch_name { let ref_name = format!("refs/heads/{branch_name}"); format!("ref: {ref_name}") @@ -604,15 +656,22 @@ impl GitRepository for FakeGitRepository { false, )?; - // Update git state: add ref and branch. - fs.with_git_state(&dot_git_path, true, move |state| { - if let Some(branch_name) = branch_name { - let ref_name = format!("refs/heads/{branch_name}"); - state.refs.insert(ref_name, sha); - state.branches.insert(branch_name); - } - Ok::<(), anyhow::Error>(()) - })??; + // Update git state for newly created branches. + if create_branch_ref { + fs.with_git_state(&dot_git_path, true, { + let branch_name = branch_name.clone(); + let sha = sha.clone(); + move |state| { + if let Some(branch_name) = branch_name { + let ref_name = format!("refs/heads/{branch_name}"); + state.refs.insert(ref_name, sha); + state.branches.insert(branch_name); + } + Ok::<(), anyhow::Error>(()) + } + })??; + } + Ok(()) } .boxed() diff --git a/crates/fs/tests/integration/fake_git_repo.rs b/crates/fs/tests/integration/fake_git_repo.rs index f4192a22bb42f88f8769ef59f817b2bf2a288fb9..3be81ad7301e6fc4ee6f4529ce8bb587de3b4565 100644 --- a/crates/fs/tests/integration/fake_git_repo.rs +++ b/crates/fs/tests/integration/fake_git_repo.rs @@ -24,9 +24,11 @@ async fn test_fake_worktree_lifecycle(cx: &mut TestAppContext) { // Create a worktree let worktree_1_dir = worktrees_dir.join("feature-branch"); repo.create_worktree( - Some("feature-branch".to_string()), + git::repository::CreateWorktreeTarget::NewBranch { + branch_name: "feature-branch".to_string(), + base_sha: Some("abc123".to_string()), + }, worktree_1_dir.clone(), - Some("abc123".to_string()), ) .await .unwrap(); @@ -48,9 +50,11 @@ async fn test_fake_worktree_lifecycle(cx: &mut TestAppContext) { // Create a second worktree (without explicit commit) let worktree_2_dir = worktrees_dir.join("bugfix-branch"); repo.create_worktree( - Some("bugfix-branch".to_string()), + git::repository::CreateWorktreeTarget::NewBranch { + branch_name: "bugfix-branch".to_string(), + base_sha: None, + }, worktree_2_dir.clone(), - None, ) .await .unwrap(); diff --git a/crates/git/src/repository.rs b/crates/git/src/repository.rs index c42d2e28cf041e40404c1b8276ddcf5d10ca5f01..ba717d00c5e40374f5315d3ee8bc12e671f09552 100644 --- a/crates/git/src/repository.rs +++ b/crates/git/src/repository.rs @@ -241,20 +241,57 @@ pub struct Worktree { pub is_main: bool, } +/// Describes how a new worktree should choose or create its checked-out HEAD. +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub enum CreateWorktreeTarget { + /// Check out an existing local branch in the new worktree. + ExistingBranch { + /// The existing local branch to check out. + branch_name: String, + }, + /// Create a new local branch for the new worktree. + NewBranch { + /// The new local branch to create and check out. + branch_name: String, + /// The commit or ref to create the branch from. Uses `HEAD` when `None`. + base_sha: Option, + }, + /// Check out a commit or ref in detached HEAD state. + Detached { + /// The commit or ref to check out. Uses `HEAD` when `None`. + base_sha: Option, + }, +} + +impl CreateWorktreeTarget { + pub fn branch_name(&self) -> Option<&str> { + match self { + Self::ExistingBranch { branch_name } | Self::NewBranch { branch_name, .. } => { + Some(branch_name) + } + Self::Detached { .. } => None, + } + } +} + impl Worktree { + /// Returns the branch name if the worktree is attached to a branch. + pub fn branch_name(&self) -> Option<&str> { + self.ref_name.as_ref().map(|ref_name| { + ref_name + .strip_prefix("refs/heads/") + .or_else(|| ref_name.strip_prefix("refs/remotes/")) + .unwrap_or(ref_name) + }) + } + /// Returns a display name for the worktree, suitable for use in the UI. /// /// If the worktree is attached to a branch, returns the branch name. /// Otherwise, returns the short SHA of the worktree's HEAD commit. pub fn display_name(&self) -> &str { - match self.ref_name { - Some(ref ref_name) => ref_name - .strip_prefix("refs/heads/") - .or_else(|| ref_name.strip_prefix("refs/remotes/")) - .unwrap_or(ref_name), - // Detached HEAD — show the short SHA as a fallback. - None => &self.sha[..self.sha.len().min(SHORT_SHA_LENGTH)], - } + self.branch_name() + .unwrap_or(&self.sha[..self.sha.len().min(SHORT_SHA_LENGTH)]) } } @@ -716,9 +753,8 @@ pub trait GitRepository: Send + Sync { fn create_worktree( &self, - branch_name: Option, + target: CreateWorktreeTarget, path: PathBuf, - from_commit: Option, ) -> BoxFuture<'_, Result<()>>; fn remove_worktree(&self, path: PathBuf, force: bool) -> BoxFuture<'_, Result<()>>; @@ -1667,24 +1703,36 @@ impl GitRepository for RealGitRepository { fn create_worktree( &self, - branch_name: Option, + target: CreateWorktreeTarget, path: PathBuf, - from_commit: Option, ) -> BoxFuture<'_, Result<()>> { let git_binary = self.git_binary(); let mut args = vec![OsString::from("worktree"), OsString::from("add")]; - if let Some(branch_name) = &branch_name { - args.push(OsString::from("-b")); - args.push(OsString::from(branch_name.as_str())); - } else { - args.push(OsString::from("--detach")); - } - args.push(OsString::from("--")); - args.push(OsString::from(path.as_os_str())); - if let Some(from_commit) = from_commit { - args.push(OsString::from(from_commit)); - } else { - args.push(OsString::from("HEAD")); + + match &target { + CreateWorktreeTarget::ExistingBranch { branch_name } => { + args.push(OsString::from("--")); + args.push(OsString::from(path.as_os_str())); + args.push(OsString::from(branch_name)); + } + CreateWorktreeTarget::NewBranch { + branch_name, + base_sha: start_point, + } => { + args.push(OsString::from("-b")); + args.push(OsString::from(branch_name)); + args.push(OsString::from("--")); + args.push(OsString::from(path.as_os_str())); + args.push(OsString::from(start_point.as_deref().unwrap_or("HEAD"))); + } + CreateWorktreeTarget::Detached { + base_sha: start_point, + } => { + args.push(OsString::from("--detach")); + args.push(OsString::from("--")); + args.push(OsString::from(path.as_os_str())); + args.push(OsString::from(start_point.as_deref().unwrap_or("HEAD"))); + } } self.executor @@ -4054,9 +4102,11 @@ mod tests { // Create a new worktree repo.create_worktree( - Some("test-branch".to_string()), + CreateWorktreeTarget::NewBranch { + branch_name: "test-branch".to_string(), + base_sha: Some("HEAD".to_string()), + }, worktree_path.clone(), - Some("HEAD".to_string()), ) .await .unwrap(); @@ -4113,9 +4163,11 @@ mod tests { // Create a worktree let worktree_path = worktrees_dir.join("worktree-to-remove"); repo.create_worktree( - Some("to-remove".to_string()), + CreateWorktreeTarget::NewBranch { + branch_name: "to-remove".to_string(), + base_sha: Some("HEAD".to_string()), + }, worktree_path.clone(), - Some("HEAD".to_string()), ) .await .unwrap(); @@ -4137,9 +4189,11 @@ mod tests { // Create a worktree let worktree_path = worktrees_dir.join("dirty-wt"); repo.create_worktree( - Some("dirty-wt".to_string()), + CreateWorktreeTarget::NewBranch { + branch_name: "dirty-wt".to_string(), + base_sha: Some("HEAD".to_string()), + }, worktree_path.clone(), - Some("HEAD".to_string()), ) .await .unwrap(); @@ -4207,9 +4261,11 @@ mod tests { // Create a worktree let old_path = worktrees_dir.join("old-worktree-name"); repo.create_worktree( - Some("old-name".to_string()), + CreateWorktreeTarget::NewBranch { + branch_name: "old-name".to_string(), + base_sha: Some("HEAD".to_string()), + }, old_path.clone(), - Some("HEAD".to_string()), ) .await .unwrap(); diff --git a/crates/git_ui/src/worktree_picker.rs b/crates/git_ui/src/worktree_picker.rs index 1b4497be1f4ea96bd4f0431c97bb538eda9faa57..bd1d694fa30bb914569fbb5e6e3c67de3e3d86a0 100644 --- a/crates/git_ui/src/worktree_picker.rs +++ b/crates/git_ui/src/worktree_picker.rs @@ -318,8 +318,13 @@ impl WorktreeListDelegate { .clone(); let new_worktree_path = repo.path_for_new_linked_worktree(&branch, &worktree_directory_setting)?; - let receiver = - repo.create_worktree(branch.clone(), new_worktree_path.clone(), commit); + let receiver = repo.create_worktree( + git::repository::CreateWorktreeTarget::NewBranch { + branch_name: branch.clone(), + base_sha: commit, + }, + new_worktree_path.clone(), + ); anyhow::Ok((receiver, new_worktree_path)) })?; receiver.await??; diff --git a/crates/project/src/git_store.rs b/crates/project/src/git_store.rs index e7e84ffe673881d898a56b64892887b9c8d6c809..8da5a14e41d9cb97865d78f4dfc2ed79f76faebd 100644 --- a/crates/project/src/git_store.rs +++ b/crates/project/src/git_store.rs @@ -32,10 +32,10 @@ use git::{ blame::Blame, parse_git_remote_url, repository::{ - Branch, CommitDetails, CommitDiff, CommitFile, CommitOptions, DiffType, FetchOptions, - GitRepository, GitRepositoryCheckpoint, GraphCommitData, InitialGraphCommitData, LogOrder, - LogSource, PushOptions, Remote, RemoteCommandOutput, RepoPath, ResetMode, SearchCommitArgs, - UpstreamTrackingStatus, Worktree as GitWorktree, + Branch, CommitDetails, CommitDiff, CommitFile, CommitOptions, CreateWorktreeTarget, + DiffType, FetchOptions, GitRepository, GitRepositoryCheckpoint, GraphCommitData, + InitialGraphCommitData, LogOrder, LogSource, PushOptions, Remote, RemoteCommandOutput, + RepoPath, ResetMode, SearchCommitArgs, UpstreamTrackingStatus, Worktree as GitWorktree, }, stash::{GitStash, StashEntry}, status::{ @@ -329,12 +329,6 @@ pub struct GraphDataResponse<'a> { pub error: Option, } -#[derive(Clone, Debug)] -enum CreateWorktreeStartPoint { - Detached, - Branched { name: String }, -} - pub struct Repository { this: WeakEntity, snapshot: RepositorySnapshot, @@ -2414,18 +2408,23 @@ impl GitStore { let repository_id = RepositoryId::from_proto(envelope.payload.repository_id); let repository_handle = Self::repository_for_request(&this, repository_id, &mut cx)?; let directory = PathBuf::from(envelope.payload.directory); - let start_point = if envelope.payload.name.is_empty() { - CreateWorktreeStartPoint::Detached + let name = envelope.payload.name; + let commit = envelope.payload.commit; + let use_existing_branch = envelope.payload.use_existing_branch; + let target = if name.is_empty() { + CreateWorktreeTarget::Detached { base_sha: commit } + } else if use_existing_branch { + CreateWorktreeTarget::ExistingBranch { branch_name: name } } else { - CreateWorktreeStartPoint::Branched { - name: envelope.payload.name, + CreateWorktreeTarget::NewBranch { + branch_name: name, + base_sha: commit, } }; - let commit = envelope.payload.commit; repository_handle .update(&mut cx, |repository_handle, _| { - repository_handle.create_worktree_with_start_point(start_point, directory, commit) + repository_handle.create_worktree(target, directory) }) .await??; @@ -6004,50 +6003,43 @@ impl Repository { }) } - fn create_worktree_with_start_point( + pub fn create_worktree( &mut self, - start_point: CreateWorktreeStartPoint, + target: CreateWorktreeTarget, path: PathBuf, - commit: Option, ) -> oneshot::Receiver> { - if matches!( - &start_point, - CreateWorktreeStartPoint::Branched { name } if name.is_empty() - ) { - let (sender, receiver) = oneshot::channel(); - sender - .send(Err(anyhow!("branch name cannot be empty"))) - .ok(); - return receiver; - } - let id = self.id; - let message = match &start_point { - CreateWorktreeStartPoint::Detached => "git worktree add (detached)".into(), - CreateWorktreeStartPoint::Branched { name } => { - format!("git worktree add: {name}").into() - } + let job_description = match target.branch_name() { + Some(branch_name) => format!("git worktree add: {branch_name}"), + None => "git worktree add (detached)".to_string(), }; - - self.send_job(Some(message), move |repo, _cx| async move { - let branch_name = match start_point { - CreateWorktreeStartPoint::Detached => None, - CreateWorktreeStartPoint::Branched { name } => Some(name), - }; - let remote_name = branch_name.clone().unwrap_or_default(); - + self.send_job(Some(job_description.into()), move |repo, _cx| async move { match repo { RepositoryState::Local(LocalRepositoryState { backend, .. }) => { - backend.create_worktree(branch_name, path, commit).await + backend.create_worktree(target, path).await } RepositoryState::Remote(RemoteRepositoryState { project_id, client }) => { + let (name, commit, use_existing_branch) = match target { + CreateWorktreeTarget::ExistingBranch { branch_name } => { + (branch_name, None, true) + } + CreateWorktreeTarget::NewBranch { + branch_name, + base_sha: start_point, + } => (branch_name, start_point, false), + CreateWorktreeTarget::Detached { + base_sha: start_point, + } => (String::new(), start_point, false), + }; + client .request(proto::GitCreateWorktree { project_id: project_id.0, repository_id: id.to_proto(), - name: remote_name, + name, directory: path.to_string_lossy().to_string(), commit, + use_existing_branch, }) .await?; @@ -6057,28 +6049,16 @@ impl Repository { }) } - pub fn create_worktree( - &mut self, - branch_name: String, - path: PathBuf, - commit: Option, - ) -> oneshot::Receiver> { - self.create_worktree_with_start_point( - CreateWorktreeStartPoint::Branched { name: branch_name }, - path, - commit, - ) - } - pub fn create_worktree_detached( &mut self, path: PathBuf, commit: String, ) -> oneshot::Receiver> { - self.create_worktree_with_start_point( - CreateWorktreeStartPoint::Detached, + self.create_worktree( + CreateWorktreeTarget::Detached { + base_sha: Some(commit), + }, path, - Some(commit), ) } diff --git a/crates/project/tests/integration/git_store.rs b/crates/project/tests/integration/git_store.rs index 02f752b28b24a8135e2cba9307a5eacdc16f0fa3..bbe5c64d7cf7f5b2ffa9160df6130cd88ddc5d69 100644 --- a/crates/project/tests/integration/git_store.rs +++ b/crates/project/tests/integration/git_store.rs @@ -1267,9 +1267,11 @@ mod git_worktrees { cx.update(|cx| { repository.update(cx, |repository, _| { repository.create_worktree( - "feature-branch".to_string(), + git::repository::CreateWorktreeTarget::NewBranch { + branch_name: "feature-branch".to_string(), + base_sha: Some("abc123".to_string()), + }, worktree_1_directory.clone(), - Some("abc123".to_string()), ) }) }) @@ -1297,9 +1299,11 @@ mod git_worktrees { cx.update(|cx| { repository.update(cx, |repository, _| { repository.create_worktree( - "bugfix-branch".to_string(), + git::repository::CreateWorktreeTarget::NewBranch { + branch_name: "bugfix-branch".to_string(), + base_sha: None, + }, worktree_2_directory.clone(), - None, ) }) }) diff --git a/crates/proto/proto/git.proto b/crates/proto/proto/git.proto index 9324feb21b1f50ac1041ed0afc8b59cb9b7fe2c6..d0a594a2817ec50d9d35383587619e311f2950d8 100644 --- a/crates/proto/proto/git.proto +++ b/crates/proto/proto/git.proto @@ -594,6 +594,7 @@ message GitCreateWorktree { string name = 3; string directory = 4; optional string commit = 5; + bool use_existing_branch = 6; } message GitCreateCheckpoint { diff --git a/crates/zed/src/visual_test_runner.rs b/crates/zed/src/visual_test_runner.rs index b59123a1a159487f802210f3916e16856daf8e61..9f69cd3458c194228f37cfdeedcf0c9023b9b7bd 100644 --- a/crates/zed/src/visual_test_runner.rs +++ b/crates/zed/src/visual_test_runner.rs @@ -3080,7 +3080,7 @@ fn run_start_thread_in_selector_visual_tests( cx: &mut VisualTestAppContext, update_baseline: bool, ) -> Result { - use agent_ui::{AgentPanel, StartThreadIn, WorktreeCreationStatus}; + use agent_ui::{AgentPanel, NewWorktreeBranchTarget, StartThreadIn, WorktreeCreationStatus}; // Enable feature flags so the thread target selector renders cx.update(|cx| { @@ -3401,7 +3401,13 @@ edition = "2021" cx.update_window(workspace_window.into(), |_, _window, cx| { panel.update(cx, |panel, cx| { - panel.set_start_thread_in_for_tests(StartThreadIn::NewWorktree, cx); + panel.set_start_thread_in_for_tests( + StartThreadIn::NewWorktree { + worktree_name: None, + branch_target: NewWorktreeBranchTarget::default(), + }, + cx, + ); }); })?; cx.run_until_parked(); @@ -3474,7 +3480,13 @@ edition = "2021" cx.run_until_parked(); cx.update_window(workspace_window.into(), |_, window, cx| { - window.dispatch_action(Box::new(StartThreadIn::NewWorktree), cx); + window.dispatch_action( + Box::new(StartThreadIn::NewWorktree { + worktree_name: None, + branch_target: NewWorktreeBranchTarget::default(), + }), + cx, + ); })?; cx.run_until_parked(); From 9c731640c7f5a4d91a94b3e68fa92eb8bc5e38ee Mon Sep 17 00:00:00 2001 From: Shardul Vaidya <31039336+5herlocked@users.noreply.github.com> Date: Tue, 7 Apr 2026 05:59:12 -0400 Subject: [PATCH 07/22] bedrock: Add new Bedrock models (NVIDIA, Z.AI, Mistral, MiniMax) (#53043) Add 9 new models across 3 new providers (NVIDIA, Z.AI) and expanded coverage for existing providers (Mistral, MiniMax): - NVIDIA Nemotron Super 3 120B, Nemotron Nano 3 30B - Mistral Devstral 2 123B, Ministral 14B - MiniMax M2.1, M2.5 - Z.AI GLM 5, GLM 4.7, GLM 4.7 Flash Self-Review Checklist: - [x] I've reviewed my own diff for quality, security, and reliability - [x] Unsafe blocks (if any) have justifying comments - [x] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) - [x] Tests cover the new/changed behavior - [x] Performance impact has been considered and is acceptable Closes #ISSUE Release Notes: - bedrock: Added 9 new models across 3 new providers (NVIDIA, Z.AI) and expanded coverage for existing providers (Mistral, MiniMax) --- crates/bedrock/src/models.rs | 64 ++++++++++++++++++++++++++++++++++-- 1 file changed, 61 insertions(+), 3 deletions(-) diff --git a/crates/bedrock/src/models.rs b/crates/bedrock/src/models.rs index 8b6113e4d5521fb3c7e27a7f2f6547c7a9db86ce..7c1e6e0e4e6ef873345c30c0af4c9e8842699c77 100644 --- a/crates/bedrock/src/models.rs +++ b/crates/bedrock/src/models.rs @@ -113,6 +113,10 @@ pub enum Model { MistralLarge3, #[serde(rename = "pixtral-large")] PixtralLarge, + #[serde(rename = "devstral-2-123b")] + Devstral2_123B, + #[serde(rename = "ministral-14b")] + Ministral14B, // Qwen models #[serde(rename = "qwen3-32b")] @@ -146,9 +150,27 @@ pub enum Model { #[serde(rename = "gpt-oss-120b")] GptOss120B, + // NVIDIA Nemotron models + #[serde(rename = "nemotron-super-3-120b")] + NemotronSuper3_120B, + #[serde(rename = "nemotron-nano-3-30b")] + NemotronNano3_30B, + // MiniMax models #[serde(rename = "minimax-m2")] MiniMaxM2, + #[serde(rename = "minimax-m2-1")] + MiniMaxM2_1, + #[serde(rename = "minimax-m2-5")] + MiniMaxM2_5, + + // Z.AI GLM models + #[serde(rename = "glm-5")] + GLM5, + #[serde(rename = "glm-4-7")] + GLM4_7, + #[serde(rename = "glm-4-7-flash")] + GLM4_7Flash, // Moonshot models #[serde(rename = "kimi-k2-thinking")] @@ -217,6 +239,8 @@ impl Model { Self::MagistralSmall => "magistral-small", Self::MistralLarge3 => "mistral-large-3", Self::PixtralLarge => "pixtral-large", + Self::Devstral2_123B => "devstral-2-123b", + Self::Ministral14B => "ministral-14b", Self::Qwen3_32B => "qwen3-32b", Self::Qwen3VL235B => "qwen3-vl-235b", Self::Qwen3_235B => "qwen3-235b", @@ -230,7 +254,14 @@ impl Model { Self::Nova2Lite => "nova-2-lite", Self::GptOss20B => "gpt-oss-20b", Self::GptOss120B => "gpt-oss-120b", + Self::NemotronSuper3_120B => "nemotron-super-3-120b", + Self::NemotronNano3_30B => "nemotron-nano-3-30b", Self::MiniMaxM2 => "minimax-m2", + Self::MiniMaxM2_1 => "minimax-m2-1", + Self::MiniMaxM2_5 => "minimax-m2-5", + Self::GLM5 => "glm-5", + Self::GLM4_7 => "glm-4-7", + Self::GLM4_7Flash => "glm-4-7-flash", Self::KimiK2Thinking => "kimi-k2-thinking", Self::KimiK2_5 => "kimi-k2-5", Self::DeepSeekR1 => "deepseek-r1", @@ -257,6 +288,8 @@ impl Model { Self::MagistralSmall => "mistral.magistral-small-2509", Self::MistralLarge3 => "mistral.mistral-large-3-675b-instruct", Self::PixtralLarge => "mistral.pixtral-large-2502-v1:0", + Self::Devstral2_123B => "mistral.devstral-2-123b", + Self::Ministral14B => "mistral.ministral-3-14b-instruct", Self::Qwen3VL235B => "qwen.qwen3-vl-235b-a22b", Self::Qwen3_32B => "qwen.qwen3-32b-v1:0", Self::Qwen3_235B => "qwen.qwen3-235b-a22b-2507-v1:0", @@ -270,7 +303,14 @@ impl Model { Self::Nova2Lite => "amazon.nova-2-lite-v1:0", Self::GptOss20B => "openai.gpt-oss-20b-1:0", Self::GptOss120B => "openai.gpt-oss-120b-1:0", + Self::NemotronSuper3_120B => "nvidia.nemotron-super-3-120b", + Self::NemotronNano3_30B => "nvidia.nemotron-nano-3-30b", Self::MiniMaxM2 => "minimax.minimax-m2", + Self::MiniMaxM2_1 => "minimax.minimax-m2.1", + Self::MiniMaxM2_5 => "minimax.minimax-m2.5", + Self::GLM5 => "zai.glm-5", + Self::GLM4_7 => "zai.glm-4.7", + Self::GLM4_7Flash => "zai.glm-4.7-flash", Self::KimiK2Thinking => "moonshot.kimi-k2-thinking", Self::KimiK2_5 => "moonshotai.kimi-k2.5", Self::DeepSeekR1 => "deepseek.r1-v1:0", @@ -297,6 +337,8 @@ impl Model { Self::MagistralSmall => "Magistral Small", Self::MistralLarge3 => "Mistral Large 3", Self::PixtralLarge => "Pixtral Large", + Self::Devstral2_123B => "Devstral 2 123B", + Self::Ministral14B => "Ministral 14B", Self::Qwen3VL235B => "Qwen3 VL 235B", Self::Qwen3_32B => "Qwen3 32B", Self::Qwen3_235B => "Qwen3 235B", @@ -310,7 +352,14 @@ impl Model { Self::Nova2Lite => "Amazon Nova 2 Lite", Self::GptOss20B => "GPT OSS 20B", Self::GptOss120B => "GPT OSS 120B", + Self::NemotronSuper3_120B => "Nemotron Super 3 120B", + Self::NemotronNano3_30B => "Nemotron Nano 3 30B", Self::MiniMaxM2 => "MiniMax M2", + Self::MiniMaxM2_1 => "MiniMax M2.1", + Self::MiniMaxM2_5 => "MiniMax M2.5", + Self::GLM5 => "GLM 5", + Self::GLM4_7 => "GLM 4.7", + Self::GLM4_7Flash => "GLM 4.7 Flash", Self::KimiK2Thinking => "Kimi K2 Thinking", Self::KimiK2_5 => "Kimi K2.5", Self::DeepSeekR1 => "DeepSeek R1", @@ -338,6 +387,7 @@ impl Model { Self::Llama4Scout17B | Self::Llama4Maverick17B => 128_000, Self::Gemma3_4B | Self::Gemma3_12B | Self::Gemma3_27B => 128_000, Self::MagistralSmall | Self::MistralLarge3 | Self::PixtralLarge => 128_000, + Self::Devstral2_123B | Self::Ministral14B => 256_000, Self::Qwen3_32B | Self::Qwen3VL235B | Self::Qwen3_235B @@ -349,7 +399,9 @@ impl Model { Self::NovaPremier => 1_000_000, Self::Nova2Lite => 300_000, Self::GptOss20B | Self::GptOss120B => 128_000, - Self::MiniMaxM2 => 128_000, + Self::NemotronSuper3_120B | Self::NemotronNano3_30B => 262_000, + Self::MiniMaxM2 | Self::MiniMaxM2_1 | Self::MiniMaxM2_5 => 196_000, + Self::GLM5 | Self::GLM4_7 | Self::GLM4_7Flash => 203_000, Self::KimiK2Thinking | Self::KimiK2_5 => 128_000, Self::DeepSeekR1 | Self::DeepSeekV3_1 | Self::DeepSeekV3_2 => 128_000, Self::Custom { max_tokens, .. } => *max_tokens, @@ -373,6 +425,7 @@ impl Model { | Self::MagistralSmall | Self::MistralLarge3 | Self::PixtralLarge => 8_192, + Self::Devstral2_123B | Self::Ministral14B => 131_000, Self::Qwen3_32B | Self::Qwen3VL235B | Self::Qwen3_235B @@ -382,7 +435,9 @@ impl Model { | Self::Qwen3Coder480B => 8_192, Self::NovaLite | Self::NovaPro | Self::NovaPremier | Self::Nova2Lite => 5_000, Self::GptOss20B | Self::GptOss120B => 16_000, - Self::MiniMaxM2 => 16_000, + Self::NemotronSuper3_120B | Self::NemotronNano3_30B => 131_000, + Self::MiniMaxM2 | Self::MiniMaxM2_1 | Self::MiniMaxM2_5 => 98_000, + Self::GLM5 | Self::GLM4_7 | Self::GLM4_7Flash => 101_000, Self::KimiK2Thinking | Self::KimiK2_5 => 16_000, Self::DeepSeekR1 | Self::DeepSeekV3_1 | Self::DeepSeekV3_2 => 16_000, Self::Custom { @@ -419,6 +474,7 @@ impl Model { | Self::ClaudeSonnet4_6 => true, Self::NovaLite | Self::NovaPro | Self::NovaPremier | Self::Nova2Lite => true, Self::MistralLarge3 | Self::PixtralLarge | Self::MagistralSmall => true, + Self::Devstral2_123B | Self::Ministral14B => true, // Gemma accepts toolConfig without error but produces unreliable tool // calls -- malformed JSON args, hallucinated tool names, dropped calls. Self::Qwen3_32B @@ -428,7 +484,9 @@ impl Model { | Self::Qwen3Coder30B | Self::Qwen3CoderNext | Self::Qwen3Coder480B => true, - Self::MiniMaxM2 => true, + Self::MiniMaxM2 | Self::MiniMaxM2_1 | Self::MiniMaxM2_5 => true, + Self::NemotronSuper3_120B | Self::NemotronNano3_30B => true, + Self::GLM5 | Self::GLM4_7 | Self::GLM4_7Flash => true, Self::KimiK2Thinking | Self::KimiK2_5 => true, Self::DeepSeekR1 | Self::DeepSeekV3_1 | Self::DeepSeekV3_2 => true, _ => false, From 93438829c75f7f73dc14bba3c79b4626709a4b4e Mon Sep 17 00:00:00 2001 From: Bhuminjay Soni Date: Tue, 7 Apr 2026 15:35:02 +0530 Subject: [PATCH 08/22] Add fuzzy_nucleo crate for order independent file finder search (#51164) Closes #14428 Before you mark this PR as ready for review, make sure that you have: - [ ] Added a solid test coverage and/or screenshots from doing manual testing https://github.com/user-attachments/assets/7e0d67ff-cc4e-4609-880d-5c1794c64dcf - [x] Done a self-review taking into account security and performance aspects - [x] Aligned any UI changes with the [UI checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) Release Notes: - Adds a new `fuzzy_nucleo` crate that implements order independent path matching using the `nucleo` library. currently integrated for file finder. --------- Signed-off-by: Bhuminjay Signed-off-by: 11happy --- Cargo.lock | 32 ++ Cargo.toml | 3 + crates/file_finder/Cargo.toml | 1 + crates/file_finder/src/file_finder.rs | 69 ++-- crates/file_finder/src/file_finder_tests.rs | 230 +++++++++++++ crates/fuzzy_nucleo/Cargo.toml | 21 ++ crates/fuzzy_nucleo/LICENSE-GPL | 1 + crates/fuzzy_nucleo/src/fuzzy_nucleo.rs | 5 + crates/fuzzy_nucleo/src/matcher.rs | 39 +++ crates/fuzzy_nucleo/src/paths.rs | 352 ++++++++++++++++++++ crates/project/Cargo.toml | 1 + crates/project/src/project.rs | 70 ++++ 12 files changed, 774 insertions(+), 50 deletions(-) create mode 100644 crates/fuzzy_nucleo/Cargo.toml create mode 120000 crates/fuzzy_nucleo/LICENSE-GPL create mode 100644 crates/fuzzy_nucleo/src/fuzzy_nucleo.rs create mode 100644 crates/fuzzy_nucleo/src/matcher.rs create mode 100644 crates/fuzzy_nucleo/src/paths.rs diff --git a/Cargo.lock b/Cargo.lock index 97412711a55667a4976a35313eb6c0388acc74ef..cbc494f9dc0fc1858a846fabe168b3538de4dbe5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6183,6 +6183,7 @@ dependencies = [ "file_icons", "futures 0.3.32", "fuzzy", + "fuzzy_nucleo", "gpui", "menu", "open_path_prompt", @@ -6740,6 +6741,15 @@ dependencies = [ "thread_local", ] +[[package]] +name = "fuzzy_nucleo" +version = "0.1.0" +dependencies = [ + "gpui", + "nucleo", + "util", +] + [[package]] name = "gaoya" version = "0.2.0" @@ -11063,6 +11073,27 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "nucleo" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5262af4c94921c2646c5ac6ff7900c2af9cbb08dc26a797e18130a7019c039d4" +dependencies = [ + "nucleo-matcher", + "parking_lot", + "rayon", +] + +[[package]] +name = "nucleo-matcher" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf33f538733d1a5a3494b836ba913207f14d9d4a1d3cd67030c5061bdd2cac85" +dependencies = [ + "memchr", + "unicode-segmentation", +] + [[package]] name = "num" version = "0.4.3" @@ -13203,6 +13234,7 @@ dependencies = [ "fs", "futures 0.3.32", "fuzzy", + "fuzzy_nucleo", "git", "git2", "git_hosting_providers", diff --git a/Cargo.toml b/Cargo.toml index 5cb5b991b645ec1b78b16f48493c7c8dc1426344..4c75dafae5df4d63815e0da5cabb95ccdad25e9d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -78,6 +78,7 @@ members = [ "crates/fs", "crates/fs_benchmarks", "crates/fuzzy", + "crates/fuzzy_nucleo", "crates/git", "crates/git_graph", "crates/git_hosting_providers", @@ -325,6 +326,7 @@ file_finder = { path = "crates/file_finder" } file_icons = { path = "crates/file_icons" } fs = { path = "crates/fs" } fuzzy = { path = "crates/fuzzy" } +fuzzy_nucleo = { path = "crates/fuzzy_nucleo" } git = { path = "crates/git" } git_graph = { path = "crates/git_graph" } git_hosting_providers = { path = "crates/git_hosting_providers" } @@ -609,6 +611,7 @@ naga = { version = "29.0", features = ["wgsl-in"] } nanoid = "0.4" nbformat = "1.2.0" nix = "0.29" +nucleo = "0.5" num-format = "0.4.4" objc = "0.2" objc2-app-kit = { version = "0.3", default-features = false, features = [ "NSGraphics" ] } diff --git a/crates/file_finder/Cargo.toml b/crates/file_finder/Cargo.toml index 5eb36f0f5150263629b407dbe07dc73b6eff31cf..67ebab62295e8db90a12f99cbc05e9b9e56c2c6b 100644 --- a/crates/file_finder/Cargo.toml +++ b/crates/file_finder/Cargo.toml @@ -21,6 +21,7 @@ editor.workspace = true file_icons.workspace = true futures.workspace = true fuzzy.workspace = true +fuzzy_nucleo.workspace = true gpui.workspace = true menu.workspace = true open_path_prompt.workspace = true diff --git a/crates/file_finder/src/file_finder.rs b/crates/file_finder/src/file_finder.rs index 4302669ddc11c94f7df128534217d00c27ef083a..a4d9ea042dea898b9dd9db7d40354cf960d210d5 100644 --- a/crates/file_finder/src/file_finder.rs +++ b/crates/file_finder/src/file_finder.rs @@ -9,7 +9,8 @@ use client::ChannelId; use collections::HashMap; use editor::Editor; use file_icons::FileIcons; -use fuzzy::{CharBag, PathMatch, PathMatchCandidate, StringMatch, StringMatchCandidate}; +use fuzzy::{StringMatch, StringMatchCandidate}; +use fuzzy_nucleo::{PathMatch, PathMatchCandidate}; use gpui::{ Action, AnyElement, App, Context, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, KeyContext, Modifiers, ModifiersChangedEvent, ParentElement, Render, Styled, Task, WeakEntity, @@ -663,15 +664,6 @@ impl Matches { // For file-vs-file matches, use the existing detailed comparison. if let (Some(a_panel), Some(b_panel)) = (a.panel_match(), b.panel_match()) { - let a_in_filename = Self::is_filename_match(a_panel); - let b_in_filename = Self::is_filename_match(b_panel); - - match (a_in_filename, b_in_filename) { - (true, false) => return cmp::Ordering::Greater, - (false, true) => return cmp::Ordering::Less, - _ => {} - } - return a_panel.cmp(b_panel); } @@ -691,32 +683,6 @@ impl Matches { Match::CreateNew(_) => 0.0, } } - - /// Determines if the match occurred within the filename rather than in the path - fn is_filename_match(panel_match: &ProjectPanelOrdMatch) -> bool { - if panel_match.0.positions.is_empty() { - return false; - } - - if let Some(filename) = panel_match.0.path.file_name() { - let path_str = panel_match.0.path.as_unix_str(); - - if let Some(filename_pos) = path_str.rfind(filename) - && panel_match.0.positions[0] >= filename_pos - { - let mut prev_position = panel_match.0.positions[0]; - for p in &panel_match.0.positions[1..] { - if *p != prev_position + 1 { - return false; - } - prev_position = *p; - } - return true; - } - } - - false - } } fn matching_history_items<'a>( @@ -731,25 +697,16 @@ fn matching_history_items<'a>( let history_items_by_worktrees = history_items .into_iter() .chain(currently_opened) - .filter_map(|found_path| { + .map(|found_path| { let candidate = PathMatchCandidate { is_dir: false, // You can't open directories as project items path: &found_path.project.path, // Only match history items names, otherwise their paths may match too many queries, producing false positives. // E.g. `foo` would match both `something/foo/bar.rs` and `something/foo/foo.rs` and if the former is a history item, // it would be shown first always, despite the latter being a better match. - char_bag: CharBag::from_iter( - found_path - .project - .path - .file_name()? - .to_string() - .to_lowercase() - .chars(), - ), }; candidates_paths.insert(&found_path.project, found_path); - Some((found_path.project.worktree_id, candidate)) + (found_path.project.worktree_id, candidate) }) .fold( HashMap::default(), @@ -767,8 +724,9 @@ fn matching_history_items<'a>( let worktree_root_name = worktree_name_by_id .as_ref() .and_then(|w| w.get(&worktree).cloned()); + matching_history_paths.extend( - fuzzy::match_fixed_path_set( + fuzzy_nucleo::match_fixed_path_set( candidates, worktree.to_usize(), worktree_root_name, @@ -778,6 +736,18 @@ fn matching_history_items<'a>( path_style, ) .into_iter() + // filter matches where at least one matched position is in filename portion, to prevent directory matches, nucleo scores them higher as history items are matched against their full path + .filter(|path_match| { + if let Some(filename) = path_match.path.file_name() { + let filename_start = path_match.path.as_unix_str().len() - filename.len(); + path_match + .positions + .iter() + .any(|&pos| pos >= filename_start) + } else { + true + } + }) .filter_map(|path_match| { candidates_paths .remove_entry(&ProjectPath { @@ -940,7 +910,7 @@ impl FileFinderDelegate { self.cancel_flag = Arc::new(AtomicBool::new(false)); let cancel_flag = self.cancel_flag.clone(); cx.spawn_in(window, async move |picker, cx| { - let matches = fuzzy::match_path_sets( + let matches = fuzzy_nucleo::match_path_sets( candidate_sets.as_slice(), query.path_query(), &relative_to, @@ -1452,7 +1422,6 @@ impl PickerDelegate for FileFinderDelegate { window: &mut Window, cx: &mut Context>, ) -> Task<()> { - let raw_query = raw_query.replace(' ', ""); let raw_query = raw_query.trim(); let raw_query = match &raw_query.get(0..2) { diff --git a/crates/file_finder/src/file_finder_tests.rs b/crates/file_finder/src/file_finder_tests.rs index cd9cdeee1ff266717d380aeaecf7cbeb66ec8309..7a17202a5e4ba96b001ea46ed310518d02baf1ff 100644 --- a/crates/file_finder/src/file_finder_tests.rs +++ b/crates/file_finder/src/file_finder_tests.rs @@ -4161,3 +4161,233 @@ async fn test_clear_navigation_history(cx: &mut TestAppContext) { "Should have no history items after clearing" ); } + +#[gpui::test] +async fn test_order_independent_search(cx: &mut TestAppContext) { + let app_state = init_test(cx); + app_state + .fs + .as_fake() + .insert_tree( + "/src", + json!({ + "internal": { + "auth": { + "login.rs": "", + } + } + }), + ) + .await; + let project = Project::test(app_state.fs.clone(), ["/src".as_ref()], cx).await; + let (picker, _, cx) = build_find_picker(project, cx); + + // forward order + picker + .update_in(cx, |picker, window, cx| { + picker + .delegate + .spawn_search(test_path_position("auth internal"), window, cx) + }) + .await; + picker.update(cx, |picker, _| { + let matches = collect_search_matches(picker).search_matches_only(); + assert_eq!(matches.len(), 1); + assert_eq!(matches[0].path.as_unix_str(), "internal/auth/login.rs"); + }); + + // reverse order should give same result + picker + .update_in(cx, |picker, window, cx| { + picker + .delegate + .spawn_search(test_path_position("internal auth"), window, cx) + }) + .await; + picker.update(cx, |picker, _| { + let matches = collect_search_matches(picker).search_matches_only(); + assert_eq!(matches.len(), 1); + assert_eq!(matches[0].path.as_unix_str(), "internal/auth/login.rs"); + }); +} + +#[gpui::test] +async fn test_filename_preferred_over_directory_match(cx: &mut TestAppContext) { + let app_state = init_test(cx); + app_state + .fs + .as_fake() + .insert_tree( + "/src", + json!({ + "crates": { + "settings_ui": { + "src": { + "pages": { + "audio_test_window.rs": "", + "audio_input_output_setup.rs": "", + } + } + }, + "audio": { + "src": { + "audio_settings.rs": "", + } + } + } + }), + ) + .await; + let project = Project::test(app_state.fs.clone(), ["/src".as_ref()], cx).await; + let (picker, _, cx) = build_find_picker(project, cx); + + picker + .update_in(cx, |picker, window, cx| { + picker + .delegate + .spawn_search(test_path_position("settings audio"), window, cx) + }) + .await; + picker.update(cx, |picker, _| { + let matches = collect_search_matches(picker).search_matches_only(); + assert!(!matches.is_empty(),); + assert_eq!( + matches[0].path.as_unix_str(), + "crates/audio/src/audio_settings.rs" + ); + }); +} + +#[gpui::test] +async fn test_start_of_word_preferred_over_scattered_match(cx: &mut TestAppContext) { + let app_state = init_test(cx); + app_state + .fs + .as_fake() + .insert_tree( + "/src", + json!({ + "crates": { + "livekit_client": { + "src": { + "livekit_client": { + "playback.rs": "", + } + } + }, + "vim": { + "test_data": { + "test_record_replay_interleaved.json": "", + } + } + } + }), + ) + .await; + let project = Project::test(app_state.fs.clone(), ["/src".as_ref()], cx).await; + let (picker, _, cx) = build_find_picker(project, cx); + + picker + .update_in(cx, |picker, window, cx| { + picker + .delegate + .spawn_search(test_path_position("live pla"), window, cx) + }) + .await; + picker.update(cx, |picker, _| { + let matches = collect_search_matches(picker).search_matches_only(); + assert!(!matches.is_empty(),); + assert_eq!( + matches[0].path.as_unix_str(), + "crates/livekit_client/src/livekit_client/playback.rs", + ); + }); +} + +#[gpui::test] +async fn test_exact_filename_stem_preferred(cx: &mut TestAppContext) { + let app_state = init_test(cx); + app_state + .fs + .as_fake() + .insert_tree( + "/src", + json!({ + "assets": { + "icons": { + "file_icons": { + "nix.svg": "", + } + } + }, + "crates": { + "zed": { + "resources": { + "app-icon-nightly@2x.png": "", + "app-icon-preview@2x.png": "", + } + } + } + }), + ) + .await; + let project = Project::test(app_state.fs.clone(), ["/src".as_ref()], cx).await; + let (picker, _, cx) = build_find_picker(project, cx); + + picker + .update_in(cx, |picker, window, cx| { + picker + .delegate + .spawn_search(test_path_position("nix icon"), window, cx) + }) + .await; + picker.update(cx, |picker, _| { + let matches = collect_search_matches(picker).search_matches_only(); + assert!(!matches.is_empty(),); + assert_eq!( + matches[0].path.as_unix_str(), + "assets/icons/file_icons/nix.svg", + ); + }); +} + +#[gpui::test] +async fn test_exact_filename_with_directory_token(cx: &mut TestAppContext) { + let app_state = init_test(cx); + app_state + .fs + .as_fake() + .insert_tree( + "/src", + json!({ + "crates": { + "agent_servers": { + "src": { + "acp.rs": "", + "agent_server.rs": "", + "custom.rs": "", + } + } + } + }), + ) + .await; + let project = Project::test(app_state.fs.clone(), ["/src".as_ref()], cx).await; + let (picker, _, cx) = build_find_picker(project, cx); + + picker + .update_in(cx, |picker, window, cx| { + picker + .delegate + .spawn_search(test_path_position("acp server"), window, cx) + }) + .await; + picker.update(cx, |picker, _| { + let matches = collect_search_matches(picker).search_matches_only(); + assert!(!matches.is_empty(),); + assert_eq!( + matches[0].path.as_unix_str(), + "crates/agent_servers/src/acp.rs", + ); + }); +} diff --git a/crates/fuzzy_nucleo/Cargo.toml b/crates/fuzzy_nucleo/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..59e8b642524777f449f79edba85093eef069ebff --- /dev/null +++ b/crates/fuzzy_nucleo/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "fuzzy_nucleo" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/fuzzy_nucleo.rs" +doctest = false + +[dependencies] +nucleo.workspace = true +gpui.workspace = true +util.workspace = true + +[dev-dependencies] +util = {workspace = true, features = ["test-support"]} diff --git a/crates/fuzzy_nucleo/LICENSE-GPL b/crates/fuzzy_nucleo/LICENSE-GPL new file mode 120000 index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4 --- /dev/null +++ b/crates/fuzzy_nucleo/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/fuzzy_nucleo/src/fuzzy_nucleo.rs b/crates/fuzzy_nucleo/src/fuzzy_nucleo.rs new file mode 100644 index 0000000000000000000000000000000000000000..ddaa5c3489cf55d41d31440f037214b1dce0358c --- /dev/null +++ b/crates/fuzzy_nucleo/src/fuzzy_nucleo.rs @@ -0,0 +1,5 @@ +mod matcher; +mod paths; +pub use paths::{ + PathMatch, PathMatchCandidate, PathMatchCandidateSet, match_fixed_path_set, match_path_sets, +}; diff --git a/crates/fuzzy_nucleo/src/matcher.rs b/crates/fuzzy_nucleo/src/matcher.rs new file mode 100644 index 0000000000000000000000000000000000000000..b31da011106341420095bcffbfd012f40014ad6c --- /dev/null +++ b/crates/fuzzy_nucleo/src/matcher.rs @@ -0,0 +1,39 @@ +use std::sync::Mutex; + +static MATCHERS: Mutex> = Mutex::new(Vec::new()); + +pub const LENGTH_PENALTY: f64 = 0.01; + +pub fn get_matcher(config: nucleo::Config) -> nucleo::Matcher { + let mut matchers = MATCHERS.lock().unwrap(); + match matchers.pop() { + Some(mut matcher) => { + matcher.config = config; + matcher + } + None => nucleo::Matcher::new(config), + } +} + +pub fn return_matcher(matcher: nucleo::Matcher) { + MATCHERS.lock().unwrap().push(matcher); +} + +pub fn get_matchers(n: usize, config: nucleo::Config) -> Vec { + let mut matchers: Vec<_> = { + let mut pool = MATCHERS.lock().unwrap(); + let available = pool.len().min(n); + pool.drain(..available) + .map(|mut matcher| { + matcher.config = config.clone(); + matcher + }) + .collect() + }; + matchers.resize_with(n, || nucleo::Matcher::new(config.clone())); + matchers +} + +pub fn return_matchers(mut matchers: Vec) { + MATCHERS.lock().unwrap().append(&mut matchers); +} diff --git a/crates/fuzzy_nucleo/src/paths.rs b/crates/fuzzy_nucleo/src/paths.rs new file mode 100644 index 0000000000000000000000000000000000000000..ac766622c9d12c6e2a119fbcd7dd7fe7a3b5a90d --- /dev/null +++ b/crates/fuzzy_nucleo/src/paths.rs @@ -0,0 +1,352 @@ +use gpui::BackgroundExecutor; +use std::{ + cmp::Ordering, + sync::{ + Arc, + atomic::{self, AtomicBool}, + }, +}; +use util::{paths::PathStyle, rel_path::RelPath}; + +use nucleo::Utf32Str; +use nucleo::pattern::{Atom, AtomKind, CaseMatching, Normalization}; + +use crate::matcher::{self, LENGTH_PENALTY}; + +#[derive(Clone, Debug)] +pub struct PathMatchCandidate<'a> { + pub is_dir: bool, + pub path: &'a RelPath, +} + +#[derive(Clone, Debug)] +pub struct PathMatch { + pub score: f64, + pub positions: Vec, + pub worktree_id: usize, + pub path: Arc, + pub path_prefix: Arc, + pub is_dir: bool, + /// Number of steps removed from a shared parent with the relative path + /// Used to order closer paths first in the search list + pub distance_to_relative_ancestor: usize, +} + +pub trait PathMatchCandidateSet<'a>: Send + Sync { + type Candidates: Iterator>; + fn id(&self) -> usize; + fn len(&self) -> usize; + fn is_empty(&self) -> bool { + self.len() == 0 + } + fn root_is_file(&self) -> bool; + fn prefix(&self) -> Arc; + fn candidates(&'a self, start: usize) -> Self::Candidates; + fn path_style(&self) -> PathStyle; +} + +impl PartialEq for PathMatch { + fn eq(&self, other: &Self) -> bool { + self.cmp(other).is_eq() + } +} + +impl Eq for PathMatch {} + +impl PartialOrd for PathMatch { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for PathMatch { + fn cmp(&self, other: &Self) -> Ordering { + self.score + .partial_cmp(&other.score) + .unwrap_or(Ordering::Equal) + .then_with(|| self.worktree_id.cmp(&other.worktree_id)) + .then_with(|| { + other + .distance_to_relative_ancestor + .cmp(&self.distance_to_relative_ancestor) + }) + .then_with(|| self.path.cmp(&other.path)) + } +} + +fn make_atoms(query: &str, smart_case: bool) -> Vec { + let case = if smart_case { + CaseMatching::Smart + } else { + CaseMatching::Ignore + }; + query + .split_whitespace() + .map(|word| Atom::new(word, case, Normalization::Smart, AtomKind::Fuzzy, false)) + .collect() +} + +pub(crate) fn distance_between_paths(path: &RelPath, relative_to: &RelPath) -> usize { + let mut path_components = path.components(); + let mut relative_components = relative_to.components(); + + while path_components + .next() + .zip(relative_components.next()) + .map(|(path_component, relative_component)| path_component == relative_component) + .unwrap_or_default() + {} + path_components.count() + relative_components.count() + 1 +} + +fn get_filename_match_bonus( + candidate_buf: &str, + query_atoms: &[Atom], + matcher: &mut nucleo::Matcher, +) -> f64 { + let filename = match std::path::Path::new(candidate_buf).file_name() { + Some(f) => f.to_str().unwrap_or(""), + None => return 0.0, + }; + if filename.is_empty() || query_atoms.is_empty() { + return 0.0; + } + let mut buf = Vec::new(); + let haystack = Utf32Str::new(filename, &mut buf); + let mut total_score = 0u32; + for atom in query_atoms { + if let Some(score) = atom.score(haystack, matcher) { + total_score = total_score.saturating_add(score as u32); + } + } + total_score as f64 / filename.len().max(1) as f64 +} +struct Cancelled; + +fn path_match_helper<'a>( + matcher: &mut nucleo::Matcher, + atoms: &[Atom], + candidates: impl Iterator>, + results: &mut Vec, + worktree_id: usize, + path_prefix: &Arc, + root_is_file: bool, + relative_to: &Option>, + path_style: PathStyle, + cancel_flag: &AtomicBool, +) -> Result<(), Cancelled> { + let mut candidate_buf = if !path_prefix.is_empty() && !root_is_file { + let mut s = path_prefix.display(path_style).to_string(); + s.push_str(path_style.primary_separator()); + s + } else { + String::new() + }; + let path_prefix_len = candidate_buf.len(); + let mut buf = Vec::new(); + let mut matched_chars: Vec = Vec::new(); + let mut atom_matched_chars = Vec::new(); + for candidate in candidates { + buf.clear(); + matched_chars.clear(); + if cancel_flag.load(atomic::Ordering::Relaxed) { + return Err(Cancelled); + } + + candidate_buf.truncate(path_prefix_len); + if root_is_file { + candidate_buf.push_str(path_prefix.as_unix_str()); + } else { + candidate_buf.push_str(candidate.path.as_unix_str()); + } + + let haystack = Utf32Str::new(&candidate_buf, &mut buf); + + let mut total_score: u32 = 0; + let mut all_matched = true; + + for atom in atoms { + atom_matched_chars.clear(); + if let Some(score) = atom.indices(haystack, matcher, &mut atom_matched_chars) { + total_score = total_score.saturating_add(score as u32); + matched_chars.extend_from_slice(&atom_matched_chars); + } else { + all_matched = false; + break; + } + } + + if all_matched && !atoms.is_empty() { + matched_chars.sort_unstable(); + matched_chars.dedup(); + + let length_penalty = candidate_buf.len() as f64 * LENGTH_PENALTY; + let filename_bonus = get_filename_match_bonus(&candidate_buf, atoms, matcher); + let adjusted_score = total_score as f64 + filename_bonus - length_penalty; + let mut positions: Vec = candidate_buf + .char_indices() + .enumerate() + .filter_map(|(char_offset, (byte_offset, _))| { + matched_chars + .contains(&(char_offset as u32)) + .then_some(byte_offset) + }) + .collect(); + positions.sort_unstable(); + + results.push(PathMatch { + score: adjusted_score, + positions, + worktree_id, + path: if root_is_file { + Arc::clone(path_prefix) + } else { + candidate.path.into() + }, + path_prefix: if root_is_file { + RelPath::empty().into() + } else { + Arc::clone(path_prefix) + }, + is_dir: candidate.is_dir, + distance_to_relative_ancestor: relative_to + .as_ref() + .map_or(usize::MAX, |relative_to| { + distance_between_paths(candidate.path, relative_to.as_ref()) + }), + }); + } + } + Ok(()) +} + +pub fn match_fixed_path_set( + candidates: Vec, + worktree_id: usize, + worktree_root_name: Option>, + query: &str, + smart_case: bool, + max_results: usize, + path_style: PathStyle, +) -> Vec { + let mut config = nucleo::Config::DEFAULT; + config.set_match_paths(); + let mut matcher = matcher::get_matcher(config); + + let atoms = make_atoms(query, smart_case); + + let root_is_file = worktree_root_name.is_some() && candidates.iter().all(|c| c.path.is_empty()); + + let path_prefix = worktree_root_name.unwrap_or_else(|| RelPath::empty().into()); + + let mut results = Vec::new(); + + path_match_helper( + &mut matcher, + &atoms, + candidates.into_iter(), + &mut results, + worktree_id, + &path_prefix, + root_is_file, + &None, + path_style, + &AtomicBool::new(false), + ) + .ok(); + util::truncate_to_bottom_n_sorted_by(&mut results, max_results, &|a, b| b.cmp(a)); + matcher::return_matcher(matcher); + results +} + +pub async fn match_path_sets<'a, Set: PathMatchCandidateSet<'a>>( + candidate_sets: &'a [Set], + query: &str, + relative_to: &Option>, + smart_case: bool, + max_results: usize, + cancel_flag: &AtomicBool, + executor: BackgroundExecutor, +) -> Vec { + let path_count: usize = candidate_sets.iter().map(|s| s.len()).sum(); + if path_count == 0 { + return Vec::new(); + } + + let path_style = candidate_sets[0].path_style(); + + let query = if path_style.is_windows() { + query.replace('\\', "/") + } else { + query.to_owned() + }; + + let atoms = make_atoms(&query, smart_case); + + let num_cpus = executor.num_cpus().min(path_count); + let segment_size = path_count.div_ceil(num_cpus); + let mut segment_results = (0..num_cpus) + .map(|_| Vec::with_capacity(max_results)) + .collect::>(); + let mut config = nucleo::Config::DEFAULT; + config.set_match_paths(); + let mut matchers = matcher::get_matchers(num_cpus, config); + executor + .scoped(|scope| { + for (segment_idx, (results, matcher)) in segment_results + .iter_mut() + .zip(matchers.iter_mut()) + .enumerate() + { + let atoms = atoms.clone(); + let relative_to = relative_to.clone(); + scope.spawn(async move { + let segment_start = segment_idx * segment_size; + let segment_end = segment_start + segment_size; + + let mut tree_start = 0; + for candidate_set in candidate_sets { + let tree_end = tree_start + candidate_set.len(); + + if tree_start < segment_end && segment_start < tree_end { + let start = tree_start.max(segment_start) - tree_start; + let end = tree_end.min(segment_end) - tree_start; + let candidates = candidate_set.candidates(start).take(end - start); + + if path_match_helper( + matcher, + &atoms, + candidates, + results, + candidate_set.id(), + &candidate_set.prefix(), + candidate_set.root_is_file(), + &relative_to, + path_style, + cancel_flag, + ) + .is_err() + { + break; + } + } + + if tree_end >= segment_end { + break; + } + tree_start = tree_end; + } + }); + } + }) + .await; + + matcher::return_matchers(matchers); + if cancel_flag.load(atomic::Ordering::Acquire) { + return Vec::new(); + } + + let mut results = segment_results.concat(); + util::truncate_to_bottom_n_sorted_by(&mut results, max_results, &|a, b| b.cmp(a)); + results +} diff --git a/crates/project/Cargo.toml b/crates/project/Cargo.toml index cd037786a399eb979fd5d9053c57efe3100dd473..628e979aab939a74bb4838477ae3e3657e2c91bc 100644 --- a/crates/project/Cargo.toml +++ b/crates/project/Cargo.toml @@ -52,6 +52,7 @@ fancy-regex.workspace = true fs.workspace = true futures.workspace = true fuzzy.workspace = true +fuzzy_nucleo.workspace = true git.workspace = true git_hosting_providers.workspace = true globset.workspace = true diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 0ec3366ca8f9f6c6e4e3cbd411e1894de4d0f2b8..b90972b3489c25f8a2bf10d7dbdb6d6cfe0c4c6c 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -6186,6 +6186,76 @@ impl<'a> Iterator for PathMatchCandidateSetIter<'a> { } } +impl<'a> fuzzy_nucleo::PathMatchCandidateSet<'a> for PathMatchCandidateSet { + type Candidates = PathMatchCandidateSetNucleoIter<'a>; + fn id(&self) -> usize { + self.snapshot.id().to_usize() + } + fn len(&self) -> usize { + match self.candidates { + Candidates::Files => { + if self.include_ignored { + self.snapshot.file_count() + } else { + self.snapshot.visible_file_count() + } + } + Candidates::Directories => { + if self.include_ignored { + self.snapshot.dir_count() + } else { + self.snapshot.visible_dir_count() + } + } + Candidates::Entries => { + if self.include_ignored { + self.snapshot.entry_count() + } else { + self.snapshot.visible_entry_count() + } + } + } + } + fn prefix(&self) -> Arc { + if self.snapshot.root_entry().is_some_and(|e| e.is_file()) || self.include_root_name { + self.snapshot.root_name().into() + } else { + RelPath::empty().into() + } + } + fn root_is_file(&self) -> bool { + self.snapshot.root_entry().is_some_and(|f| f.is_file()) + } + fn path_style(&self) -> PathStyle { + self.snapshot.path_style() + } + fn candidates(&'a self, start: usize) -> Self::Candidates { + PathMatchCandidateSetNucleoIter { + traversal: match self.candidates { + Candidates::Directories => self.snapshot.directories(self.include_ignored, start), + Candidates::Files => self.snapshot.files(self.include_ignored, start), + Candidates::Entries => self.snapshot.entries(self.include_ignored, start), + }, + } + } +} + +pub struct PathMatchCandidateSetNucleoIter<'a> { + traversal: Traversal<'a>, +} + +impl<'a> Iterator for PathMatchCandidateSetNucleoIter<'a> { + type Item = fuzzy_nucleo::PathMatchCandidate<'a>; + fn next(&mut self) -> Option { + self.traversal + .next() + .map(|entry| fuzzy_nucleo::PathMatchCandidate { + is_dir: entry.kind.is_dir(), + path: &entry.path, + }) + } +} + impl EventEmitter for Project {} impl<'a> From<&'a ProjectPath> for SettingsLocation<'a> { From 1dc3bb90e96be26cab72e7392c4042e1e5d0d71a Mon Sep 17 00:00:00 2001 From: Pratik Karki Date: Tue, 7 Apr 2026 17:10:55 +0545 Subject: [PATCH 09/22] Fix pane::RevealInProjectPanel to focus/open project panel for non-project buffers (#51246) Update how `workspace::pane::Pane` handles the `RevealInProjectPanel` action so as to display a notification when the user attempts to reveal an unsaved buffer or a file that does not belong to any of the open projects. Closes #23967 Release Notes: - Update `pane: reveal in project panel` to display a notification when the user attempts to use it with an unsaved buffer or a file that is not part of the open projects --------- Signed-off-by: Pratik Karki Co-authored-by: dino --- .../project_panel/src/project_panel_tests.rs | 146 +++++++++++++++++- crates/workspace/src/pane.rs | 66 +++++++- 2 files changed, 203 insertions(+), 9 deletions(-) diff --git a/crates/project_panel/src/project_panel_tests.rs b/crates/project_panel/src/project_panel_tests.rs index 55b53cde8b6252f8b9732cf4effc35ea53c073e0..603cfd892a218d866383f485d058296ad179da05 100644 --- a/crates/project_panel/src/project_panel_tests.rs +++ b/crates/project_panel/src/project_panel_tests.rs @@ -11,7 +11,7 @@ use std::path::{Path, PathBuf}; use util::{path, paths::PathStyle, rel_path::rel_path}; use workspace::{ AppState, ItemHandle, MultiWorkspace, Pane, Workspace, - item::{Item, ProjectItem}, + item::{Item, ProjectItem, test::TestItem}, register_project_item, }; @@ -6015,6 +6015,150 @@ async fn test_explicit_reveal(cx: &mut gpui::TestAppContext) { ); } +#[gpui::test] +async fn test_reveal_in_project_panel_notifications(cx: &mut gpui::TestAppContext) { + init_test_with_editor(cx); + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree( + "/workspace", + json!({ + "README.md": "" + }), + ) + .await; + + let project = Project::test(fs.clone(), ["/workspace".as_ref()], cx).await; + let window = cx.add_window(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx)); + let workspace = window + .read_with(cx, |mw, _| mw.workspace().clone()) + .unwrap(); + let cx = &mut VisualTestContext::from_window(window.into(), cx); + let panel = workspace.update_in(cx, ProjectPanel::new); + cx.run_until_parked(); + + // Ensure that, attempting to run `pane: reveal in project panel` without + // any active item does nothing, i.e., does not focus the project panel but + // it also does not show a notification. + cx.dispatch_action(workspace::RevealInProjectPanel::default()); + cx.run_until_parked(); + + panel.update_in(cx, |panel, window, cx| { + assert!( + !panel.focus_handle(cx).is_focused(window), + "Project panel should not be focused after attempting to reveal an invisible worktree entry" + ); + + panel.workspace.update(cx, |workspace, cx| { + assert!( + workspace.active_item(cx).is_none(), + "Workspace should not have an active item" + ); + assert_eq!( + workspace.notification_ids(), + vec![], + "No notification should be shown when there's no active item" + ); + }).unwrap(); + }); + + // Create a file in a different folder than the one in the project so we can + // later open it and ensure that, attempting to reveal it in the project + // panel shows a notification and does not focus the project panel. + fs.insert_tree( + "/external", + json!({ + "file.txt": "External File", + }), + ) + .await; + + let (worktree, _) = project + .update(cx, |project, cx| { + project.find_or_create_worktree("/external/file.txt", false, cx) + }) + .await + .unwrap(); + + workspace + .update_in(cx, |workspace, window, cx| { + let worktree_id = worktree.read(cx).id(); + let path = rel_path("").into(); + let project_path = ProjectPath { worktree_id, path }; + + workspace.open_path(project_path, None, true, window, cx) + }) + .await + .unwrap(); + cx.run_until_parked(); + + cx.dispatch_action(workspace::RevealInProjectPanel::default()); + cx.run_until_parked(); + + panel.update_in(cx, |panel, window, cx| { + assert!( + !panel.focus_handle(cx).is_focused(window), + "Project panel should not be focused after attempting to reveal an invisible worktree entry" + ); + + panel.workspace.update(cx, |workspace, cx| { + assert!( + workspace.active_item(cx).is_some(), + "Workspace should have an active item" + ); + + let notification_ids = workspace.notification_ids(); + assert_eq!( + notification_ids.len(), + 1, + "A notification should be shown when trying to reveal an invisible worktree entry" + ); + + workspace.dismiss_notification(¬ification_ids[0], cx); + assert_eq!( + workspace.notification_ids().len(), + 0, + "No notifications should be left after dismissing" + ); + }).unwrap(); + }); + + // Create an empty buffer so we can ensure that, attempting to reveal it in + // the project panel shows a notification and does not focus the project + // panel. + let pane = workspace.update(cx, |workspace, _| workspace.active_pane().clone()); + pane.update_in(cx, |pane, window, cx| { + let item = cx.new(|cx| TestItem::new(cx).with_label("Unsaved buffer")); + pane.add_item(Box::new(item), false, false, None, window, cx); + }); + + cx.dispatch_action(workspace::RevealInProjectPanel::default()); + cx.run_until_parked(); + + panel.update_in(cx, |panel, window, cx| { + assert!( + !panel.focus_handle(cx).is_focused(window), + "Project panel should not be focused after attempting to reveal an unsaved buffer" + ); + + panel + .workspace + .update(cx, |workspace, cx| { + assert!( + workspace.active_item(cx).is_some(), + "Workspace should have an active item" + ); + + let notification_ids = workspace.notification_ids(); + assert_eq!( + notification_ids.len(), + 1, + "A notification should be shown when trying to reveal an unsaved buffer" + ); + }) + .unwrap(); + }); +} + #[gpui::test] async fn test_creating_excluded_entries(cx: &mut gpui::TestAppContext) { init_test(cx); diff --git a/crates/workspace/src/pane.rs b/crates/workspace/src/pane.rs index 27cc96ae80a010db2dd5357a9a0bc037ca762875..a09ba73add7e94fbe6910eb400b1364bd21cd313 100644 --- a/crates/workspace/src/pane.rs +++ b/crates/workspace/src/pane.rs @@ -10,7 +10,10 @@ use crate::{ TabContentParams, TabTooltipContent, WeakItemHandle, }, move_item, - notifications::NotifyResultExt, + notifications::{ + NotificationId, NotifyResultExt, show_app_notification, + simple_message_notification::MessageNotification, + }, toolbar::Toolbar, workspace_settings::{AutosaveSetting, FocusFollowsMouse, TabBarSettings, WorkspaceSettings}, }; @@ -4400,17 +4403,64 @@ impl Render for Pane { )) .on_action( cx.listener(|pane: &mut Self, action: &RevealInProjectPanel, _, cx| { + let Some(active_item) = pane.active_item() else { + return; + }; + let entry_id = action .entry_id .map(ProjectEntryId::from_proto) - .or_else(|| pane.active_item()?.project_entry_ids(cx).first().copied()); - if let Some(entry_id) = entry_id { - pane.project - .update(cx, |_, cx| { - cx.emit(project::Event::RevealInProjectPanel(entry_id)) - }) - .ok(); + .or_else(|| active_item.project_entry_ids(cx).first().copied()); + + let show_reveal_error_toast = |display_name: &str, cx: &mut App| { + let notification_id = NotificationId::unique::(); + let message = SharedString::from(format!( + "\"{display_name}\" is not part of any open projects." + )); + + show_app_notification(notification_id, cx, move |cx| { + let message = message.clone(); + cx.new(|cx| MessageNotification::new(message, cx)) + }); + }; + + let Some(entry_id) = entry_id else { + // When working with an unsaved buffer, display a toast + // informing the user that the buffer is not present in + // any of the open projects and stop execution, as we + // don't want to open the project panel. + let display_name = active_item + .tab_tooltip_text(cx) + .unwrap_or_else(|| active_item.tab_content_text(0, cx)); + + return show_reveal_error_toast(&display_name, cx); + }; + + // We'll now check whether the entry belongs to a visible + // worktree and, if that's not the case, it means the user + // is interacting with a file that does not belong to any of + // the open projects, so we'll show a toast informing them + // of this and stop execution. + let display_name = pane + .project + .read_with(cx, |project, cx| { + project + .worktree_for_entry(entry_id, cx) + .filter(|worktree| !worktree.read(cx).is_visible()) + .map(|worktree| worktree.read(cx).root_name_str().to_string()) + }) + .ok() + .flatten(); + + if let Some(display_name) = display_name { + return show_reveal_error_toast(&display_name, cx); } + + pane.project + .update(cx, |_, cx| { + cx.emit(project::Event::RevealInProjectPanel(entry_id)) + }) + .log_err(); }), ) .on_action(cx.listener(|_, _: &menu::Cancel, window, cx| { From eaf14d028a6c9cca193f725871116cd05a21c305 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Soares?= <37777652+Dnreikronos@users.noreply.github.com> Date: Tue, 7 Apr 2026 09:12:30 -0300 Subject: [PATCH 10/22] gpui: Fix SVG renderer not rendering text when system fonts are unavailable (#51623) Closes #51466 Before you mark this PR as ready for review, make sure that you have: - [x] Added a solid test coverage and/or screenshots from doing manual testing - [x] Done a self-review taking into account security and performance aspects - [ ] Aligned any UI changes with the [UI checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) Release Notes: - Fixed mermaid diagrams not showing text in markdown preview by bundling fallback fonts and fixing generic font family resolution in the SVG renderer. --- crates/gpui/src/svg_renderer.rs | 127 ++++++++++++++++++++++++++++++-- 1 file changed, 120 insertions(+), 7 deletions(-) diff --git a/crates/gpui/src/svg_renderer.rs b/crates/gpui/src/svg_renderer.rs index 8653ab9b162031772ab29367b60ff988e33cd823..a766a25cc1ef66039f5b2a1d0aeaab51ace89578 100644 --- a/crates/gpui/src/svg_renderer.rs +++ b/crates/gpui/src/svg_renderer.rs @@ -105,18 +105,36 @@ pub enum SvgSize { impl SvgRenderer { /// Creates a new SVG renderer with the provided asset source. pub fn new(asset_source: Arc) -> Self { - static FONT_DB: LazyLock> = LazyLock::new(|| { + static SYSTEM_FONT_DB: LazyLock> = LazyLock::new(|| { let mut db = usvg::fontdb::Database::new(); db.load_system_fonts(); Arc::new(db) }); + + let fontdb = { + let mut db = (**SYSTEM_FONT_DB).clone(); + load_bundled_fonts(&*asset_source, &mut db); + fix_generic_font_families(&mut db); + Arc::new(db) + }; + let default_font_resolver = usvg::FontResolver::default_font_selector(); let font_resolver = Box::new( move |font: &usvg::Font, db: &mut Arc| { if db.is_empty() { - *db = FONT_DB.clone(); + *db = fontdb.clone(); + } + if let Some(id) = default_font_resolver(font, db) { + return Some(id); } - default_font_resolver(font, db) + // fontdb doesn't recognize CSS system font keywords like "system-ui" + // or "ui-sans-serif", so fall back to sans-serif before any face. + let sans_query = usvg::fontdb::Query { + families: &[usvg::fontdb::Family::SansSerif], + ..Default::default() + }; + db.query(&sans_query) + .or_else(|| db.faces().next().map(|f| f.id)) }, ); let default_fallback_selection = usvg::FontResolver::default_fallback_selector(); @@ -226,14 +244,69 @@ impl SvgRenderer { } } +fn load_bundled_fonts(asset_source: &dyn AssetSource, db: &mut usvg::fontdb::Database) { + let font_paths = [ + "fonts/ibm-plex-sans/IBMPlexSans-Regular.ttf", + "fonts/lilex/Lilex-Regular.ttf", + ]; + for path in font_paths { + match asset_source.load(path) { + Ok(Some(data)) => db.load_font_data(data.into_owned()), + Ok(None) => log::warn!("Bundled font not found: {path}"), + Err(error) => log::warn!("Failed to load bundled font {path}: {error}"), + } + } +} + +// fontdb defaults generic families to Microsoft fonts ("Arial", "Times New Roman") +// which aren't installed on most Linux systems. fontconfig normally overrides these, +// but when it fails the defaults remain and all generic family queries return None. +fn fix_generic_font_families(db: &mut usvg::fontdb::Database) { + use usvg::fontdb::{Family, Query}; + + let families_and_fallbacks: &[(Family<'_>, &str)] = &[ + (Family::SansSerif, "IBM Plex Sans"), + // No serif font bundled; use sans-serif as best available fallback. + (Family::Serif, "IBM Plex Sans"), + (Family::Monospace, "Lilex"), + (Family::Cursive, "IBM Plex Sans"), + (Family::Fantasy, "IBM Plex Sans"), + ]; + + for (family, fallback_name) in families_and_fallbacks { + let query = Query { + families: &[*family], + ..Default::default() + }; + if db.query(&query).is_none() { + match family { + Family::SansSerif => db.set_sans_serif_family(*fallback_name), + Family::Serif => db.set_serif_family(*fallback_name), + Family::Monospace => db.set_monospace_family(*fallback_name), + Family::Cursive => db.set_cursive_family(*fallback_name), + Family::Fantasy => db.set_fantasy_family(*fallback_name), + _ => {} + } + } + } +} + #[cfg(test)] mod tests { use super::*; + use usvg::fontdb::{Database, Family, Query}; const IBM_PLEX_REGULAR: &[u8] = include_bytes!("../../../assets/fonts/ibm-plex-sans/IBMPlexSans-Regular.ttf"); const LILEX_REGULAR: &[u8] = include_bytes!("../../../assets/fonts/lilex/Lilex-Regular.ttf"); + fn db_with_bundled_fonts() -> Database { + let mut db = Database::new(); + db.load_font_data(IBM_PLEX_REGULAR.to_vec()); + db.load_font_data(LILEX_REGULAR.to_vec()); + db + } + #[test] fn test_is_emoji_presentation() { let cases = [ @@ -266,11 +339,33 @@ mod tests { } #[test] - fn test_select_emoji_font_skips_family_without_glyph() { - let mut db = usvg::fontdb::Database::new(); + fn fix_generic_font_families_sets_all_families() { + let mut db = db_with_bundled_fonts(); + fix_generic_font_families(&mut db); + + let families = [ + Family::SansSerif, + Family::Serif, + Family::Monospace, + Family::Cursive, + Family::Fantasy, + ]; - db.load_font_data(IBM_PLEX_REGULAR.to_vec()); - db.load_font_data(LILEX_REGULAR.to_vec()); + for family in families { + let query = Query { + families: &[family], + ..Default::default() + }; + assert!( + db.query(&query).is_some(), + "Expected generic family {family:?} to resolve after fix_generic_font_families" + ); + } + } + + #[test] + fn test_select_emoji_font_skips_family_without_glyph() { + let mut db = db_with_bundled_fonts(); let ibm_plex_sans = db .query(&usvg::fontdb::Query { @@ -294,4 +389,22 @@ mod tests { assert!(!font_has_char(&db, ibm_plex_sans, '│')); assert!(font_has_char(&db, selected, '│')); } + + #[test] + fn fix_generic_font_families_monospace_resolves_to_lilex() { + let mut db = db_with_bundled_fonts(); + fix_generic_font_families(&mut db); + + let query = Query { + families: &[Family::Monospace], + ..Default::default() + }; + let id = db.query(&query).expect("Monospace should resolve"); + let face = db.face(id).expect("Face should exist"); + assert!( + face.families.iter().any(|(name, _)| name.contains("Lilex")), + "Monospace should map to Lilex, got {:?}", + face.families + ); + } } From 0bde5094f695c9ddf4e5fa591712baab546d3b4b Mon Sep 17 00:00:00 2001 From: Danilo Leal <67129314+danilo-leal@users.noreply.github.com> Date: Tue, 7 Apr 2026 09:13:05 -0300 Subject: [PATCH 11/22] agent_ui: Set max-width for thread view content (#52730) This PR adds a configurable max-width to the agent panel. This will be particularly useful when opting into an agentic-first layout where the thread will be at the center of the UI (with the panel most likely full-screen'ed, which is why I'm also adding here the button to make it full screen in the toolbar). The default max-width is 850, which is a bit bigger than the one generally considered as a standard (~66 characters wide, which usually sums up to 750 pixels). Release Notes: - Agent: Added a max-width to the thread view for better readability, particularly when the panel is zoomed in. --- assets/settings/default.json | 3 + crates/agent/src/tool_permissions.rs | 1 + crates/agent_settings/src/agent_settings.rs | 2 + crates/agent_ui/src/agent_panel.rs | 85 +++--- crates/agent_ui/src/agent_ui.rs | 1 + .../src/conversation_view/thread_view.rs | 242 ++++++++++-------- crates/settings_content/src/agent.rs | 6 + crates/settings_ui/src/page_data.rs | 20 +- 8 files changed, 203 insertions(+), 157 deletions(-) diff --git a/assets/settings/default.json b/assets/settings/default.json index 63e906e3b11206fc458f8d7353f3ecba0abeb825..a32e1b27aee08bf2676922fea3790a99b7d7844b 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -965,6 +965,9 @@ "default_width": 640, // Default height when the agent panel is docked to the bottom. "default_height": 320, + // Maximum content width when the agent panel is wider than this value. + // Content will be centered within the panel. + "max_content_width": 850, // The default model to use when creating new threads. "default_model": { // The provider to use. diff --git a/crates/agent/src/tool_permissions.rs b/crates/agent/src/tool_permissions.rs index 58e779da59aef176464839ed6f2d6a5c16e4bc12..ff9e735b6c4181588ed5cddbd6dada7fbae5f18f 100644 --- a/crates/agent/src/tool_permissions.rs +++ b/crates/agent/src/tool_permissions.rs @@ -574,6 +574,7 @@ mod tests { flexible: true, default_width: px(300.), default_height: px(600.), + max_content_width: px(850.), default_model: None, inline_assistant_model: None, inline_assistant_use_streaming_tools: false, diff --git a/crates/agent_settings/src/agent_settings.rs b/crates/agent_settings/src/agent_settings.rs index 0c68d2f25d54f966d1cc0a93476457bbba79c959..5d6dca9322482daecf7525f79ead63b4471b7a53 100644 --- a/crates/agent_settings/src/agent_settings.rs +++ b/crates/agent_settings/src/agent_settings.rs @@ -154,6 +154,7 @@ pub struct AgentSettings { pub sidebar_side: SidebarDockPosition, pub default_width: Pixels, pub default_height: Pixels, + pub max_content_width: Pixels, pub default_model: Option, pub inline_assistant_model: Option, pub inline_assistant_use_streaming_tools: bool, @@ -600,6 +601,7 @@ impl Settings for AgentSettings { sidebar_side: agent.sidebar_side.unwrap(), default_width: px(agent.default_width.unwrap()), default_height: px(agent.default_height.unwrap()), + max_content_width: px(agent.max_content_width.unwrap()), flexible: agent.flexible.unwrap(), default_model: Some(agent.default_model.unwrap()), inline_assistant_model: agent.inline_assistant_model, diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 8f456e0e955b823a5bbaf2815df3b409441bb0af..01b897fc63da76247b5624f8316ea06b2c1f85e5 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -3186,17 +3186,11 @@ impl AgentPanel { fn render_panel_options_menu( &self, - window: &mut Window, + _window: &mut Window, cx: &mut Context, ) -> impl IntoElement { let focus_handle = self.focus_handle(cx); - let full_screen_label = if self.is_zoomed(window, cx) { - "Disable Full Screen" - } else { - "Enable Full Screen" - }; - let conversation_view = match &self.active_view { ActiveView::AgentThread { conversation_view } => Some(conversation_view.clone()), _ => None, @@ -3272,8 +3266,7 @@ impl AgentPanel { .action("Profiles", Box::new(ManageProfiles::default())) .action("Settings", Box::new(OpenSettings)) .separator() - .action("Toggle Threads Sidebar", Box::new(ToggleWorkspaceSidebar)) - .action(full_screen_label, Box::new(ToggleZoom)); + .action("Toggle Threads Sidebar", Box::new(ToggleWorkspaceSidebar)); if has_auth_methods { menu = menu.action("Reauthenticate", Box::new(ReauthenticateAgent)) @@ -3709,21 +3702,37 @@ impl AgentPanel { ); let is_full_screen = self.is_zoomed(window, cx); + let full_screen_button = if is_full_screen { + IconButton::new("disable-full-screen", IconName::Minimize) + .icon_size(IconSize::Small) + .tooltip(move |_, cx| Tooltip::for_action("Disable Full Screen", &ToggleZoom, cx)) + .on_click(cx.listener(move |this, _, window, cx| { + this.toggle_zoom(&ToggleZoom, window, cx); + })) + } else { + IconButton::new("enable-full-screen", IconName::Maximize) + .icon_size(IconSize::Small) + .tooltip(move |_, cx| Tooltip::for_action("Enable Full Screen", &ToggleZoom, cx)) + .on_click(cx.listener(move |this, _, window, cx| { + this.toggle_zoom(&ToggleZoom, window, cx); + })) + }; let use_v2_empty_toolbar = has_v2_flag && is_empty_state && !is_in_history_or_config; + let max_content_width = AgentSettings::get_global(cx).max_content_width; + let base_container = h_flex() - .id("agent-panel-toolbar") - .h(Tab::container_height(cx)) - .max_w_full() + .size_full() + // TODO: This is only until we remove Agent settings from the panel. + .when(!is_in_history_or_config, |this| { + this.max_w(max_content_width).mx_auto() + }) .flex_none() .justify_between() - .gap_2() - .bg(cx.theme().colors().tab_bar_background) - .border_b_1() - .border_color(cx.theme().colors().border); + .gap_2(); - if use_v2_empty_toolbar { + let toolbar_content = if use_v2_empty_toolbar { let (chevron_icon, icon_color, label_color) = if self.new_thread_menu_handle.is_deployed() { (IconName::ChevronUp, Color::Accent, Color::Accent) @@ -3805,20 +3814,7 @@ impl AgentPanel { cx, )) }) - .when(is_full_screen, |this| { - this.child( - IconButton::new("disable-full-screen", IconName::Minimize) - .icon_size(IconSize::Small) - .tooltip(move |_, cx| { - Tooltip::for_action("Disable Full Screen", &ToggleZoom, cx) - }) - .on_click({ - cx.listener(move |_, _, window, cx| { - window.dispatch_action(ToggleZoom.boxed_clone(), cx); - }) - }), - ) - }) + .child(full_screen_button) .child(self.render_panel_options_menu(window, cx)), ) .into_any_element() @@ -3871,24 +3867,21 @@ impl AgentPanel { cx, )) }) - .when(is_full_screen, |this| { - this.child( - IconButton::new("disable-full-screen", IconName::Minimize) - .icon_size(IconSize::Small) - .tooltip(move |_, cx| { - Tooltip::for_action("Disable Full Screen", &ToggleZoom, cx) - }) - .on_click({ - cx.listener(move |_, _, window, cx| { - window.dispatch_action(ToggleZoom.boxed_clone(), cx); - }) - }), - ) - }) + .child(full_screen_button) .child(self.render_panel_options_menu(window, cx)), ) .into_any_element() - } + }; + + h_flex() + .id("agent-panel-toolbar") + .h(Tab::container_height(cx)) + .flex_shrink_0() + .max_w_full() + .bg(cx.theme().colors().tab_bar_background) + .border_b_1() + .border_color(cx.theme().colors().border) + .child(toolbar_content) } fn render_worktree_creation_status(&self, cx: &mut Context) -> Option { diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index 9daa7c6cd83c276aa99adc9e3aae3e6c82c5ba88..58b52d9ea2eb10a4f7f483402b98c4be4b08924f 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -742,6 +742,7 @@ mod tests { flexible: true, default_width: px(300.), default_height: px(600.), + max_content_width: px(850.), default_model: None, inline_assistant_model: None, inline_assistant_use_streaming_tools: false, diff --git a/crates/agent_ui/src/conversation_view/thread_view.rs b/crates/agent_ui/src/conversation_view/thread_view.rs index ff3dab1170064e058c0ebb44505c0906349517ee..27ebadade8047db5f2b4de63c5c3731708d9af59 100644 --- a/crates/agent_ui/src/conversation_view/thread_view.rs +++ b/crates/agent_ui/src/conversation_view/thread_view.rs @@ -3014,14 +3014,12 @@ impl ThreadView { let is_done = thread.read(cx).status() == ThreadStatus::Idle; let is_canceled_or_failed = self.is_subagent_canceled_or_failed(cx); + let max_content_width = AgentSettings::get_global(cx).max_content_width; + Some( h_flex() - .h(Tab::container_height(cx)) - .pl_2() - .pr_1p5() .w_full() - .justify_between() - .gap_1() + .h(Tab::container_height(cx)) .border_b_1() .when(is_done && is_canceled_or_failed, |this| { this.border_dashed() @@ -3030,50 +3028,61 @@ impl ThreadView { .bg(cx.theme().colors().editor_background.opacity(0.2)) .child( h_flex() - .flex_1() - .gap_2() + .size_full() + .max_w(max_content_width) + .mx_auto() + .pl_2() + .pr_1() + .flex_shrink_0() + .justify_between() + .gap_1() .child( - Icon::new(IconName::ForwardArrowUp) - .size(IconSize::Small) - .color(Color::Muted), + h_flex() + .flex_1() + .gap_2() + .child( + Icon::new(IconName::ForwardArrowUp) + .size(IconSize::Small) + .color(Color::Muted), + ) + .child(self.title_editor.clone()) + .when(is_done && is_canceled_or_failed, |this| { + this.child(Icon::new(IconName::Close).color(Color::Error)) + }) + .when(is_done && !is_canceled_or_failed, |this| { + this.child(Icon::new(IconName::Check).color(Color::Success)) + }), ) - .child(self.title_editor.clone()) - .when(is_done && is_canceled_or_failed, |this| { - this.child(Icon::new(IconName::Close).color(Color::Error)) - }) - .when(is_done && !is_canceled_or_failed, |this| { - this.child(Icon::new(IconName::Check).color(Color::Success)) - }), - ) - .child( - h_flex() - .gap_0p5() - .when(!is_done, |this| { - this.child( - IconButton::new("stop_subagent", IconName::Stop) - .icon_size(IconSize::Small) - .icon_color(Color::Error) - .tooltip(Tooltip::text("Stop Subagent")) - .on_click(move |_, _, cx| { - thread.update(cx, |thread, cx| { - thread.cancel(cx).detach(); - }); - }), - ) - }) .child( - IconButton::new("minimize_subagent", IconName::Minimize) - .icon_size(IconSize::Small) - .tooltip(Tooltip::text("Minimize Subagent")) - .on_click(move |_, window, cx| { - let _ = server_view.update(cx, |server_view, cx| { - server_view.navigate_to_session( - parent_session_id.clone(), - window, - cx, - ); - }); - }), + h_flex() + .gap_0p5() + .when(!is_done, |this| { + this.child( + IconButton::new("stop_subagent", IconName::Stop) + .icon_size(IconSize::Small) + .icon_color(Color::Error) + .tooltip(Tooltip::text("Stop Subagent")) + .on_click(move |_, _, cx| { + thread.update(cx, |thread, cx| { + thread.cancel(cx).detach(); + }); + }), + ) + }) + .child( + IconButton::new("minimize_subagent", IconName::Dash) + .icon_size(IconSize::Small) + .tooltip(Tooltip::text("Minimize Subagent")) + .on_click(move |_, window, cx| { + let _ = server_view.update(cx, |server_view, cx| { + server_view.navigate_to_session( + parent_session_id.clone(), + window, + cx, + ); + }); + }), + ), ), ), ) @@ -3099,6 +3108,8 @@ impl ThreadView { (IconName::Maximize, "Expand Message Editor") }; + let max_content_width = AgentSettings::get_global(cx).max_content_width; + v_flex() .on_action(cx.listener(Self::expand_message_editor)) .p_2() @@ -3113,73 +3124,80 @@ impl ThreadView { }) .child( v_flex() - .relative() - .size_full() - .when(v2_empty_state, |this| this.flex_1()) - .pt_1() - .pr_2p5() - .child(self.message_editor.clone()) - .when(!v2_empty_state, |this| { - this.child( - h_flex() - .absolute() - .top_0() - .right_0() - .opacity(0.5) - .hover(|this| this.opacity(1.0)) - .child( - IconButton::new("toggle-height", expand_icon) - .icon_size(IconSize::Small) - .icon_color(Color::Muted) - .tooltip({ - move |_window, cx| { - Tooltip::for_action_in( - expand_tooltip, - &ExpandMessageEditor, - &focus_handle, - cx, - ) - } - }) - .on_click(cx.listener(|this, _, window, cx| { - this.expand_message_editor( - &ExpandMessageEditor, - window, - cx, - ); - })), - ), - ) - }), - ) - .child( - h_flex() - .flex_none() - .flex_wrap() - .justify_between() + .flex_1() + .w_full() + .max_w(max_content_width) + .mx_auto() .child( - h_flex() - .gap_0p5() - .child(self.render_add_context_button(cx)) - .child(self.render_follow_toggle(cx)) - .children(self.render_fast_mode_control(cx)) - .children(self.render_thinking_control(cx)), + v_flex() + .relative() + .size_full() + .when(v2_empty_state, |this| this.flex_1()) + .pt_1() + .pr_2p5() + .child(self.message_editor.clone()) + .when(!v2_empty_state, |this| { + this.child( + h_flex() + .absolute() + .top_0() + .right_0() + .opacity(0.5) + .hover(|this| this.opacity(1.0)) + .child( + IconButton::new("toggle-height", expand_icon) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .tooltip({ + move |_window, cx| { + Tooltip::for_action_in( + expand_tooltip, + &ExpandMessageEditor, + &focus_handle, + cx, + ) + } + }) + .on_click(cx.listener(|this, _, window, cx| { + this.expand_message_editor( + &ExpandMessageEditor, + window, + cx, + ); + })), + ), + ) + }), ) .child( h_flex() - .gap_1() - .children(self.render_token_usage(cx)) - .children(self.profile_selector.clone()) - .map(|this| { - // Either config_options_view OR (mode_selector + model_selector) - match self.config_options_view.clone() { - Some(config_view) => this.child(config_view), - None => this - .children(self.mode_selector.clone()) - .children(self.model_selector.clone()), - } - }) - .child(self.render_send_button(cx)), + .flex_none() + .flex_wrap() + .justify_between() + .child( + h_flex() + .gap_0p5() + .child(self.render_add_context_button(cx)) + .child(self.render_follow_toggle(cx)) + .children(self.render_fast_mode_control(cx)) + .children(self.render_thinking_control(cx)), + ) + .child( + h_flex() + .gap_1() + .children(self.render_token_usage(cx)) + .children(self.profile_selector.clone()) + .map(|this| { + // Either config_options_view OR (mode_selector + model_selector) + match self.config_options_view.clone() { + Some(config_view) => this.child(config_view), + None => this + .children(self.mode_selector.clone()) + .children(self.model_selector.clone()), + } + }) + .child(self.render_send_button(cx)), + ), ), ) .into_any() @@ -8559,8 +8577,12 @@ impl Render for ThreadView { let has_messages = self.list_state.item_count() > 0; let v2_empty_state = cx.has_flag::() && !has_messages; + let max_content_width = AgentSettings::get_global(cx).max_content_width; + let conversation = v_flex() - .when(!v2_empty_state, |this| this.flex_1()) + .mx_auto() + .max_w(max_content_width) + .when(!v2_empty_state, |this| this.flex_1().size_full()) .map(|this| { let this = this.when(self.resumed_without_history, |this| { this.child(Self::render_resume_notice(cx)) diff --git a/crates/settings_content/src/agent.rs b/crates/settings_content/src/agent.rs index 5b1b3c014f8c538cb0dff506e05d84a80dc863d1..7a9a1ddb16ac91f90f73e17b3972cd31536d7a66 100644 --- a/crates/settings_content/src/agent.rs +++ b/crates/settings_content/src/agent.rs @@ -128,6 +128,12 @@ pub struct AgentSettingsContent { /// Default: 320 #[serde(serialize_with = "crate::serialize_optional_f32_with_two_decimal_places")] pub default_height: Option, + /// Maximum content width in pixels for the agent panel. Content will be + /// centered when the panel is wider than this value. + /// + /// Default: 850 + #[serde(serialize_with = "crate::serialize_optional_f32_with_two_decimal_places")] + pub max_content_width: Option, /// The default model to use when creating new chats and for other features when a specific model is not specified. pub default_model: Option, /// Favorite models to show at the top of the model selector. diff --git a/crates/settings_ui/src/page_data.rs b/crates/settings_ui/src/page_data.rs index 9978832c05bb29c97f118fccbe301214d81fa0c6..259ee2cf261f9e435a5431ddf3c470640daf41f9 100644 --- a/crates/settings_ui/src/page_data.rs +++ b/crates/settings_ui/src/page_data.rs @@ -5737,7 +5737,7 @@ fn panels_page() -> SettingsPage { ] } - fn agent_panel_section() -> [SettingsPageItem; 6] { + fn agent_panel_section() -> [SettingsPageItem; 7] { [ SettingsPageItem::SectionHeader("Agent Panel"), SettingsPageItem::SettingItem(SettingItem { @@ -5812,6 +5812,24 @@ fn panels_page() -> SettingsPage { metadata: None, files: USER, }), + SettingsPageItem::SettingItem(SettingItem { + title: "Agent Panel Max Content Width", + description: "Maximum content width in pixels. Content will be centered when the panel is wider than this value.", + field: Box::new(SettingField { + json_path: Some("agent.max_content_width"), + pick: |settings_content| { + settings_content.agent.as_ref()?.max_content_width.as_ref() + }, + write: |settings_content, value| { + settings_content + .agent + .get_or_insert_default() + .max_content_width = value; + }, + }), + metadata: None, + files: USER, + }), ] } From 8292ab440d87172c6663e2dffa1fad33d10ddb11 Mon Sep 17 00:00:00 2001 From: Danilo Leal <67129314+danilo-leal@users.noreply.github.com> Date: Tue, 7 Apr 2026 09:26:09 -0300 Subject: [PATCH 12/22] collab_panel: Make channel items have a fixed height (#53304) Follow-up to https://github.com/zed-industries/zed/pull/53290 This PR fixes a mistake I pushed before of making the `ListItem`'s height method take pixels instead of a scalable unit like rems. Now, it takes `DefiniteLength` which can house both values, meaning we should be clear to set an explicit height for all of these items while still preserving font-size scaling. Release Notes: - N/A --- crates/collab_ui/src/collab_panel.rs | 7 +++++++ crates/ui/src/components/list/list_item.rs | 6 +++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/crates/collab_ui/src/collab_panel.rs b/crates/collab_ui/src/collab_panel.rs index 1e1aab3b9d4aa0e48ad4a84ec77bdc6dff51c7f5..7dc807998760a8e65d373164eec5c7663171e5d0 100644 --- a/crates/collab_ui/src/collab_panel.rs +++ b/crates/collab_ui/src/collab_panel.rs @@ -1181,6 +1181,7 @@ impl CollabPanel { .into(); ListItem::new(project_id as usize) + .height(rems_from_px(24.)) .toggle_state(is_selected) .on_click(cx.listener(move |this, _, window, cx| { this.workspace @@ -1221,6 +1222,7 @@ impl CollabPanel { let id = peer_id.map_or(usize::MAX, |id| id.as_u64() as usize); ListItem::new(("screen", id)) + .height(rems_from_px(24.)) .toggle_state(is_selected) .start_slot( h_flex() @@ -1267,6 +1269,7 @@ impl CollabPanel { let has_channel_buffer_changed = channel_store.has_channel_buffer_changed(channel_id); ListItem::new("channel-notes") + .height(rems_from_px(24.)) .toggle_state(is_selected) .on_click(cx.listener(move |this, _, window, cx| { this.open_channel_notes(channel_id, window, cx); @@ -3207,9 +3210,12 @@ impl CollabPanel { (IconName::Star, Color::Default, "Add to Favorites") }; + let height = rems_from_px(24.); + h_flex() .id(ix) .group("") + .h(height) .w_full() .overflow_hidden() .when(!channel.is_root_channel(), |el| { @@ -3239,6 +3245,7 @@ impl CollabPanel { ) .child( ListItem::new(ix) + .height(height) // Add one level of depth for the disclosure arrow. .indent_level(depth + 1) .indent_step_size(px(20.)) diff --git a/crates/ui/src/components/list/list_item.rs b/crates/ui/src/components/list/list_item.rs index 9a764efd58cfd3365d92e534a715a0f23ce46e90..ece1fd3c61ec486c090808891a8eec662138b1b4 100644 --- a/crates/ui/src/components/list/list_item.rs +++ b/crates/ui/src/components/list/list_item.rs @@ -52,7 +52,7 @@ pub struct ListItem { overflow_x: bool, focused: Option, docked_right: bool, - height: Option, + height: Option, } impl ListItem { @@ -207,8 +207,8 @@ impl ListItem { self } - pub fn height(mut self, height: Pixels) -> Self { - self.height = Some(height); + pub fn height(mut self, height: impl Into) -> Self { + self.height = Some(height.into()); self } } From 3ed1c32bf9a1ebb485e3da6cabc8b3c0a423beea Mon Sep 17 00:00:00 2001 From: Xin Zhao Date: Tue, 7 Apr 2026 22:15:33 +0800 Subject: [PATCH 13/22] editor: Fix diagnostic rendering when semantic tokens set to full (#53008) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Self-Review Checklist: - [x] I've reviewed my own diff for quality, security, and reliability - [x] Unsafe blocks (if any) have justifying comments - [x] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) - [ ] Tests cover the new/changed behavior - [x] Performance impact has been considered and is acceptable Closes #50212 There are two unreasonable coupling account for this issue, the coupling of `use_tree_sitter` with `languge_aware` in https://github.com/zed-industries/zed/blob/7892b932795911516f26f3c1c1c72249ed181ba8/crates/editor/src/element.rs#L3820-L3822 and the coupling of `language_aware` with `diagnostics` in https://github.com/zed-industries/zed/blob/7892b932795911516f26f3c1c1c72249ed181ba8/crates/language/src/buffer.rs#L3736-L3746 Because of these couplings, when the editor stops using Tree-sitter highlighting when `"semantic_tokens"` set to `"full"`, it also accidentally stops fetching diagnostic information. This is why error and warning underlines disappear. I’ve fixed this by adding a separate `use_tree_sitter` parameter to `highlighted_chunks`. This way, we can keep `language_aware` true to get the diagnostic data we need, but still decide whether or not to apply Tree-sitter highlights. I chose to fix this at the `highlighted_chunks` level because I’m worried that changing the logic in the deeper layers of the DisplayMap or Buffer might have too many side effects that are hard to predict. This approach feels like a safer way to solve the problem. Release Notes: - Fixed a bug where diagnostic underlines disappeared when "semantic_tokens" set to "full" --------- Co-authored-by: Kirill Bulatov --- crates/editor/src/display_map.rs | 48 +++++-- crates/editor/src/display_map/block_map.rs | 14 +- .../src/display_map/custom_highlights.rs | 9 +- crates/editor/src/display_map/fold_map.rs | 33 ++++- crates/editor/src/display_map/inlay_map.rs | 37 +++-- crates/editor/src/display_map/tab_map.rs | 71 ++++++++-- crates/editor/src/display_map/wrap_map.rs | 19 ++- crates/editor/src/editor.rs | 22 ++- crates/editor/src/element.rs | 17 ++- crates/editor/src/semantic_tokens.rs | 132 +++++++++++++++++- crates/language/src/buffer.rs | 43 +++++- crates/language/src/buffer_tests.rs | 8 +- crates/multi_buffer/src/multi_buffer.rs | 33 +++-- crates/multi_buffer/src/multi_buffer_tests.rs | 24 +++- crates/outline_panel/src/outline_panel.rs | 13 +- crates/project/src/lsp_store.rs | 15 +- .../tests/integration/project_tests.rs | 15 +- crates/vim/src/state.rs | 12 +- 18 files changed, 468 insertions(+), 97 deletions(-) diff --git a/crates/editor/src/display_map.rs b/crates/editor/src/display_map.rs index f95f1030276015af4825119fc98ac68b876d0e5f..7cb8040e282a47d27cf5d7b33e5453295b4f645f 100644 --- a/crates/editor/src/display_map.rs +++ b/crates/editor/src/display_map.rs @@ -98,7 +98,7 @@ use gpui::{ WeakEntity, }; use language::{ - Point, Subscription as BufferSubscription, + LanguageAwareStyling, Point, Subscription as BufferSubscription, language_settings::{AllLanguageSettings, LanguageSettings}, }; @@ -1769,7 +1769,10 @@ impl DisplaySnapshot { self.block_snapshot .chunks( BlockRow(display_row.0)..BlockRow(self.max_point().row().next_row().0), - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, self.masked, Highlights::default(), ) @@ -1783,7 +1786,10 @@ impl DisplaySnapshot { self.block_snapshot .chunks( BlockRow(row)..BlockRow(row + 1), - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, self.masked, Highlights::default(), ) @@ -1798,7 +1804,7 @@ impl DisplaySnapshot { pub fn chunks( &self, display_rows: Range, - language_aware: bool, + language_aware: LanguageAwareStyling, highlight_styles: HighlightStyles, ) -> DisplayChunks<'_> { self.block_snapshot.chunks( @@ -1818,7 +1824,7 @@ impl DisplaySnapshot { pub fn highlighted_chunks<'a>( &'a self, display_rows: Range, - language_aware: bool, + language_aware: LanguageAwareStyling, editor_style: &'a EditorStyle, ) -> impl Iterator> { self.chunks( @@ -1910,7 +1916,10 @@ impl DisplaySnapshot { let chunks = custom_highlights::CustomHighlightsChunks::new( multibuffer_range, - true, + LanguageAwareStyling { + tree_sitter: true, + diagnostics: true, + }, None, Some(&self.semantic_token_highlights), multibuffer, @@ -1961,7 +1970,14 @@ impl DisplaySnapshot { let mut line = String::new(); let range = display_row..display_row.next_row(); - for chunk in self.highlighted_chunks(range, false, editor_style) { + for chunk in self.highlighted_chunks( + range, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, + editor_style, + ) { line.push_str(chunk.text); let text_style = if let Some(style) = chunk.style { @@ -3388,7 +3404,14 @@ pub mod tests { let snapshot = map.update(cx, |map, cx| map.snapshot(cx)); let mut chunks = Vec::<(String, Option, Rgba)>::new(); - for chunk in snapshot.chunks(DisplayRow(0)..DisplayRow(5), true, Default::default()) { + for chunk in snapshot.chunks( + DisplayRow(0)..DisplayRow(5), + LanguageAwareStyling { + tree_sitter: true, + diagnostics: true, + }, + Default::default(), + ) { let color = chunk .highlight_style .and_then(|style| style.color) @@ -3940,7 +3963,14 @@ pub mod tests { ) -> Vec<(String, Option, Option)> { let snapshot = map.update(cx, |map, cx| map.snapshot(cx)); let mut chunks: Vec<(String, Option, Option)> = Vec::new(); - for chunk in snapshot.chunks(rows, true, HighlightStyles::default()) { + for chunk in snapshot.chunks( + rows, + LanguageAwareStyling { + tree_sitter: true, + diagnostics: true, + }, + HighlightStyles::default(), + ) { let syntax_color = chunk .syntax_highlight_id .and_then(|id| theme.get(id)?.color); diff --git a/crates/editor/src/display_map/block_map.rs b/crates/editor/src/display_map/block_map.rs index 67318e3300e73085fe40c2e22edfcd06778902c8..17fa7e3de4a361f6728664e76368583788053cfd 100644 --- a/crates/editor/src/display_map/block_map.rs +++ b/crates/editor/src/display_map/block_map.rs @@ -9,7 +9,7 @@ use crate::{ }; use collections::{Bound, HashMap, HashSet}; use gpui::{AnyElement, App, EntityId, Pixels, Window}; -use language::{Patch, Point}; +use language::{LanguageAwareStyling, Patch, Point}; use multi_buffer::{ Anchor, ExcerptBoundaryInfo, MultiBuffer, MultiBufferOffset, MultiBufferPoint, MultiBufferRow, MultiBufferSnapshot, RowInfo, ToOffset, ToPoint as _, @@ -2140,7 +2140,10 @@ impl BlockSnapshot { pub fn text(&self) -> String { self.chunks( BlockRow(0)..self.transforms.summary().output_rows, - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, false, Highlights::default(), ) @@ -2152,7 +2155,7 @@ impl BlockSnapshot { pub(crate) fn chunks<'a>( &'a self, rows: Range, - language_aware: bool, + language_aware: LanguageAwareStyling, masked: bool, highlights: Highlights<'a>, ) -> BlockChunks<'a> { @@ -4300,7 +4303,10 @@ mod tests { let actual_text = blocks_snapshot .chunks( BlockRow(start_row as u32)..BlockRow(end_row as u32), - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, false, Highlights::default(), ) diff --git a/crates/editor/src/display_map/custom_highlights.rs b/crates/editor/src/display_map/custom_highlights.rs index 39eabef2f9627b8088dc826ec64379bf76a6c9fa..6e93e562172decb0843da35c7f55fafd92ed21cc 100644 --- a/crates/editor/src/display_map/custom_highlights.rs +++ b/crates/editor/src/display_map/custom_highlights.rs @@ -1,6 +1,6 @@ use collections::BTreeMap; use gpui::HighlightStyle; -use language::Chunk; +use language::{Chunk, LanguageAwareStyling}; use multi_buffer::{MultiBufferChunks, MultiBufferOffset, MultiBufferSnapshot, ToOffset as _}; use std::{ cmp, @@ -34,7 +34,7 @@ impl<'a> CustomHighlightsChunks<'a> { #[ztracing::instrument(skip_all)] pub fn new( range: Range, - language_aware: bool, + language_aware: LanguageAwareStyling, text_highlights: Option<&'a TextHighlights>, semantic_token_highlights: Option<&'a SemanticTokensHighlights>, multibuffer_snapshot: &'a MultiBufferSnapshot, @@ -308,7 +308,10 @@ mod tests { // Get all chunks and verify their bitmaps let chunks = CustomHighlightsChunks::new( MultiBufferOffset(0)..buffer_snapshot.len(), - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, None, None, &buffer_snapshot, diff --git a/crates/editor/src/display_map/fold_map.rs b/crates/editor/src/display_map/fold_map.rs index 1554bb96dab0e2f76a17df1396bd945f332af208..4c6c04b86cc3e2fb9ef10be58c14faae623dc65f 100644 --- a/crates/editor/src/display_map/fold_map.rs +++ b/crates/editor/src/display_map/fold_map.rs @@ -5,7 +5,7 @@ use super::{ inlay_map::{InlayBufferRows, InlayChunks, InlayEdit, InlayOffset, InlayPoint, InlaySnapshot}, }; use gpui::{AnyElement, App, ElementId, HighlightStyle, Pixels, SharedString, Stateful, Window}; -use language::{Edit, HighlightId, Point}; +use language::{Edit, HighlightId, LanguageAwareStyling, Point}; use multi_buffer::{ Anchor, AnchorRangeExt, MBTextSummary, MultiBufferOffset, MultiBufferRow, MultiBufferSnapshot, RowInfo, ToOffset, @@ -707,7 +707,10 @@ impl FoldSnapshot { pub fn text(&self) -> String { self.chunks( FoldOffset(MultiBufferOffset(0))..self.len(), - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, Highlights::default(), ) .map(|c| c.text) @@ -909,7 +912,7 @@ impl FoldSnapshot { pub(crate) fn chunks<'a>( &'a self, range: Range, - language_aware: bool, + language_aware: LanguageAwareStyling, highlights: Highlights<'a>, ) -> FoldChunks<'a> { let mut transform_cursor = self @@ -954,7 +957,10 @@ impl FoldSnapshot { pub fn chars_at(&self, start: FoldPoint) -> impl '_ + Iterator { self.chunks( start.to_offset(self)..self.len(), - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, Highlights::default(), ) .flat_map(|chunk| chunk.text.chars()) @@ -964,7 +970,10 @@ impl FoldSnapshot { pub fn chunks_at(&self, start: FoldPoint) -> FoldChunks<'_> { self.chunks( start.to_offset(self)..self.len(), - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, Highlights::default(), ) } @@ -2131,7 +2140,14 @@ mod tests { let text = &expected_text[start.0.0..end.0.0]; assert_eq!( snapshot - .chunks(start..end, false, Highlights::default()) + .chunks( + start..end, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, + Highlights::default() + ) .map(|c| c.text) .collect::(), text, @@ -2303,7 +2319,10 @@ mod tests { // Get all chunks and verify their bitmaps let chunks = snapshot.chunks( FoldOffset(MultiBufferOffset(0))..FoldOffset(snapshot.len().0), - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, Highlights::default(), ); diff --git a/crates/editor/src/display_map/inlay_map.rs b/crates/editor/src/display_map/inlay_map.rs index 47ca295ccb1a08768ce129b92d10506294a9cf78..698b58682d7ef7682094e7728f419348fd5d32d9 100644 --- a/crates/editor/src/display_map/inlay_map.rs +++ b/crates/editor/src/display_map/inlay_map.rs @@ -10,7 +10,7 @@ use crate::{ inlays::{Inlay, InlayContent}, }; use collections::BTreeSet; -use language::{Chunk, Edit, Point, TextSummary}; +use language::{Chunk, Edit, LanguageAwareStyling, Point, TextSummary}; use multi_buffer::{ MBTextSummary, MultiBufferOffset, MultiBufferRow, MultiBufferRows, MultiBufferSnapshot, RowInfo, ToOffset, @@ -1200,7 +1200,7 @@ impl InlaySnapshot { pub(crate) fn chunks<'a>( &'a self, range: Range, - language_aware: bool, + language_aware: LanguageAwareStyling, highlights: Highlights<'a>, ) -> InlayChunks<'a> { let mut cursor = self @@ -1234,9 +1234,16 @@ impl InlaySnapshot { #[cfg(test)] #[ztracing::instrument(skip_all)] pub fn text(&self) -> String { - self.chunks(Default::default()..self.len(), false, Highlights::default()) - .map(|chunk| chunk.chunk.text) - .collect() + self.chunks( + Default::default()..self.len(), + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, + Highlights::default(), + ) + .map(|chunk| chunk.chunk.text) + .collect() } #[ztracing::instrument(skip_all)] @@ -1979,7 +1986,10 @@ mod tests { let actual_text = inlay_snapshot .chunks( range, - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, Highlights { text_highlights: Some(&text_highlights), inlay_highlights: Some(&inlay_highlights), @@ -2158,7 +2168,10 @@ mod tests { // Get all chunks and verify their bitmaps let chunks = snapshot.chunks( InlayOffset(MultiBufferOffset(0))..snapshot.len(), - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, Highlights::default(), ); @@ -2293,7 +2306,10 @@ mod tests { let chunks: Vec<_> = inlay_snapshot .chunks( InlayOffset(MultiBufferOffset(0))..inlay_snapshot.len(), - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, highlights, ) .collect(); @@ -2408,7 +2424,10 @@ mod tests { let chunks: Vec<_> = inlay_snapshot .chunks( InlayOffset(MultiBufferOffset(0))..inlay_snapshot.len(), - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, highlights, ) .collect(); diff --git a/crates/editor/src/display_map/tab_map.rs b/crates/editor/src/display_map/tab_map.rs index 187ed8614e01ddb8dcdae930fd484de9594cf63f..bb0e642df380e04fcfa9b9533f027be7171b4975 100644 --- a/crates/editor/src/display_map/tab_map.rs +++ b/crates/editor/src/display_map/tab_map.rs @@ -3,7 +3,7 @@ use super::{ fold_map::{self, Chunk, FoldChunks, FoldEdit, FoldPoint, FoldSnapshot}, }; -use language::Point; +use language::{LanguageAwareStyling, Point}; use multi_buffer::MultiBufferSnapshot; use std::{cmp, num::NonZeroU32, ops::Range}; use sum_tree::Bias; @@ -101,7 +101,10 @@ impl TabMap { let mut last_tab_with_changed_expansion_offset = None; 'outer: for chunk in old_snapshot.fold_snapshot.chunks( fold_edit.old.end..old_end_row_successor_offset, - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, Highlights::default(), ) { let mut remaining_tabs = chunk.tabs; @@ -244,7 +247,14 @@ impl TabSnapshot { self.max_point() }; let first_line_chars = self - .chunks(range.start..line_end, false, Highlights::default()) + .chunks( + range.start..line_end, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, + Highlights::default(), + ) .flat_map(|chunk| chunk.text.chars()) .take_while(|&c| c != '\n') .count() as u32; @@ -254,7 +264,10 @@ impl TabSnapshot { } else { self.chunks( TabPoint::new(range.end.row(), 0)..range.end, - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, Highlights::default(), ) .flat_map(|chunk| chunk.text.chars()) @@ -274,7 +287,7 @@ impl TabSnapshot { pub(crate) fn chunks<'a>( &'a self, range: Range, - language_aware: bool, + language_aware: LanguageAwareStyling, highlights: Highlights<'a>, ) -> TabChunks<'a> { let (input_start, expanded_char_column, to_next_stop) = @@ -324,7 +337,10 @@ impl TabSnapshot { pub fn text(&self) -> String { self.chunks( TabPoint::zero()..self.max_point(), - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, Highlights::default(), ) .map(|chunk| chunk.text) @@ -1170,7 +1186,10 @@ mod tests { tab_snapshot .chunks( TabPoint::new(0, ix as u32)..tab_snapshot.max_point(), - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, Highlights::default(), ) .map(|c| c.text) @@ -1246,8 +1265,14 @@ mod tests { let mut chunks = Vec::new(); let mut was_tab = false; let mut text = String::new(); - for chunk in snapshot.chunks(start..snapshot.max_point(), false, Highlights::default()) - { + for chunk in snapshot.chunks( + start..snapshot.max_point(), + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, + Highlights::default(), + ) { if chunk.is_tab != was_tab { if !text.is_empty() { chunks.push((mem::take(&mut text), was_tab)); @@ -1296,7 +1321,14 @@ mod tests { // This should not panic. let result: String = tab_snapshot - .chunks(start..end, false, Highlights::default()) + .chunks( + start..end, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, + Highlights::default(), + ) .map(|c| c.text) .collect(); assert!(!result.is_empty()); @@ -1354,7 +1386,14 @@ mod tests { let expected_summary = TextSummary::from(expected_text.as_str()); assert_eq!( tabs_snapshot - .chunks(start..end, false, Highlights::default()) + .chunks( + start..end, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, + Highlights::default() + ) .map(|c| c.text) .collect::(), expected_text, @@ -1436,7 +1475,10 @@ mod tests { let (_, fold_snapshot) = FoldMap::new(inlay_snapshot); let chunks = fold_snapshot.chunks( FoldOffset(MultiBufferOffset(0))..fold_snapshot.len(), - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, Default::default(), ); let mut cursor = TabStopCursor::new(chunks); @@ -1598,7 +1640,10 @@ mod tests { let (_, fold_snapshot) = FoldMap::new(inlay_snapshot); let chunks = fold_snapshot.chunks( FoldOffset(MultiBufferOffset(0))..fold_snapshot.len(), - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, Default::default(), ); let mut cursor = TabStopCursor::new(chunks); diff --git a/crates/editor/src/display_map/wrap_map.rs b/crates/editor/src/display_map/wrap_map.rs index d21642977ed923e15a583dfe767fd566e78c5de9..4ff11b1ef67971c5159a81278a5afaaaea171a28 100644 --- a/crates/editor/src/display_map/wrap_map.rs +++ b/crates/editor/src/display_map/wrap_map.rs @@ -5,7 +5,7 @@ use super::{ tab_map::{self, TabEdit, TabPoint, TabSnapshot}, }; use gpui::{App, AppContext as _, Context, Entity, Font, LineWrapper, Pixels, Task}; -use language::Point; +use language::{LanguageAwareStyling, Point}; use multi_buffer::{MultiBufferSnapshot, RowInfo}; use smol::future::yield_now; use std::{cmp, collections::VecDeque, mem, ops::Range, sync::LazyLock, time::Duration}; @@ -513,7 +513,10 @@ impl WrapSnapshot { let mut remaining = None; let mut chunks = new_tab_snapshot.chunks( TabPoint::new(edit.new_rows.start, 0)..new_tab_snapshot.max_point(), - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, Highlights::default(), ); let mut edit_transforms = Vec::::new(); @@ -656,7 +659,7 @@ impl WrapSnapshot { pub(crate) fn chunks<'a>( &'a self, rows: Range, - language_aware: bool, + language_aware: LanguageAwareStyling, highlights: Highlights<'a>, ) -> WrapChunks<'a> { let output_start = WrapPoint::new(rows.start, 0); @@ -960,7 +963,10 @@ impl WrapSnapshot { pub fn text_chunks(&self, wrap_row: WrapRow) -> impl Iterator { self.chunks( wrap_row..self.max_point().row() + WrapRow(1), - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, Highlights::default(), ) .map(|h| h.text) @@ -1719,7 +1725,10 @@ mod tests { let actual_text = self .chunks( WrapRow(start_row)..WrapRow(end_row), - true, + LanguageAwareStyling { + tree_sitter: true, + diagnostics: true, + }, Highlights::default(), ) .map(|c| c.text) diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index 6550d79c9f73799d37ccf6433db38f2719636ee6..ae852b1055b33f151b402ee999ce50ba064788a4 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -132,9 +132,9 @@ use language::{ AutoindentMode, BlockCommentConfig, BracketMatch, BracketPair, Buffer, BufferRow, BufferSnapshot, Capability, CharClassifier, CharKind, CharScopeContext, CodeLabel, CursorShape, DiagnosticEntryRef, DiffOptions, EditPredictionsMode, EditPreview, HighlightedText, IndentKind, - IndentSize, Language, LanguageName, LanguageRegistry, LanguageScope, LocalFile, OffsetRangeExt, - OutlineItem, Point, Selection, SelectionGoal, TextObject, TransactionId, TreeSitterOptions, - WordsQuery, + IndentSize, Language, LanguageAwareStyling, LanguageName, LanguageRegistry, LanguageScope, + LocalFile, OffsetRangeExt, OutlineItem, Point, Selection, SelectionGoal, TextObject, + TransactionId, TreeSitterOptions, WordsQuery, language_settings::{ self, AllLanguageSettings, LanguageSettings, LspInsertMode, RewrapBehavior, WordsCompletionMode, all_language_settings, @@ -19147,7 +19147,13 @@ impl Editor { let range = buffer.anchor_before(rename_start)..buffer.anchor_after(rename_end); let mut old_highlight_id = None; let old_name: Arc = buffer - .chunks(rename_start..rename_end, true) + .chunks( + rename_start..rename_end, + LanguageAwareStyling { + tree_sitter: true, + diagnostics: true, + }, + ) .map(|chunk| { if old_highlight_id.is_none() { old_highlight_id = chunk.syntax_highlight_id; @@ -25005,7 +25011,13 @@ impl Editor { selection.range() }; - let chunks = snapshot.chunks(range, true); + let chunks = snapshot.chunks( + range, + LanguageAwareStyling { + tree_sitter: true, + diagnostics: true, + }, + ); let mut lines = Vec::new(); let mut line: VecDeque = VecDeque::new(); diff --git a/crates/editor/src/element.rs b/crates/editor/src/element.rs index 7a532dc7a75ea3583456be6611ef072cd7692bc7..512fbb8855aa11d8c540065a55eb296919012821 100644 --- a/crates/editor/src/element.rs +++ b/crates/editor/src/element.rs @@ -51,7 +51,10 @@ use gpui::{ pattern_slash, point, px, quad, relative, size, solid_background, transparent_black, }; use itertools::Itertools; -use language::{HighlightedText, IndentGuideSettings, language_settings::ShowWhitespaceSetting}; +use language::{ + HighlightedText, IndentGuideSettings, LanguageAwareStyling, + language_settings::ShowWhitespaceSetting, +}; use markdown::Markdown; use multi_buffer::{ Anchor, ExcerptBoundaryInfo, ExpandExcerptDirection, ExpandInfo, MultiBufferPoint, @@ -3819,7 +3822,11 @@ impl EditorElement { } else { let use_tree_sitter = !snapshot.semantic_tokens_enabled || snapshot.use_tree_sitter_for_syntax(rows.start, cx); - let chunks = snapshot.highlighted_chunks(rows.clone(), use_tree_sitter, style); + let language_aware = LanguageAwareStyling { + tree_sitter: use_tree_sitter, + diagnostics: true, + }; + let chunks = snapshot.highlighted_chunks(rows.clone(), language_aware, style); LineWithInvisibles::from_chunks( chunks, style, @@ -11999,7 +12006,11 @@ pub fn layout_line( ) -> LineWithInvisibles { let use_tree_sitter = !snapshot.semantic_tokens_enabled || snapshot.use_tree_sitter_for_syntax(row, cx); - let chunks = snapshot.highlighted_chunks(row..row + DisplayRow(1), use_tree_sitter, style); + let language_aware = LanguageAwareStyling { + tree_sitter: use_tree_sitter, + diagnostics: true, + }; + let chunks = snapshot.highlighted_chunks(row..row + DisplayRow(1), language_aware, style); LineWithInvisibles::from_chunks( chunks, style, diff --git a/crates/editor/src/semantic_tokens.rs b/crates/editor/src/semantic_tokens.rs index 5e78be70d5627bd4f484a3efd44b13519b31b400..d485cfa70237fed542a240f202a8dc47b07467c4 100644 --- a/crates/editor/src/semantic_tokens.rs +++ b/crates/editor/src/semantic_tokens.rs @@ -475,13 +475,17 @@ mod tests { use gpui::{ AppContext as _, Entity, Focusable as _, HighlightStyle, TestAppContext, UpdateGlobal as _, }; - use language::{Language, LanguageConfig, LanguageMatcher}; + use language::{ + Diagnostic, DiagnosticEntry, DiagnosticSet, Language, LanguageAwareStyling, LanguageConfig, + LanguageMatcher, + }; use languages::FakeLspAdapter; + use lsp::LanguageServerId; use multi_buffer::{ AnchorRangeExt, ExpandExcerptDirection, MultiBuffer, MultiBufferOffset, PathKey, }; use project::Project; - use rope::Point; + use rope::{Point, PointUtf16}; use serde_json::json; use settings::{ GlobalLspSettingsContent, LanguageSettingsContent, SemanticTokenRule, SemanticTokenRules, @@ -2088,6 +2092,130 @@ mod tests { ); } + #[gpui::test] + async fn test_diagnostics_visible_when_semantic_token_set_to_full(cx: &mut TestAppContext) { + init_test(cx, |_| {}); + + update_test_language_settings(cx, &|language_settings| { + language_settings.languages.0.insert( + "Rust".into(), + LanguageSettingsContent { + semantic_tokens: Some(SemanticTokens::Full), + ..LanguageSettingsContent::default() + }, + ); + }); + + let mut cx = EditorLspTestContext::new_rust( + lsp::ServerCapabilities { + semantic_tokens_provider: Some( + lsp::SemanticTokensServerCapabilities::SemanticTokensOptions( + lsp::SemanticTokensOptions { + legend: lsp::SemanticTokensLegend { + token_types: vec!["function".into()], + token_modifiers: Vec::new(), + }, + full: Some(lsp::SemanticTokensFullOptions::Delta { delta: None }), + ..lsp::SemanticTokensOptions::default() + }, + ), + ), + ..lsp::ServerCapabilities::default() + }, + cx, + ) + .await; + + let mut full_request = cx + .set_request_handler::( + move |_, _, _| { + async move { + Ok(Some(lsp::SemanticTokensResult::Tokens( + lsp::SemanticTokens { + data: vec![ + 0, // delta_line + 3, // delta_start + 4, // length + 0, // token_type + 0, // token_modifiers_bitset + ], + result_id: Some("a".into()), + }, + ))) + } + }, + ); + + cx.set_state("ˇfn main() {}"); + assert!(full_request.next().await.is_some()); + + let task = cx.update_editor(|e, _, _| e.semantic_token_state.take_update_task()); + task.await; + + cx.update_buffer(|buffer, cx| { + buffer.update_diagnostics( + LanguageServerId(0), + DiagnosticSet::new( + [DiagnosticEntry { + range: PointUtf16::new(0, 3)..PointUtf16::new(0, 7), + diagnostic: Diagnostic { + severity: lsp::DiagnosticSeverity::ERROR, + group_id: 1, + message: "unused function".into(), + ..Default::default() + }, + }], + buffer, + ), + cx, + ) + }); + + cx.run_until_parked(); + let chunks = cx.update_editor(|editor, window, cx| { + editor + .snapshot(window, cx) + .display_snapshot + .chunks( + crate::display_map::DisplayRow(0)..crate::display_map::DisplayRow(1), + LanguageAwareStyling { + tree_sitter: false, + diagnostics: true, + }, + crate::HighlightStyles::default(), + ) + .map(|chunk| { + ( + chunk.text.to_string(), + chunk.diagnostic_severity, + chunk.highlight_style, + ) + }) + .collect::>() + }); + + assert_eq!( + extract_semantic_highlights(&cx.editor, &cx), + vec![MultiBufferOffset(3)..MultiBufferOffset(7)] + ); + + assert!( + chunks.iter().any( + |(text, severity, style): &( + String, + Option, + Option + )| { + text == "main" + && *severity == Some(lsp::DiagnosticSeverity::ERROR) + && style.is_some() + } + ), + "expected 'main' chunk to have both diagnostic and semantic styling: {:?}", + chunks + ); + } + fn extract_semantic_highlight_styles( editor: &Entity, cx: &TestAppContext, diff --git a/crates/language/src/buffer.rs b/crates/language/src/buffer.rs index a467cd789555d39a32ad4e1d7b21da7b14df9c25..1e54134efcab4f0074a73b241f8e0d04cfbcbcdd 100644 --- a/crates/language/src/buffer.rs +++ b/crates/language/src/buffer.rs @@ -3733,16 +3733,24 @@ impl BufferSnapshot { /// returned in chunks where each chunk has a single syntax highlighting style and /// diagnostic status. #[ztracing::instrument(skip_all)] - pub fn chunks(&self, range: Range, language_aware: bool) -> BufferChunks<'_> { + pub fn chunks( + &self, + range: Range, + language_aware: LanguageAwareStyling, + ) -> BufferChunks<'_> { let range = range.start.to_offset(self)..range.end.to_offset(self); let mut syntax = None; - if language_aware { + if language_aware.tree_sitter { syntax = Some(self.get_highlights(range.clone())); } - // We want to look at diagnostic spans only when iterating over language-annotated chunks. - let diagnostics = language_aware; - BufferChunks::new(self.text.as_rope(), range, syntax, diagnostics, Some(self)) + BufferChunks::new( + self.text.as_rope(), + range, + syntax, + language_aware.diagnostics, + Some(self), + ) } pub fn highlighted_text_for_range( @@ -4477,7 +4485,13 @@ impl BufferSnapshot { let mut text = String::new(); let mut highlight_ranges = Vec::new(); let mut name_ranges = Vec::new(); - let mut chunks = self.chunks(source_range_for_text.clone(), true); + let mut chunks = self.chunks( + source_range_for_text.clone(), + LanguageAwareStyling { + tree_sitter: true, + diagnostics: true, + }, + ); let mut last_buffer_range_end = 0; for (buffer_range, is_name) in buffer_ranges { let space_added = !text.is_empty() && buffer_range.start > last_buffer_range_end; @@ -5402,7 +5416,13 @@ impl BufferSnapshot { let mut words = BTreeMap::default(); let mut current_word_start_ix = None; let mut chunk_ix = query.range.start; - for chunk in self.chunks(query.range, false) { + for chunk in self.chunks( + query.range, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, + ) { for (i, c) in chunk.text.char_indices() { let ix = chunk_ix + i; if classifier.is_word(c) { @@ -5441,6 +5461,15 @@ impl BufferSnapshot { } } +/// A configuration to use when producing styled text chunks. +#[derive(Clone, Copy)] +pub struct LanguageAwareStyling { + /// Whether to highlight text chunks using tree-sitter. + pub tree_sitter: bool, + /// Whether to highlight text chunks based on the diagnostics data. + pub diagnostics: bool, +} + pub struct WordsQuery<'a> { /// Only returns words with all chars from the fuzzy string in them. pub fuzzy_contents: Option<&'a str>, diff --git a/crates/language/src/buffer_tests.rs b/crates/language/src/buffer_tests.rs index 9308ee6f0a0ee207b30be9e6fafa73ba9452d94c..9f4562bf547f389c5ecc5ca29470ac4e49da0e04 100644 --- a/crates/language/src/buffer_tests.rs +++ b/crates/language/src/buffer_tests.rs @@ -4102,7 +4102,13 @@ fn test_random_chunk_bitmaps(cx: &mut App, mut rng: StdRng) { let snapshot = buffer.read(cx).snapshot(); // Get all chunks and verify their bitmaps - let chunks = snapshot.chunks(0..snapshot.len(), false); + let chunks = snapshot.chunks( + 0..snapshot.len(), + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, + ); for chunk in chunks { let chunk_text = chunk.text; diff --git a/crates/multi_buffer/src/multi_buffer.rs b/crates/multi_buffer/src/multi_buffer.rs index a54ff64af028f44adced1758933f794e9a002c5a..47c1288c8f9baeebf4afd54dd0597bfe5a41d15f 100644 --- a/crates/multi_buffer/src/multi_buffer.rs +++ b/crates/multi_buffer/src/multi_buffer.rs @@ -21,9 +21,9 @@ use itertools::Itertools; use language::{ AutoindentMode, Buffer, BufferChunks, BufferRow, BufferSnapshot, Capability, CharClassifier, CharKind, CharScopeContext, Chunk, CursorShape, DiagnosticEntryRef, File, IndentGuideSettings, - IndentSize, Language, LanguageScope, OffsetRangeExt, OffsetUtf16, Outline, OutlineItem, Point, - PointUtf16, Selection, TextDimension, TextObject, ToOffset as _, ToPoint as _, TransactionId, - TreeSitterOptions, Unclipped, + IndentSize, Language, LanguageAwareStyling, LanguageScope, OffsetRangeExt, OffsetUtf16, + Outline, OutlineItem, Point, PointUtf16, Selection, TextDimension, TextObject, ToOffset as _, + ToPoint as _, TransactionId, TreeSitterOptions, Unclipped, language_settings::{AllLanguageSettings, LanguageSettings}, }; @@ -1072,7 +1072,7 @@ pub struct MultiBufferChunks<'a> { range: Range, excerpt_offset_range: Range, excerpt_chunks: Option>, - language_aware: bool, + language_aware: LanguageAwareStyling, snapshot: &'a MultiBufferSnapshot, } @@ -3340,9 +3340,15 @@ impl EventEmitter for MultiBuffer {} impl MultiBufferSnapshot { pub fn text(&self) -> String { - self.chunks(MultiBufferOffset::ZERO..self.len(), false) - .map(|chunk| chunk.text) - .collect() + self.chunks( + MultiBufferOffset::ZERO..self.len(), + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, + ) + .map(|chunk| chunk.text) + .collect() } pub fn reversed_chars_at(&self, position: T) -> impl Iterator + '_ { @@ -3378,7 +3384,14 @@ impl MultiBufferSnapshot { } pub fn text_for_range(&self, range: Range) -> impl Iterator + '_ { - self.chunks(range, false).map(|chunk| chunk.text) + self.chunks( + range, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, + ) + .map(|chunk| chunk.text) } pub fn is_line_blank(&self, row: MultiBufferRow) -> bool { @@ -4178,7 +4191,7 @@ impl MultiBufferSnapshot { pub fn chunks( &self, range: Range, - language_aware: bool, + language_aware: LanguageAwareStyling, ) -> MultiBufferChunks<'_> { let mut chunks = MultiBufferChunks { excerpt_offset_range: ExcerptDimension(MultiBufferOffset::ZERO) @@ -7227,7 +7240,7 @@ impl Excerpt { fn chunks_in_range<'a>( &'a self, range: Range, - language_aware: bool, + language_aware: LanguageAwareStyling, snapshot: &'a MultiBufferSnapshot, ) -> ExcerptChunks<'a> { let buffer = self.buffer_snapshot(snapshot); diff --git a/crates/multi_buffer/src/multi_buffer_tests.rs b/crates/multi_buffer/src/multi_buffer_tests.rs index bc904d1a05488ee365ebddf36c3b30accdfb9301..cebc9073e9d87a3c6eaf71d78e181d3e833ad56a 100644 --- a/crates/multi_buffer/src/multi_buffer_tests.rs +++ b/crates/multi_buffer/src/multi_buffer_tests.rs @@ -5039,7 +5039,13 @@ fn check_edits( fn assert_chunks_in_ranges(snapshot: &MultiBufferSnapshot) { let full_text = snapshot.text(); for ix in 0..full_text.len() { - let mut chunks = snapshot.chunks(MultiBufferOffset(0)..snapshot.len(), false); + let mut chunks = snapshot.chunks( + MultiBufferOffset(0)..snapshot.len(), + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, + ); chunks.seek(MultiBufferOffset(ix)..snapshot.len()); let tail = chunks.map(|chunk| chunk.text).collect::(); assert_eq!(tail, &full_text[ix..], "seek to range: {:?}", ix..); @@ -5300,7 +5306,13 @@ fn test_random_chunk_bitmaps(cx: &mut App, mut rng: StdRng) { let snapshot = multibuffer.read(cx).snapshot(cx); - let chunks = snapshot.chunks(MultiBufferOffset(0)..snapshot.len(), false); + let chunks = snapshot.chunks( + MultiBufferOffset(0)..snapshot.len(), + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, + ); for chunk in chunks { let chunk_text = chunk.text; @@ -5466,7 +5478,13 @@ fn test_random_chunk_bitmaps_with_diffs(cx: &mut App, mut rng: StdRng) { let snapshot = multibuffer.read(cx).snapshot(cx); - let chunks = snapshot.chunks(MultiBufferOffset(0)..snapshot.len(), false); + let chunks = snapshot.chunks( + MultiBufferOffset(0)..snapshot.len(), + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, + ); for chunk in chunks { let chunk_text = chunk.text; diff --git a/crates/outline_panel/src/outline_panel.rs b/crates/outline_panel/src/outline_panel.rs index b7d5afcb687c017fdf253717a9dae2c95c55b53b..fa23b805cd48461dabaddbb7670155cdfe1ba8b0 100644 --- a/crates/outline_panel/src/outline_panel.rs +++ b/crates/outline_panel/src/outline_panel.rs @@ -23,8 +23,8 @@ use gpui::{ uniform_list, }; use itertools::Itertools; -use language::language_settings::LanguageSettings; use language::{Anchor, BufferId, BufferSnapshot, OffsetRangeExt, OutlineItem}; +use language::{LanguageAwareStyling, language_settings::LanguageSettings}; use menu::{Cancel, SelectFirst, SelectLast, SelectNext, SelectPrevious}; use std::{ @@ -217,10 +217,13 @@ impl SearchState { let mut offset = context_offset_range.start; let mut context_text = String::new(); let mut highlight_ranges = Vec::new(); - for mut chunk in highlight_arguments - .multi_buffer_snapshot - .chunks(context_offset_range.start..context_offset_range.end, true) - { + for mut chunk in highlight_arguments.multi_buffer_snapshot.chunks( + context_offset_range.start..context_offset_range.end, + LanguageAwareStyling { + tree_sitter: true, + diagnostics: true, + }, + ) { if !non_whitespace_symbol_occurred { for c in chunk.text.chars() { if c.is_whitespace() { diff --git a/crates/project/src/lsp_store.rs b/crates/project/src/lsp_store.rs index 2f579f5a724db143bbd4b0f9853a217bd6b14655..9ea50fdc8f12b68147c1073219625c4fd257afd3 100644 --- a/crates/project/src/lsp_store.rs +++ b/crates/project/src/lsp_store.rs @@ -72,9 +72,10 @@ use itertools::Itertools as _; use language::{ Bias, BinaryStatus, Buffer, BufferRow, BufferSnapshot, CachedLspAdapter, Capability, CodeLabel, CodeLabelExt, Diagnostic, DiagnosticEntry, DiagnosticSet, DiagnosticSourceKind, Diff, - File as _, Language, LanguageName, LanguageRegistry, LocalFile, LspAdapter, LspAdapterDelegate, - LspInstaller, ManifestDelegate, ManifestName, ModelineSettings, OffsetUtf16, Patch, PointUtf16, - TextBufferSnapshot, ToOffset, ToOffsetUtf16, ToPointUtf16, Toolchain, Transaction, Unclipped, + File as _, Language, LanguageAwareStyling, LanguageName, LanguageRegistry, LocalFile, + LspAdapter, LspAdapterDelegate, LspInstaller, ManifestDelegate, ManifestName, ModelineSettings, + OffsetUtf16, Patch, PointUtf16, TextBufferSnapshot, ToOffset, ToOffsetUtf16, ToPointUtf16, + Toolchain, Transaction, Unclipped, language_settings::{ AllLanguageSettings, FormatOnSave, Formatter, LanguageSettings, all_language_settings, }, @@ -13527,7 +13528,13 @@ fn resolve_word_completion(snapshot: &BufferSnapshot, completion: &mut Completio } let mut offset = 0; - for chunk in snapshot.chunks(word_range.clone(), true) { + for chunk in snapshot.chunks( + word_range.clone(), + LanguageAwareStyling { + tree_sitter: true, + diagnostics: true, + }, + ) { let end_offset = offset + chunk.text.len(); if let Some(highlight_id) = chunk.syntax_highlight_id { completion diff --git a/crates/project/tests/integration/project_tests.rs b/crates/project/tests/integration/project_tests.rs index d6c2ce37c9e60e17bd43c3f6c3ad10cde52b4bec..f680ccee78e997064af2647f68d8aa3631fa4bd3 100644 --- a/crates/project/tests/integration/project_tests.rs +++ b/crates/project/tests/integration/project_tests.rs @@ -41,9 +41,10 @@ use gpui::{ use itertools::Itertools; use language::{ Buffer, BufferEvent, Diagnostic, DiagnosticEntry, DiagnosticEntryRef, DiagnosticSet, - DiagnosticSourceKind, DiskState, FakeLspAdapter, Language, LanguageConfig, LanguageMatcher, - LanguageName, LineEnding, ManifestName, ManifestProvider, ManifestQuery, OffsetRangeExt, Point, - ToPoint, Toolchain, ToolchainList, ToolchainLister, ToolchainMetadata, + DiagnosticSourceKind, DiskState, FakeLspAdapter, Language, LanguageAwareStyling, + LanguageConfig, LanguageMatcher, LanguageName, LineEnding, ManifestName, ManifestProvider, + ManifestQuery, OffsetRangeExt, Point, ToPoint, Toolchain, ToolchainList, ToolchainLister, + ToolchainMetadata, language_settings::{LanguageSettings, LanguageSettingsContent}, markdown_lang, rust_lang, tree_sitter_typescript, }; @@ -4382,7 +4383,13 @@ fn chunks_with_diagnostics( range: Range, ) -> Vec<(String, Option)> { let mut chunks: Vec<(String, Option)> = Vec::new(); - for chunk in buffer.snapshot().chunks(range, true) { + for chunk in buffer.snapshot().chunks( + range, + LanguageAwareStyling { + tree_sitter: true, + diagnostics: true, + }, + ) { if chunks .last() .is_some_and(|prev_chunk| prev_chunk.1 == chunk.diagnostic_severity) diff --git a/crates/vim/src/state.rs b/crates/vim/src/state.rs index 4dd557199ab9aebe0a2b26438bdaa0e321a956b2..9e9b42d31900e0ceb160df4ad4dd3ce3a530e155 100644 --- a/crates/vim/src/state.rs +++ b/crates/vim/src/state.rs @@ -17,7 +17,7 @@ use gpui::{ Action, App, AppContext, BorrowAppContext, ClipboardEntry, ClipboardItem, DismissEvent, Entity, EntityId, Global, HighlightStyle, StyledText, Subscription, Task, TextStyle, WeakEntity, }; -use language::{Buffer, BufferEvent, BufferId, Chunk, Point}; +use language::{Buffer, BufferEvent, BufferId, Chunk, LanguageAwareStyling, Point}; use multi_buffer::MultiBufferRow; use picker::{Picker, PickerDelegate}; @@ -1504,7 +1504,10 @@ impl PickerDelegate for MarksViewDelegate { position.row, snapshot.line_len(MultiBufferRow(position.row)), ), - true, + LanguageAwareStyling { + tree_sitter: true, + diagnostics: true, + }, ); matches.push(MarksMatch { name: name.clone(), @@ -1530,7 +1533,10 @@ impl PickerDelegate for MarksViewDelegate { let chunks = snapshot.chunks( Point::new(position.row, 0) ..Point::new(position.row, snapshot.line_len(position.row)), - true, + LanguageAwareStyling { + tree_sitter: true, + diagnostics: true, + }, ); matches.push(MarksMatch { From a856093ccafa7b422080f3073097560b04e8918d Mon Sep 17 00:00:00 2001 From: Danilo Leal <67129314+danilo-leal@users.noreply.github.com> Date: Tue, 7 Apr 2026 11:51:46 -0300 Subject: [PATCH 14/22] sidebar: Fix focus movement while toggling it on and off (#53283) I was testing out the changes in https://github.com/zed-industries/zed/pull/52730 and realized that the agent panel, when full screen, would be auto-dismissed if I toggled the sidebar off. Turns out this happens because we were "hard-coding" the focus back to the center pane, which was automatically dismissing zoomed items. So, in this PR, I essentially am copying the ModalLayer approach of storing whatever was focused before so we can return focus back to it if possible. Release Notes: - N/A --- crates/workspace/src/multi_workspace.rs | 36 ++++++++++++++++++------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/crates/workspace/src/multi_workspace.rs b/crates/workspace/src/multi_workspace.rs index a61ad3576c57ecd8b1811363d6b5607ead737821..1b057e3fb1e3b5e0639e4a44462fc7528f6db85d 100644 --- a/crates/workspace/src/multi_workspace.rs +++ b/crates/workspace/src/multi_workspace.rs @@ -276,6 +276,7 @@ pub struct MultiWorkspace { pending_removal_tasks: Vec>, _serialize_task: Option>, _subscriptions: Vec, + previous_focus_handle: Option, } impl EventEmitter for MultiWorkspace {} @@ -333,6 +334,7 @@ impl MultiWorkspace { quit_subscription, settings_subscription, ], + previous_focus_handle: None, } } @@ -387,6 +389,7 @@ impl MultiWorkspace { if self.sidebar_open() { self.close_sidebar(window, cx); } else { + self.previous_focus_handle = window.focused(cx); self.open_sidebar(cx); if let Some(sidebar) = &self.sidebar { sidebar.prepare_for_focus(window, cx); @@ -417,14 +420,16 @@ impl MultiWorkspace { .is_some_and(|s| s.focus_handle(cx).contains_focused(window, cx)); if sidebar_is_focused { - let pane = self.workspace().read(cx).active_pane().clone(); - let pane_focus = pane.read(cx).focus_handle(cx); - window.focus(&pane_focus, cx); - } else if let Some(sidebar) = &self.sidebar { - sidebar.prepare_for_focus(window, cx); - sidebar.focus(window, cx); + self.restore_previous_focus(false, window, cx); + } else { + self.previous_focus_handle = window.focused(cx); + if let Some(sidebar) = &self.sidebar { + sidebar.prepare_for_focus(window, cx); + sidebar.focus(window, cx); + } } } else { + self.previous_focus_handle = window.focused(cx); self.open_sidebar(cx); if let Some(sidebar) = &self.sidebar { sidebar.prepare_for_focus(window, cx); @@ -457,13 +462,26 @@ impl MultiWorkspace { workspace.set_sidebar_focus_handle(None); }); } - let pane = self.workspace().read(cx).active_pane().clone(); - let pane_focus = pane.read(cx).focus_handle(cx); - window.focus(&pane_focus, cx); + self.restore_previous_focus(true, window, cx); self.serialize(cx); cx.notify(); } + fn restore_previous_focus(&mut self, clear: bool, window: &mut Window, cx: &mut Context) { + let focus_handle = if clear { + self.previous_focus_handle.take() + } else { + self.previous_focus_handle.clone() + }; + + if let Some(previous_focus) = focus_handle { + previous_focus.focus(window, cx); + } else { + let pane = self.workspace().read(cx).active_pane().clone(); + window.focus(&pane.read(cx).focus_handle(cx), cx); + } + } + pub fn close_window(&mut self, _: &CloseWindow, window: &mut Window, cx: &mut Context) { cx.spawn_in(window, async move |this, cx| { let workspaces = this.update(cx, |multi_workspace, _cx| { From 98c17ca1607a9bb223f831af6e221b3e7d47b28c Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Tue, 7 Apr 2026 12:28:19 -0300 Subject: [PATCH 15/22] language_models: Refactor deps and extract cloud (#53270) - `language_model` no longer depends on provider-specific crates such as `anthropic` and `open_ai` (inverted dependency) - `language_model_core` was extracted from `language_model` which contains the types for the provider-specific crates to convert to/from. - `gpui::SharedString` has been extracted into its own crate (still exposed by `gpui`), so `language_model_core` and provider API crates don't have to depend on `gpui`. - Removes some unnecessary `&'static str` | `SharedString` -> `String` -> `SharedString` conversions across the codebase. - Extracts the core logic of the cloud `LanguageModelProvider` into its own crate with simpler dependencies. Release Notes: - N/A --------- Co-authored-by: John Tur --- Cargo.lock | 89 +- Cargo.toml | 6 + crates/agent/src/tools/read_file_tool.rs | 2 +- crates/agent_servers/src/acp.rs | 2 +- crates/agent_ui/src/agent_registry_ui.rs | 2 +- crates/agent_ui/src/mention_set.rs | 2 +- crates/anthropic/Cargo.toml | 4 + crates/anthropic/src/anthropic.rs | 84 + crates/anthropic/src/completion.rs | 765 +++++++ crates/client/Cargo.toml | 1 - crates/client/src/client.rs | 2 +- crates/client/src/llm_token.rs | 2 +- crates/cloud_api_client/Cargo.toml | 1 + .../cloud_api_client/src/cloud_api_client.rs | 3 + crates/cloud_api_client/src/llm_token.rs | 74 + crates/cloud_llm_client/Cargo.toml | 3 +- .../cloud_llm_client/src/cloud_llm_client.rs | 1 + crates/collab_ui/src/collab_panel.rs | 4 +- crates/edit_prediction/Cargo.toml | 3 +- crates/edit_prediction/src/edit_prediction.rs | 2 +- crates/edit_prediction/src/ollama.rs | 2 +- .../src/zed_edit_prediction_delegate.rs | 4 +- crates/edit_prediction_cli/Cargo.toml | 2 +- crates/env_var/Cargo.toml | 2 +- crates/env_var/src/env_var.rs | 2 +- crates/git_ui/src/branch_picker.rs | 2 +- crates/google_ai/Cargo.toml | 4 +- crates/google_ai/src/completion.rs | 492 +++++ crates/google_ai/src/google_ai.rs | 3 +- crates/gpui/Cargo.toml | 1 + crates/gpui/src/gpui.rs | 3 +- crates/gpui/src/text_system/line.rs | 2 +- crates/gpui_shared_string/Cargo.toml | 17 + crates/gpui_shared_string/LICENSE-APACHE | 1 + .../gpui_shared_string.rs} | 0 crates/language_core/Cargo.toml | 4 +- crates/language_core/src/diagnostic.rs | 2 +- crates/language_core/src/grammar.rs | 2 +- crates/language_core/src/language_config.rs | 2 +- crates/language_core/src/language_name.rs | 2 +- crates/language_core/src/lsp_adapter.rs | 2 +- crates/language_core/src/manifest.rs | 2 +- crates/language_core/src/toolchain.rs | 2 +- crates/language_model/Cargo.toml | 9 +- crates/language_model/src/fake_provider.rs | 3 +- crates/language_model/src/language_model.rs | 633 +----- .../language_model/src/model/cloud_model.rs | 73 - crates/language_model/src/provider.rs | 12 - .../language_model/src/provider/anthropic.rs | 80 - crates/language_model/src/provider/google.rs | 5 - crates/language_model/src/provider/open_ai.rs | 28 - .../src/provider/open_router.rs | 69 - crates/language_model/src/provider/x_ai.rs | 4 - crates/language_model/src/provider/zed.rs | 5 - crates/language_model/src/registry.rs | 4 +- crates/language_model/src/request.rs | 626 +----- crates/language_model_core/Cargo.toml | 27 + crates/language_model_core/LICENSE-GPL | 1 + .../src/language_model_core.rs | 658 ++++++ crates/language_model_core/src/provider.rs | 21 + .../src/rate_limiter.rs | 0 crates/language_model_core/src/request.rs | 463 +++++ .../src/role.rs | 0 .../src/tool_schema.rs | 12 - .../src}/util.rs | 18 +- crates/language_models/Cargo.toml | 7 +- crates/language_models/src/provider.rs | 2 +- .../language_models/src/provider/anthropic.rs | 779 +------- .../language_models/src/provider/bedrock.rs | 2 +- crates/language_models/src/provider/cloud.rs | 1159 +---------- .../src/provider/copilot_chat.rs | 8 +- .../language_models/src/provider/deepseek.rs | 2 +- crates/language_models/src/provider/google.rs | 805 +------- .../language_models/src/provider/lmstudio.rs | 2 +- .../language_models/src/provider/mistral.rs | 2 +- .../language_models/src/provider/open_ai.rs | 1756 +---------------- .../src/provider/open_ai_compatible.rs | 4 +- .../src/provider/open_router.rs | 2 +- crates/language_models/src/provider/x_ai.rs | 40 +- crates/language_models_cloud/Cargo.toml | 33 + crates/language_models_cloud/LICENSE-GPL | 1 + .../src/language_models_cloud.rs | 1059 ++++++++++ crates/open_ai/Cargo.toml | 7 +- crates/open_ai/src/completion.rs | 1693 ++++++++++++++++ crates/open_ai/src/open_ai.rs | 26 +- crates/open_router/Cargo.toml | 1 + crates/open_router/src/open_router.rs | 68 + crates/project/src/prettier_store.rs | 2 +- crates/settings_content/Cargo.toml | 1 + crates/settings_content/src/language_model.rs | 34 +- crates/web_search_providers/Cargo.toml | 1 + crates/web_search_providers/src/cloud.rs | 2 +- crates/x_ai/Cargo.toml | 2 + crates/x_ai/src/completion.rs | 30 + crates/x_ai/src/x_ai.rs | 2 + 95 files changed, 5895 insertions(+), 5995 deletions(-) create mode 100644 crates/anthropic/src/completion.rs create mode 100644 crates/cloud_api_client/src/llm_token.rs create mode 100644 crates/google_ai/src/completion.rs create mode 100644 crates/gpui_shared_string/Cargo.toml create mode 120000 crates/gpui_shared_string/LICENSE-APACHE rename crates/{gpui/src/shared_string.rs => gpui_shared_string/gpui_shared_string.rs} (100%) delete mode 100644 crates/language_model/src/provider.rs delete mode 100644 crates/language_model/src/provider/anthropic.rs delete mode 100644 crates/language_model/src/provider/google.rs delete mode 100644 crates/language_model/src/provider/open_ai.rs delete mode 100644 crates/language_model/src/provider/open_router.rs delete mode 100644 crates/language_model/src/provider/x_ai.rs delete mode 100644 crates/language_model/src/provider/zed.rs create mode 100644 crates/language_model_core/Cargo.toml create mode 120000 crates/language_model_core/LICENSE-GPL create mode 100644 crates/language_model_core/src/language_model_core.rs create mode 100644 crates/language_model_core/src/provider.rs rename crates/{language_model => language_model_core}/src/rate_limiter.rs (100%) create mode 100644 crates/language_model_core/src/request.rs rename crates/{language_model => language_model_core}/src/role.rs (100%) rename crates/{language_model => language_model_core}/src/tool_schema.rs (92%) rename crates/{language_models/src/provider => language_model_core/src}/util.rs (88%) create mode 100644 crates/language_models_cloud/Cargo.toml create mode 120000 crates/language_models_cloud/LICENSE-GPL create mode 100644 crates/language_models_cloud/src/language_models_cloud.rs create mode 100644 crates/open_ai/src/completion.rs create mode 100644 crates/x_ai/src/completion.rs diff --git a/Cargo.lock b/Cargo.lock index cbc494f9dc0fc1858a846fabe168b3538de4dbe5..3fccd850ae697925330d15ed6b72804c39f4795e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -629,13 +629,17 @@ version = "0.1.0" dependencies = [ "anyhow", "chrono", + "collections", "futures 0.3.32", "http_client", + "language_model_core", + "log", "schemars", "serde", "serde_json", "strum 0.27.2", "thiserror 2.0.17", + "tiktoken-rs", ] [[package]] @@ -2903,7 +2907,6 @@ dependencies = [ "http_client", "http_client_tls", "httparse", - "language_model", "log", "objc2-foundation", "parking_lot", @@ -2959,6 +2962,7 @@ dependencies = [ "http_client", "parking_lot", "serde_json", + "smol", "thiserror 2.0.17", "yawc", ] @@ -5162,6 +5166,7 @@ dependencies = [ "buffer_diff", "client", "clock", + "cloud_api_client", "cloud_api_types", "cloud_llm_client", "collections", @@ -5641,7 +5646,7 @@ dependencies = [ name = "env_var" version = "0.1.0" dependencies = [ - "gpui", + "gpui_shared_string", ] [[package]] @@ -7468,11 +7473,13 @@ dependencies = [ "anyhow", "futures 0.3.32", "http_client", + "language_model_core", + "log", "schemars", "serde", "serde_json", - "settings", "strum 0.27.2", + "tiktoken-rs", ] [[package]] @@ -7541,6 +7548,7 @@ dependencies = [ "getrandom 0.3.4", "gpui_macros", "gpui_platform", + "gpui_shared_string", "gpui_util", "gpui_web", "http_client", @@ -7710,6 +7718,16 @@ dependencies = [ "gpui_windows", ] +[[package]] +name = "gpui_shared_string" +version = "0.1.0" +dependencies = [ + "derive_more", + "gpui_util", + "schemars", + "serde", +] + [[package]] name = "gpui_tokio" version = "0.1.0" @@ -9358,7 +9376,7 @@ version = "0.1.0" dependencies = [ "anyhow", "collections", - "gpui", + "gpui_shared_string", "log", "lsp", "parking_lot", @@ -9397,12 +9415,8 @@ dependencies = [ name = "language_model" version = "0.1.0" dependencies = [ - "anthropic", "anyhow", "base64 0.22.1", - "cloud_api_client", - "cloud_api_types", - "cloud_llm_client", "collections", "credentials_provider", "env_var", @@ -9411,16 +9425,31 @@ dependencies = [ "http_client", "icons", "image", + "language_model_core", "log", - "open_ai", - "open_router", "parking_lot", + "serde", + "serde_json", + "thiserror 2.0.17", + "util", +] + +[[package]] +name = "language_model_core" +version = "0.1.0" +dependencies = [ + "anyhow", + "cloud_llm_client", + "futures 0.3.32", + "gpui_shared_string", + "http_client", + "partial-json-fixer", "schemars", "serde", "serde_json", "smol", + "strum 0.27.2", "thiserror 2.0.17", - "util", ] [[package]] @@ -9436,8 +9465,8 @@ dependencies = [ "base64 0.22.1", "bedrock", "client", + "cloud_api_client", "cloud_api_types", - "cloud_llm_client", "collections", "component", "convert_case 0.8.0", @@ -9456,6 +9485,7 @@ dependencies = [ "http_client", "language", "language_model", + "language_models_cloud", "lmstudio", "log", "menu", @@ -9464,17 +9494,14 @@ dependencies = [ "open_ai", "open_router", "opencode", - "partial-json-fixer", "pretty_assertions", "release_channel", "schemars", - "semver", "serde", "serde_json", "settings", "smol", "strum 0.27.2", - "thiserror 2.0.17", "tiktoken-rs", "tokio", "ui", @@ -9484,6 +9511,28 @@ dependencies = [ "x_ai", ] +[[package]] +name = "language_models_cloud" +version = "0.1.0" +dependencies = [ + "anthropic", + "anyhow", + "cloud_llm_client", + "futures 0.3.32", + "google_ai", + "gpui", + "http_client", + "language_model", + "open_ai", + "schemars", + "semver", + "serde", + "serde_json", + "smol", + "thiserror 2.0.17", + "x_ai", +] + [[package]] name = "language_onboarding" version = "0.1.0" @@ -11631,16 +11680,19 @@ name = "open_ai" version = "0.1.0" dependencies = [ "anyhow", + "collections", "futures 0.3.32", "http_client", + "language_model_core", "log", + "pretty_assertions", "rand 0.9.2", "schemars", "serde", "serde_json", - "settings", "strum 0.27.2", "thiserror 2.0.17", + "tiktoken-rs", ] [[package]] @@ -11672,6 +11724,7 @@ dependencies = [ "anyhow", "futures 0.3.32", "http_client", + "language_model_core", "schemars", "serde", "serde_json", @@ -15801,6 +15854,7 @@ dependencies = [ "collections", "derive_more", "gpui", + "language_model_core", "log", "schemars", "serde", @@ -20180,6 +20234,7 @@ version = "0.1.0" dependencies = [ "anyhow", "client", + "cloud_api_client", "cloud_api_types", "cloud_llm_client", "futures 0.3.32", @@ -21783,9 +21838,11 @@ name = "x_ai" version = "0.1.0" dependencies = [ "anyhow", + "language_model_core", "schemars", "serde", "strum 0.27.2", + "tiktoken-rs", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 4c75dafae5df4d63815e0da5cabb95ccdad25e9d..5a7fc9caaf982953168855671bebbcf4f010df03 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -87,6 +87,7 @@ members = [ "crates/google_ai", "crates/grammars", "crates/gpui", + "crates/gpui_shared_string", "crates/gpui_linux", "crates/gpui_macos", "crates/gpui_macros", @@ -110,7 +111,9 @@ members = [ "crates/language_core", "crates/language_extension", "crates/language_model", + "crates/language_model_core", "crates/language_models", + "crates/language_models_cloud", "crates/language_onboarding", "crates/language_selector", "crates/language_tools", @@ -335,6 +338,7 @@ go_to_line = { path = "crates/go_to_line" } google_ai = { path = "crates/google_ai" } grammars = { path = "crates/grammars" } gpui = { path = "crates/gpui", default-features = false } +gpui_shared_string = { path = "crates/gpui_shared_string" } gpui_linux = { path = "crates/gpui_linux", default-features = false } gpui_macos = { path = "crates/gpui_macos", default-features = false } gpui_macros = { path = "crates/gpui_macros" } @@ -361,7 +365,9 @@ language = { path = "crates/language" } language_core = { path = "crates/language_core" } language_extension = { path = "crates/language_extension" } language_model = { path = "crates/language_model" } +language_model_core = { path = "crates/language_model_core" } language_models = { path = "crates/language_models" } +language_models_cloud = { path = "crates/language_models_cloud" } language_onboarding = { path = "crates/language_onboarding" } language_selector = { path = "crates/language_selector" } language_tools = { path = "crates/language_tools" } diff --git a/crates/agent/src/tools/read_file_tool.rs b/crates/agent/src/tools/read_file_tool.rs index 0086a82f4e79c9924502202873ceb2b25d2e66fb..9b013f111e7eaa981652d8868dfcf3c098d9dc7e 100644 --- a/crates/agent/src/tools/read_file_tool.rs +++ b/crates/agent/src/tools/read_file_tool.rs @@ -5,7 +5,7 @@ use futures::FutureExt as _; use gpui::{App, Entity, SharedString, Task}; use indoc::formatdoc; use language::Point; -use language_model::{LanguageModelImage, LanguageModelToolResultContent}; +use language_model::{LanguageModelImage, LanguageModelImageExt, LanguageModelToolResultContent}; use project::{AgentLocation, ImageItem, Project, WorktreeSettings, image_store}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; diff --git a/crates/agent_servers/src/acp.rs b/crates/agent_servers/src/acp.rs index 5f452bc9c0e2e9c2322042583295894a5866b053..e56db9df927ab3cdf838587f1cb4f9514eb5a758 100644 --- a/crates/agent_servers/src/acp.rs +++ b/crates/agent_servers/src/acp.rs @@ -325,7 +325,7 @@ impl AcpConnection { // Use the one the agent provides if we have one .map(|info| info.name.into()) // Otherwise, just use the name - .unwrap_or_else(|| agent_id.0.to_string().into()); + .unwrap_or_else(|| agent_id.0.clone()); let session_list = if response .agent_capabilities diff --git a/crates/agent_ui/src/agent_registry_ui.rs b/crates/agent_ui/src/agent_registry_ui.rs index 78b4e3a5a3965c72b96d4ec201139b1d8e510fb2..e19afdecc390268cefbd7be4e5d0759aa2a29c19 100644 --- a/crates/agent_ui/src/agent_registry_ui.rs +++ b/crates/agent_ui/src/agent_registry_ui.rs @@ -382,7 +382,7 @@ impl AgentRegistryPage { self.install_button(agent, install_status, supports_current_platform, cx); let repository_button = agent.repository().map(|repository| { - let repository_for_tooltip: SharedString = repository.to_string().into(); + let repository_for_tooltip = repository.clone(); let repository_for_click = repository.to_string(); IconButton::new( diff --git a/crates/agent_ui/src/mention_set.rs b/crates/agent_ui/src/mention_set.rs index 1b2ec0ad2fd460b4eec5a8b757bdd3058d4a3704..880257e3f942bf71d1d51b1e661d911474aa786b 100644 --- a/crates/agent_ui/src/mention_set.rs +++ b/crates/agent_ui/src/mention_set.rs @@ -18,7 +18,7 @@ use gpui::{ use http_client::{AsyncBody, HttpClientWithUrl}; use itertools::Either; use language::Buffer; -use language_model::LanguageModelImage; +use language_model::{LanguageModelImage, LanguageModelImageExt}; use multi_buffer::MultiBufferRow; use postage::stream::Stream as _; use project::{Project, ProjectItem, ProjectPath, Worktree}; diff --git a/crates/anthropic/Cargo.toml b/crates/anthropic/Cargo.toml index 1e2587435489dea6952c697b0e0a4cf627226728..458f9bfae7da4736c4e54e42f08b5e3a926ed30a 100644 --- a/crates/anthropic/Cargo.toml +++ b/crates/anthropic/Cargo.toml @@ -18,12 +18,16 @@ path = "src/anthropic.rs" [dependencies] anyhow.workspace = true chrono.workspace = true +collections.workspace = true futures.workspace = true http_client.workspace = true +language_model_core.workspace = true +log.workspace = true schemars = { workspace = true, optional = true } serde.workspace = true serde_json.workspace = true strum.workspace = true thiserror.workspace = true +tiktoken-rs.workspace = true diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index 5d7790b86b09853e22436252fcde1bebf5feff9b..48fa318d7c1d87e63725cef836baf9c945966206 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -12,6 +12,7 @@ use strum::{EnumIter, EnumString}; use thiserror::Error; pub mod batches; +pub mod completion; pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com"; @@ -1026,6 +1027,89 @@ pub async fn count_tokens( } } +// -- Conversions from/to `language_model_core` types -- + +impl From for Speed { + fn from(speed: language_model_core::Speed) -> Self { + match speed { + language_model_core::Speed::Standard => Speed::Standard, + language_model_core::Speed::Fast => Speed::Fast, + } + } +} + +impl From for language_model_core::LanguageModelCompletionError { + fn from(error: AnthropicError) -> Self { + let provider = language_model_core::ANTHROPIC_PROVIDER_NAME; + match error { + AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error }, + AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error }, + AnthropicError::HttpSend(error) => Self::HttpSend { provider, error }, + AnthropicError::DeserializeResponse(error) => { + Self::DeserializeResponse { provider, error } + } + AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error }, + AnthropicError::HttpResponseError { + status_code, + message, + } => Self::HttpResponseError { + provider, + status_code, + message, + }, + AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded { + provider, + retry_after: Some(retry_after), + }, + AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded { + provider, + retry_after, + }, + AnthropicError::ApiError(api_error) => api_error.into(), + } + } +} + +impl From for language_model_core::LanguageModelCompletionError { + fn from(error: ApiError) -> Self { + use ApiErrorCode::*; + let provider = language_model_core::ANTHROPIC_PROVIDER_NAME; + match error.code() { + Some(code) => match code { + InvalidRequestError => Self::BadRequestFormat { + provider, + message: error.message, + }, + AuthenticationError => Self::AuthenticationError { + provider, + message: error.message, + }, + PermissionError => Self::PermissionError { + provider, + message: error.message, + }, + NotFoundError => Self::ApiEndpointNotFound { provider }, + RequestTooLarge => Self::PromptTooLarge { + tokens: language_model_core::parse_prompt_too_long(&error.message), + }, + RateLimitError => Self::RateLimitExceeded { + provider, + retry_after: None, + }, + ApiError => Self::ApiInternalServerError { + provider, + message: error.message, + }, + OverloadedError => Self::ServerOverloaded { + provider, + retry_after: None, + }, + }, + None => Self::Other(error.into()), + } + } +} + #[test] fn test_match_window_exceeded() { let error = ApiError { diff --git a/crates/anthropic/src/completion.rs b/crates/anthropic/src/completion.rs new file mode 100644 index 0000000000000000000000000000000000000000..a6175a4f7c24b3b724734b2edef48ef8acfaa159 --- /dev/null +++ b/crates/anthropic/src/completion.rs @@ -0,0 +1,765 @@ +use anyhow::Result; +use collections::HashMap; +use futures::{Stream, StreamExt}; +use language_model_core::{ + LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRequest, + LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent, + Role, StopReason, TokenUsage, + util::{fix_streamed_json, parse_tool_arguments}, +}; +use std::pin::Pin; +use std::str::FromStr; + +use crate::{ + AnthropicError, AnthropicModelMode, CacheControl, CacheControlType, ContentDelta, + CountTokensRequest, Event, ImageSource, Message, RequestContent, ResponseContent, + StringOrContents, Thinking, Tool, ToolChoice, ToolResultContent, ToolResultPart, Usage, +}; + +fn to_anthropic_content(content: MessageContent) -> Option { + match content { + MessageContent::Text(text) => { + let text = if text.chars().last().is_some_and(|c| c.is_whitespace()) { + text.trim_end().to_string() + } else { + text + }; + if !text.is_empty() { + Some(RequestContent::Text { + text, + cache_control: None, + }) + } else { + None + } + } + MessageContent::Thinking { + text: thinking, + signature, + } => { + if let Some(signature) = signature + && !thinking.is_empty() + { + Some(RequestContent::Thinking { + thinking, + signature, + cache_control: None, + }) + } else { + None + } + } + MessageContent::RedactedThinking(data) => { + if !data.is_empty() { + Some(RequestContent::RedactedThinking { data }) + } else { + None + } + } + MessageContent::Image(image) => Some(RequestContent::Image { + source: ImageSource { + source_type: "base64".to_string(), + media_type: "image/png".to_string(), + data: image.source.to_string(), + }, + cache_control: None, + }), + MessageContent::ToolUse(tool_use) => Some(RequestContent::ToolUse { + id: tool_use.id.to_string(), + name: tool_use.name.to_string(), + input: tool_use.input, + cache_control: None, + }), + MessageContent::ToolResult(tool_result) => Some(RequestContent::ToolResult { + tool_use_id: tool_result.tool_use_id.to_string(), + is_error: tool_result.is_error, + content: match tool_result.content { + LanguageModelToolResultContent::Text(text) => { + ToolResultContent::Plain(text.to_string()) + } + LanguageModelToolResultContent::Image(image) => { + ToolResultContent::Multipart(vec![ToolResultPart::Image { + source: ImageSource { + source_type: "base64".to_string(), + media_type: "image/png".to_string(), + data: image.source.to_string(), + }, + }]) + } + }, + cache_control: None, + }), + } +} + +/// Convert a LanguageModelRequest to an Anthropic CountTokensRequest. +pub fn into_anthropic_count_tokens_request( + request: LanguageModelRequest, + model: String, + mode: AnthropicModelMode, +) -> CountTokensRequest { + let mut new_messages: Vec = Vec::new(); + let mut system_message = String::new(); + + for message in request.messages { + if message.contents_empty() { + continue; + } + + match message.role { + Role::User | Role::Assistant => { + let anthropic_message_content: Vec = message + .content + .into_iter() + .filter_map(to_anthropic_content) + .collect(); + let anthropic_role = match message.role { + Role::User => crate::Role::User, + Role::Assistant => crate::Role::Assistant, + Role::System => unreachable!("System role should never occur here"), + }; + if anthropic_message_content.is_empty() { + continue; + } + + if let Some(last_message) = new_messages.last_mut() + && last_message.role == anthropic_role + { + last_message.content.extend(anthropic_message_content); + continue; + } + + new_messages.push(Message { + role: anthropic_role, + content: anthropic_message_content, + }); + } + Role::System => { + if !system_message.is_empty() { + system_message.push_str("\n\n"); + } + system_message.push_str(&message.string_contents()); + } + } + } + + CountTokensRequest { + model, + messages: new_messages, + system: if system_message.is_empty() { + None + } else { + Some(StringOrContents::String(system_message)) + }, + thinking: if request.thinking_allowed { + match mode { + AnthropicModelMode::Thinking { budget_tokens } => { + Some(Thinking::Enabled { budget_tokens }) + } + AnthropicModelMode::AdaptiveThinking => Some(Thinking::Adaptive), + AnthropicModelMode::Default => None, + } + } else { + None + }, + tools: request + .tools + .into_iter() + .map(|tool| Tool { + name: tool.name, + description: tool.description, + input_schema: tool.input_schema, + eager_input_streaming: tool.use_input_streaming, + }) + .collect(), + tool_choice: request.tool_choice.map(|choice| match choice { + LanguageModelToolChoice::Auto => ToolChoice::Auto, + LanguageModelToolChoice::Any => ToolChoice::Any, + LanguageModelToolChoice::None => ToolChoice::None, + }), + } +} + +/// Estimate tokens using tiktoken. Used as a fallback when the API is unavailable, +/// or by providers (like Zed Cloud) that don't have direct Anthropic API access. +pub fn count_anthropic_tokens_with_tiktoken(request: LanguageModelRequest) -> Result { + let messages = request.messages; + let mut tokens_from_images = 0; + let mut string_messages = Vec::with_capacity(messages.len()); + + for message in messages { + let mut string_contents = String::new(); + + for content in message.content { + match content { + MessageContent::Text(text) => { + string_contents.push_str(&text); + } + MessageContent::Thinking { .. } => { + // Thinking blocks are not included in the input token count. + } + MessageContent::RedactedThinking(_) => { + // Thinking blocks are not included in the input token count. + } + MessageContent::Image(image) => { + tokens_from_images += image.estimate_tokens(); + } + MessageContent::ToolUse(_tool_use) => { + // TODO: Estimate token usage from tool uses. + } + MessageContent::ToolResult(tool_result) => match &tool_result.content { + LanguageModelToolResultContent::Text(text) => { + string_contents.push_str(text); + } + LanguageModelToolResultContent::Image(image) => { + tokens_from_images += image.estimate_tokens(); + } + }, + } + } + + if !string_contents.is_empty() { + string_messages.push(tiktoken_rs::ChatCompletionRequestMessage { + role: match message.role { + Role::User => "user".into(), + Role::Assistant => "assistant".into(), + Role::System => "system".into(), + }, + content: Some(string_contents), + name: None, + function_call: None, + }); + } + } + + // Tiktoken doesn't yet support these models, so we manually use the + // same tokenizer as GPT-4. + tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages) + .map(|tokens| (tokens + tokens_from_images) as u64) +} + +pub fn into_anthropic( + request: LanguageModelRequest, + model: String, + default_temperature: f32, + max_output_tokens: u64, + mode: AnthropicModelMode, +) -> crate::Request { + let mut new_messages: Vec = Vec::new(); + let mut system_message = String::new(); + + for message in request.messages { + if message.contents_empty() { + continue; + } + + match message.role { + Role::User | Role::Assistant => { + let mut anthropic_message_content: Vec = message + .content + .into_iter() + .filter_map(to_anthropic_content) + .collect(); + let anthropic_role = match message.role { + Role::User => crate::Role::User, + Role::Assistant => crate::Role::Assistant, + Role::System => unreachable!("System role should never occur here"), + }; + if anthropic_message_content.is_empty() { + continue; + } + + if let Some(last_message) = new_messages.last_mut() + && last_message.role == anthropic_role + { + last_message.content.extend(anthropic_message_content); + continue; + } + + // Mark the last segment of the message as cached + if message.cache { + let cache_control_value = Some(CacheControl { + cache_type: CacheControlType::Ephemeral, + }); + for message_content in anthropic_message_content.iter_mut().rev() { + match message_content { + RequestContent::RedactedThinking { .. } => { + // Caching is not possible, fallback to next message + } + RequestContent::Text { cache_control, .. } + | RequestContent::Thinking { cache_control, .. } + | RequestContent::Image { cache_control, .. } + | RequestContent::ToolUse { cache_control, .. } + | RequestContent::ToolResult { cache_control, .. } => { + *cache_control = cache_control_value; + break; + } + } + } + } + + new_messages.push(Message { + role: anthropic_role, + content: anthropic_message_content, + }); + } + Role::System => { + if !system_message.is_empty() { + system_message.push_str("\n\n"); + } + system_message.push_str(&message.string_contents()); + } + } + } + + crate::Request { + model, + messages: new_messages, + max_tokens: max_output_tokens, + system: if system_message.is_empty() { + None + } else { + Some(StringOrContents::String(system_message)) + }, + thinking: if request.thinking_allowed { + match mode { + AnthropicModelMode::Thinking { budget_tokens } => { + Some(Thinking::Enabled { budget_tokens }) + } + AnthropicModelMode::AdaptiveThinking => Some(Thinking::Adaptive), + AnthropicModelMode::Default => None, + } + } else { + None + }, + tools: request + .tools + .into_iter() + .map(|tool| Tool { + name: tool.name, + description: tool.description, + input_schema: tool.input_schema, + eager_input_streaming: tool.use_input_streaming, + }) + .collect(), + tool_choice: request.tool_choice.map(|choice| match choice { + LanguageModelToolChoice::Auto => ToolChoice::Auto, + LanguageModelToolChoice::Any => ToolChoice::Any, + LanguageModelToolChoice::None => ToolChoice::None, + }), + metadata: None, + output_config: if request.thinking_allowed + && matches!(mode, AnthropicModelMode::AdaptiveThinking) + { + request.thinking_effort.as_deref().and_then(|effort| { + let effort = match effort { + "low" => Some(crate::Effort::Low), + "medium" => Some(crate::Effort::Medium), + "high" => Some(crate::Effort::High), + "max" => Some(crate::Effort::Max), + _ => None, + }; + effort.map(|effort| crate::OutputConfig { + effort: Some(effort), + }) + }) + } else { + None + }, + stop_sequences: Vec::new(), + speed: request.speed.map(Into::into), + temperature: request.temperature.or(Some(default_temperature)), + top_k: None, + top_p: None, + } +} + +pub struct AnthropicEventMapper { + tool_uses_by_index: HashMap, + usage: Usage, + stop_reason: StopReason, +} + +impl AnthropicEventMapper { + pub fn new() -> Self { + Self { + tool_uses_by_index: HashMap::default(), + usage: Usage::default(), + stop_reason: StopReason::EndTurn, + } + } + + pub fn map_stream( + mut self, + events: Pin>>>, + ) -> impl Stream> + { + events.flat_map(move |event| { + futures::stream::iter(match event { + Ok(event) => self.map_event(event), + Err(error) => vec![Err(error.into())], + }) + }) + } + + pub fn map_event( + &mut self, + event: Event, + ) -> Vec> { + match event { + Event::ContentBlockStart { + index, + content_block, + } => match content_block { + ResponseContent::Text { text } => { + vec![Ok(LanguageModelCompletionEvent::Text(text))] + } + ResponseContent::Thinking { thinking } => { + vec![Ok(LanguageModelCompletionEvent::Thinking { + text: thinking, + signature: None, + })] + } + ResponseContent::RedactedThinking { data } => { + vec![Ok(LanguageModelCompletionEvent::RedactedThinking { data })] + } + ResponseContent::ToolUse { id, name, .. } => { + self.tool_uses_by_index.insert( + index, + RawToolUse { + id, + name, + input_json: String::new(), + }, + ); + Vec::new() + } + }, + Event::ContentBlockDelta { index, delta } => match delta { + ContentDelta::TextDelta { text } => { + vec![Ok(LanguageModelCompletionEvent::Text(text))] + } + ContentDelta::ThinkingDelta { thinking } => { + vec![Ok(LanguageModelCompletionEvent::Thinking { + text: thinking, + signature: None, + })] + } + ContentDelta::SignatureDelta { signature } => { + vec![Ok(LanguageModelCompletionEvent::Thinking { + text: "".to_string(), + signature: Some(signature), + })] + } + ContentDelta::InputJsonDelta { partial_json } => { + if let Some(tool_use) = self.tool_uses_by_index.get_mut(&index) { + tool_use.input_json.push_str(&partial_json); + + // Try to convert invalid (incomplete) JSON into + // valid JSON that serde can accept, e.g. by closing + // unclosed delimiters. This way, we can update the + // UI with whatever has been streamed back so far. + if let Ok(input) = + serde_json::Value::from_str(&fix_streamed_json(&tool_use.input_json)) + { + return vec![Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: tool_use.id.clone().into(), + name: tool_use.name.clone().into(), + is_input_complete: false, + raw_input: tool_use.input_json.clone(), + input, + thought_signature: None, + }, + ))]; + } + } + vec![] + } + }, + Event::ContentBlockStop { index } => { + if let Some(tool_use) = self.tool_uses_by_index.remove(&index) { + let input_json = tool_use.input_json.trim(); + let event_result = match parse_tool_arguments(input_json) { + Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: tool_use.id.into(), + name: tool_use.name.into(), + is_input_complete: true, + input, + raw_input: tool_use.input_json.clone(), + thought_signature: None, + }, + )), + Err(json_parse_err) => { + Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { + id: tool_use.id.into(), + tool_name: tool_use.name.into(), + raw_input: input_json.into(), + json_parse_error: json_parse_err.to_string(), + }) + } + }; + + vec![event_result] + } else { + Vec::new() + } + } + Event::MessageStart { message } => { + update_usage(&mut self.usage, &message.usage); + vec![ + Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage( + &self.usage, + ))), + Ok(LanguageModelCompletionEvent::StartMessage { + message_id: message.id, + }), + ] + } + Event::MessageDelta { delta, usage } => { + update_usage(&mut self.usage, &usage); + if let Some(stop_reason) = delta.stop_reason.as_deref() { + self.stop_reason = match stop_reason { + "end_turn" => StopReason::EndTurn, + "max_tokens" => StopReason::MaxTokens, + "tool_use" => StopReason::ToolUse, + "refusal" => StopReason::Refusal, + _ => { + log::error!("Unexpected anthropic stop_reason: {stop_reason}"); + StopReason::EndTurn + } + }; + } + vec![Ok(LanguageModelCompletionEvent::UsageUpdate( + convert_usage(&self.usage), + ))] + } + Event::MessageStop => { + vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))] + } + Event::Error { error } => { + vec![Err(error.into())] + } + _ => Vec::new(), + } + } +} + +struct RawToolUse { + id: String, + name: String, + input_json: String, +} + +/// Updates usage data by preferring counts from `new`. +fn update_usage(usage: &mut Usage, new: &Usage) { + if let Some(input_tokens) = new.input_tokens { + usage.input_tokens = Some(input_tokens); + } + if let Some(output_tokens) = new.output_tokens { + usage.output_tokens = Some(output_tokens); + } + if let Some(cache_creation_input_tokens) = new.cache_creation_input_tokens { + usage.cache_creation_input_tokens = Some(cache_creation_input_tokens); + } + if let Some(cache_read_input_tokens) = new.cache_read_input_tokens { + usage.cache_read_input_tokens = Some(cache_read_input_tokens); + } +} + +fn convert_usage(usage: &Usage) -> TokenUsage { + TokenUsage { + input_tokens: usage.input_tokens.unwrap_or(0), + output_tokens: usage.output_tokens.unwrap_or(0), + cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0), + cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::AnthropicModelMode; + use language_model_core::{LanguageModelImage, LanguageModelRequestMessage, MessageContent}; + + #[test] + fn test_cache_control_only_on_last_segment() { + let request = LanguageModelRequest { + messages: vec![LanguageModelRequestMessage { + role: Role::User, + content: vec![ + MessageContent::Text("Some prompt".to_string()), + MessageContent::Image(LanguageModelImage::empty()), + MessageContent::Image(LanguageModelImage::empty()), + MessageContent::Image(LanguageModelImage::empty()), + MessageContent::Image(LanguageModelImage::empty()), + ], + cache: true, + reasoning_details: None, + }], + thread_id: None, + prompt_id: None, + intent: None, + stop: vec![], + temperature: None, + tools: vec![], + tool_choice: None, + thinking_allowed: true, + thinking_effort: None, + speed: None, + }; + + let anthropic_request = into_anthropic( + request, + "claude-3-5-sonnet".to_string(), + 0.7, + 4096, + AnthropicModelMode::Default, + ); + + assert_eq!(anthropic_request.messages.len(), 1); + + let message = &anthropic_request.messages[0]; + assert_eq!(message.content.len(), 5); + + assert!(matches!( + message.content[0], + RequestContent::Text { + cache_control: None, + .. + } + )); + for i in 1..3 { + assert!(matches!( + message.content[i], + RequestContent::Image { + cache_control: None, + .. + } + )); + } + + assert!(matches!( + message.content[4], + RequestContent::Image { + cache_control: Some(CacheControl { + cache_type: CacheControlType::Ephemeral, + }), + .. + } + )); + } + + fn request_with_assistant_content(assistant_content: Vec) -> crate::Request { + let mut request = LanguageModelRequest { + messages: vec![LanguageModelRequestMessage { + role: Role::User, + content: vec![MessageContent::Text("Hello".to_string())], + cache: false, + reasoning_details: None, + }], + thinking_effort: None, + thread_id: None, + prompt_id: None, + intent: None, + stop: vec![], + temperature: None, + tools: vec![], + tool_choice: None, + thinking_allowed: true, + speed: None, + }; + request.messages.push(LanguageModelRequestMessage { + role: Role::Assistant, + content: assistant_content, + cache: false, + reasoning_details: None, + }); + into_anthropic( + request, + "claude-sonnet-4-5".to_string(), + 1.0, + 16000, + AnthropicModelMode::Thinking { + budget_tokens: Some(10000), + }, + ) + } + + #[test] + fn test_unsigned_thinking_blocks_stripped() { + let result = request_with_assistant_content(vec![ + MessageContent::Thinking { + text: "Cancelled mid-think, no signature".to_string(), + signature: None, + }, + MessageContent::Text("Some response text".to_string()), + ]); + + let assistant_message = result + .messages + .iter() + .find(|m| m.role == crate::Role::Assistant) + .expect("assistant message should still exist"); + + assert_eq!( + assistant_message.content.len(), + 1, + "Only the text content should remain; unsigned thinking block should be stripped" + ); + assert!(matches!( + &assistant_message.content[0], + RequestContent::Text { text, .. } if text == "Some response text" + )); + } + + #[test] + fn test_signed_thinking_blocks_preserved() { + let result = request_with_assistant_content(vec![ + MessageContent::Thinking { + text: "Completed thinking".to_string(), + signature: Some("valid-signature".to_string()), + }, + MessageContent::Text("Response".to_string()), + ]); + + let assistant_message = result + .messages + .iter() + .find(|m| m.role == crate::Role::Assistant) + .expect("assistant message should exist"); + + assert_eq!( + assistant_message.content.len(), + 2, + "Both the signed thinking block and text should be preserved" + ); + assert!(matches!( + &assistant_message.content[0], + RequestContent::Thinking { thinking, signature, .. } + if thinking == "Completed thinking" && signature == "valid-signature" + )); + } + + #[test] + fn test_only_unsigned_thinking_block_omits_entire_message() { + let result = request_with_assistant_content(vec![MessageContent::Thinking { + text: "Cancelled before any text or signature".to_string(), + signature: None, + }]); + + let assistant_messages: Vec<_> = result + .messages + .iter() + .filter(|m| m.role == crate::Role::Assistant) + .collect(); + + assert_eq!( + assistant_messages.len(), + 0, + "An assistant message whose only content was an unsigned thinking block \ + should be omitted entirely" + ); + } +} diff --git a/crates/client/Cargo.toml b/crates/client/Cargo.toml index 7bbaccb22e0e6c7508240186103e216f83be2f0c..532fe38f7df1f686730ed862a81806e9a531e156 100644 --- a/crates/client/Cargo.toml +++ b/crates/client/Cargo.toml @@ -36,7 +36,6 @@ gpui_tokio.workspace = true http_client.workspace = true http_client_tls.workspace = true httparse = "1.10" -language_model.workspace = true log.workspace = true parking_lot.workspace = true paths.workspace = true diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index dfd9963a0ee52d167f8d4edb0b850f4debed7fd4..05ca974f80438542b232262dd375e0e38ab4327c 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -14,6 +14,7 @@ use async_tungstenite::tungstenite::{ http::{HeaderValue, Request, StatusCode}, }; use clock::SystemClock; +use cloud_api_client::LlmApiToken; use cloud_api_client::websocket_protocol::MessageToClient; use cloud_api_client::{ClientApiError, CloudApiClient}; use cloud_api_types::OrganizationId; @@ -26,7 +27,6 @@ use futures::{ }; use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions}; use http_client::{HttpClient, HttpClientWithUrl, http, read_proxy_from_env}; -use language_model::LlmApiToken; use parking_lot::{Mutex, RwLock}; use postage::watch; use proxy::connect_proxy_stream; diff --git a/crates/client/src/llm_token.rs b/crates/client/src/llm_token.rs index f62aa6dd4dc3462bc3a0f6f46c35f0e4e5499816..70457679e4b965e3251ae4861d3052bfa41fd65a 100644 --- a/crates/client/src/llm_token.rs +++ b/crates/client/src/llm_token.rs @@ -1,10 +1,10 @@ use super::{Client, UserStore}; +use cloud_api_client::LlmApiToken; use cloud_api_types::websocket_protocol::MessageToClient; use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME}; use gpui::{ App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription, }; -use language_model::LlmApiToken; use std::sync::Arc; pub trait NeedsLlmTokenRefresh { diff --git a/crates/cloud_api_client/Cargo.toml b/crates/cloud_api_client/Cargo.toml index 78c684e3e54ee29a5f3f3ae5620d4a52b445f92e..cf293d83f848e1266dec977c0925af7f66608ce6 100644 --- a/crates/cloud_api_client/Cargo.toml +++ b/crates/cloud_api_client/Cargo.toml @@ -20,5 +20,6 @@ gpui_tokio.workspace = true http_client.workspace = true parking_lot.workspace = true serde_json.workspace = true +smol.workspace = true thiserror.workspace = true yawc.workspace = true diff --git a/crates/cloud_api_client/src/cloud_api_client.rs b/crates/cloud_api_client/src/cloud_api_client.rs index 13d67838b216f4990f15ec22c1701aa7aef9dbf2..8c605bb3490ef5c7aea6e96045680338e8344a83 100644 --- a/crates/cloud_api_client/src/cloud_api_client.rs +++ b/crates/cloud_api_client/src/cloud_api_client.rs @@ -1,3 +1,4 @@ +mod llm_token; mod websocket; use std::sync::Arc; @@ -18,6 +19,8 @@ use yawc::WebSocket; use crate::websocket::Connection; +pub use llm_token::LlmApiToken; + struct Credentials { user_id: u32, access_token: String, diff --git a/crates/cloud_api_client/src/llm_token.rs b/crates/cloud_api_client/src/llm_token.rs new file mode 100644 index 0000000000000000000000000000000000000000..711e0d51b89bf34db255d7cb1e58483c9de340fc --- /dev/null +++ b/crates/cloud_api_client/src/llm_token.rs @@ -0,0 +1,74 @@ +use std::sync::Arc; + +use cloud_api_types::OrganizationId; +use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard}; + +use crate::{ClientApiError, CloudApiClient}; + +#[derive(Clone, Default)] +pub struct LlmApiToken(Arc>>); + +impl LlmApiToken { + pub async fn acquire( + &self, + client: &CloudApiClient, + system_id: Option, + organization_id: Option, + ) -> Result { + let lock = self.0.upgradable_read().await; + if let Some(token) = lock.as_ref() { + Ok(token.to_string()) + } else { + Self::fetch( + RwLockUpgradableReadGuard::upgrade(lock).await, + client, + system_id, + organization_id, + ) + .await + } + } + + pub async fn refresh( + &self, + client: &CloudApiClient, + system_id: Option, + organization_id: Option, + ) -> Result { + Self::fetch(self.0.write().await, client, system_id, organization_id).await + } + + /// Clears the existing token before attempting to fetch a new one. + /// + /// Used when switching organizations so that a failed refresh doesn't + /// leave a token for the wrong organization. + pub async fn clear_and_refresh( + &self, + client: &CloudApiClient, + system_id: Option, + organization_id: Option, + ) -> Result { + let mut lock = self.0.write().await; + *lock = None; + Self::fetch(lock, client, system_id, organization_id).await + } + + async fn fetch( + mut lock: RwLockWriteGuard<'_, Option>, + client: &CloudApiClient, + system_id: Option, + organization_id: Option, + ) -> Result { + let result = client.create_llm_token(system_id, organization_id).await; + match result { + Ok(response) => { + *lock = Some(response.token.0.clone()); + Ok(response.token.0) + } + Err(err) => { + *lock = None; + Err(err) + } + } + } +} diff --git a/crates/cloud_llm_client/Cargo.toml b/crates/cloud_llm_client/Cargo.toml index a7b4f925a9302296e8fe25a14177a583e5f44b33..7cc59f255abeb27c6e35a2064654d8eca1a581fe 100644 --- a/crates/cloud_llm_client/Cargo.toml +++ b/crates/cloud_llm_client/Cargo.toml @@ -7,6 +7,7 @@ license = "Apache-2.0" [features] test-support = [] +predict-edits = ["dep:zeta_prompt"] [lints] workspace = true @@ -20,6 +21,6 @@ serde = { workspace = true, features = ["derive", "rc"] } serde_json.workspace = true strum = { workspace = true, features = ["derive"] } uuid = { workspace = true, features = ["serde"] } -zeta_prompt.workspace = true +zeta_prompt = { workspace = true, optional = true } diff --git a/crates/cloud_llm_client/src/cloud_llm_client.rs b/crates/cloud_llm_client/src/cloud_llm_client.rs index 35eb3f2b80dd400558b1f027781f5b8cf63bb6cb..ac8bdd462a9c4754ef42a6afa41f1bef8b5bbe6a 100644 --- a/crates/cloud_llm_client/src/cloud_llm_client.rs +++ b/crates/cloud_llm_client/src/cloud_llm_client.rs @@ -1,3 +1,4 @@ +#[cfg(feature = "predict-edits")] pub mod predict_edits_v3; use std::str::FromStr; diff --git a/crates/collab_ui/src/collab_panel.rs b/crates/collab_ui/src/collab_panel.rs index 7dc807998760a8e65d373164eec5c7663171e5d0..327ef1cf6003eb959bd0926d67d2b0ed3b4ab0ba 100644 --- a/crates/collab_ui/src/collab_panel.rs +++ b/crates/collab_ui/src/collab_panel.rs @@ -2846,11 +2846,11 @@ impl CollabPanel { } }; - Some(channel.name.as_ref()) + Some(channel.name.clone()) }); if let Some(name) = channel_name { - SharedString::from(name.to_string()) + name } else { SharedString::from("Current Call") } diff --git a/crates/edit_prediction/Cargo.toml b/crates/edit_prediction/Cargo.toml index eabb1641fd4fbec7b2f8ef0ba399a8fe9600dfa3..87ad4e42e7826cdda4fc6a8c31a27afe888830f0 100644 --- a/crates/edit_prediction/Cargo.toml +++ b/crates/edit_prediction/Cargo.toml @@ -21,8 +21,9 @@ heapless.workspace = true buffer_diff.workspace = true client.workspace = true clock.workspace = true +cloud_api_client.workspace = true cloud_api_types.workspace = true -cloud_llm_client.workspace = true +cloud_llm_client = { workspace = true, features = ["predict-edits"] } collections.workspace = true copilot.workspace = true copilot_ui.workspace = true diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index 280427df006b510e1854ffb40cd7f995fcd9fdc6..2d90e13fb9b45aedd354f753502cd4e616ae3bcd 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -1,5 +1,6 @@ use anyhow::Result; use client::{Client, EditPredictionUsage, NeedsLlmTokenRefresh, UserStore, global_llm_token}; +use cloud_api_client::LlmApiToken; use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody}; use cloud_llm_client::predict_edits_v3::{ PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse, @@ -31,7 +32,6 @@ use heapless::Vec as ArrayVec; use language::language_settings::all_language_settings; use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint}; use language::{BufferSnapshot, OffsetRangeExt}; -use language_model::LlmApiToken; use project::{DisableAiSettings, Project, ProjectPath, WorktreeId}; use release_channel::AppVersion; use semver::Version; diff --git a/crates/edit_prediction/src/ollama.rs b/crates/edit_prediction/src/ollama.rs index 0250ec44a46cf081c6badc6fa11a9c34ebb65c4a..0ae90dd9f6eca4bfe9f87950a5a66916d8894df4 100644 --- a/crates/edit_prediction/src/ollama.rs +++ b/crates/edit_prediction/src/ollama.rs @@ -57,7 +57,7 @@ pub fn fetch_models(cx: &mut App) -> Vec { let mut models: Vec = provider .provided_models(cx) .into_iter() - .map(|model| SharedString::from(model.id().0.to_string())) + .map(|model| model.id().0) .collect(); models.sort(); models diff --git a/crates/edit_prediction/src/zed_edit_prediction_delegate.rs b/crates/edit_prediction/src/zed_edit_prediction_delegate.rs index c5e97fd87eaad9b98aeb9b946a9a69b1c1071db2..1a574e9389715ce888f8b8c5ec8be921ceab4a38 100644 --- a/crates/edit_prediction/src/zed_edit_prediction_delegate.rs +++ b/crates/edit_prediction/src/zed_edit_prediction_delegate.rs @@ -177,7 +177,7 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate { BufferEditPrediction::Local { prediction } => prediction, BufferEditPrediction::Jump { prediction } => { return Some(edit_prediction_types::EditPrediction::Jump { - id: Some(prediction.id.to_string().into()), + id: Some(prediction.id.0.clone()), snapshot: prediction.snapshot.clone(), target: prediction.edits.first().unwrap().0.start, }); @@ -228,7 +228,7 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate { } Some(edit_prediction_types::EditPrediction::Local { - id: Some(prediction.id.to_string().into()), + id: Some(prediction.id.0.clone()), edits: edits[edit_start_ix..edit_end_ix].to_vec(), cursor_position: prediction.cursor_position, edit_preview: Some(prediction.edit_preview.clone()), diff --git a/crates/edit_prediction_cli/Cargo.toml b/crates/edit_prediction_cli/Cargo.toml index 323ee3de41902b2140f95da22b0e37fb98d31fd5..a999fed2baf990273f0801bac15573b3aed0cc78 100644 --- a/crates/edit_prediction_cli/Cargo.toml +++ b/crates/edit_prediction_cli/Cargo.toml @@ -22,7 +22,7 @@ http_client.workspace = true chrono.workspace = true clap = "4" client.workspace = true -cloud_llm_client.workspace= true +cloud_llm_client = { workspace = true, features = ["predict-edits"] } collections.workspace = true db.workspace = true debug_adapter_extension.workspace = true diff --git a/crates/env_var/Cargo.toml b/crates/env_var/Cargo.toml index 2cbbd08c7833d3e57a09766d42ffffe35c620a93..3c879a2f49184e19a131046320d767931e1ca8ec 100644 --- a/crates/env_var/Cargo.toml +++ b/crates/env_var/Cargo.toml @@ -12,4 +12,4 @@ workspace = true path = "src/env_var.rs" [dependencies] -gpui.workspace = true +gpui_shared_string.workspace = true diff --git a/crates/env_var/src/env_var.rs b/crates/env_var/src/env_var.rs index 79f671e0147ebfaad4ab76a123cc477dc7e55cb7..cb436e95e0e734e4b7d8d271199246e1558a074d 100644 --- a/crates/env_var/src/env_var.rs +++ b/crates/env_var/src/env_var.rs @@ -1,4 +1,4 @@ -use gpui::SharedString; +use gpui_shared_string::SharedString; #[derive(Clone)] pub struct EnvVar { diff --git a/crates/git_ui/src/branch_picker.rs b/crates/git_ui/src/branch_picker.rs index 83c8119a077ac1c024dbb3b3df948f762b072ec1..2bf4a1991f7a302ed73fe098e8914fedd0f9eb2a 100644 --- a/crates/git_ui/src/branch_picker.rs +++ b/crates/git_ui/src/branch_picker.rs @@ -1906,7 +1906,7 @@ mod tests { assert_eq!( remotes, vec![Remote { - name: SharedString::from("my_new_remote".to_string()) + name: SharedString::from("my_new_remote") }] ); } diff --git a/crates/google_ai/Cargo.toml b/crates/google_ai/Cargo.toml index 81e05e4836529e9b73b58b72683a7e72a4d5c984..d91d28851997723835ba85be343a453918301c71 100644 --- a/crates/google_ai/Cargo.toml +++ b/crates/google_ai/Cargo.toml @@ -18,8 +18,10 @@ schemars = ["dep:schemars"] anyhow.workspace = true futures.workspace = true http_client.workspace = true +language_model_core.workspace = true +log.workspace = true schemars = { workspace = true, optional = true } serde.workspace = true serde_json.workspace = true -settings.workspace = true strum.workspace = true +tiktoken-rs.workspace = true diff --git a/crates/google_ai/src/completion.rs b/crates/google_ai/src/completion.rs new file mode 100644 index 0000000000000000000000000000000000000000..3a15fdaa0187e52cb82dc8c71b5b861eb797f1a8 --- /dev/null +++ b/crates/google_ai/src/completion.rs @@ -0,0 +1,492 @@ +use anyhow::Result; +use futures::{Stream, StreamExt}; +use language_model_core::{ + LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRequest, + LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, + StopReason, TokenUsage, +}; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{self, AtomicU64}; + +use crate::{ + Content, FunctionCallingConfig, FunctionCallingMode, FunctionDeclaration, + GenerateContentResponse, GenerationConfig, GenerativeContentBlob, GoogleModelMode, + InlineDataPart, ModelName, Part, SystemInstruction, TextPart, ThinkingConfig, ToolConfig, + UsageMetadata, +}; + +pub fn into_google( + mut request: LanguageModelRequest, + model_id: String, + mode: GoogleModelMode, +) -> crate::GenerateContentRequest { + fn map_content(content: Vec) -> Vec { + content + .into_iter() + .flat_map(|content| match content { + MessageContent::Text(text) => { + if !text.is_empty() { + vec![Part::TextPart(TextPart { text })] + } else { + vec![] + } + } + MessageContent::Thinking { + text: _, + signature: Some(signature), + } => { + if !signature.is_empty() { + vec![Part::ThoughtPart(crate::ThoughtPart { + thought: true, + thought_signature: signature, + })] + } else { + vec![] + } + } + MessageContent::Thinking { .. } => { + vec![] + } + MessageContent::RedactedThinking(_) => vec![], + MessageContent::Image(image) => { + vec![Part::InlineDataPart(InlineDataPart { + inline_data: GenerativeContentBlob { + mime_type: "image/png".to_string(), + data: image.source.to_string(), + }, + })] + } + MessageContent::ToolUse(tool_use) => { + // Normalize empty string signatures to None + let thought_signature = tool_use.thought_signature.filter(|s| !s.is_empty()); + + vec![Part::FunctionCallPart(crate::FunctionCallPart { + function_call: crate::FunctionCall { + name: tool_use.name.to_string(), + args: tool_use.input, + }, + thought_signature, + })] + } + MessageContent::ToolResult(tool_result) => { + match tool_result.content { + language_model_core::LanguageModelToolResultContent::Text(text) => { + vec![Part::FunctionResponsePart(crate::FunctionResponsePart { + function_response: crate::FunctionResponse { + name: tool_result.tool_name.to_string(), + // The API expects a valid JSON object + response: serde_json::json!({ + "output": text + }), + }, + })] + } + language_model_core::LanguageModelToolResultContent::Image(image) => { + vec![ + Part::FunctionResponsePart(crate::FunctionResponsePart { + function_response: crate::FunctionResponse { + name: tool_result.tool_name.to_string(), + // The API expects a valid JSON object + response: serde_json::json!({ + "output": "Tool responded with an image" + }), + }, + }), + Part::InlineDataPart(InlineDataPart { + inline_data: GenerativeContentBlob { + mime_type: "image/png".to_string(), + data: image.source.to_string(), + }, + }), + ] + } + } + } + }) + .collect() + } + + let system_instructions = if request + .messages + .first() + .is_some_and(|msg| matches!(msg.role, Role::System)) + { + let message = request.messages.remove(0); + Some(SystemInstruction { + parts: map_content(message.content), + }) + } else { + None + }; + + crate::GenerateContentRequest { + model: ModelName { model_id }, + system_instruction: system_instructions, + contents: request + .messages + .into_iter() + .filter_map(|message| { + let parts = map_content(message.content); + if parts.is_empty() { + None + } else { + Some(Content { + parts, + role: match message.role { + Role::User => crate::Role::User, + Role::Assistant => crate::Role::Model, + Role::System => crate::Role::User, // Google AI doesn't have a system role + }, + }) + } + }) + .collect(), + generation_config: Some(GenerationConfig { + candidate_count: Some(1), + stop_sequences: Some(request.stop), + max_output_tokens: None, + temperature: request.temperature.map(|t| t as f64).or(Some(1.0)), + thinking_config: match (request.thinking_allowed, mode) { + (true, GoogleModelMode::Thinking { budget_tokens }) => { + budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget }) + } + _ => None, + }, + top_p: None, + top_k: None, + }), + safety_settings: None, + tools: (!request.tools.is_empty()).then(|| { + vec![crate::Tool { + function_declarations: request + .tools + .into_iter() + .map(|tool| FunctionDeclaration { + name: tool.name, + description: tool.description, + parameters: tool.input_schema, + }) + .collect(), + }] + }), + tool_config: request.tool_choice.map(|choice| ToolConfig { + function_calling_config: FunctionCallingConfig { + mode: match choice { + LanguageModelToolChoice::Auto => FunctionCallingMode::Auto, + LanguageModelToolChoice::Any => FunctionCallingMode::Any, + LanguageModelToolChoice::None => FunctionCallingMode::None, + }, + allowed_function_names: None, + }, + }), + } +} + +pub struct GoogleEventMapper { + usage: UsageMetadata, + stop_reason: StopReason, +} + +impl GoogleEventMapper { + pub fn new() -> Self { + Self { + usage: UsageMetadata::default(), + stop_reason: StopReason::EndTurn, + } + } + + pub fn map_stream( + mut self, + events: Pin>>>, + ) -> impl Stream> + { + events + .map(Some) + .chain(futures::stream::once(async { None })) + .flat_map(move |event| { + futures::stream::iter(match event { + Some(Ok(event)) => self.map_event(event), + Some(Err(error)) => { + vec![Err(LanguageModelCompletionError::from(error))] + } + None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))], + }) + }) + } + + pub fn map_event( + &mut self, + event: GenerateContentResponse, + ) -> Vec> { + static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0); + + let mut events: Vec<_> = Vec::new(); + let mut wants_to_use_tool = false; + if let Some(usage_metadata) = event.usage_metadata { + update_usage(&mut self.usage, &usage_metadata); + events.push(Ok(LanguageModelCompletionEvent::UsageUpdate( + convert_usage(&self.usage), + ))) + } + + if let Some(prompt_feedback) = event.prompt_feedback + && let Some(block_reason) = prompt_feedback.block_reason.as_deref() + { + self.stop_reason = match block_reason { + "SAFETY" | "OTHER" | "BLOCKLIST" | "PROHIBITED_CONTENT" | "IMAGE_SAFETY" => { + StopReason::Refusal + } + _ => { + log::error!("Unexpected Google block_reason: {block_reason}"); + StopReason::Refusal + } + }; + events.push(Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))); + + return events; + } + + if let Some(candidates) = event.candidates { + for candidate in candidates { + if let Some(finish_reason) = candidate.finish_reason.as_deref() { + self.stop_reason = match finish_reason { + "STOP" => StopReason::EndTurn, + "MAX_TOKENS" => StopReason::MaxTokens, + _ => { + log::error!("Unexpected google finish_reason: {finish_reason}"); + StopReason::EndTurn + } + }; + } + candidate + .content + .parts + .into_iter() + .for_each(|part| match part { + Part::TextPart(text_part) => { + events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text))) + } + Part::InlineDataPart(_) => {} + Part::FunctionCallPart(function_call_part) => { + wants_to_use_tool = true; + let name: Arc = function_call_part.function_call.name.into(); + let next_tool_id = + TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst); + let id: LanguageModelToolUseId = + format!("{}-{}", name, next_tool_id).into(); + + // Normalize empty string signatures to None + let thought_signature = function_call_part + .thought_signature + .filter(|s| !s.is_empty()); + + events.push(Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id, + name, + is_input_complete: true, + raw_input: function_call_part.function_call.args.to_string(), + input: function_call_part.function_call.args, + thought_signature, + }, + ))); + } + Part::FunctionResponsePart(_) => {} + Part::ThoughtPart(part) => { + events.push(Ok(LanguageModelCompletionEvent::Thinking { + text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries? + signature: Some(part.thought_signature), + })); + } + }); + } + } + + // Even when Gemini wants to use a Tool, the API + // responds with `finish_reason: STOP` + if wants_to_use_tool { + self.stop_reason = StopReason::ToolUse; + events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse))); + } + events + } +} + +/// Count tokens for a Google AI model using tiktoken. This is synchronous; +/// callers should spawn it on a background thread if needed. +pub fn count_google_tokens(request: LanguageModelRequest) -> Result { + let messages = request + .messages + .into_iter() + .map(|message| tiktoken_rs::ChatCompletionRequestMessage { + role: match message.role { + Role::User => "user".into(), + Role::Assistant => "assistant".into(), + Role::System => "system".into(), + }, + content: Some(message.string_contents()), + name: None, + function_call: None, + }) + .collect::>(); + + // Tiktoken doesn't yet support these models, so we manually use the + // same tokenizer as GPT-4. + tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64) +} + +fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) { + if let Some(prompt_token_count) = new.prompt_token_count { + usage.prompt_token_count = Some(prompt_token_count); + } + if let Some(cached_content_token_count) = new.cached_content_token_count { + usage.cached_content_token_count = Some(cached_content_token_count); + } + if let Some(candidates_token_count) = new.candidates_token_count { + usage.candidates_token_count = Some(candidates_token_count); + } + if let Some(tool_use_prompt_token_count) = new.tool_use_prompt_token_count { + usage.tool_use_prompt_token_count = Some(tool_use_prompt_token_count); + } + if let Some(thoughts_token_count) = new.thoughts_token_count { + usage.thoughts_token_count = Some(thoughts_token_count); + } + if let Some(total_token_count) = new.total_token_count { + usage.total_token_count = Some(total_token_count); + } +} + +fn convert_usage(usage: &UsageMetadata) -> TokenUsage { + let prompt_tokens = usage.prompt_token_count.unwrap_or(0); + let cached_tokens = usage.cached_content_token_count.unwrap_or(0); + let input_tokens = prompt_tokens - cached_tokens; + let output_tokens = usage.candidates_token_count.unwrap_or(0); + + TokenUsage { + input_tokens, + output_tokens, + cache_read_input_tokens: cached_tokens, + cache_creation_input_tokens: 0, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + Content, FunctionCall, FunctionCallPart, GenerateContentCandidate, GenerateContentResponse, + Part, Role as GoogleRole, + }; + use serde_json::json; + + #[test] + fn test_function_call_with_signature_creates_tool_use_with_signature() { + let mut mapper = GoogleEventMapper::new(); + + let response = GenerateContentResponse { + candidates: Some(vec![GenerateContentCandidate { + index: Some(0), + content: Content { + parts: vec![Part::FunctionCallPart(FunctionCallPart { + function_call: FunctionCall { + name: "test_function".to_string(), + args: json!({"arg": "value"}), + }, + thought_signature: Some("test_signature_123".to_string()), + })], + role: GoogleRole::Model, + }, + finish_reason: None, + finish_message: None, + safety_ratings: None, + citation_metadata: None, + }]), + prompt_feedback: None, + usage_metadata: None, + }; + + let events = mapper.map_event(response); + assert_eq!(events.len(), 2); + + if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { + assert_eq!(tool_use.name.as_ref(), "test_function"); + assert_eq!( + tool_use.thought_signature.as_deref(), + Some("test_signature_123") + ); + } else { + panic!("Expected ToolUse event"); + } + } + + #[test] + fn test_function_call_without_signature_has_none() { + let mut mapper = GoogleEventMapper::new(); + + let response = GenerateContentResponse { + candidates: Some(vec![GenerateContentCandidate { + index: Some(0), + content: Content { + parts: vec![Part::FunctionCallPart(FunctionCallPart { + function_call: FunctionCall { + name: "test_function".to_string(), + args: json!({"arg": "value"}), + }, + thought_signature: None, + })], + role: GoogleRole::Model, + }, + finish_reason: None, + finish_message: None, + safety_ratings: None, + citation_metadata: None, + }]), + prompt_feedback: None, + usage_metadata: None, + }; + + let events = mapper.map_event(response); + assert_eq!(events.len(), 2); + + if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { + assert!(tool_use.thought_signature.is_none()); + } else { + panic!("Expected ToolUse event"); + } + } + + #[test] + fn test_empty_string_signature_normalized_to_none() { + let mut mapper = GoogleEventMapper::new(); + + let response = GenerateContentResponse { + candidates: Some(vec![GenerateContentCandidate { + index: Some(0), + content: Content { + parts: vec![Part::FunctionCallPart(FunctionCallPart { + function_call: FunctionCall { + name: "test_function".to_string(), + args: json!({"arg": "value"}), + }, + thought_signature: Some("".to_string()), + })], + role: GoogleRole::Model, + }, + finish_reason: None, + finish_message: None, + safety_ratings: None, + citation_metadata: None, + }]), + prompt_feedback: None, + usage_metadata: None, + }; + + let events = mapper.map_event(response); + if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { + assert!(tool_use.thought_signature.is_none()); + } else { + panic!("Expected ToolUse event"); + } + } +} diff --git a/crates/google_ai/src/google_ai.rs b/crates/google_ai/src/google_ai.rs index 7659be8ab44da35efd16389c4abd0bf99d8cf3a4..5770c9a020b04bf280908993911b67ec3a5b980f 100644 --- a/crates/google_ai/src/google_ai.rs +++ b/crates/google_ai/src/google_ai.rs @@ -3,8 +3,9 @@ use std::mem; use anyhow::{Result, anyhow, bail}; use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream}; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; +pub use language_model_core::ModelMode as GoogleModelMode; use serde::{Deserialize, Deserializer, Serialize, Serializer}; -pub use settings::ModelMode as GoogleModelMode; +pub mod completion; pub const API_URL: &str = "https://generativelanguage.googleapis.com"; diff --git a/crates/gpui/Cargo.toml b/crates/gpui/Cargo.toml index 915f0fc03e2cc5beaf40c810654724295c41cde8..efb4817ef0e0c037bc08d0c5a8ad702705cb996d 100644 --- a/crates/gpui/Cargo.toml +++ b/crates/gpui/Cargo.toml @@ -56,6 +56,7 @@ etagere = "0.2" futures.workspace = true futures-concurrency.workspace = true gpui_macros.workspace = true +gpui_shared_string.workspace = true http_client.workspace = true image.workspace = true inventory.workspace = true diff --git a/crates/gpui/src/gpui.rs b/crates/gpui/src/gpui.rs index 6d7d801cd42c3639d7892295a660319d21b05dfa..dbb57f46efc37678c07dfd4f02bb3faebc60c9a3 100644 --- a/crates/gpui/src/gpui.rs +++ b/crates/gpui/src/gpui.rs @@ -39,7 +39,6 @@ pub mod profiler; #[expect(missing_docs)] pub mod queue; mod scene; -mod shared_string; mod shared_uri; mod style; mod styled; @@ -92,6 +91,7 @@ pub use global::*; pub use gpui_macros::{ AppContext, IntoElement, Render, VisualContext, property_test, register_action, test, }; +pub use gpui_shared_string::*; pub use gpui_util::arc_cow::ArcCow; pub use http_client; pub use input::*; @@ -106,7 +106,6 @@ pub use profiler::*; pub use queue::{PriorityQueueReceiver, PriorityQueueSender}; pub use refineable::*; pub use scene::*; -pub use shared_string::*; pub use shared_uri::*; use std::{any::Any, future::Future}; pub use style::*; diff --git a/crates/gpui/src/text_system/line.rs b/crates/gpui/src/text_system/line.rs index 7b5714188ff97d0169806ac5da9f039f9be2c16a..611c979bc29f488fa18386c7b319a7310b6ce1c6 100644 --- a/crates/gpui/src/text_system/line.rs +++ b/crates/gpui/src/text_system/line.rs @@ -882,7 +882,7 @@ mod tests { ], len: 6, }), - text: SharedString::new("abcdef".to_string()), + text: "abcdef".into(), decoration_runs: SmallVec::new(), }; diff --git a/crates/gpui_shared_string/Cargo.toml b/crates/gpui_shared_string/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..4f7735b4f88253de7cd62d30445153d2a6284751 --- /dev/null +++ b/crates/gpui_shared_string/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "gpui_shared_string" +version = "0.1.0" +publish.workspace = true +edition.workspace = true + +[lib] +path = "gpui_shared_string.rs" + +[dependencies] +derive_more.workspace = true +gpui_util.workspace = true +schemars.workspace = true +serde.workspace = true + +[lints] +workspace = true diff --git a/crates/gpui_shared_string/LICENSE-APACHE b/crates/gpui_shared_string/LICENSE-APACHE new file mode 120000 index 0000000000000000000000000000000000000000..1cd601d0a3affae83854be02a0afdec3b7a9ec4d --- /dev/null +++ b/crates/gpui_shared_string/LICENSE-APACHE @@ -0,0 +1 @@ +../../LICENSE-APACHE \ No newline at end of file diff --git a/crates/gpui/src/shared_string.rs b/crates/gpui_shared_string/gpui_shared_string.rs similarity index 100% rename from crates/gpui/src/shared_string.rs rename to crates/gpui_shared_string/gpui_shared_string.rs diff --git a/crates/language_core/Cargo.toml b/crates/language_core/Cargo.toml index 4861632b4663c860706525c65cd8607133b3ec71..cd1143f61d3af1d3b72bb5bd3a23e53b27aa9aba 100644 --- a/crates/language_core/Cargo.toml +++ b/crates/language_core/Cargo.toml @@ -10,7 +10,7 @@ path = "src/language_core.rs" [dependencies] anyhow.workspace = true collections.workspace = true -gpui.workspace = true +gpui_shared_string.workspace = true log.workspace = true lsp.workspace = true parking_lot.workspace = true @@ -22,8 +22,6 @@ toml.workspace = true tree-sitter.workspace = true util.workspace = true -[dev-dependencies] -gpui = { workspace = true, features = ["test-support"] } [features] test-support = [] diff --git a/crates/language_core/src/diagnostic.rs b/crates/language_core/src/diagnostic.rs index 9a468a14b863a94ef23e00c3e15edd9fa2d8b09a..00abcb61d1b1290dd96c69b31296eebfd3900348 100644 --- a/crates/language_core/src/diagnostic.rs +++ b/crates/language_core/src/diagnostic.rs @@ -1,4 +1,4 @@ -use gpui::SharedString; +use gpui_shared_string::SharedString; use lsp::{DiagnosticSeverity, NumberOrString}; use serde::{Deserialize, Serialize}; use serde_json::Value; diff --git a/crates/language_core/src/grammar.rs b/crates/language_core/src/grammar.rs index 54e9a3f1b3309718436b206874802779925a9d04..44f73ac6dea235a522393b5b0bd10729999b45bf 100644 --- a/crates/language_core/src/grammar.rs +++ b/crates/language_core/src/grammar.rs @@ -4,7 +4,7 @@ use crate::{ }; use anyhow::{Context as _, Result}; use collections::HashMap; -use gpui::SharedString; +use gpui_shared_string::SharedString; use lsp::LanguageServerName; use parking_lot::Mutex; use std::sync::atomic::{AtomicUsize, Ordering::SeqCst}; diff --git a/crates/language_core/src/language_config.rs b/crates/language_core/src/language_config.rs index f412af418b7948b40e3bdac5a3a649d12d008e8a..89474dbad9171d37cfb1b7f55f70a137eeb535d5 100644 --- a/crates/language_core/src/language_config.rs +++ b/crates/language_core/src/language_config.rs @@ -1,6 +1,6 @@ use crate::LanguageName; use collections::{HashMap, HashSet, IndexSet}; -use gpui::SharedString; +use gpui_shared_string::SharedString; use lsp::LanguageServerName; use regex::Regex; use schemars::{JsonSchema, SchemaGenerator, json_schema}; diff --git a/crates/language_core/src/language_name.rs b/crates/language_core/src/language_name.rs index 764b54a48a566ad98212de3e22bce6aca9a1e393..14528435d9103b4faad3e055ea69bbdaf372113c 100644 --- a/crates/language_core/src/language_name.rs +++ b/crates/language_core/src/language_name.rs @@ -1,4 +1,4 @@ -use gpui::SharedString; +use gpui_shared_string::SharedString; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use std::{ diff --git a/crates/language_core/src/lsp_adapter.rs b/crates/language_core/src/lsp_adapter.rs index 03012f71143428b49ea9d75a03b0118b50e413b4..8f449637b306c2a33a76cb5b356d0280903f4187 100644 --- a/crates/language_core/src/lsp_adapter.rs +++ b/crates/language_core/src/lsp_adapter.rs @@ -1,4 +1,4 @@ -use gpui::SharedString; +use gpui_shared_string::SharedString; use serde::{Deserialize, Serialize}; /// Converts a value into an LSP position. diff --git a/crates/language_core/src/manifest.rs b/crates/language_core/src/manifest.rs index 1e762ff6e7c364eef02eea16ce9e1ecaaa198554..864f89e6cee65b0dff7c4462c99940c32ba0830f 100644 --- a/crates/language_core/src/manifest.rs +++ b/crates/language_core/src/manifest.rs @@ -1,6 +1,6 @@ use std::borrow::Borrow; -use gpui::SharedString; +use gpui_shared_string::SharedString; #[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct ManifestName(SharedString); diff --git a/crates/language_core/src/toolchain.rs b/crates/language_core/src/toolchain.rs index a021cb86bd36295a065b16281209c5fc3b63cffc..78bd69917fbc0f66af454ba262c1eb3b7c357290 100644 --- a/crates/language_core/src/toolchain.rs +++ b/crates/language_core/src/toolchain.rs @@ -6,7 +6,7 @@ use std::{path::Path, sync::Arc}; -use gpui::SharedString; +use gpui_shared_string::SharedString; use util::rel_path::RelPath; use crate::{LanguageName, ManifestName}; diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml index 4712d86dff6c44f9cdd8576a08349ccfa7d0ecca..d679588138ccec0f8d9fd830d26d13f2f65d44a3 100644 --- a/crates/language_model/Cargo.toml +++ b/crates/language_model/Cargo.toml @@ -16,13 +16,9 @@ doctest = false test-support = [] [dependencies] -anthropic = { workspace = true, features = ["schemars"] } anyhow.workspace = true credentials_provider.workspace = true base64.workspace = true -cloud_api_client.workspace = true -cloud_api_types.workspace = true -cloud_llm_client.workspace = true collections.workspace = true env_var.workspace = true futures.workspace = true @@ -30,14 +26,11 @@ gpui.workspace = true http_client.workspace = true icons.workspace = true image.workspace = true +language_model_core.workspace = true log.workspace = true -open_ai = { workspace = true, features = ["schemars"] } -open_router.workspace = true parking_lot.workspace = true -schemars.workspace = true serde.workspace = true serde_json.workspace = true -smol.workspace = true thiserror.workspace = true util.workspace = true diff --git a/crates/language_model/src/fake_provider.rs b/crates/language_model/src/fake_provider.rs index 50037f31facbac446de7ecf38536d1e4a24c7867..cee65c21e575e7c96579c271805386527a29d4da 100644 --- a/crates/language_model/src/fake_provider.rs +++ b/crates/language_model/src/fake_provider.rs @@ -5,11 +5,10 @@ use crate::{ LanguageModelRequest, LanguageModelToolChoice, }; use anyhow::anyhow; -use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream}; +use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream, stream::StreamExt}; use gpui::{AnyView, App, AsyncApp, Entity, Task, Window}; use http_client::Result; use parking_lot::Mutex; -use smol::stream::StreamExt; use std::sync::{ Arc, atomic::{AtomicBool, Ordering::SeqCst}, diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 3f309b7b1d4152c54324efaaf0ad3bdb7035eea4..60e8228fec52ffee763e19541f042ce47246dad2 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -1,380 +1,31 @@ mod api_key; mod model; -mod provider; -mod rate_limiter; mod registry; mod request; -mod role; -pub mod tool_schema; #[cfg(any(test, feature = "test-support"))] pub mod fake_provider; -use anyhow::{Result, anyhow}; -use cloud_llm_client::CompletionRequestStatus; +pub use language_model_core::*; + +use anyhow::Result; use futures::FutureExt; use futures::{StreamExt, future::BoxFuture, stream::BoxStream}; -use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window}; -use http_client::{StatusCode, http}; +use gpui::{AnyView, App, AsyncApp, Task, Window}; use icons::IconName; use parking_lot::Mutex; -use serde::{Deserialize, Serialize}; -use std::ops::{Add, Sub}; -use std::str::FromStr; use std::sync::Arc; -use std::time::Duration; -use std::{fmt, io}; -use thiserror::Error; -use util::serde::is_default; pub use crate::api_key::{ApiKey, ApiKeyState}; pub use crate::model::*; -pub use crate::rate_limiter::*; pub use crate::registry::*; -pub use crate::request::*; -pub use crate::role::*; -pub use crate::tool_schema::LanguageModelToolSchemaFormat; +pub use crate::request::{LanguageModelImageExt, gpui_size_to_image_size, image_size_to_gpui}; pub use env_var::{EnvVar, env_var}; -pub use provider::*; pub fn init(cx: &mut App) { registry::init(cx); } -#[derive(Clone, Debug)] -pub struct LanguageModelCacheConfiguration { - pub max_cache_anchors: usize, - pub should_speculate: bool, - pub min_total_token: u64, -} - -/// A completion event from a language model. -#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] -pub enum LanguageModelCompletionEvent { - Queued { - position: usize, - }, - Started, - Stop(StopReason), - Text(String), - Thinking { - text: String, - signature: Option, - }, - RedactedThinking { - data: String, - }, - ToolUse(LanguageModelToolUse), - ToolUseJsonParseError { - id: LanguageModelToolUseId, - tool_name: Arc, - raw_input: Arc, - json_parse_error: String, - }, - StartMessage { - message_id: String, - }, - ReasoningDetails(serde_json::Value), - UsageUpdate(TokenUsage), -} - -impl LanguageModelCompletionEvent { - pub fn from_completion_request_status( - status: CompletionRequestStatus, - upstream_provider: LanguageModelProviderName, - ) -> Result, LanguageModelCompletionError> { - match status { - CompletionRequestStatus::Queued { position } => { - Ok(Some(LanguageModelCompletionEvent::Queued { position })) - } - CompletionRequestStatus::Started => Ok(Some(LanguageModelCompletionEvent::Started)), - CompletionRequestStatus::Unknown | CompletionRequestStatus::StreamEnded => Ok(None), - CompletionRequestStatus::Failed { - code, - message, - request_id: _, - retry_after, - } => Err(LanguageModelCompletionError::from_cloud_failure( - upstream_provider, - code, - message, - retry_after.map(Duration::from_secs_f64), - )), - } - } -} - -#[derive(Error, Debug)] -pub enum LanguageModelCompletionError { - #[error("prompt too large for context window")] - PromptTooLarge { tokens: Option }, - #[error("missing {provider} API key")] - NoApiKey { provider: LanguageModelProviderName }, - #[error("{provider}'s API rate limit exceeded")] - RateLimitExceeded { - provider: LanguageModelProviderName, - retry_after: Option, - }, - #[error("{provider}'s API servers are overloaded right now")] - ServerOverloaded { - provider: LanguageModelProviderName, - retry_after: Option, - }, - #[error("{provider}'s API server reported an internal server error: {message}")] - ApiInternalServerError { - provider: LanguageModelProviderName, - message: String, - }, - #[error("{message}")] - UpstreamProviderError { - message: String, - status: StatusCode, - retry_after: Option, - }, - #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")] - HttpResponseError { - provider: LanguageModelProviderName, - status_code: StatusCode, - message: String, - }, - - // Client errors - #[error("invalid request format to {provider}'s API: {message}")] - BadRequestFormat { - provider: LanguageModelProviderName, - message: String, - }, - #[error("authentication error with {provider}'s API: {message}")] - AuthenticationError { - provider: LanguageModelProviderName, - message: String, - }, - #[error("Permission error with {provider}'s API: {message}")] - PermissionError { - provider: LanguageModelProviderName, - message: String, - }, - #[error("language model provider API endpoint not found")] - ApiEndpointNotFound { provider: LanguageModelProviderName }, - #[error("I/O error reading response from {provider}'s API")] - ApiReadResponseError { - provider: LanguageModelProviderName, - #[source] - error: io::Error, - }, - #[error("error serializing request to {provider} API")] - SerializeRequest { - provider: LanguageModelProviderName, - #[source] - error: serde_json::Error, - }, - #[error("error building request body to {provider} API")] - BuildRequestBody { - provider: LanguageModelProviderName, - #[source] - error: http::Error, - }, - #[error("error sending HTTP request to {provider} API")] - HttpSend { - provider: LanguageModelProviderName, - #[source] - error: anyhow::Error, - }, - #[error("error deserializing {provider} API response")] - DeserializeResponse { - provider: LanguageModelProviderName, - #[source] - error: serde_json::Error, - }, - - #[error("stream from {provider} ended unexpectedly")] - StreamEndedUnexpectedly { provider: LanguageModelProviderName }, - - // TODO: Ideally this would be removed in favor of having a comprehensive list of errors. - #[error(transparent)] - Other(#[from] anyhow::Error), -} - -impl LanguageModelCompletionError { - fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> { - let error_json = serde_json::from_str::(message).ok()?; - let upstream_status = error_json - .get("upstream_status") - .and_then(|v| v.as_u64()) - .and_then(|status| u16::try_from(status).ok()) - .and_then(|status| StatusCode::from_u16(status).ok())?; - let inner_message = error_json - .get("message") - .and_then(|v| v.as_str()) - .unwrap_or(message) - .to_string(); - Some((upstream_status, inner_message)) - } - - pub fn from_cloud_failure( - upstream_provider: LanguageModelProviderName, - code: String, - message: String, - retry_after: Option, - ) -> Self { - if let Some(tokens) = parse_prompt_too_long(&message) { - // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR - // to be reported. This is a temporary workaround to handle this in the case where the - // token limit has been exceeded. - Self::PromptTooLarge { - tokens: Some(tokens), - } - } else if code == "upstream_http_error" { - if let Some((upstream_status, inner_message)) = - Self::parse_upstream_error_json(&message) - { - return Self::from_http_status( - upstream_provider, - upstream_status, - inner_message, - retry_after, - ); - } - anyhow!("completion request failed, code: {code}, message: {message}").into() - } else if let Some(status_code) = code - .strip_prefix("upstream_http_") - .and_then(|code| StatusCode::from_str(code).ok()) - { - Self::from_http_status(upstream_provider, status_code, message, retry_after) - } else if let Some(status_code) = code - .strip_prefix("http_") - .and_then(|code| StatusCode::from_str(code).ok()) - { - Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after) - } else { - anyhow!("completion request failed, code: {code}, message: {message}").into() - } - } - - pub fn from_http_status( - provider: LanguageModelProviderName, - status_code: StatusCode, - message: String, - retry_after: Option, - ) -> Self { - match status_code { - StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message }, - StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message }, - StatusCode::FORBIDDEN => Self::PermissionError { provider, message }, - StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider }, - StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge { - tokens: parse_prompt_too_long(&message), - }, - StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded { - provider, - retry_after, - }, - StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message }, - StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded { - provider, - retry_after, - }, - _ if status_code.as_u16() == 529 => Self::ServerOverloaded { - provider, - retry_after, - }, - _ => Self::HttpResponseError { - provider, - status_code, - message, - }, - } - } -} - -#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum StopReason { - EndTurn, - MaxTokens, - ToolUse, - Refusal, -} - -#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)] -pub struct TokenUsage { - #[serde(default, skip_serializing_if = "is_default")] - pub input_tokens: u64, - #[serde(default, skip_serializing_if = "is_default")] - pub output_tokens: u64, - #[serde(default, skip_serializing_if = "is_default")] - pub cache_creation_input_tokens: u64, - #[serde(default, skip_serializing_if = "is_default")] - pub cache_read_input_tokens: u64, -} - -impl TokenUsage { - pub fn total_tokens(&self) -> u64 { - self.input_tokens - + self.output_tokens - + self.cache_read_input_tokens - + self.cache_creation_input_tokens - } -} - -impl Add for TokenUsage { - type Output = Self; - - fn add(self, other: Self) -> Self { - Self { - input_tokens: self.input_tokens + other.input_tokens, - output_tokens: self.output_tokens + other.output_tokens, - cache_creation_input_tokens: self.cache_creation_input_tokens - + other.cache_creation_input_tokens, - cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens, - } - } -} - -impl Sub for TokenUsage { - type Output = Self; - - fn sub(self, other: Self) -> Self { - Self { - input_tokens: self.input_tokens - other.input_tokens, - output_tokens: self.output_tokens - other.output_tokens, - cache_creation_input_tokens: self.cache_creation_input_tokens - - other.cache_creation_input_tokens, - cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens, - } - } -} - -#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)] -pub struct LanguageModelToolUseId(Arc); - -impl fmt::Display for LanguageModelToolUseId { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.0) - } -} - -impl From for LanguageModelToolUseId -where - T: Into>, -{ - fn from(value: T) -> Self { - Self(value.into()) - } -} - -#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)] -pub struct LanguageModelToolUse { - pub id: LanguageModelToolUseId, - pub name: Arc, - pub raw_input: String, - pub input: serde_json::Value, - pub is_input_complete: bool, - /// Thought signature the model sent us. Some models require that this - /// signature be preserved and sent back in conversation history for validation. - pub thought_signature: Option, -} - pub struct LanguageModelTextStream { pub message_id: Option, pub stream: BoxStream<'static, Result>, @@ -392,13 +43,6 @@ impl Default for LanguageModelTextStream { } } -#[derive(Debug, Clone)] -pub struct LanguageModelEffortLevel { - pub name: SharedString, - pub value: SharedString, - pub is_default: bool, -} - pub trait LanguageModel: Send + Sync { fn id(&self) -> LanguageModelId; fn name(&self) -> LanguageModelName; @@ -605,7 +249,7 @@ pub trait LanguageModel: Send + Sync { } impl std::fmt::Debug for dyn LanguageModel { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("") .field("id", &self.id()) .field("name", &self.name()) @@ -619,17 +263,6 @@ impl std::fmt::Debug for dyn LanguageModel { } } -/// An error that occurred when trying to authenticate the language model provider. -#[derive(Debug, Error)] -pub enum AuthenticateError { - #[error("connection refused")] - ConnectionRefused, - #[error("credentials not found")] - CredentialsNotFound, - #[error(transparent)] - Other(#[from] anyhow::Error), -} - /// Either a built-in icon name or a path to an external SVG. #[derive(Debug, Clone, PartialEq, Eq)] pub enum IconOrSvg { @@ -692,18 +325,6 @@ pub trait LanguageModelProviderState: 'static { } } -#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)] -pub struct LanguageModelId(pub SharedString); - -#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)] -pub struct LanguageModelName(pub SharedString); - -#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)] -pub struct LanguageModelProviderId(pub SharedString); - -#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)] -pub struct LanguageModelProviderName(pub SharedString); - #[derive(Clone, Debug, PartialEq)] pub enum LanguageModelCostInfo { /// Cost per 1,000 input and output tokens @@ -741,245 +362,3 @@ impl LanguageModelCostInfo { } } } - -impl LanguageModelProviderId { - pub const fn new(id: &'static str) -> Self { - Self(SharedString::new_static(id)) - } -} - -impl LanguageModelProviderName { - pub const fn new(id: &'static str) -> Self { - Self(SharedString::new_static(id)) - } -} - -impl fmt::Display for LanguageModelProviderId { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.0) - } -} - -impl fmt::Display for LanguageModelProviderName { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.0) - } -} - -impl From for LanguageModelId { - fn from(value: String) -> Self { - Self(SharedString::from(value)) - } -} - -impl From for LanguageModelName { - fn from(value: String) -> Self { - Self(SharedString::from(value)) - } -} - -impl From for LanguageModelProviderId { - fn from(value: String) -> Self { - Self(SharedString::from(value)) - } -} - -impl From for LanguageModelProviderName { - fn from(value: String) -> Self { - Self(SharedString::from(value)) - } -} - -impl From> for LanguageModelProviderId { - fn from(value: Arc) -> Self { - Self(SharedString::from(value)) - } -} - -impl From> for LanguageModelProviderName { - fn from(value: Arc) -> Self { - Self(SharedString::from(value)) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_from_cloud_failure_with_upstream_http_error() { - let error = LanguageModelCompletionError::from_cloud_failure( - String::from("anthropic").into(), - "upstream_http_error".to_string(), - r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(), - None, - ); - - match error { - LanguageModelCompletionError::ServerOverloaded { provider, .. } => { - assert_eq!(provider.0, "anthropic"); - } - _ => panic!( - "Expected ServerOverloaded error for 503 status, got: {:?}", - error - ), - } - - let error = LanguageModelCompletionError::from_cloud_failure( - String::from("anthropic").into(), - "upstream_http_error".to_string(), - r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(), - None, - ); - - match error { - LanguageModelCompletionError::ApiInternalServerError { provider, message } => { - assert_eq!(provider.0, "anthropic"); - assert_eq!(message, "Internal server error"); - } - _ => panic!( - "Expected ApiInternalServerError for 500 status, got: {:?}", - error - ), - } - } - - #[test] - fn test_from_cloud_failure_with_standard_format() { - let error = LanguageModelCompletionError::from_cloud_failure( - String::from("anthropic").into(), - "upstream_http_503".to_string(), - "Service unavailable".to_string(), - None, - ); - - match error { - LanguageModelCompletionError::ServerOverloaded { provider, .. } => { - assert_eq!(provider.0, "anthropic"); - } - _ => panic!("Expected ServerOverloaded error for upstream_http_503"), - } - } - - #[test] - fn test_upstream_http_error_connection_timeout() { - let error = LanguageModelCompletionError::from_cloud_failure( - String::from("anthropic").into(), - "upstream_http_error".to_string(), - r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(), - None, - ); - - match error { - LanguageModelCompletionError::ServerOverloaded { provider, .. } => { - assert_eq!(provider.0, "anthropic"); - } - _ => panic!( - "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}", - error - ), - } - - let error = LanguageModelCompletionError::from_cloud_failure( - String::from("anthropic").into(), - "upstream_http_error".to_string(), - r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(), - None, - ); - - match error { - LanguageModelCompletionError::ApiInternalServerError { provider, message } => { - assert_eq!(provider.0, "anthropic"); - assert_eq!( - message, - "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout" - ); - } - _ => panic!( - "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}", - error - ), - } - } - - #[test] - fn test_language_model_tool_use_serializes_with_signature() { - use serde_json::json; - - let tool_use = LanguageModelToolUse { - id: LanguageModelToolUseId::from("test_id"), - name: "test_tool".into(), - raw_input: json!({"arg": "value"}).to_string(), - input: json!({"arg": "value"}), - is_input_complete: true, - thought_signature: Some("test_signature".to_string()), - }; - - let serialized = serde_json::to_value(&tool_use).unwrap(); - - assert_eq!(serialized["id"], "test_id"); - assert_eq!(serialized["name"], "test_tool"); - assert_eq!(serialized["thought_signature"], "test_signature"); - } - - #[test] - fn test_language_model_tool_use_deserializes_with_missing_signature() { - use serde_json::json; - - let json = json!({ - "id": "test_id", - "name": "test_tool", - "raw_input": "{\"arg\":\"value\"}", - "input": {"arg": "value"}, - "is_input_complete": true - }); - - let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap(); - - assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id")); - assert_eq!(tool_use.name.as_ref(), "test_tool"); - assert_eq!(tool_use.thought_signature, None); - } - - #[test] - fn test_language_model_tool_use_round_trip_with_signature() { - use serde_json::json; - - let original = LanguageModelToolUse { - id: LanguageModelToolUseId::from("round_trip_id"), - name: "round_trip_tool".into(), - raw_input: json!({"key": "value"}).to_string(), - input: json!({"key": "value"}), - is_input_complete: true, - thought_signature: Some("round_trip_sig".to_string()), - }; - - let serialized = serde_json::to_value(&original).unwrap(); - let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap(); - - assert_eq!(deserialized.id, original.id); - assert_eq!(deserialized.name, original.name); - assert_eq!(deserialized.thought_signature, original.thought_signature); - } - - #[test] - fn test_language_model_tool_use_round_trip_without_signature() { - use serde_json::json; - - let original = LanguageModelToolUse { - id: LanguageModelToolUseId::from("no_sig_id"), - name: "no_sig_tool".into(), - raw_input: json!({"arg": "value"}).to_string(), - input: json!({"arg": "value"}), - is_input_complete: true, - thought_signature: None, - }; - - let serialized = serde_json::to_value(&original).unwrap(); - let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap(); - - assert_eq!(deserialized.id, original.id); - assert_eq!(deserialized.name, original.name); - assert_eq!(deserialized.thought_signature, None); - } -} diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index db926aab1f70a46a4e70b1b67c2c9e4c4f465c2c..8cd71928b10fb1e86f3df40ca118305c198c094f 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -1,10 +1,5 @@ use std::fmt; -use std::sync::Arc; -use cloud_api_client::ClientApiError; -use cloud_api_client::CloudApiClient; -use cloud_api_types::OrganizationId; -use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard}; use thiserror::Error; #[derive(Error, Debug)] @@ -18,71 +13,3 @@ impl fmt::Display for PaymentRequiredError { ) } } - -#[derive(Clone, Default)] -pub struct LlmApiToken(Arc>>); - -impl LlmApiToken { - pub async fn acquire( - &self, - client: &CloudApiClient, - system_id: Option, - organization_id: Option, - ) -> Result { - let lock = self.0.upgradable_read().await; - if let Some(token) = lock.as_ref() { - Ok(token.to_string()) - } else { - Self::fetch( - RwLockUpgradableReadGuard::upgrade(lock).await, - client, - system_id, - organization_id, - ) - .await - } - } - - pub async fn refresh( - &self, - client: &CloudApiClient, - system_id: Option, - organization_id: Option, - ) -> Result { - Self::fetch(self.0.write().await, client, system_id, organization_id).await - } - - /// Clears the existing token before attempting to fetch a new one. - /// - /// Used when switching organizations so that a failed refresh doesn't - /// leave a token for the wrong organization. - pub async fn clear_and_refresh( - &self, - client: &CloudApiClient, - system_id: Option, - organization_id: Option, - ) -> Result { - let mut lock = self.0.write().await; - *lock = None; - Self::fetch(lock, client, system_id, organization_id).await - } - - async fn fetch( - mut lock: RwLockWriteGuard<'_, Option>, - client: &CloudApiClient, - system_id: Option, - organization_id: Option, - ) -> Result { - let result = client.create_llm_token(system_id, organization_id).await; - match result { - Ok(response) => { - *lock = Some(response.token.0.clone()); - Ok(response.token.0) - } - Err(err) => { - *lock = None; - Err(err) - } - } - } -} diff --git a/crates/language_model/src/provider.rs b/crates/language_model/src/provider.rs deleted file mode 100644 index 707d8e2d618894e2898e253450dbfbb5e9483bba..0000000000000000000000000000000000000000 --- a/crates/language_model/src/provider.rs +++ /dev/null @@ -1,12 +0,0 @@ -pub mod anthropic; -pub mod google; -pub mod open_ai; -pub mod open_router; -pub mod x_ai; -pub mod zed; - -pub use anthropic::*; -pub use google::*; -pub use open_ai::*; -pub use x_ai::*; -pub use zed::*; diff --git a/crates/language_model/src/provider/anthropic.rs b/crates/language_model/src/provider/anthropic.rs deleted file mode 100644 index 0878be2070fdbb9e57145684f59c962a32bb9fd2..0000000000000000000000000000000000000000 --- a/crates/language_model/src/provider/anthropic.rs +++ /dev/null @@ -1,80 +0,0 @@ -use crate::{LanguageModelCompletionError, LanguageModelProviderId, LanguageModelProviderName}; -use anthropic::AnthropicError; -pub use anthropic::parse_prompt_too_long; - -pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId = - LanguageModelProviderId::new("anthropic"); -pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName = - LanguageModelProviderName::new("Anthropic"); - -impl From for LanguageModelCompletionError { - fn from(error: AnthropicError) -> Self { - let provider = ANTHROPIC_PROVIDER_NAME; - match error { - AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error }, - AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error }, - AnthropicError::HttpSend(error) => Self::HttpSend { provider, error }, - AnthropicError::DeserializeResponse(error) => { - Self::DeserializeResponse { provider, error } - } - AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error }, - AnthropicError::HttpResponseError { - status_code, - message, - } => Self::HttpResponseError { - provider, - status_code, - message, - }, - AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded { - provider, - retry_after: Some(retry_after), - }, - AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded { - provider, - retry_after, - }, - AnthropicError::ApiError(api_error) => api_error.into(), - } - } -} - -impl From for LanguageModelCompletionError { - fn from(error: anthropic::ApiError) -> Self { - use anthropic::ApiErrorCode::*; - let provider = ANTHROPIC_PROVIDER_NAME; - match error.code() { - Some(code) => match code { - InvalidRequestError => Self::BadRequestFormat { - provider, - message: error.message, - }, - AuthenticationError => Self::AuthenticationError { - provider, - message: error.message, - }, - PermissionError => Self::PermissionError { - provider, - message: error.message, - }, - NotFoundError => Self::ApiEndpointNotFound { provider }, - RequestTooLarge => Self::PromptTooLarge { - tokens: parse_prompt_too_long(&error.message), - }, - RateLimitError => Self::RateLimitExceeded { - provider, - retry_after: None, - }, - ApiError => Self::ApiInternalServerError { - provider, - message: error.message, - }, - OverloadedError => Self::ServerOverloaded { - provider, - retry_after: None, - }, - }, - None => Self::Other(error.into()), - } - } -} diff --git a/crates/language_model/src/provider/google.rs b/crates/language_model/src/provider/google.rs deleted file mode 100644 index 1caee496b519f395dd10744b127bc29ee893849f..0000000000000000000000000000000000000000 --- a/crates/language_model/src/provider/google.rs +++ /dev/null @@ -1,5 +0,0 @@ -use crate::{LanguageModelProviderId, LanguageModelProviderName}; - -pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google"); -pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName = - LanguageModelProviderName::new("Google AI"); diff --git a/crates/language_model/src/provider/open_ai.rs b/crates/language_model/src/provider/open_ai.rs deleted file mode 100644 index 3796eb9a3aef78628c52d92e92fabb3812249e04..0000000000000000000000000000000000000000 --- a/crates/language_model/src/provider/open_ai.rs +++ /dev/null @@ -1,28 +0,0 @@ -use crate::{LanguageModelCompletionError, LanguageModelProviderId, LanguageModelProviderName}; -use http_client::http; -use std::time::Duration; - -pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai"); -pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName = - LanguageModelProviderName::new("OpenAI"); - -impl From for LanguageModelCompletionError { - fn from(error: open_ai::RequestError) -> Self { - match error { - open_ai::RequestError::HttpResponseError { - provider, - status_code, - body, - headers, - } => { - let retry_after = headers - .get(http::header::RETRY_AFTER) - .and_then(|val| val.to_str().ok()?.parse::().ok()) - .map(Duration::from_secs); - - Self::from_http_status(provider.into(), status_code, body, retry_after) - } - open_ai::RequestError::Other(e) => Self::Other(e), - } - } -} diff --git a/crates/language_model/src/provider/open_router.rs b/crates/language_model/src/provider/open_router.rs deleted file mode 100644 index 809e22f1fec0f2d205caa3ebbcb0baaf129b062c..0000000000000000000000000000000000000000 --- a/crates/language_model/src/provider/open_router.rs +++ /dev/null @@ -1,69 +0,0 @@ -use crate::{LanguageModelCompletionError, LanguageModelProviderName}; -use http_client::StatusCode; -use open_router::OpenRouterError; - -impl From for LanguageModelCompletionError { - fn from(error: OpenRouterError) -> Self { - let provider = LanguageModelProviderName::new("OpenRouter"); - match error { - OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error }, - OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error }, - OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error }, - OpenRouterError::DeserializeResponse(error) => { - Self::DeserializeResponse { provider, error } - } - OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error }, - OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded { - provider, - retry_after: Some(retry_after), - }, - OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded { - provider, - retry_after, - }, - OpenRouterError::ApiError(api_error) => api_error.into(), - } - } -} - -impl From for LanguageModelCompletionError { - fn from(error: open_router::ApiError) -> Self { - use open_router::ApiErrorCode::*; - let provider = LanguageModelProviderName::new("OpenRouter"); - match error.code { - InvalidRequestError => Self::BadRequestFormat { - provider, - message: error.message, - }, - AuthenticationError => Self::AuthenticationError { - provider, - message: error.message, - }, - PaymentRequiredError => Self::AuthenticationError { - provider, - message: format!("Payment required: {}", error.message), - }, - PermissionError => Self::PermissionError { - provider, - message: error.message, - }, - RequestTimedOut => Self::HttpResponseError { - provider, - status_code: StatusCode::REQUEST_TIMEOUT, - message: error.message, - }, - RateLimitError => Self::RateLimitExceeded { - provider, - retry_after: None, - }, - ApiError => Self::ApiInternalServerError { - provider, - message: error.message, - }, - OverloadedError => Self::ServerOverloaded { - provider, - retry_after: None, - }, - } - } -} diff --git a/crates/language_model/src/provider/x_ai.rs b/crates/language_model/src/provider/x_ai.rs deleted file mode 100644 index 3d0f794fa4087a4beeb4a9b6253d016a9b592f0e..0000000000000000000000000000000000000000 --- a/crates/language_model/src/provider/x_ai.rs +++ /dev/null @@ -1,4 +0,0 @@ -use crate::{LanguageModelProviderId, LanguageModelProviderName}; - -pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai"); -pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI"); diff --git a/crates/language_model/src/provider/zed.rs b/crates/language_model/src/provider/zed.rs deleted file mode 100644 index 0ba793e99aad1caa25f049a96faf02c16e8970fa..0000000000000000000000000000000000000000 --- a/crates/language_model/src/provider/zed.rs +++ /dev/null @@ -1,5 +0,0 @@ -use crate::{LanguageModelProviderId, LanguageModelProviderName}; - -pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev"); -pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName = - LanguageModelProviderName::new("Zed"); diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs index bf14fbb0b5804505b33074e6e4cbcc36ddf21fab..680078808ab33cc2a90caead8b304326beccf11b 100644 --- a/crates/language_model/src/registry.rs +++ b/crates/language_model/src/registry.rs @@ -1,6 +1,6 @@ use crate::{ LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId, - LanguageModelProviderState, + LanguageModelProviderState, ZED_CLOUD_PROVIDER_ID, }; use collections::{BTreeMap, HashSet}; use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*}; @@ -101,7 +101,7 @@ impl ConfiguredModel { } pub fn is_provided_by_zed(&self) -> bool { - self.provider.id() == crate::provider::ZED_CLOUD_PROVIDER_ID + self.provider.id() == ZED_CLOUD_PROVIDER_ID } } diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs index 9a5e96078cd4d952185261c79032c5c5fdf30060..ef73864fe3e2f5b58e73dec848c686123a61fcde 100644 --- a/crates/language_model/src/request.rs +++ b/crates/language_model/src/request.rs @@ -4,78 +4,13 @@ use std::sync::Arc; use anyhow::Result; use base64::write::EncoderWriter; use gpui::{ - App, AppContext as _, DevicePixels, Image, ImageFormat, ObjectFit, SharedString, Size, Task, - point, px, size, + App, AppContext as _, DevicePixels, Image, ImageFormat, ObjectFit, Size, Task, point, px, size, }; use image::GenericImageView as _; use image::codecs::png::PngEncoder; -use serde::{Deserialize, Serialize}; use util::ResultExt; -use crate::role::Role; -use crate::{LanguageModelToolUse, LanguageModelToolUseId}; - -#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] -pub struct LanguageModelImage { - /// A base64-encoded PNG image. - pub source: SharedString, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub size: Option>, -} - -impl LanguageModelImage { - pub fn len(&self) -> usize { - self.source.len() - } - - pub fn is_empty(&self) -> bool { - self.source.is_empty() - } - - // Parse Self from a JSON object with case-insensitive field names - pub fn from_json(obj: &serde_json::Map) -> Option { - let mut source = None; - let mut size_obj = None; - - // Find source and size fields (case-insensitive) - for (k, v) in obj.iter() { - match k.to_lowercase().as_str() { - "source" => source = v.as_str(), - "size" => size_obj = v.as_object(), - _ => {} - } - } - - let source = source?; - let size_obj = size_obj?; - - let mut width = None; - let mut height = None; - - // Find width and height in size object (case-insensitive) - for (k, v) in size_obj.iter() { - match k.to_lowercase().as_str() { - "width" => width = v.as_i64().map(|w| w as i32), - "height" => height = v.as_i64().map(|h| h as i32), - _ => {} - } - } - - Some(Self { - size: Some(size(DevicePixels(width?), DevicePixels(height?))), - source: SharedString::from(source.to_string()), - }) - } -} - -impl std::fmt::Debug for LanguageModelImage { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("LanguageModelImage") - .field("source", &format!("<{} bytes>", self.source.len())) - .field("size", &self.size) - .finish() - } -} +use language_model_core::{ImageSize, LanguageModelImage}; /// Anthropic wants uploaded images to be smaller than this in both dimensions. const ANTHROPIC_SIZE_LIMIT: f32 = 1568.; @@ -90,18 +25,16 @@ const DEFAULT_IMAGE_MAX_BYTES: usize = 5 * 1024 * 1024; /// `DEFAULT_IMAGE_MAX_BYTES`. const MAX_IMAGE_DOWNSCALE_PASSES: usize = 8; -impl LanguageModelImage { - // All language model images are encoded as PNGs. - pub const FORMAT: ImageFormat = ImageFormat::Png; +/// Extension trait for `LanguageModelImage` that provides GPUI-dependent functionality. +pub trait LanguageModelImageExt { + const FORMAT: ImageFormat; + fn from_image(data: Arc, cx: &mut App) -> Task>; +} - pub fn empty() -> Self { - Self { - source: "".into(), - size: None, - } - } +impl LanguageModelImageExt for LanguageModelImage { + const FORMAT: ImageFormat = ImageFormat::Png; - pub fn from_image(data: Arc, cx: &mut App) -> Task> { + fn from_image(data: Arc, cx: &mut App) -> Task> { cx.background_spawn(async move { let image_bytes = Cursor::new(data.bytes()); let dynamic_image = match data.format() { @@ -186,28 +119,14 @@ impl LanguageModelImage { let source = unsafe { String::from_utf8_unchecked(base64_image) }; Some(LanguageModelImage { - size: Some(image_size), + size: Some(ImageSize { + width: width as i32, + height: height as i32, + }), source: source.into(), }) }) } - - pub fn estimate_tokens(&self) -> usize { - let Some(size) = self.size.as_ref() else { - return 0; - }; - let width = size.width.0.unsigned_abs() as usize; - let height = size.height.0.unsigned_abs() as usize; - - // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs - // Note that are a lot of conditions on Anthropic's API, and OpenAI doesn't use this, - // so this method is more of a rough guess. - (width * height) / 750 - } - - pub fn to_base64_url(&self) -> String { - format!("data:image/png;base64,{}", self.source) - } } fn encode_png_bytes(image: &image::DynamicImage) -> Result> { @@ -228,512 +147,85 @@ fn encode_bytes_as_base64(bytes: &[u8]) -> Result> { Ok(base64_image) } -#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)] -pub struct LanguageModelToolResult { - pub tool_use_id: LanguageModelToolUseId, - pub tool_name: Arc, - pub is_error: bool, - /// The tool output formatted for presenting to the model - pub content: LanguageModelToolResultContent, - /// The raw tool output, if available, often for debugging or extra state for replay - pub output: Option, -} - -#[derive(Debug, Clone, Serialize, Eq, PartialEq, Hash)] -pub enum LanguageModelToolResultContent { - Text(Arc), - Image(LanguageModelImage), -} - -impl<'de> Deserialize<'de> for LanguageModelToolResultContent { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - use serde::de::Error; - - let value = serde_json::Value::deserialize(deserializer)?; - - // Models can provide these responses in several styles. Try each in order. - - // 1. Try as plain string - if let Ok(text) = serde_json::from_value::(value.clone()) { - return Ok(Self::Text(Arc::from(text))); - } - - // 2. Try as object - if let Some(obj) = value.as_object() { - // get a JSON field case-insensitively - fn get_field<'a>( - obj: &'a serde_json::Map, - field: &str, - ) -> Option<&'a serde_json::Value> { - obj.iter() - .find(|(k, _)| k.to_lowercase() == field.to_lowercase()) - .map(|(_, v)| v) - } - - // Accept wrapped text format: { "type": "text", "text": "..." } - if let (Some(type_value), Some(text_value)) = - (get_field(obj, "type"), get_field(obj, "text")) - && let Some(type_str) = type_value.as_str() - && type_str.to_lowercase() == "text" - && let Some(text) = text_value.as_str() - { - return Ok(Self::Text(Arc::from(text))); - } - - // Check for wrapped Text variant: { "text": "..." } - if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "text") - && obj.len() == 1 - { - // Only one field, and it's "text" (case-insensitive) - if let Some(text) = value.as_str() { - return Ok(Self::Text(Arc::from(text))); - } - } - - // Check for wrapped Image variant: { "image": { "source": "...", "size": ... } } - if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "image") - && obj.len() == 1 - { - // Only one field, and it's "image" (case-insensitive) - // Try to parse the nested image object - if let Some(image_obj) = value.as_object() - && let Some(image) = LanguageModelImage::from_json(image_obj) - { - return Ok(Self::Image(image)); - } - } - - // Try as direct Image (object with "source" and "size" fields) - if let Some(image) = LanguageModelImage::from_json(obj) { - return Ok(Self::Image(image)); - } - } - - // If none of the variants match, return an error with the problematic JSON - Err(D::Error::custom(format!( - "data did not match any variant of LanguageModelToolResultContent. Expected either a string, \ - an object with 'type': 'text', a wrapped variant like {{\"Text\": \"...\"}}, or an image object. Got: {}", - serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string()) - ))) - } -} - -impl LanguageModelToolResultContent { - pub fn to_str(&self) -> Option<&str> { - match self { - Self::Text(text) => Some(text), - Self::Image(_) => None, - } - } - - pub fn is_empty(&self) -> bool { - match self { - Self::Text(text) => text.chars().all(|c| c.is_whitespace()), - Self::Image(_) => false, - } - } -} - -impl From<&str> for LanguageModelToolResultContent { - fn from(value: &str) -> Self { - Self::Text(Arc::from(value)) - } -} - -impl From for LanguageModelToolResultContent { - fn from(value: String) -> Self { - Self::Text(Arc::from(value)) - } -} - -impl From for LanguageModelToolResultContent { - fn from(image: LanguageModelImage) -> Self { - Self::Image(image) - } -} - -#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)] -pub enum MessageContent { - Text(String), - Thinking { - text: String, - signature: Option, - }, - RedactedThinking(String), - Image(LanguageModelImage), - ToolUse(LanguageModelToolUse), - ToolResult(LanguageModelToolResult), -} - -impl MessageContent { - pub fn to_str(&self) -> Option<&str> { - match self { - MessageContent::Text(text) => Some(text.as_str()), - MessageContent::Thinking { text, .. } => Some(text.as_str()), - MessageContent::RedactedThinking(_) => None, - MessageContent::ToolResult(tool_result) => tool_result.content.to_str(), - MessageContent::ToolUse(_) | MessageContent::Image(_) => None, - } - } - - pub fn is_empty(&self) -> bool { - match self { - MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()), - MessageContent::Thinking { text, .. } => text.chars().all(|c| c.is_whitespace()), - MessageContent::ToolResult(tool_result) => tool_result.content.is_empty(), - MessageContent::RedactedThinking(_) - | MessageContent::ToolUse(_) - | MessageContent::Image(_) => false, - } - } -} - -impl From for MessageContent { - fn from(value: String) -> Self { - MessageContent::Text(value) - } -} - -impl From<&str> for MessageContent { - fn from(value: &str) -> Self { - MessageContent::Text(value.to_string()) - } -} - -#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)] -pub struct LanguageModelRequestMessage { - pub role: Role, - pub content: Vec, - pub cache: bool, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub reasoning_details: Option, -} - -impl LanguageModelRequestMessage { - pub fn string_contents(&self) -> String { - let mut buffer = String::new(); - for string in self.content.iter().filter_map(|content| content.to_str()) { - buffer.push_str(string); - } - - buffer - } - - pub fn contents_empty(&self) -> bool { - self.content.iter().all(|content| content.is_empty()) - } -} - -#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)] -pub struct LanguageModelRequestTool { - pub name: String, - pub description: String, - pub input_schema: serde_json::Value, - pub use_input_streaming: bool, -} - -#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)] -pub enum LanguageModelToolChoice { - Auto, - Any, - None, -} - -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum CompletionIntent { - UserPrompt, - Subagent, - ToolResults, - ThreadSummarization, - ThreadContextSummarization, - CreateFile, - EditFile, - InlineAssist, - TerminalInlineAssist, - GenerateGitCommitMessage, -} - -#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] -pub struct LanguageModelRequest { - pub thread_id: Option, - pub prompt_id: Option, - pub intent: Option, - pub messages: Vec, - pub tools: Vec, - pub tool_choice: Option, - pub stop: Vec, - pub temperature: Option, - pub thinking_allowed: bool, - pub thinking_effort: Option, - pub speed: Option, -} - -#[derive(Clone, Copy, Default, Debug, Serialize, Deserialize, PartialEq, Eq)] -#[serde(rename_all = "snake_case")] -pub enum Speed { - #[default] - Standard, - Fast, -} - -impl Speed { - pub fn toggle(self) -> Self { - match self { - Speed::Standard => Speed::Fast, - Speed::Fast => Speed::Standard, - } +/// Convert a core `ImageSize` to a gpui `Size`. +pub fn image_size_to_gpui(size: ImageSize) -> Size { + Size { + width: DevicePixels(size.width), + height: DevicePixels(size.height), } } -impl From for anthropic::Speed { - fn from(speed: Speed) -> Self { - match speed { - Speed::Standard => anthropic::Speed::Standard, - Speed::Fast => anthropic::Speed::Fast, - } +/// Convert a gpui `Size` to a core `ImageSize`. +pub fn gpui_size_to_image_size(size: Size) -> ImageSize { + ImageSize { + width: size.width.0, + height: size.height.0, } } -#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] -pub struct LanguageModelResponseMessage { - pub role: Option, - pub content: Option, -} - #[cfg(test)] mod tests { use super::*; use base64::Engine as _; use gpui::TestAppContext; - use image::ImageDecoder as _; - fn base64_to_png_bytes(base64_png: &str) -> Vec { + fn base64_to_png_bytes(base64: &str) -> Vec { base64::engine::general_purpose::STANDARD - .decode(base64_png.as_bytes()) - .expect("base64 should decode") + .decode(base64) + .expect("valid base64") } fn png_dimensions(png_bytes: &[u8]) -> (u32, u32) { - let decoder = - image::codecs::png::PngDecoder::new(Cursor::new(png_bytes)).expect("png should decode"); - decoder.dimensions() + let img = image::load_from_memory(png_bytes).expect("valid png"); + (img.width(), img.height()) } fn make_noisy_png_bytes(width: u32, height: u32) -> Vec { - // Create an RGBA image with per-pixel variance to avoid PNG compressing too well. - let mut img = image::RgbaImage::new(width, height); - for y in 0..height { - for x in 0..width { - let r = ((x ^ y) & 0xFF) as u8; - let g = ((x.wrapping_mul(31) ^ y.wrapping_mul(17)) & 0xFF) as u8; - let b = ((x.wrapping_mul(131) ^ y.wrapping_mul(7)) & 0xFF) as u8; - img.put_pixel(x, y, image::Rgba([r, g, b, 0xFF])); - } - } + use image::{ImageBuffer, Rgba}; + use std::hash::{Hash, Hasher}; + + let img = ImageBuffer::from_fn(width, height, |x, y| { + let mut hasher = std::hash::DefaultHasher::new(); + (x, y, width, height).hash(&mut hasher); + let h = hasher.finish(); + Rgba([h as u8, (h >> 8) as u8, (h >> 16) as u8, 255]) + }); - let mut out = Vec::new(); - image::DynamicImage::ImageRgba8(img) - .write_with_encoder(PngEncoder::new(&mut out)) - .expect("png encoding should succeed"); - out + let mut buf = Cursor::new(Vec::new()); + img.write_with_encoder(PngEncoder::new(&mut buf)) + .expect("encode"); + buf.into_inner() } #[gpui::test] async fn test_from_image_downscales_to_default_5mb_limit(cx: &mut TestAppContext) { - // Pick a size that reliably produces a PNG > 5MB when filled with noise. - // If this fails (image is too small), bump dimensions. - let original_png = make_noisy_png_bytes(4096, 4096); + let raw_png = make_noisy_png_bytes(4096, 4096); assert!( - original_png.len() > DEFAULT_IMAGE_MAX_BYTES, - "precondition failed: noisy PNG must exceed DEFAULT_IMAGE_MAX_BYTES" + raw_png.len() > DEFAULT_IMAGE_MAX_BYTES, + "Test image should exceed the 5 MB limit (actual: {} bytes)", + raw_png.len() ); - let image = gpui::Image::from_bytes(ImageFormat::Png, original_png); + let image = Arc::new(gpui::Image::from_bytes(ImageFormat::Png, raw_png)); let lm_image = cx - .update(|cx| LanguageModelImage::from_image(Arc::new(image), cx)) + .update(|cx| LanguageModelImage::from_image(Arc::clone(&image), cx)) .await - .expect("image conversion should succeed"); + .expect("from_image should succeed"); - let encoded_png = base64_to_png_bytes(lm_image.source.as_ref()); + let decoded_png = base64_to_png_bytes(lm_image.source.as_ref()); assert!( - encoded_png.len() <= DEFAULT_IMAGE_MAX_BYTES, - "expected encoded PNG <= DEFAULT_IMAGE_MAX_BYTES, got {} bytes", - encoded_png.len() + decoded_png.len() <= DEFAULT_IMAGE_MAX_BYTES, + "Encoded PNG should be ≤ {} bytes after downscale, but was {} bytes", + DEFAULT_IMAGE_MAX_BYTES, + decoded_png.len() ); - // Ensure we actually downscaled in pixels (not just re-encoded). - let (w, h) = png_dimensions(&encoded_png); + let (w, h) = png_dimensions(&decoded_png); assert!( - w < 4096 || h < 4096, - "expected image to be downscaled in at least one dimension; got {w}x{h}" - ); - } - - #[test] - fn test_language_model_tool_result_content_deserialization() { - let json = r#""This is plain text""#; - let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); - assert_eq!( - result, - LanguageModelToolResultContent::Text("This is plain text".into()) - ); - - let json = r#"{"type": "text", "text": "This is wrapped text"}"#; - let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); - assert_eq!( - result, - LanguageModelToolResultContent::Text("This is wrapped text".into()) - ); - - let json = r#"{"Type": "TEXT", "TEXT": "Case insensitive"}"#; - let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); - assert_eq!( - result, - LanguageModelToolResultContent::Text("Case insensitive".into()) - ); - - let json = r#"{"Text": "Wrapped variant"}"#; - let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); - assert_eq!( - result, - LanguageModelToolResultContent::Text("Wrapped variant".into()) - ); - - let json = r#"{"text": "Lowercase wrapped"}"#; - let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); - assert_eq!( - result, - LanguageModelToolResultContent::Text("Lowercase wrapped".into()) + w < 4096 && h < 4096, + "Dimensions should have shrunk: got {}×{}", + w, + h ); - - // Test image deserialization - let json = r#"{ - "source": "base64encodedimagedata", - "size": { - "width": 100, - "height": 200 - } - }"#; - let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); - match result { - LanguageModelToolResultContent::Image(image) => { - assert_eq!(image.source.as_ref(), "base64encodedimagedata"); - let size = image.size.expect("size"); - assert_eq!(size.width.0, 100); - assert_eq!(size.height.0, 200); - } - _ => panic!("Expected Image variant"), - } - - // Test wrapped Image variant - let json = r#"{ - "Image": { - "source": "wrappedimagedata", - "size": { - "width": 50, - "height": 75 - } - } - }"#; - let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); - match result { - LanguageModelToolResultContent::Image(image) => { - assert_eq!(image.source.as_ref(), "wrappedimagedata"); - let size = image.size.expect("size"); - assert_eq!(size.width.0, 50); - assert_eq!(size.height.0, 75); - } - _ => panic!("Expected Image variant"), - } - - // Test wrapped Image variant with case insensitive - let json = r#"{ - "image": { - "Source": "caseinsensitive", - "SIZE": { - "width": 30, - "height": 40 - } - } - }"#; - let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); - match result { - LanguageModelToolResultContent::Image(image) => { - assert_eq!(image.source.as_ref(), "caseinsensitive"); - let size = image.size.expect("size"); - assert_eq!(size.width.0, 30); - assert_eq!(size.height.0, 40); - } - _ => panic!("Expected Image variant"), - } - - // Test that wrapped text with wrong type fails - let json = r#"{"type": "blahblah", "text": "This should fail"}"#; - let result: Result = serde_json::from_str(json); - assert!(result.is_err()); - - // Test that malformed JSON fails - let json = r#"{"invalid": "structure"}"#; - let result: Result = serde_json::from_str(json); - assert!(result.is_err()); - - // Test edge cases - let json = r#""""#; // Empty string - let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); - assert_eq!(result, LanguageModelToolResultContent::Text("".into())); - - // Test with extra fields in wrapped text (should be ignored) - let json = r#"{"type": "text", "text": "Hello", "extra": "field"}"#; - let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); - assert_eq!(result, LanguageModelToolResultContent::Text("Hello".into())); - - // Test direct image with case-insensitive fields - let json = r#"{ - "SOURCE": "directimage", - "Size": { - "width": 200, - "height": 300 - } - }"#; - let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); - match result { - LanguageModelToolResultContent::Image(image) => { - assert_eq!(image.source.as_ref(), "directimage"); - let size = image.size.expect("size"); - assert_eq!(size.width.0, 200); - assert_eq!(size.height.0, 300); - } - _ => panic!("Expected Image variant"), - } - - // Test that multiple fields prevent wrapped variant interpretation - let json = r#"{"Text": "not wrapped", "extra": "field"}"#; - let result: Result = serde_json::from_str(json); - assert!(result.is_err()); - - // Test wrapped text with uppercase TEXT variant - let json = r#"{"TEXT": "Uppercase variant"}"#; - let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); - assert_eq!( - result, - LanguageModelToolResultContent::Text("Uppercase variant".into()) - ); - - // Test that numbers and other JSON values fail gracefully - let json = r#"123"#; - let result: Result = serde_json::from_str(json); - assert!(result.is_err()); - - let json = r#"null"#; - let result: Result = serde_json::from_str(json); - assert!(result.is_err()); - - let json = r#"[1, 2, 3]"#; - let result: Result = serde_json::from_str(json); - assert!(result.is_err()); } } diff --git a/crates/language_model_core/Cargo.toml b/crates/language_model_core/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..7a6de00f3e4a774537d93e2f77ea9107845a7c50 --- /dev/null +++ b/crates/language_model_core/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "language_model_core" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/language_model_core.rs" +doctest = false + +[dependencies] +anyhow.workspace = true +cloud_llm_client.workspace = true +futures.workspace = true +gpui_shared_string.workspace = true +http_client.workspace = true +partial-json-fixer.workspace = true +schemars.workspace = true +serde.workspace = true +serde_json.workspace = true +smol.workspace = true +strum.workspace = true +thiserror.workspace = true diff --git a/crates/language_model_core/LICENSE-GPL b/crates/language_model_core/LICENSE-GPL new file mode 120000 index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4 --- /dev/null +++ b/crates/language_model_core/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/language_model_core/src/language_model_core.rs b/crates/language_model_core/src/language_model_core.rs new file mode 100644 index 0000000000000000000000000000000000000000..5f932690869a2c17ec1c89cbe9401bcdef6e1e73 --- /dev/null +++ b/crates/language_model_core/src/language_model_core.rs @@ -0,0 +1,658 @@ +mod provider; +mod rate_limiter; +mod request; +mod role; +pub mod tool_schema; +pub mod util; + +use anyhow::{Result, anyhow}; +use cloud_llm_client::CompletionRequestStatus; +use http_client::{StatusCode, http}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use std::ops::{Add, Sub}; +use std::str::FromStr; +use std::sync::Arc; +use std::time::Duration; +use std::{fmt, io}; +use thiserror::Error; +fn is_default(value: &T) -> bool { + *value == T::default() +} + +pub use crate::provider::*; +pub use crate::rate_limiter::*; +pub use crate::request::*; +pub use crate::role::*; +pub use crate::tool_schema::LanguageModelToolSchemaFormat; +pub use crate::util::{fix_streamed_json, parse_prompt_too_long, parse_tool_arguments}; +pub use gpui_shared_string::SharedString; + +#[derive(Clone, Debug)] +pub struct LanguageModelCacheConfiguration { + pub max_cache_anchors: usize, + pub should_speculate: bool, + pub min_total_token: u64, +} + +/// A completion event from a language model. +#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] +pub enum LanguageModelCompletionEvent { + Queued { + position: usize, + }, + Started, + Stop(StopReason), + Text(String), + Thinking { + text: String, + signature: Option, + }, + RedactedThinking { + data: String, + }, + ToolUse(LanguageModelToolUse), + ToolUseJsonParseError { + id: LanguageModelToolUseId, + tool_name: Arc, + raw_input: Arc, + json_parse_error: String, + }, + StartMessage { + message_id: String, + }, + ReasoningDetails(serde_json::Value), + UsageUpdate(TokenUsage), +} + +impl LanguageModelCompletionEvent { + pub fn from_completion_request_status( + status: CompletionRequestStatus, + upstream_provider: LanguageModelProviderName, + ) -> Result, LanguageModelCompletionError> { + match status { + CompletionRequestStatus::Queued { position } => { + Ok(Some(LanguageModelCompletionEvent::Queued { position })) + } + CompletionRequestStatus::Started => Ok(Some(LanguageModelCompletionEvent::Started)), + CompletionRequestStatus::Unknown | CompletionRequestStatus::StreamEnded => Ok(None), + CompletionRequestStatus::Failed { + code, + message, + request_id: _, + retry_after, + } => Err(LanguageModelCompletionError::from_cloud_failure( + upstream_provider, + code, + message, + retry_after.map(Duration::from_secs_f64), + )), + } + } +} + +#[derive(Error, Debug)] +pub enum LanguageModelCompletionError { + #[error("prompt too large for context window")] + PromptTooLarge { tokens: Option }, + #[error("missing {provider} API key")] + NoApiKey { provider: LanguageModelProviderName }, + #[error("{provider}'s API rate limit exceeded")] + RateLimitExceeded { + provider: LanguageModelProviderName, + retry_after: Option, + }, + #[error("{provider}'s API servers are overloaded right now")] + ServerOverloaded { + provider: LanguageModelProviderName, + retry_after: Option, + }, + #[error("{provider}'s API server reported an internal server error: {message}")] + ApiInternalServerError { + provider: LanguageModelProviderName, + message: String, + }, + #[error("{message}")] + UpstreamProviderError { + message: String, + status: StatusCode, + retry_after: Option, + }, + #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")] + HttpResponseError { + provider: LanguageModelProviderName, + status_code: StatusCode, + message: String, + }, + #[error("invalid request format to {provider}'s API: {message}")] + BadRequestFormat { + provider: LanguageModelProviderName, + message: String, + }, + #[error("authentication error with {provider}'s API: {message}")] + AuthenticationError { + provider: LanguageModelProviderName, + message: String, + }, + #[error("Permission error with {provider}'s API: {message}")] + PermissionError { + provider: LanguageModelProviderName, + message: String, + }, + #[error("language model provider API endpoint not found")] + ApiEndpointNotFound { provider: LanguageModelProviderName }, + #[error("I/O error reading response from {provider}'s API")] + ApiReadResponseError { + provider: LanguageModelProviderName, + #[source] + error: io::Error, + }, + #[error("error serializing request to {provider} API")] + SerializeRequest { + provider: LanguageModelProviderName, + #[source] + error: serde_json::Error, + }, + #[error("error building request body to {provider} API")] + BuildRequestBody { + provider: LanguageModelProviderName, + #[source] + error: http::Error, + }, + #[error("error sending HTTP request to {provider} API")] + HttpSend { + provider: LanguageModelProviderName, + #[source] + error: anyhow::Error, + }, + #[error("error deserializing {provider} API response")] + DeserializeResponse { + provider: LanguageModelProviderName, + #[source] + error: serde_json::Error, + }, + #[error("stream from {provider} ended unexpectedly")] + StreamEndedUnexpectedly { provider: LanguageModelProviderName }, + #[error(transparent)] + Other(#[from] anyhow::Error), +} + +impl LanguageModelCompletionError { + fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> { + let error_json = serde_json::from_str::(message).ok()?; + let upstream_status = error_json + .get("upstream_status") + .and_then(|v| v.as_u64()) + .and_then(|status| u16::try_from(status).ok()) + .and_then(|status| StatusCode::from_u16(status).ok())?; + let inner_message = error_json + .get("message") + .and_then(|v| v.as_str()) + .unwrap_or(message) + .to_string(); + Some((upstream_status, inner_message)) + } + + pub fn from_cloud_failure( + upstream_provider: LanguageModelProviderName, + code: String, + message: String, + retry_after: Option, + ) -> Self { + if let Some(tokens) = parse_prompt_too_long(&message) { + Self::PromptTooLarge { + tokens: Some(tokens), + } + } else if code == "upstream_http_error" { + if let Some((upstream_status, inner_message)) = + Self::parse_upstream_error_json(&message) + { + return Self::from_http_status( + upstream_provider, + upstream_status, + inner_message, + retry_after, + ); + } + anyhow!("completion request failed, code: {code}, message: {message}").into() + } else if let Some(status_code) = code + .strip_prefix("upstream_http_") + .and_then(|code| StatusCode::from_str(code).ok()) + { + Self::from_http_status(upstream_provider, status_code, message, retry_after) + } else if let Some(status_code) = code + .strip_prefix("http_") + .and_then(|code| StatusCode::from_str(code).ok()) + { + Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after) + } else { + anyhow!("completion request failed, code: {code}, message: {message}").into() + } + } + + pub fn from_http_status( + provider: LanguageModelProviderName, + status_code: StatusCode, + message: String, + retry_after: Option, + ) -> Self { + match status_code { + StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message }, + StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message }, + StatusCode::FORBIDDEN => Self::PermissionError { provider, message }, + StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider }, + StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge { + tokens: parse_prompt_too_long(&message), + }, + StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded { + provider, + retry_after, + }, + StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message }, + StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded { + provider, + retry_after, + }, + _ if status_code.as_u16() == 529 => Self::ServerOverloaded { + provider, + retry_after, + }, + _ => Self::HttpResponseError { + provider, + status_code, + message, + }, + } + } +} + +#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum StopReason { + EndTurn, + MaxTokens, + ToolUse, + Refusal, +} + +#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)] +pub struct TokenUsage { + #[serde(default, skip_serializing_if = "is_default")] + pub input_tokens: u64, + #[serde(default, skip_serializing_if = "is_default")] + pub output_tokens: u64, + #[serde(default, skip_serializing_if = "is_default")] + pub cache_creation_input_tokens: u64, + #[serde(default, skip_serializing_if = "is_default")] + pub cache_read_input_tokens: u64, +} + +impl TokenUsage { + pub fn total_tokens(&self) -> u64 { + self.input_tokens + + self.output_tokens + + self.cache_read_input_tokens + + self.cache_creation_input_tokens + } +} + +impl Add for TokenUsage { + type Output = Self; + + fn add(self, other: Self) -> Self { + Self { + input_tokens: self.input_tokens + other.input_tokens, + output_tokens: self.output_tokens + other.output_tokens, + cache_creation_input_tokens: self.cache_creation_input_tokens + + other.cache_creation_input_tokens, + cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens, + } + } +} + +impl Sub for TokenUsage { + type Output = Self; + + fn sub(self, other: Self) -> Self { + Self { + input_tokens: self.input_tokens - other.input_tokens, + output_tokens: self.output_tokens - other.output_tokens, + cache_creation_input_tokens: self.cache_creation_input_tokens + - other.cache_creation_input_tokens, + cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens, + } + } +} + +#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)] +pub struct LanguageModelToolUseId(Arc); + +impl fmt::Display for LanguageModelToolUseId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl From for LanguageModelToolUseId +where + T: Into>, +{ + fn from(value: T) -> Self { + Self(value.into()) + } +} + +#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)] +pub struct LanguageModelToolUse { + pub id: LanguageModelToolUseId, + pub name: Arc, + pub raw_input: String, + pub input: serde_json::Value, + pub is_input_complete: bool, + /// Thought signature the model sent us. Some models require that this + /// signature be preserved and sent back in conversation history for validation. + pub thought_signature: Option, +} + +#[derive(Debug, Clone)] +pub struct LanguageModelEffortLevel { + pub name: SharedString, + pub value: SharedString, + pub is_default: bool, +} + +/// An error that occurred when trying to authenticate the language model provider. +#[derive(Debug, Error)] +pub enum AuthenticateError { + #[error("connection refused")] + ConnectionRefused, + #[error("credentials not found")] + CredentialsNotFound, + #[error(transparent)] + Other(#[from] anyhow::Error), +} + +#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)] +pub struct LanguageModelId(pub SharedString); + +#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)] +pub struct LanguageModelName(pub SharedString); + +#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)] +pub struct LanguageModelProviderId(pub SharedString); + +#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)] +pub struct LanguageModelProviderName(pub SharedString); + +impl LanguageModelProviderId { + pub const fn new(id: &'static str) -> Self { + Self(SharedString::new_static(id)) + } +} + +impl LanguageModelProviderName { + pub const fn new(id: &'static str) -> Self { + Self(SharedString::new_static(id)) + } +} + +impl fmt::Display for LanguageModelProviderId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl fmt::Display for LanguageModelProviderName { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl From for LanguageModelId { + fn from(value: String) -> Self { + Self(SharedString::from(value)) + } +} + +impl From for LanguageModelName { + fn from(value: String) -> Self { + Self(SharedString::from(value)) + } +} + +impl From for LanguageModelProviderId { + fn from(value: String) -> Self { + Self(SharedString::from(value)) + } +} + +impl From for LanguageModelProviderName { + fn from(value: String) -> Self { + Self(SharedString::from(value)) + } +} + +impl From> for LanguageModelProviderId { + fn from(value: Arc) -> Self { + Self(SharedString::from(value)) + } +} + +impl From> for LanguageModelProviderName { + fn from(value: Arc) -> Self { + Self(SharedString::from(value)) + } +} + +/// Settings-layer–free model mode enum. +/// +/// Mirrors the shape of `settings_content::ModelMode` but lives here so that +/// crates below the settings layer can reference it. +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum ModelMode { + #[default] + Default, + Thinking { + budget_tokens: Option, + }, +} + +/// Settings-layer–free reasoning-effort enum. +/// +/// Mirrors the shape of `settings_content::OpenAiReasoningEffort` but lives +/// here so that crates below the settings layer can reference it. +#[derive( + Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema, strum::EnumString, +)] +#[serde(rename_all = "lowercase")] +#[strum(serialize_all = "lowercase")] +pub enum ReasoningEffort { + Minimal, + Low, + Medium, + High, + XHigh, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_from_cloud_failure_with_upstream_http_error() { + let error = LanguageModelCompletionError::from_cloud_failure( + String::from("anthropic").into(), + "upstream_http_error".to_string(), + r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(), + None, + ); + + match error { + LanguageModelCompletionError::ServerOverloaded { provider, .. } => { + assert_eq!(provider.0, "anthropic"); + } + _ => panic!( + "Expected ServerOverloaded error for 503 status, got: {:?}", + error + ), + } + + let error = LanguageModelCompletionError::from_cloud_failure( + String::from("anthropic").into(), + "upstream_http_error".to_string(), + r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(), + None, + ); + + match error { + LanguageModelCompletionError::ApiInternalServerError { provider, message } => { + assert_eq!(provider.0, "anthropic"); + assert_eq!(message, "Internal server error"); + } + _ => panic!( + "Expected ApiInternalServerError for 500 status, got: {:?}", + error + ), + } + } + + #[test] + fn test_from_cloud_failure_with_standard_format() { + let error = LanguageModelCompletionError::from_cloud_failure( + String::from("anthropic").into(), + "upstream_http_503".to_string(), + "Service unavailable".to_string(), + None, + ); + + match error { + LanguageModelCompletionError::ServerOverloaded { provider, .. } => { + assert_eq!(provider.0, "anthropic"); + } + _ => panic!("Expected ServerOverloaded error for upstream_http_503"), + } + } + + #[test] + fn test_upstream_http_error_connection_timeout() { + let error = LanguageModelCompletionError::from_cloud_failure( + String::from("anthropic").into(), + "upstream_http_error".to_string(), + r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(), + None, + ); + + match error { + LanguageModelCompletionError::ServerOverloaded { provider, .. } => { + assert_eq!(provider.0, "anthropic"); + } + _ => panic!( + "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}", + error + ), + } + + let error = LanguageModelCompletionError::from_cloud_failure( + String::from("anthropic").into(), + "upstream_http_error".to_string(), + r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(), + None, + ); + + match error { + LanguageModelCompletionError::ApiInternalServerError { provider, message } => { + assert_eq!(provider.0, "anthropic"); + assert_eq!( + message, + "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout" + ); + } + _ => panic!( + "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}", + error + ), + } + } + + #[test] + fn test_language_model_tool_use_serializes_with_signature() { + use serde_json::json; + + let tool_use = LanguageModelToolUse { + id: LanguageModelToolUseId::from("test_id"), + name: "test_tool".into(), + raw_input: json!({"arg": "value"}).to_string(), + input: json!({"arg": "value"}), + is_input_complete: true, + thought_signature: Some("test_signature".to_string()), + }; + + let serialized = serde_json::to_value(&tool_use).unwrap(); + + assert_eq!(serialized["id"], "test_id"); + assert_eq!(serialized["name"], "test_tool"); + assert_eq!(serialized["thought_signature"], "test_signature"); + } + + #[test] + fn test_language_model_tool_use_deserializes_with_missing_signature() { + use serde_json::json; + + let json = json!({ + "id": "test_id", + "name": "test_tool", + "raw_input": "{\"arg\":\"value\"}", + "input": {"arg": "value"}, + "is_input_complete": true + }); + + let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap(); + + assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id")); + assert_eq!(tool_use.name.as_ref(), "test_tool"); + assert_eq!(tool_use.thought_signature, None); + } + + #[test] + fn test_language_model_tool_use_round_trip_with_signature() { + use serde_json::json; + + let original = LanguageModelToolUse { + id: LanguageModelToolUseId::from("round_trip_id"), + name: "round_trip_tool".into(), + raw_input: json!({"key": "value"}).to_string(), + input: json!({"key": "value"}), + is_input_complete: true, + thought_signature: Some("round_trip_sig".to_string()), + }; + + let serialized = serde_json::to_value(&original).unwrap(); + let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap(); + + assert_eq!(deserialized.id, original.id); + assert_eq!(deserialized.name, original.name); + assert_eq!(deserialized.thought_signature, original.thought_signature); + } + + #[test] + fn test_language_model_tool_use_round_trip_without_signature() { + use serde_json::json; + + let original = LanguageModelToolUse { + id: LanguageModelToolUseId::from("no_sig_id"), + name: "no_sig_tool".into(), + raw_input: json!({"arg": "value"}).to_string(), + input: json!({"arg": "value"}), + is_input_complete: true, + thought_signature: None, + }; + + let serialized = serde_json::to_value(&original).unwrap(); + let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap(); + + assert_eq!(deserialized.id, original.id); + assert_eq!(deserialized.name, original.name); + assert_eq!(deserialized.thought_signature, None); + } +} diff --git a/crates/language_model_core/src/provider.rs b/crates/language_model_core/src/provider.rs new file mode 100644 index 0000000000000000000000000000000000000000..da8b208147ad1d5b58a35888dfd07c821965097c --- /dev/null +++ b/crates/language_model_core/src/provider.rs @@ -0,0 +1,21 @@ +use crate::{LanguageModelProviderId, LanguageModelProviderName}; + +pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId = + LanguageModelProviderId::new("anthropic"); +pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName = + LanguageModelProviderName::new("Anthropic"); + +pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai"); +pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName = + LanguageModelProviderName::new("OpenAI"); + +pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google"); +pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName = + LanguageModelProviderName::new("Google AI"); + +pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai"); +pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI"); + +pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev"); +pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName = + LanguageModelProviderName::new("Zed"); diff --git a/crates/language_model/src/rate_limiter.rs b/crates/language_model_core/src/rate_limiter.rs similarity index 100% rename from crates/language_model/src/rate_limiter.rs rename to crates/language_model_core/src/rate_limiter.rs diff --git a/crates/language_model_core/src/request.rs b/crates/language_model_core/src/request.rs new file mode 100644 index 0000000000000000000000000000000000000000..48f7f00522bc3dd5c06747d662761efb003886c0 --- /dev/null +++ b/crates/language_model_core/src/request.rs @@ -0,0 +1,463 @@ +use std::sync::Arc; + +use serde::{Deserialize, Serialize}; + +use crate::role::Role; +use crate::{LanguageModelToolUse, LanguageModelToolUseId, SharedString}; + +/// Dimensions of a `LanguageModelImage` +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct ImageSize { + pub width: i32, + pub height: i32, +} + +#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] +pub struct LanguageModelImage { + /// A base64-encoded PNG image. + pub source: SharedString, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub size: Option, +} + +impl LanguageModelImage { + pub fn len(&self) -> usize { + self.source.len() + } + + pub fn is_empty(&self) -> bool { + self.source.is_empty() + } + + pub fn empty() -> Self { + Self { + source: "".into(), + size: None, + } + } + + /// Parse Self from a JSON object with case-insensitive field names + pub fn from_json(obj: &serde_json::Map) -> Option { + let mut source = None; + let mut size_obj = None; + + for (k, v) in obj.iter() { + match k.to_lowercase().as_str() { + "source" => source = v.as_str(), + "size" => size_obj = v.as_object(), + _ => {} + } + } + + let source = source?; + let size_obj = size_obj?; + + let mut width = None; + let mut height = None; + + for (k, v) in size_obj.iter() { + match k.to_lowercase().as_str() { + "width" => width = v.as_i64().map(|w| w as i32), + "height" => height = v.as_i64().map(|h| h as i32), + _ => {} + } + } + + Some(Self { + size: Some(ImageSize { + width: width?, + height: height?, + }), + source: SharedString::from(source.to_string()), + }) + } + + pub fn estimate_tokens(&self) -> usize { + let Some(size) = self.size.as_ref() else { + return 0; + }; + let width = size.width.unsigned_abs() as usize; + let height = size.height.unsigned_abs() as usize; + + // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs + (width * height) / 750 + } + + pub fn to_base64_url(&self) -> String { + format!("data:image/png;base64,{}", self.source) + } +} + +impl std::fmt::Debug for LanguageModelImage { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("LanguageModelImage") + .field("source", &format!("<{} bytes>", self.source.len())) + .field("size", &self.size) + .finish() + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)] +pub struct LanguageModelToolResult { + pub tool_use_id: LanguageModelToolUseId, + pub tool_name: Arc, + pub is_error: bool, + /// The tool output formatted for presenting to the model + pub content: LanguageModelToolResultContent, + /// The raw tool output, if available, often for debugging or extra state for replay + pub output: Option, +} + +#[derive(Debug, Clone, Serialize, Eq, PartialEq, Hash)] +pub enum LanguageModelToolResultContent { + Text(Arc), + Image(LanguageModelImage), +} + +impl<'de> Deserialize<'de> for LanguageModelToolResultContent { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + use serde::de::Error; + + let value = serde_json::Value::deserialize(deserializer)?; + + // 1. Try as plain string + if let Ok(text) = serde_json::from_value::(value.clone()) { + return Ok(Self::Text(Arc::from(text))); + } + + // 2. Try as object + if let Some(obj) = value.as_object() { + fn get_field<'a>( + obj: &'a serde_json::Map, + field: &str, + ) -> Option<&'a serde_json::Value> { + obj.iter() + .find(|(k, _)| k.to_lowercase() == field.to_lowercase()) + .map(|(_, v)| v) + } + + // Accept wrapped text format: { "type": "text", "text": "..." } + if let (Some(type_value), Some(text_value)) = + (get_field(obj, "type"), get_field(obj, "text")) + && let Some(type_str) = type_value.as_str() + && type_str.to_lowercase() == "text" + && let Some(text) = text_value.as_str() + { + return Ok(Self::Text(Arc::from(text))); + } + + // Check for wrapped Text variant: { "text": "..." } + if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "text") + && obj.len() == 1 + { + if let Some(text) = value.as_str() { + return Ok(Self::Text(Arc::from(text))); + } + } + + // Check for wrapped Image variant: { "image": { "source": "...", "size": ... } } + if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "image") + && obj.len() == 1 + { + if let Some(image_obj) = value.as_object() + && let Some(image) = LanguageModelImage::from_json(image_obj) + { + return Ok(Self::Image(image)); + } + } + + // Try as direct Image + if let Some(image) = LanguageModelImage::from_json(obj) { + return Ok(Self::Image(image)); + } + } + + Err(D::Error::custom(format!( + "data did not match any variant of LanguageModelToolResultContent. Expected either a string, \ + an object with 'type': 'text', a wrapped variant like {{\"Text\": \"...\"}}, or an image object. Got: {}", + serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string()) + ))) + } +} + +impl LanguageModelToolResultContent { + pub fn to_str(&self) -> Option<&str> { + match self { + Self::Text(text) => Some(text), + Self::Image(_) => None, + } + } + + pub fn is_empty(&self) -> bool { + match self { + Self::Text(text) => text.chars().all(|c| c.is_whitespace()), + Self::Image(_) => false, + } + } +} + +impl From<&str> for LanguageModelToolResultContent { + fn from(value: &str) -> Self { + Self::Text(Arc::from(value)) + } +} + +impl From for LanguageModelToolResultContent { + fn from(value: String) -> Self { + Self::Text(Arc::from(value)) + } +} + +impl From for LanguageModelToolResultContent { + fn from(image: LanguageModelImage) -> Self { + Self::Image(image) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)] +pub enum MessageContent { + Text(String), + Thinking { + text: String, + signature: Option, + }, + RedactedThinking(String), + Image(LanguageModelImage), + ToolUse(LanguageModelToolUse), + ToolResult(LanguageModelToolResult), +} + +impl MessageContent { + pub fn to_str(&self) -> Option<&str> { + match self { + MessageContent::Text(text) => Some(text.as_str()), + MessageContent::Thinking { text, .. } => Some(text.as_str()), + MessageContent::RedactedThinking(_) => None, + MessageContent::ToolResult(tool_result) => tool_result.content.to_str(), + MessageContent::ToolUse(_) | MessageContent::Image(_) => None, + } + } + + pub fn is_empty(&self) -> bool { + match self { + MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()), + MessageContent::Thinking { text, .. } => text.chars().all(|c| c.is_whitespace()), + MessageContent::ToolResult(tool_result) => tool_result.content.is_empty(), + MessageContent::RedactedThinking(_) + | MessageContent::ToolUse(_) + | MessageContent::Image(_) => false, + } + } +} + +impl From for MessageContent { + fn from(value: String) -> Self { + MessageContent::Text(value) + } +} + +impl From<&str> for MessageContent { + fn from(value: &str) -> Self { + MessageContent::Text(value.to_string()) + } +} + +#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)] +pub struct LanguageModelRequestMessage { + pub role: Role, + pub content: Vec, + pub cache: bool, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub reasoning_details: Option, +} + +impl LanguageModelRequestMessage { + pub fn string_contents(&self) -> String { + let mut buffer = String::new(); + for string in self.content.iter().filter_map(|content| content.to_str()) { + buffer.push_str(string); + } + buffer + } + + pub fn contents_empty(&self) -> bool { + self.content.iter().all(|content| content.is_empty()) + } +} + +#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)] +pub struct LanguageModelRequestTool { + pub name: String, + pub description: String, + pub input_schema: serde_json::Value, + pub use_input_streaming: bool, +} + +#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)] +pub enum LanguageModelToolChoice { + Auto, + Any, + None, +} + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum CompletionIntent { + UserPrompt, + Subagent, + ToolResults, + ThreadSummarization, + ThreadContextSummarization, + CreateFile, + EditFile, + InlineAssist, + TerminalInlineAssist, + GenerateGitCommitMessage, +} + +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] +pub struct LanguageModelRequest { + pub thread_id: Option, + pub prompt_id: Option, + pub intent: Option, + pub messages: Vec, + pub tools: Vec, + pub tool_choice: Option, + pub stop: Vec, + pub temperature: Option, + pub thinking_allowed: bool, + pub thinking_effort: Option, + pub speed: Option, +} + +#[derive(Clone, Copy, Default, Debug, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum Speed { + #[default] + Standard, + Fast, +} + +impl Speed { + pub fn toggle(self) -> Self { + match self { + Speed::Standard => Speed::Fast, + Speed::Fast => Speed::Standard, + } + } +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct LanguageModelResponseMessage { + pub role: Option, + pub content: Option, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_language_model_tool_result_content_deserialization() { + // Test plain string + let json = serde_json::json!("hello world"); + let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap(); + assert_eq!( + content, + LanguageModelToolResultContent::Text(Arc::from("hello world")) + ); + + // Test wrapped text format: { "type": "text", "text": "..." } + let json = serde_json::json!({"type": "text", "text": "hello"}); + let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap(); + assert_eq!( + content, + LanguageModelToolResultContent::Text(Arc::from("hello")) + ); + + // Test single-field text object: { "text": "..." } + let json = serde_json::json!({"text": "hello"}); + let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap(); + assert_eq!( + content, + LanguageModelToolResultContent::Text(Arc::from("hello")) + ); + + // Test case-insensitive type field + let json = serde_json::json!({"Type": "Text", "Text": "hello"}); + let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap(); + assert_eq!( + content, + LanguageModelToolResultContent::Text(Arc::from("hello")) + ); + + // Test image object + let json = serde_json::json!({ + "source": "base64encodedimagedata", + "size": {"width": 100, "height": 200} + }); + let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap(); + match content { + LanguageModelToolResultContent::Image(image) => { + assert_eq!(image.source.as_ref(), "base64encodedimagedata"); + let size = image.size.expect("size"); + assert_eq!(size.width, 100); + assert_eq!(size.height, 200); + } + _ => panic!("Expected Image variant"), + } + + // Test wrapped image: { "image": { "source": "...", "size": ... } } + let json = serde_json::json!({ + "image": { + "source": "wrappedimagedata", + "size": {"width": 50, "height": 75} + } + }); + let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap(); + match content { + LanguageModelToolResultContent::Image(image) => { + assert_eq!(image.source.as_ref(), "wrappedimagedata"); + let size = image.size.expect("size"); + assert_eq!(size.width, 50); + assert_eq!(size.height, 75); + } + _ => panic!("Expected Image variant"), + } + + // Test case insensitive + let json = serde_json::json!({ + "Source": "caseinsensitive", + "Size": {"Width": 30, "Height": 40} + }); + let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap(); + match content { + LanguageModelToolResultContent::Image(image) => { + assert_eq!(image.source.as_ref(), "caseinsensitive"); + let size = image.size.expect("size"); + assert_eq!(size.width, 30); + assert_eq!(size.height, 40); + } + _ => panic!("Expected Image variant"), + } + + // Test direct image object + let json = serde_json::json!({ + "source": "directimage", + "size": {"width": 200, "height": 300} + }); + let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap(); + match content { + LanguageModelToolResultContent::Image(image) => { + assert_eq!(image.source.as_ref(), "directimage"); + let size = image.size.expect("size"); + assert_eq!(size.width, 200); + assert_eq!(size.height, 300); + } + _ => panic!("Expected Image variant"), + } + } +} diff --git a/crates/language_model/src/role.rs b/crates/language_model_core/src/role.rs similarity index 100% rename from crates/language_model/src/role.rs rename to crates/language_model_core/src/role.rs diff --git a/crates/language_model/src/tool_schema.rs b/crates/language_model_core/src/tool_schema.rs similarity index 92% rename from crates/language_model/src/tool_schema.rs rename to crates/language_model_core/src/tool_schema.rs index 878870482a7527bf815797d16e03ad8edc79642e..0e82b2f41081469c6c04d16765e8336eb903fd94 100644 --- a/crates/language_model/src/tool_schema.rs +++ b/crates/language_model_core/src/tool_schema.rs @@ -77,8 +77,6 @@ pub fn adapt_schema_to_format( } fn preprocess_json_schema(json: &mut Value) -> Result<()> { - // `additionalProperties` defaults to `false` unless explicitly specified. - // This prevents models from hallucinating tool parameters. if let Value::Object(obj) = json && matches!(obj.get("type"), Some(Value::String(s)) if s == "object") { @@ -86,7 +84,6 @@ fn preprocess_json_schema(json: &mut Value) -> Result<()> { obj.insert("additionalProperties".to_string(), Value::Bool(false)); } - // OpenAI API requires non-missing `properties` if !obj.contains_key("properties") { obj.insert("properties".to_string(), Value::Object(Default::default())); } @@ -94,7 +91,6 @@ fn preprocess_json_schema(json: &mut Value) -> Result<()> { Ok(()) } -/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> { if let Value::Object(obj) = json { const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"]; @@ -108,9 +104,7 @@ fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> { const KEYS_TO_REMOVE: [(&str, fn(&Value) -> bool); 6] = [ ("format", |value| value.is_string()), - // Gemini doesn't support `additionalProperties` in any form (boolean or schema object) ("additionalProperties", |_| true), - // Gemini doesn't support `propertyNames` ("propertyNames", |_| true), ("exclusiveMinimum", |value| value.is_number()), ("exclusiveMaximum", |value| value.is_number()), @@ -124,7 +118,6 @@ fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> { } } - // If a type is not specified for an input parameter, add a default type if matches!(obj.get("description"), Some(Value::String(_))) && !obj.contains_key("type") && !(obj.contains_key("anyOf") @@ -134,7 +127,6 @@ fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> { obj.insert("type".to_string(), Value::String("string".to_string())); } - // Handle oneOf -> anyOf conversion if let Some(subschemas) = obj.get_mut("oneOf") && subschemas.is_array() { @@ -143,7 +135,6 @@ fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> { obj.insert("anyOf".to_string(), subschemas_clone); } - // Recursively process all nested objects and arrays for (_, value) in obj.iter_mut() { if let Value::Object(_) | Value::Array(_) = value { adapt_to_json_schema_subset(value)?; @@ -178,7 +169,6 @@ mod tests { }) ); - // Ensure that we do not add a type if it is an object let mut json = json!({ "description": { "value": "abc", @@ -221,7 +211,6 @@ mod tests { }) ); - // Ensure that we do not remove keys that are actually supported (e.g. "format" can just be used as another property) let mut json = json!({ "description": "A test field", "type": "integer", @@ -239,7 +228,6 @@ mod tests { }) ); - // additionalProperties as an object schema is also unsupported by Gemini let mut json = json!({ "type": "object", "properties": { diff --git a/crates/language_models/src/provider/util.rs b/crates/language_model_core/src/util.rs similarity index 88% rename from crates/language_models/src/provider/util.rs rename to crates/language_model_core/src/util.rs index 76a02b6de40a3e36c7c506f11a6f6d34d2aaca3e..3db2e0b76fd76070aa4d30e97c525fa8f3460c9d 100644 --- a/crates/language_models/src/provider/util.rs +++ b/crates/language_model_core/src/util.rs @@ -38,13 +38,22 @@ fn strip_trailing_incomplete_escape(json: &str) -> &str { } } +/// Parses a "prompt is too long: N tokens ..." message and extracts the token count. +pub fn parse_prompt_too_long(message: &str) -> Option { + message + .strip_prefix("prompt is too long: ")? + .split_once(" tokens")? + .0 + .parse() + .ok() +} + #[cfg(test)] mod tests { use super::*; #[test] fn test_fix_streamed_json_strips_incomplete_escape() { - // Trailing `\` inside a string — incomplete escape sequence let fixed = fix_streamed_json(r#"{"text": "hello\"#); let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json"); assert_eq!(parsed["text"], "hello"); @@ -52,7 +61,6 @@ mod tests { #[test] fn test_fix_streamed_json_preserves_complete_escape() { - // `\\` is a complete escape (literal backslash) let fixed = fix_streamed_json(r#"{"text": "hello\\"#); let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json"); assert_eq!(parsed["text"], "hello\\"); @@ -60,7 +68,6 @@ mod tests { #[test] fn test_fix_streamed_json_strips_escape_after_complete_escape() { - // `\\\` = complete `\\` (literal backslash) + incomplete `\` let fixed = fix_streamed_json(r#"{"text": "hello\\\"#); let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json"); assert_eq!(parsed["text"], "hello\\"); @@ -75,12 +82,10 @@ mod tests { #[test] fn test_fix_streamed_json_newline_escape_boundary() { - // Simulates a stream boundary landing between `\` and `n` let fixed = fix_streamed_json(r#"{"text": "line1\"#); let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json"); assert_eq!(parsed["text"], "line1"); - // Next chunk completes the escape let fixed = fix_streamed_json(r#"{"text": "line1\nline2"#); let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json"); assert_eq!(parsed["text"], "line1\nline2"); @@ -88,8 +93,6 @@ mod tests { #[test] fn test_fix_streamed_json_incremental_delta_correctness() { - // This is the actual scenario that causes the bug: - // chunk 1 ends mid-escape, chunk 2 completes it. let chunk1 = r#"{"replacement_text": "fn foo() {\"#; let fixed1 = fix_streamed_json(chunk1); let parsed1: serde_json::Value = serde_json::from_str(&fixed1).expect("valid json"); @@ -102,7 +105,6 @@ mod tests { let text2 = parsed2["replacement_text"].as_str().expect("string"); assert_eq!(text2, "fn foo() {\n return bar;\n}"); - // The delta should be the newline + rest, with no spurious backslash let delta = &text2[text1.len()..]; assert_eq!(delta, "\n return bar;\n}"); } diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml index 4ebfce695e587265ea39077c67c84ce9b01e5352..60670114529b07dca78202cc438ff5e243acaeee 100644 --- a/crates/language_models/Cargo.toml +++ b/crates/language_models/Cargo.toml @@ -21,8 +21,8 @@ aws_http_client.workspace = true base64.workspace = true bedrock = { workspace = true, features = ["schemars"] } client.workspace = true +cloud_api_client.workspace = true cloud_api_types.workspace = true -cloud_llm_client.workspace = true collections.workspace = true component.workspace = true convert_case.workspace = true @@ -41,6 +41,7 @@ gpui_tokio.workspace = true http_client.workspace = true language.workspace = true language_model.workspace = true +language_models_cloud.workspace = true lmstudio = { workspace = true, features = ["schemars"] } log.workspace = true menu.workspace = true @@ -49,16 +50,13 @@ ollama = { workspace = true, features = ["schemars"] } open_ai = { workspace = true, features = ["schemars"] } opencode = { workspace = true, features = ["schemars"] } open_router = { workspace = true, features = ["schemars"] } -partial-json-fixer.workspace = true release_channel.workspace = true schemars.workspace = true -semver.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true smol.workspace = true strum.workspace = true -thiserror.workspace = true tiktoken-rs.workspace = true tokio = { workspace = true, features = ["rt", "rt-multi-thread"] } ui.workspace = true @@ -70,4 +68,3 @@ x_ai = { workspace = true, features = ["schemars"] } [dev-dependencies] language_model = { workspace = true, features = ["test-support"] } pretty_assertions.workspace = true - diff --git a/crates/language_models/src/provider.rs b/crates/language_models/src/provider.rs index d3c433974599399160e602b8f201b9fd0af874cb..35a1e90e4483ba03e1ded8ce8c7519fc0fa7a746 100644 --- a/crates/language_models/src/provider.rs +++ b/crates/language_models/src/provider.rs @@ -11,7 +11,7 @@ pub mod open_ai; pub mod open_ai_compatible; pub mod open_router; pub mod opencode; -mod util; + pub mod vercel; pub mod vercel_ai_gateway; pub mod x_ai; diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index c1b8bc1a3bb1b602b67ae5563d8acc3b05a94d47..58de77d573293345ec2120695866c824f10c6108 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -1,13 +1,10 @@ pub mod telemetry; -use anthropic::{ - ANTHROPIC_API_URL, AnthropicError, AnthropicModelMode, ContentDelta, CountTokensRequest, Event, - ResponseContent, ToolResultContent, ToolResultPart, Usage, -}; +use anthropic::{ANTHROPIC_API_URL, AnthropicError, AnthropicModelMode}; use anyhow::Result; -use collections::{BTreeMap, HashMap}; +use collections::BTreeMap; use credentials_provider::CredentialsProvider; -use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream}; +use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream}; use gpui::{AnyView, App, AsyncApp, Context, Entity, Task}; use http_client::HttpClient; use language_model::{ @@ -16,20 +13,19 @@ use language_model::{ LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, - LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent, - RateLimiter, Role, StopReason, env_var, + LanguageModelToolChoice, RateLimiter, env_var, }; use settings::{Settings, SettingsStore}; -use std::pin::Pin; -use std::str::FromStr; use std::sync::{Arc, LazyLock}; use strum::IntoEnumIterator; use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*}; use ui_input::InputField; use util::ResultExt; -use crate::provider::util::{fix_streamed_json, parse_tool_arguments}; - +pub use anthropic::completion::{ + AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic, + into_anthropic_count_tokens_request, +}; pub use settings::AnthropicAvailableModel as AvailableModel; const PROVIDER_ID: LanguageModelProviderId = ANTHROPIC_PROVIDER_ID; @@ -249,228 +245,6 @@ pub struct AnthropicModel { request_limiter: RateLimiter, } -fn to_anthropic_content(content: MessageContent) -> Option { - match content { - MessageContent::Text(text) => { - let text = if text.chars().last().is_some_and(|c| c.is_whitespace()) { - text.trim_end().to_string() - } else { - text - }; - if !text.is_empty() { - Some(anthropic::RequestContent::Text { - text, - cache_control: None, - }) - } else { - None - } - } - MessageContent::Thinking { - text: thinking, - signature, - } => { - if let Some(signature) = signature - && !thinking.is_empty() - { - Some(anthropic::RequestContent::Thinking { - thinking, - signature, - cache_control: None, - }) - } else { - None - } - } - MessageContent::RedactedThinking(data) => { - if !data.is_empty() { - Some(anthropic::RequestContent::RedactedThinking { data }) - } else { - None - } - } - MessageContent::Image(image) => Some(anthropic::RequestContent::Image { - source: anthropic::ImageSource { - source_type: "base64".to_string(), - media_type: "image/png".to_string(), - data: image.source.to_string(), - }, - cache_control: None, - }), - MessageContent::ToolUse(tool_use) => Some(anthropic::RequestContent::ToolUse { - id: tool_use.id.to_string(), - name: tool_use.name.to_string(), - input: tool_use.input, - cache_control: None, - }), - MessageContent::ToolResult(tool_result) => Some(anthropic::RequestContent::ToolResult { - tool_use_id: tool_result.tool_use_id.to_string(), - is_error: tool_result.is_error, - content: match tool_result.content { - LanguageModelToolResultContent::Text(text) => { - ToolResultContent::Plain(text.to_string()) - } - LanguageModelToolResultContent::Image(image) => { - ToolResultContent::Multipart(vec![ToolResultPart::Image { - source: anthropic::ImageSource { - source_type: "base64".to_string(), - media_type: "image/png".to_string(), - data: image.source.to_string(), - }, - }]) - } - }, - cache_control: None, - }), - } -} - -/// Convert a LanguageModelRequest to an Anthropic CountTokensRequest. -pub fn into_anthropic_count_tokens_request( - request: LanguageModelRequest, - model: String, - mode: AnthropicModelMode, -) -> CountTokensRequest { - let mut new_messages: Vec = Vec::new(); - let mut system_message = String::new(); - - for message in request.messages { - if message.contents_empty() { - continue; - } - - match message.role { - Role::User | Role::Assistant => { - let anthropic_message_content: Vec = message - .content - .into_iter() - .filter_map(to_anthropic_content) - .collect(); - let anthropic_role = match message.role { - Role::User => anthropic::Role::User, - Role::Assistant => anthropic::Role::Assistant, - Role::System => unreachable!("System role should never occur here"), - }; - if anthropic_message_content.is_empty() { - continue; - } - - if let Some(last_message) = new_messages.last_mut() - && last_message.role == anthropic_role - { - last_message.content.extend(anthropic_message_content); - continue; - } - - new_messages.push(anthropic::Message { - role: anthropic_role, - content: anthropic_message_content, - }); - } - Role::System => { - if !system_message.is_empty() { - system_message.push_str("\n\n"); - } - system_message.push_str(&message.string_contents()); - } - } - } - - CountTokensRequest { - model, - messages: new_messages, - system: if system_message.is_empty() { - None - } else { - Some(anthropic::StringOrContents::String(system_message)) - }, - thinking: if request.thinking_allowed { - match mode { - AnthropicModelMode::Thinking { budget_tokens } => { - Some(anthropic::Thinking::Enabled { budget_tokens }) - } - AnthropicModelMode::AdaptiveThinking => Some(anthropic::Thinking::Adaptive), - AnthropicModelMode::Default => None, - } - } else { - None - }, - tools: request - .tools - .into_iter() - .map(|tool| anthropic::Tool { - name: tool.name, - description: tool.description, - input_schema: tool.input_schema, - eager_input_streaming: tool.use_input_streaming, - }) - .collect(), - tool_choice: request.tool_choice.map(|choice| match choice { - LanguageModelToolChoice::Auto => anthropic::ToolChoice::Auto, - LanguageModelToolChoice::Any => anthropic::ToolChoice::Any, - LanguageModelToolChoice::None => anthropic::ToolChoice::None, - }), - } -} - -/// Estimate tokens using tiktoken. Used as a fallback when the API is unavailable, -/// or by providers (like Zed Cloud) that don't have direct Anthropic API access. -pub fn count_anthropic_tokens_with_tiktoken(request: LanguageModelRequest) -> Result { - let messages = request.messages; - let mut tokens_from_images = 0; - let mut string_messages = Vec::with_capacity(messages.len()); - - for message in messages { - let mut string_contents = String::new(); - - for content in message.content { - match content { - MessageContent::Text(text) => { - string_contents.push_str(&text); - } - MessageContent::Thinking { .. } => { - // Thinking blocks are not included in the input token count. - } - MessageContent::RedactedThinking(_) => { - // Thinking blocks are not included in the input token count. - } - MessageContent::Image(image) => { - tokens_from_images += image.estimate_tokens(); - } - MessageContent::ToolUse(_tool_use) => { - // TODO: Estimate token usage from tool uses. - } - MessageContent::ToolResult(tool_result) => match &tool_result.content { - LanguageModelToolResultContent::Text(text) => { - string_contents.push_str(text); - } - LanguageModelToolResultContent::Image(image) => { - tokens_from_images += image.estimate_tokens(); - } - }, - } - } - - if !string_contents.is_empty() { - string_messages.push(tiktoken_rs::ChatCompletionRequestMessage { - role: match message.role { - Role::User => "user".into(), - Role::Assistant => "assistant".into(), - Role::System => "system".into(), - }, - content: Some(string_contents), - name: None, - function_call: None, - }); - } - } - - // Tiktoken doesn't yet support these models, so we manually use the - // same tokenizer as GPT-4. - tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages) - .map(|tokens| (tokens + tokens_from_images) as u64) -} - impl AnthropicModel { fn stream_completion( &self, @@ -617,10 +391,13 @@ impl LanguageModel for AnthropicModel { ) }); + let background = cx.background_executor().clone(); async move { // If no API key, fall back to tiktoken estimation let Some(api_key) = api_key else { - return count_anthropic_tokens_with_tiktoken(request); + return background + .spawn(async move { count_anthropic_tokens_with_tiktoken(request) }) + .await; }; let count_request = @@ -634,7 +411,9 @@ impl LanguageModel for AnthropicModel { log::error!( "Anthropic count_tokens API failed, falling back to tiktoken: {err:?}" ); - count_anthropic_tokens_with_tiktoken(request) + background + .spawn(async move { count_anthropic_tokens_with_tiktoken(request) }) + .await } } } @@ -678,345 +457,6 @@ impl LanguageModel for AnthropicModel { } } -pub fn into_anthropic( - request: LanguageModelRequest, - model: String, - default_temperature: f32, - max_output_tokens: u64, - mode: AnthropicModelMode, -) -> anthropic::Request { - let mut new_messages: Vec = Vec::new(); - let mut system_message = String::new(); - - for message in request.messages { - if message.contents_empty() { - continue; - } - - match message.role { - Role::User | Role::Assistant => { - let mut anthropic_message_content: Vec = message - .content - .into_iter() - .filter_map(to_anthropic_content) - .collect(); - let anthropic_role = match message.role { - Role::User => anthropic::Role::User, - Role::Assistant => anthropic::Role::Assistant, - Role::System => unreachable!("System role should never occur here"), - }; - if anthropic_message_content.is_empty() { - continue; - } - - if let Some(last_message) = new_messages.last_mut() - && last_message.role == anthropic_role - { - last_message.content.extend(anthropic_message_content); - continue; - } - - // Mark the last segment of the message as cached - if message.cache { - let cache_control_value = Some(anthropic::CacheControl { - cache_type: anthropic::CacheControlType::Ephemeral, - }); - for message_content in anthropic_message_content.iter_mut().rev() { - match message_content { - anthropic::RequestContent::RedactedThinking { .. } => { - // Caching is not possible, fallback to next message - } - anthropic::RequestContent::Text { cache_control, .. } - | anthropic::RequestContent::Thinking { cache_control, .. } - | anthropic::RequestContent::Image { cache_control, .. } - | anthropic::RequestContent::ToolUse { cache_control, .. } - | anthropic::RequestContent::ToolResult { cache_control, .. } => { - *cache_control = cache_control_value; - break; - } - } - } - } - - new_messages.push(anthropic::Message { - role: anthropic_role, - content: anthropic_message_content, - }); - } - Role::System => { - if !system_message.is_empty() { - system_message.push_str("\n\n"); - } - system_message.push_str(&message.string_contents()); - } - } - } - - anthropic::Request { - model, - messages: new_messages, - max_tokens: max_output_tokens, - system: if system_message.is_empty() { - None - } else { - Some(anthropic::StringOrContents::String(system_message)) - }, - thinking: if request.thinking_allowed { - match mode { - AnthropicModelMode::Thinking { budget_tokens } => { - Some(anthropic::Thinking::Enabled { budget_tokens }) - } - AnthropicModelMode::AdaptiveThinking => Some(anthropic::Thinking::Adaptive), - AnthropicModelMode::Default => None, - } - } else { - None - }, - tools: request - .tools - .into_iter() - .map(|tool| anthropic::Tool { - name: tool.name, - description: tool.description, - input_schema: tool.input_schema, - eager_input_streaming: tool.use_input_streaming, - }) - .collect(), - tool_choice: request.tool_choice.map(|choice| match choice { - LanguageModelToolChoice::Auto => anthropic::ToolChoice::Auto, - LanguageModelToolChoice::Any => anthropic::ToolChoice::Any, - LanguageModelToolChoice::None => anthropic::ToolChoice::None, - }), - metadata: None, - output_config: if request.thinking_allowed - && matches!(mode, AnthropicModelMode::AdaptiveThinking) - { - request.thinking_effort.as_deref().and_then(|effort| { - let effort = match effort { - "low" => Some(anthropic::Effort::Low), - "medium" => Some(anthropic::Effort::Medium), - "high" => Some(anthropic::Effort::High), - "max" => Some(anthropic::Effort::Max), - _ => None, - }; - effort.map(|effort| anthropic::OutputConfig { - effort: Some(effort), - }) - }) - } else { - None - }, - stop_sequences: Vec::new(), - speed: request.speed.map(From::from), - temperature: request.temperature.or(Some(default_temperature)), - top_k: None, - top_p: None, - } -} - -pub struct AnthropicEventMapper { - tool_uses_by_index: HashMap, - usage: Usage, - stop_reason: StopReason, -} - -impl AnthropicEventMapper { - pub fn new() -> Self { - Self { - tool_uses_by_index: HashMap::default(), - usage: Usage::default(), - stop_reason: StopReason::EndTurn, - } - } - - pub fn map_stream( - mut self, - events: Pin>>>, - ) -> impl Stream> - { - events.flat_map(move |event| { - futures::stream::iter(match event { - Ok(event) => self.map_event(event), - Err(error) => vec![Err(error.into())], - }) - }) - } - - pub fn map_event( - &mut self, - event: Event, - ) -> Vec> { - match event { - Event::ContentBlockStart { - index, - content_block, - } => match content_block { - ResponseContent::Text { text } => { - vec![Ok(LanguageModelCompletionEvent::Text(text))] - } - ResponseContent::Thinking { thinking } => { - vec![Ok(LanguageModelCompletionEvent::Thinking { - text: thinking, - signature: None, - })] - } - ResponseContent::RedactedThinking { data } => { - vec![Ok(LanguageModelCompletionEvent::RedactedThinking { data })] - } - ResponseContent::ToolUse { id, name, .. } => { - self.tool_uses_by_index.insert( - index, - RawToolUse { - id, - name, - input_json: String::new(), - }, - ); - Vec::new() - } - }, - Event::ContentBlockDelta { index, delta } => match delta { - ContentDelta::TextDelta { text } => { - vec![Ok(LanguageModelCompletionEvent::Text(text))] - } - ContentDelta::ThinkingDelta { thinking } => { - vec![Ok(LanguageModelCompletionEvent::Thinking { - text: thinking, - signature: None, - })] - } - ContentDelta::SignatureDelta { signature } => { - vec![Ok(LanguageModelCompletionEvent::Thinking { - text: "".to_string(), - signature: Some(signature), - })] - } - ContentDelta::InputJsonDelta { partial_json } => { - if let Some(tool_use) = self.tool_uses_by_index.get_mut(&index) { - tool_use.input_json.push_str(&partial_json); - - // Try to convert invalid (incomplete) JSON into - // valid JSON that serde can accept, e.g. by closing - // unclosed delimiters. This way, we can update the - // UI with whatever has been streamed back so far. - if let Ok(input) = - serde_json::Value::from_str(&fix_streamed_json(&tool_use.input_json)) - { - return vec![Ok(LanguageModelCompletionEvent::ToolUse( - LanguageModelToolUse { - id: tool_use.id.clone().into(), - name: tool_use.name.clone().into(), - is_input_complete: false, - raw_input: tool_use.input_json.clone(), - input, - thought_signature: None, - }, - ))]; - } - } - vec![] - } - }, - Event::ContentBlockStop { index } => { - if let Some(tool_use) = self.tool_uses_by_index.remove(&index) { - let input_json = tool_use.input_json.trim(); - let event_result = match parse_tool_arguments(input_json) { - Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse( - LanguageModelToolUse { - id: tool_use.id.into(), - name: tool_use.name.into(), - is_input_complete: true, - input, - raw_input: tool_use.input_json.clone(), - thought_signature: None, - }, - )), - Err(json_parse_err) => { - Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { - id: tool_use.id.into(), - tool_name: tool_use.name.into(), - raw_input: input_json.into(), - json_parse_error: json_parse_err.to_string(), - }) - } - }; - - vec![event_result] - } else { - Vec::new() - } - } - Event::MessageStart { message } => { - update_usage(&mut self.usage, &message.usage); - vec![ - Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage( - &self.usage, - ))), - Ok(LanguageModelCompletionEvent::StartMessage { - message_id: message.id, - }), - ] - } - Event::MessageDelta { delta, usage } => { - update_usage(&mut self.usage, &usage); - if let Some(stop_reason) = delta.stop_reason.as_deref() { - self.stop_reason = match stop_reason { - "end_turn" => StopReason::EndTurn, - "max_tokens" => StopReason::MaxTokens, - "tool_use" => StopReason::ToolUse, - "refusal" => StopReason::Refusal, - _ => { - log::error!("Unexpected anthropic stop_reason: {stop_reason}"); - StopReason::EndTurn - } - }; - } - vec![Ok(LanguageModelCompletionEvent::UsageUpdate( - convert_usage(&self.usage), - ))] - } - Event::MessageStop => { - vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))] - } - Event::Error { error } => { - vec![Err(error.into())] - } - _ => Vec::new(), - } - } -} - -struct RawToolUse { - id: String, - name: String, - input_json: String, -} - -/// Updates usage data by preferring counts from `new`. -fn update_usage(usage: &mut Usage, new: &Usage) { - if let Some(input_tokens) = new.input_tokens { - usage.input_tokens = Some(input_tokens); - } - if let Some(output_tokens) = new.output_tokens { - usage.output_tokens = Some(output_tokens); - } - if let Some(cache_creation_input_tokens) = new.cache_creation_input_tokens { - usage.cache_creation_input_tokens = Some(cache_creation_input_tokens); - } - if let Some(cache_read_input_tokens) = new.cache_read_input_tokens { - usage.cache_read_input_tokens = Some(cache_read_input_tokens); - } -} - -fn convert_usage(usage: &Usage) -> language_model::TokenUsage { - language_model::TokenUsage { - input_tokens: usage.input_tokens.unwrap_or(0), - output_tokens: usage.output_tokens.unwrap_or(0), - cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0), - cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0), - } -} - struct ConfigurationView { api_key_editor: Entity, state: Entity, @@ -1157,192 +597,3 @@ impl Render for ConfigurationView { } } } - -#[cfg(test)] -mod tests { - use super::*; - use anthropic::AnthropicModelMode; - use language_model::{LanguageModelRequestMessage, MessageContent}; - - #[test] - fn test_cache_control_only_on_last_segment() { - let request = LanguageModelRequest { - messages: vec![LanguageModelRequestMessage { - role: Role::User, - content: vec![ - MessageContent::Text("Some prompt".to_string()), - MessageContent::Image(language_model::LanguageModelImage::empty()), - MessageContent::Image(language_model::LanguageModelImage::empty()), - MessageContent::Image(language_model::LanguageModelImage::empty()), - MessageContent::Image(language_model::LanguageModelImage::empty()), - ], - cache: true, - reasoning_details: None, - }], - thread_id: None, - prompt_id: None, - intent: None, - stop: vec![], - temperature: None, - tools: vec![], - tool_choice: None, - thinking_allowed: true, - thinking_effort: None, - speed: None, - }; - - let anthropic_request = into_anthropic( - request, - "claude-3-5-sonnet".to_string(), - 0.7, - 4096, - AnthropicModelMode::Default, - ); - - assert_eq!(anthropic_request.messages.len(), 1); - - let message = &anthropic_request.messages[0]; - assert_eq!(message.content.len(), 5); - - assert!(matches!( - message.content[0], - anthropic::RequestContent::Text { - cache_control: None, - .. - } - )); - for i in 1..3 { - assert!(matches!( - message.content[i], - anthropic::RequestContent::Image { - cache_control: None, - .. - } - )); - } - - assert!(matches!( - message.content[4], - anthropic::RequestContent::Image { - cache_control: Some(anthropic::CacheControl { - cache_type: anthropic::CacheControlType::Ephemeral, - }), - .. - } - )); - } - - fn request_with_assistant_content( - assistant_content: Vec, - ) -> anthropic::Request { - let mut request = LanguageModelRequest { - messages: vec![LanguageModelRequestMessage { - role: Role::User, - content: vec![MessageContent::Text("Hello".to_string())], - cache: false, - reasoning_details: None, - }], - thinking_effort: None, - thread_id: None, - prompt_id: None, - intent: None, - stop: vec![], - temperature: None, - tools: vec![], - tool_choice: None, - thinking_allowed: true, - speed: None, - }; - request.messages.push(LanguageModelRequestMessage { - role: Role::Assistant, - content: assistant_content, - cache: false, - reasoning_details: None, - }); - into_anthropic( - request, - "claude-sonnet-4-5".to_string(), - 1.0, - 16000, - AnthropicModelMode::Thinking { - budget_tokens: Some(10000), - }, - ) - } - - #[test] - fn test_unsigned_thinking_blocks_stripped() { - let result = request_with_assistant_content(vec![ - MessageContent::Thinking { - text: "Cancelled mid-think, no signature".to_string(), - signature: None, - }, - MessageContent::Text("Some response text".to_string()), - ]); - - let assistant_message = result - .messages - .iter() - .find(|m| m.role == anthropic::Role::Assistant) - .expect("assistant message should still exist"); - - assert_eq!( - assistant_message.content.len(), - 1, - "Only the text content should remain; unsigned thinking block should be stripped" - ); - assert!(matches!( - &assistant_message.content[0], - anthropic::RequestContent::Text { text, .. } if text == "Some response text" - )); - } - - #[test] - fn test_signed_thinking_blocks_preserved() { - let result = request_with_assistant_content(vec![ - MessageContent::Thinking { - text: "Completed thinking".to_string(), - signature: Some("valid-signature".to_string()), - }, - MessageContent::Text("Response".to_string()), - ]); - - let assistant_message = result - .messages - .iter() - .find(|m| m.role == anthropic::Role::Assistant) - .expect("assistant message should exist"); - - assert_eq!( - assistant_message.content.len(), - 2, - "Both the signed thinking block and text should be preserved" - ); - assert!(matches!( - &assistant_message.content[0], - anthropic::RequestContent::Thinking { thinking, signature, .. } - if thinking == "Completed thinking" && signature == "valid-signature" - )); - } - - #[test] - fn test_only_unsigned_thinking_block_omits_entire_message() { - let result = request_with_assistant_content(vec![MessageContent::Thinking { - text: "Cancelled before any text or signature".to_string(), - signature: None, - }]); - - let assistant_messages: Vec<_> = result - .messages - .iter() - .filter(|m| m.role == anthropic::Role::Assistant) - .collect(); - - assert_eq!( - assistant_messages.len(), - 0, - "An assistant message whose only content was an unsigned thinking block \ - should be omitted entirely" - ); - } -} diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index 4320763e2c5c6de7f3fe9238d7a4991565c3bfcd..80c758769cd990c00f5942433143bf6fb2216b7c 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -48,7 +48,7 @@ use ui_input::InputField; use util::ResultExt; use crate::AllLanguageModelSettings; -use crate::provider::util::{fix_streamed_json, parse_tool_arguments}; +use language_model::util::{fix_streamed_json, parse_tool_arguments}; actions!(bedrock, [Tab, TabPrev]); diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index 29623cc998ad0fe933e9a29c45c651f7be010b07..294b44ecae9941481e26c2341018ce584d68b3ec 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -1,107 +1,93 @@ use ai_onboarding::YoungAccountBanner; -use anthropic::AnthropicModelMode; -use anyhow::{Context as _, Result, anyhow}; -use client::{ - Client, NeedsLlmTokenRefresh, RefreshLlmTokenListener, UserStore, global_llm_token, zed_urls, -}; -use cloud_api_types::{OrganizationId, Plan}; -use cloud_llm_client::{ - CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, - CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus, - CountTokensBody, CountTokensResponse, ListModelsResponse, - SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, ZED_VERSION_HEADER_NAME, -}; -use futures::{ - AsyncBufReadExt, FutureExt, Stream, StreamExt, - future::BoxFuture, - stream::{self, BoxStream}, -}; -use google_ai::GoogleModelMode; -use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task}; -use http_client::http::{HeaderMap, HeaderValue}; -use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response, StatusCode}; +use anyhow::Result; +use client::{Client, RefreshLlmTokenListener, UserStore, global_llm_token, zed_urls}; +use cloud_api_client::LlmApiToken; +use cloud_api_types::OrganizationId; +use cloud_api_types::Plan; +use futures::StreamExt; +use futures::future::BoxFuture; +use gpui::AsyncApp; +use gpui::{AnyElement, AnyView, App, Context, Entity, Subscription, Task}; use language_model::{ - ANTHROPIC_PROVIDER_ID, ANTHROPIC_PROVIDER_NAME, AuthenticateError, GOOGLE_PROVIDER_ID, - GOOGLE_PROVIDER_NAME, IconOrSvg, LanguageModel, LanguageModelCacheConfiguration, - LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelEffortLevel, - LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, - LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, - LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, OPEN_AI_PROVIDER_ID, - OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME, - ZED_CLOUD_PROVIDER_ID, ZED_CLOUD_PROVIDER_NAME, + AuthenticateError, IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelProviderId, + LanguageModelProviderName, LanguageModelProviderState, ZED_CLOUD_PROVIDER_ID, + ZED_CLOUD_PROVIDER_NAME, }; +use language_models_cloud::{CloudLlmTokenProvider, CloudModelProvider}; use release_channel::AppVersion; -use schemars::JsonSchema; -use semver::Version; -use serde::{Deserialize, Serialize, de::DeserializeOwned}; + use settings::SettingsStore; pub use settings::ZedDotDevAvailableModel as AvailableModel; pub use settings::ZedDotDevAvailableProvider as AvailableProvider; -use smol::io::{AsyncReadExt, BufReader}; -use std::collections::VecDeque; -use std::pin::Pin; -use std::str::FromStr; use std::sync::Arc; -use std::task::Poll; -use std::time::Duration; -use thiserror::Error; use ui::{TintColor, prelude::*}; -use crate::provider::anthropic::{ - AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic, -}; -use crate::provider::google::{GoogleEventMapper, into_google}; -use crate::provider::open_ai::{ - OpenAiEventMapper, OpenAiResponseEventMapper, count_open_ai_tokens, into_open_ai, - into_open_ai_response, -}; -use crate::provider::x_ai::count_xai_tokens; - const PROVIDER_ID: LanguageModelProviderId = ZED_CLOUD_PROVIDER_ID; const PROVIDER_NAME: LanguageModelProviderName = ZED_CLOUD_PROVIDER_NAME; -#[derive(Default, Clone, Debug, PartialEq)] -pub struct ZedDotDevSettings { - pub available_models: Vec, -} -#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] -#[serde(tag = "type", rename_all = "lowercase")] -pub enum ModelMode { - #[default] - Default, - Thinking { - /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`. - budget_tokens: Option, - }, +struct ClientTokenProvider { + client: Arc, + llm_api_token: LlmApiToken, + user_store: Entity, } -impl From for AnthropicModelMode { - fn from(value: ModelMode) -> Self { - match value { - ModelMode::Default => AnthropicModelMode::Default, - ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens }, - } +impl CloudLlmTokenProvider for ClientTokenProvider { + type AuthContext = Option; + + fn auth_context(&self, cx: &AsyncApp) -> Self::AuthContext { + self.user_store.read_with(cx, |user_store, _| { + user_store + .current_organization() + .map(|organization| organization.id.clone()) + }) } + + fn acquire_token( + &self, + organization_id: Self::AuthContext, + ) -> BoxFuture<'static, Result> { + let client = self.client.clone(); + let llm_api_token = self.llm_api_token.clone(); + Box::pin(async move { + client + .acquire_llm_token(&llm_api_token, organization_id) + .await + }) + } + + fn refresh_token( + &self, + organization_id: Self::AuthContext, + ) -> BoxFuture<'static, Result> { + let client = self.client.clone(); + let llm_api_token = self.llm_api_token.clone(); + Box::pin(async move { + client + .refresh_llm_token(&llm_api_token, organization_id) + .await + }) + } +} + +#[derive(Default, Clone, Debug, PartialEq)] +pub struct ZedDotDevSettings { + pub available_models: Vec, } pub struct CloudLanguageModelProvider { - client: Arc, state: Entity, _maintain_client_status: Task<()>, } pub struct State { client: Arc, - llm_api_token: LlmApiToken, user_store: Entity, status: client::Status, - models: Vec>, - default_model: Option>, - default_fast_model: Option>, - recommended_models: Vec>, + provider: Entity>, _user_store_subscription: Subscription, _settings_subscription: Subscription, _llm_token_subscription: Subscription, + _provider_subscription: Subscription, } impl State { @@ -112,16 +98,26 @@ impl State { cx: &mut Context, ) -> Self { let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx); - let llm_api_token = global_llm_token(cx); + let token_provider = Arc::new(ClientTokenProvider { + client: client.clone(), + llm_api_token: global_llm_token(cx), + user_store: user_store.clone(), + }); + + let provider = cx.new(|cx| { + CloudModelProvider::new( + token_provider.clone(), + client.http_client(), + Some(AppVersion::global(cx)), + ) + }); + Self { client: client.clone(), - llm_api_token, user_store: user_store.clone(), status, - models: Vec::new(), - default_model: None, - default_fast_model: None, - recommended_models: Vec::new(), + _provider_subscription: cx.observe(&provider, |_, _, cx| cx.notify()), + provider, _user_store_subscription: cx.subscribe( &user_store, move |this, _user_store, event, cx| match event { @@ -131,19 +127,7 @@ impl State { return; } - let client = this.client.clone(); - let llm_api_token = this.llm_api_token.clone(); - let organization_id = this - .user_store - .read(cx) - .current_organization() - .map(|organization| organization.id.clone()); - cx.spawn(async move |this, cx| { - let response = - Self::fetch_models(client, llm_api_token, organization_id).await?; - this.update(cx, |this, cx| this.update_models(response, cx)) - }) - .detach_and_log_err(cx); + this.refresh_models(cx); } _ => {} }, @@ -154,21 +138,7 @@ impl State { _llm_token_subscription: cx.subscribe( &refresh_llm_token_listener, move |this, _listener, _event, cx| { - let client = this.client.clone(); - let llm_api_token = this.llm_api_token.clone(); - let organization_id = this - .user_store - .read(cx) - .current_organization() - .map(|organization| organization.id.clone()); - cx.spawn(async move |this, cx| { - let response = - Self::fetch_models(client, llm_api_token, organization_id).await?; - this.update(cx, |this, cx| { - this.update_models(response, cx); - }) - }) - .detach_and_log_err(cx); + this.refresh_models(cx); }, ), } @@ -186,74 +156,10 @@ impl State { }) } - fn update_models(&mut self, response: ListModelsResponse, cx: &mut Context) { - let mut models = Vec::new(); - - for model in response.models { - models.push(Arc::new(model.clone())); - } - - self.default_model = models - .iter() - .find(|model| { - response - .default_model - .as_ref() - .is_some_and(|default_model_id| &model.id == default_model_id) - }) - .cloned(); - self.default_fast_model = models - .iter() - .find(|model| { - response - .default_fast_model - .as_ref() - .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id) - }) - .cloned(); - self.recommended_models = response - .recommended_models - .iter() - .filter_map(|id| models.iter().find(|model| &model.id == id)) - .cloned() - .collect(); - self.models = models; - cx.notify(); - } - - async fn fetch_models( - client: Arc, - llm_api_token: LlmApiToken, - organization_id: Option, - ) -> Result { - let http_client = &client.http_client(); - let token = client - .acquire_llm_token(&llm_api_token, organization_id) - .await?; - - let request = http_client::Request::builder() - .method(Method::GET) - .header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true") - .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref()) - .header("Authorization", format!("Bearer {token}")) - .body(AsyncBody::empty())?; - let mut response = http_client - .send(request) - .await - .context("failed to send list models request")?; - - if response.status().is_success() { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - Ok(serde_json::from_str(&body)?) - } else { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - anyhow::bail!( - "error listing models.\nStatus: {:?}\nBody: {body}", - response.status(), - ); - } + fn refresh_models(&mut self, cx: &mut Context) { + self.provider.update(cx, |provider, cx| { + provider.refresh_models(cx).detach_and_log_err(cx); + }); } } @@ -281,27 +187,10 @@ impl CloudLanguageModelProvider { }); Self { - client, state, _maintain_client_status: maintain_client_status, } } - - fn create_language_model( - &self, - model: Arc, - llm_api_token: LlmApiToken, - user_store: Entity, - ) -> Arc { - Arc::new(CloudLanguageModel { - id: LanguageModelId(SharedString::from(model.id.0.clone())), - model, - llm_api_token, - user_store, - client: self.client.clone(), - request_limiter: RateLimiter::new(4), - }) - } } impl LanguageModelProviderState for CloudLanguageModelProvider { @@ -327,45 +216,35 @@ impl LanguageModelProvider for CloudLanguageModelProvider { fn default_model(&self, cx: &App) -> Option> { let state = self.state.read(cx); - let default_model = state.default_model.clone()?; - let llm_api_token = state.llm_api_token.clone(); - let user_store = state.user_store.clone(); - Some(self.create_language_model(default_model, llm_api_token, user_store)) + let provider = state.provider.read(cx); + let model = provider.default_model()?; + Some(provider.create_model(model)) } fn default_fast_model(&self, cx: &App) -> Option> { let state = self.state.read(cx); - let default_fast_model = state.default_fast_model.clone()?; - let llm_api_token = state.llm_api_token.clone(); - let user_store = state.user_store.clone(); - Some(self.create_language_model(default_fast_model, llm_api_token, user_store)) + let provider = state.provider.read(cx); + let model = provider.default_fast_model()?; + Some(provider.create_model(model)) } fn recommended_models(&self, cx: &App) -> Vec> { let state = self.state.read(cx); - let llm_api_token = state.llm_api_token.clone(); - let user_store = state.user_store.clone(); - state - .recommended_models + let provider = state.provider.read(cx); + provider + .recommended_models() .iter() - .cloned() - .map(|model| { - self.create_language_model(model, llm_api_token.clone(), user_store.clone()) - }) + .map(|model| provider.create_model(model)) .collect() } fn provided_models(&self, cx: &App) -> Vec> { let state = self.state.read(cx); - let llm_api_token = state.llm_api_token.clone(); - let user_store = state.user_store.clone(); - state - .models + let provider = state.provider.read(cx); + provider + .models() .iter() - .cloned() - .map(|model| { - self.create_language_model(model, llm_api_token.clone(), user_store.clone()) - }) + .map(|model| provider.create_model(model)) .collect() } @@ -393,700 +272,6 @@ impl LanguageModelProvider for CloudLanguageModelProvider { } } -pub struct CloudLanguageModel { - id: LanguageModelId, - model: Arc, - llm_api_token: LlmApiToken, - user_store: Entity, - client: Arc, - request_limiter: RateLimiter, -} - -struct PerformLlmCompletionResponse { - response: Response, - includes_status_messages: bool, -} - -impl CloudLanguageModel { - async fn perform_llm_completion( - client: Arc, - llm_api_token: LlmApiToken, - organization_id: Option, - app_version: Option, - body: CompletionBody, - ) -> Result { - let http_client = &client.http_client(); - - let mut token = client - .acquire_llm_token(&llm_api_token, organization_id.clone()) - .await?; - let mut refreshed_token = false; - - loop { - let request = http_client::Request::builder() - .method(Method::POST) - .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref()) - .when_some(app_version.as_ref(), |builder, app_version| { - builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string()) - }) - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {token}")) - .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true") - .header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true") - .body(serde_json::to_string(&body)?.into())?; - - let mut response = http_client.send(request).await?; - let status = response.status(); - if status.is_success() { - let includes_status_messages = response - .headers() - .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME) - .is_some(); - - return Ok(PerformLlmCompletionResponse { - response, - includes_status_messages, - }); - } - - if !refreshed_token && response.needs_llm_token_refresh() { - token = client - .refresh_llm_token(&llm_api_token, organization_id.clone()) - .await?; - refreshed_token = true; - continue; - } - - if status == StatusCode::PAYMENT_REQUIRED { - return Err(anyhow!(PaymentRequiredError)); - } - - let mut body = String::new(); - let headers = response.headers().clone(); - response.body_mut().read_to_string(&mut body).await?; - return Err(anyhow!(ApiError { - status, - body, - headers - })); - } - } -} - -#[derive(Debug, Error)] -#[error("cloud language model request failed with status {status}: {body}")] -struct ApiError { - status: StatusCode, - body: String, - headers: HeaderMap, -} - -/// Represents error responses from Zed's cloud API. -/// -/// Example JSON for an upstream HTTP error: -/// ```json -/// { -/// "code": "upstream_http_error", -/// "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout", -/// "upstream_status": 503 -/// } -/// ``` -#[derive(Debug, serde::Deserialize)] -struct CloudApiError { - code: String, - message: String, - #[serde(default)] - #[serde(deserialize_with = "deserialize_optional_status_code")] - upstream_status: Option, - #[serde(default)] - retry_after: Option, -} - -fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result, D::Error> -where - D: serde::Deserializer<'de>, -{ - let opt: Option = Option::deserialize(deserializer)?; - Ok(opt.and_then(|code| StatusCode::from_u16(code).ok())) -} - -impl From for LanguageModelCompletionError { - fn from(error: ApiError) -> Self { - if let Ok(cloud_error) = serde_json::from_str::(&error.body) { - if cloud_error.code.starts_with("upstream_http_") { - let status = if let Some(status) = cloud_error.upstream_status { - status - } else if cloud_error.code.ends_with("_error") { - error.status - } else { - // If there's a status code in the code string (e.g. "upstream_http_429") - // then use that; otherwise, see if the JSON contains a status code. - cloud_error - .code - .strip_prefix("upstream_http_") - .and_then(|code_str| code_str.parse::().ok()) - .and_then(|code| StatusCode::from_u16(code).ok()) - .unwrap_or(error.status) - }; - - return LanguageModelCompletionError::UpstreamProviderError { - message: cloud_error.message, - status, - retry_after: cloud_error.retry_after.map(Duration::from_secs_f64), - }; - } - - return LanguageModelCompletionError::from_http_status( - PROVIDER_NAME, - error.status, - cloud_error.message, - None, - ); - } - - let retry_after = None; - LanguageModelCompletionError::from_http_status( - PROVIDER_NAME, - error.status, - error.body, - retry_after, - ) - } -} - -impl LanguageModel for CloudLanguageModel { - fn id(&self) -> LanguageModelId { - self.id.clone() - } - - fn name(&self) -> LanguageModelName { - LanguageModelName::from(self.model.display_name.clone()) - } - - fn provider_id(&self) -> LanguageModelProviderId { - PROVIDER_ID - } - - fn provider_name(&self) -> LanguageModelProviderName { - PROVIDER_NAME - } - - fn upstream_provider_id(&self) -> LanguageModelProviderId { - use cloud_llm_client::LanguageModelProvider::*; - match self.model.provider { - Anthropic => ANTHROPIC_PROVIDER_ID, - OpenAi => OPEN_AI_PROVIDER_ID, - Google => GOOGLE_PROVIDER_ID, - XAi => X_AI_PROVIDER_ID, - } - } - - fn upstream_provider_name(&self) -> LanguageModelProviderName { - use cloud_llm_client::LanguageModelProvider::*; - match self.model.provider { - Anthropic => ANTHROPIC_PROVIDER_NAME, - OpenAi => OPEN_AI_PROVIDER_NAME, - Google => GOOGLE_PROVIDER_NAME, - XAi => X_AI_PROVIDER_NAME, - } - } - - fn is_latest(&self) -> bool { - self.model.is_latest - } - - fn supports_tools(&self) -> bool { - self.model.supports_tools - } - - fn supports_images(&self) -> bool { - self.model.supports_images - } - - fn supports_thinking(&self) -> bool { - self.model.supports_thinking - } - - fn supports_fast_mode(&self) -> bool { - self.model.supports_fast_mode - } - - fn supported_effort_levels(&self) -> Vec { - self.model - .supported_effort_levels - .iter() - .map(|effort_level| LanguageModelEffortLevel { - name: effort_level.name.clone().into(), - value: effort_level.value.clone().into(), - is_default: effort_level.is_default.unwrap_or(false), - }) - .collect() - } - - fn supports_streaming_tools(&self) -> bool { - self.model.supports_streaming_tools - } - - fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { - match choice { - LanguageModelToolChoice::Auto - | LanguageModelToolChoice::Any - | LanguageModelToolChoice::None => true, - } - } - - fn supports_split_token_display(&self) -> bool { - use cloud_llm_client::LanguageModelProvider::*; - matches!(self.model.provider, OpenAi | XAi) - } - - fn telemetry_id(&self) -> String { - format!("zed.dev/{}", self.model.id) - } - - fn tool_input_format(&self) -> LanguageModelToolSchemaFormat { - match self.model.provider { - cloud_llm_client::LanguageModelProvider::Anthropic - | cloud_llm_client::LanguageModelProvider::OpenAi => { - LanguageModelToolSchemaFormat::JsonSchema - } - cloud_llm_client::LanguageModelProvider::Google - | cloud_llm_client::LanguageModelProvider::XAi => { - LanguageModelToolSchemaFormat::JsonSchemaSubset - } - } - } - - fn max_token_count(&self) -> u64 { - self.model.max_token_count as u64 - } - - fn max_output_tokens(&self) -> Option { - Some(self.model.max_output_tokens as u64) - } - - fn cache_configuration(&self) -> Option { - match &self.model.provider { - cloud_llm_client::LanguageModelProvider::Anthropic => { - Some(LanguageModelCacheConfiguration { - min_total_token: 2_048, - should_speculate: true, - max_cache_anchors: 4, - }) - } - cloud_llm_client::LanguageModelProvider::OpenAi - | cloud_llm_client::LanguageModelProvider::XAi - | cloud_llm_client::LanguageModelProvider::Google => None, - } - } - - fn count_tokens( - &self, - request: LanguageModelRequest, - cx: &App, - ) -> BoxFuture<'static, Result> { - match self.model.provider { - cloud_llm_client::LanguageModelProvider::Anthropic => cx - .background_spawn(async move { count_anthropic_tokens_with_tiktoken(request) }) - .boxed(), - cloud_llm_client::LanguageModelProvider::OpenAi => { - let model = match open_ai::Model::from_id(&self.model.id.0) { - Ok(model) => model, - Err(err) => return async move { Err(anyhow!(err)) }.boxed(), - }; - count_open_ai_tokens(request, model, cx) - } - cloud_llm_client::LanguageModelProvider::XAi => { - let model = match x_ai::Model::from_id(&self.model.id.0) { - Ok(model) => model, - Err(err) => return async move { Err(anyhow!(err)) }.boxed(), - }; - count_xai_tokens(request, model, cx) - } - cloud_llm_client::LanguageModelProvider::Google => { - let client = self.client.clone(); - let llm_api_token = self.llm_api_token.clone(); - let organization_id = self - .user_store - .read(cx) - .current_organization() - .map(|organization| organization.id.clone()); - let model_id = self.model.id.to_string(); - let generate_content_request = - into_google(request, model_id.clone(), GoogleModelMode::Default); - async move { - let http_client = &client.http_client(); - let token = client - .acquire_llm_token(&llm_api_token, organization_id) - .await?; - - let request_body = CountTokensBody { - provider: cloud_llm_client::LanguageModelProvider::Google, - model: model_id, - provider_request: serde_json::to_value(&google_ai::CountTokensRequest { - generate_content_request, - })?, - }; - let request = http_client::Request::builder() - .method(Method::POST) - .uri( - http_client - .build_zed_llm_url("/count_tokens", &[])? - .as_ref(), - ) - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {token}")) - .body(serde_json::to_string(&request_body)?.into())?; - let mut response = http_client.send(request).await?; - let status = response.status(); - let headers = response.headers().clone(); - let mut response_body = String::new(); - response - .body_mut() - .read_to_string(&mut response_body) - .await?; - - if status.is_success() { - let response_body: CountTokensResponse = - serde_json::from_str(&response_body)?; - - Ok(response_body.tokens as u64) - } else { - Err(anyhow!(ApiError { - status, - body: response_body, - headers - })) - } - } - .boxed() - } - } - } - - fn stream_completion( - &self, - request: LanguageModelRequest, - cx: &AsyncApp, - ) -> BoxFuture< - 'static, - Result< - BoxStream<'static, Result>, - LanguageModelCompletionError, - >, - > { - let thread_id = request.thread_id.clone(); - let prompt_id = request.prompt_id.clone(); - let app_version = Some(cx.update(|cx| AppVersion::global(cx))); - let user_store = self.user_store.clone(); - let organization_id = cx.update(|cx| { - user_store - .read(cx) - .current_organization() - .map(|organization| organization.id.clone()) - }); - let thinking_allowed = request.thinking_allowed; - let enable_thinking = thinking_allowed && self.model.supports_thinking; - let provider_name = provider_name(&self.model.provider); - match self.model.provider { - cloud_llm_client::LanguageModelProvider::Anthropic => { - let effort = request - .thinking_effort - .as_ref() - .and_then(|effort| anthropic::Effort::from_str(effort).ok()); - - let mut request = into_anthropic( - request, - self.model.id.to_string(), - 1.0, - self.model.max_output_tokens as u64, - if enable_thinking { - AnthropicModelMode::Thinking { - budget_tokens: Some(4_096), - } - } else { - AnthropicModelMode::Default - }, - ); - - if enable_thinking && effort.is_some() { - request.thinking = Some(anthropic::Thinking::Adaptive); - request.output_config = Some(anthropic::OutputConfig { effort }); - } - - let client = self.client.clone(); - let llm_api_token = self.llm_api_token.clone(); - let organization_id = organization_id.clone(); - let future = self.request_limiter.stream(async move { - let PerformLlmCompletionResponse { - response, - includes_status_messages, - } = Self::perform_llm_completion( - client.clone(), - llm_api_token, - organization_id, - app_version, - CompletionBody { - thread_id, - prompt_id, - provider: cloud_llm_client::LanguageModelProvider::Anthropic, - model: request.model.clone(), - provider_request: serde_json::to_value(&request) - .map_err(|e| anyhow!(e))?, - }, - ) - .await - .map_err(|err| match err.downcast::() { - Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)), - Err(err) => anyhow!(err), - })?; - - let mut mapper = AnthropicEventMapper::new(); - Ok(map_cloud_completion_events( - Box::pin(response_lines(response, includes_status_messages)), - &provider_name, - move |event| mapper.map_event(event), - )) - }); - async move { Ok(future.await?.boxed()) }.boxed() - } - cloud_llm_client::LanguageModelProvider::OpenAi => { - let client = self.client.clone(); - let llm_api_token = self.llm_api_token.clone(); - let organization_id = organization_id.clone(); - let effort = request - .thinking_effort - .as_ref() - .and_then(|effort| open_ai::ReasoningEffort::from_str(effort).ok()); - - let mut request = into_open_ai_response( - request, - &self.model.id.0, - self.model.supports_parallel_tool_calls, - true, - None, - None, - ); - - if enable_thinking && let Some(effort) = effort { - request.reasoning = Some(open_ai::responses::ReasoningConfig { - effort, - summary: Some(open_ai::responses::ReasoningSummaryMode::Auto), - }); - } - - let future = self.request_limiter.stream(async move { - let PerformLlmCompletionResponse { - response, - includes_status_messages, - } = Self::perform_llm_completion( - client.clone(), - llm_api_token, - organization_id, - app_version, - CompletionBody { - thread_id, - prompt_id, - provider: cloud_llm_client::LanguageModelProvider::OpenAi, - model: request.model.clone(), - provider_request: serde_json::to_value(&request) - .map_err(|e| anyhow!(e))?, - }, - ) - .await?; - - let mut mapper = OpenAiResponseEventMapper::new(); - Ok(map_cloud_completion_events( - Box::pin(response_lines(response, includes_status_messages)), - &provider_name, - move |event| mapper.map_event(event), - )) - }); - async move { Ok(future.await?.boxed()) }.boxed() - } - cloud_llm_client::LanguageModelProvider::XAi => { - let client = self.client.clone(); - let request = into_open_ai( - request, - &self.model.id.0, - self.model.supports_parallel_tool_calls, - false, - None, - None, - ); - let llm_api_token = self.llm_api_token.clone(); - let organization_id = organization_id.clone(); - let future = self.request_limiter.stream(async move { - let PerformLlmCompletionResponse { - response, - includes_status_messages, - } = Self::perform_llm_completion( - client.clone(), - llm_api_token, - organization_id, - app_version, - CompletionBody { - thread_id, - prompt_id, - provider: cloud_llm_client::LanguageModelProvider::XAi, - model: request.model.clone(), - provider_request: serde_json::to_value(&request) - .map_err(|e| anyhow!(e))?, - }, - ) - .await?; - - let mut mapper = OpenAiEventMapper::new(); - Ok(map_cloud_completion_events( - Box::pin(response_lines(response, includes_status_messages)), - &provider_name, - move |event| mapper.map_event(event), - )) - }); - async move { Ok(future.await?.boxed()) }.boxed() - } - cloud_llm_client::LanguageModelProvider::Google => { - let client = self.client.clone(); - let request = - into_google(request, self.model.id.to_string(), GoogleModelMode::Default); - let llm_api_token = self.llm_api_token.clone(); - let future = self.request_limiter.stream(async move { - let PerformLlmCompletionResponse { - response, - includes_status_messages, - } = Self::perform_llm_completion( - client.clone(), - llm_api_token, - organization_id, - app_version, - CompletionBody { - thread_id, - prompt_id, - provider: cloud_llm_client::LanguageModelProvider::Google, - model: request.model.model_id.clone(), - provider_request: serde_json::to_value(&request) - .map_err(|e| anyhow!(e))?, - }, - ) - .await?; - - let mut mapper = GoogleEventMapper::new(); - Ok(map_cloud_completion_events( - Box::pin(response_lines(response, includes_status_messages)), - &provider_name, - move |event| mapper.map_event(event), - )) - }); - async move { Ok(future.await?.boxed()) }.boxed() - } - } - } -} - -fn map_cloud_completion_events( - stream: Pin>> + Send>>, - provider: &LanguageModelProviderName, - mut map_callback: F, -) -> BoxStream<'static, Result> -where - T: DeserializeOwned + 'static, - F: FnMut(T) -> Vec> - + Send - + 'static, -{ - let provider = provider.clone(); - let mut stream = stream.fuse(); - - let mut saw_stream_ended = false; - - let mut done = false; - let mut pending = VecDeque::new(); - - stream::poll_fn(move |cx| { - loop { - if let Some(item) = pending.pop_front() { - return Poll::Ready(Some(item)); - } - - if done { - return Poll::Ready(None); - } - - match stream.poll_next_unpin(cx) { - Poll::Ready(Some(event)) => { - let items = match event { - Err(error) => { - vec![Err(LanguageModelCompletionError::from(error))] - } - Ok(CompletionEvent::Status(CompletionRequestStatus::StreamEnded)) => { - saw_stream_ended = true; - vec![] - } - Ok(CompletionEvent::Status(status)) => { - LanguageModelCompletionEvent::from_completion_request_status( - status, - provider.clone(), - ) - .transpose() - .map(|event| vec![event]) - .unwrap_or_default() - } - Ok(CompletionEvent::Event(event)) => map_callback(event), - }; - pending.extend(items); - } - Poll::Ready(None) => { - done = true; - - if !saw_stream_ended { - return Poll::Ready(Some(Err( - LanguageModelCompletionError::StreamEndedUnexpectedly { - provider: provider.clone(), - }, - ))); - } - } - Poll::Pending => return Poll::Pending, - } - } - }) - .boxed() -} - -fn provider_name(provider: &cloud_llm_client::LanguageModelProvider) -> LanguageModelProviderName { - match provider { - cloud_llm_client::LanguageModelProvider::Anthropic => ANTHROPIC_PROVIDER_NAME, - cloud_llm_client::LanguageModelProvider::OpenAi => OPEN_AI_PROVIDER_NAME, - cloud_llm_client::LanguageModelProvider::Google => GOOGLE_PROVIDER_NAME, - cloud_llm_client::LanguageModelProvider::XAi => X_AI_PROVIDER_NAME, - } -} - -fn response_lines( - response: Response, - includes_status_messages: bool, -) -> impl Stream>> { - futures::stream::try_unfold( - (String::new(), BufReader::new(response.into_body())), - move |(mut line, mut body)| async move { - match body.read_line(&mut line).await { - Ok(0) => Ok(None), - Ok(_) => { - let event = if includes_status_messages { - serde_json::from_str::>(&line)? - } else { - CompletionEvent::Event(serde_json::from_str::(&line)?) - }; - - line.clear(); - Ok(Some((event, (line, body)))) - } - Err(e) => Err(e.into()), - } - }, - ) -} - #[derive(IntoElement, RegisterComponent)] struct ZedAiConfiguration { is_connected: bool, @@ -1281,155 +466,3 @@ impl Component for ZedAiConfiguration { ) } } - -#[cfg(test)] -mod tests { - use super::*; - use http_client::http::{HeaderMap, StatusCode}; - use language_model::LanguageModelCompletionError; - - #[test] - fn test_api_error_conversion_with_upstream_http_error() { - // upstream_http_error with 503 status should become ServerOverloaded - let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#; - - let api_error = ApiError { - status: StatusCode::INTERNAL_SERVER_ERROR, - body: error_body.to_string(), - headers: HeaderMap::new(), - }; - - let completion_error: LanguageModelCompletionError = api_error.into(); - - match completion_error { - LanguageModelCompletionError::UpstreamProviderError { message, .. } => { - assert_eq!( - message, - "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout" - ); - } - _ => panic!( - "Expected UpstreamProviderError for upstream 503, got: {:?}", - completion_error - ), - } - - // upstream_http_error with 500 status should become ApiInternalServerError - let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#; - - let api_error = ApiError { - status: StatusCode::INTERNAL_SERVER_ERROR, - body: error_body.to_string(), - headers: HeaderMap::new(), - }; - - let completion_error: LanguageModelCompletionError = api_error.into(); - - match completion_error { - LanguageModelCompletionError::UpstreamProviderError { message, .. } => { - assert_eq!( - message, - "Received an error from the OpenAI API: internal server error" - ); - } - _ => panic!( - "Expected UpstreamProviderError for upstream 500, got: {:?}", - completion_error - ), - } - - // upstream_http_error with 429 status should become RateLimitExceeded - let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#; - - let api_error = ApiError { - status: StatusCode::INTERNAL_SERVER_ERROR, - body: error_body.to_string(), - headers: HeaderMap::new(), - }; - - let completion_error: LanguageModelCompletionError = api_error.into(); - - match completion_error { - LanguageModelCompletionError::UpstreamProviderError { message, .. } => { - assert_eq!( - message, - "Received an error from the Google API: rate limit exceeded" - ); - } - _ => panic!( - "Expected UpstreamProviderError for upstream 429, got: {:?}", - completion_error - ), - } - - // Regular 500 error without upstream_http_error should remain ApiInternalServerError for Zed - let error_body = "Regular internal server error"; - - let api_error = ApiError { - status: StatusCode::INTERNAL_SERVER_ERROR, - body: error_body.to_string(), - headers: HeaderMap::new(), - }; - - let completion_error: LanguageModelCompletionError = api_error.into(); - - match completion_error { - LanguageModelCompletionError::ApiInternalServerError { provider, message } => { - assert_eq!(provider, PROVIDER_NAME); - assert_eq!(message, "Regular internal server error"); - } - _ => panic!( - "Expected ApiInternalServerError for regular 500, got: {:?}", - completion_error - ), - } - - // upstream_http_429 format should be converted to UpstreamProviderError - let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#; - - let api_error = ApiError { - status: StatusCode::INTERNAL_SERVER_ERROR, - body: error_body.to_string(), - headers: HeaderMap::new(), - }; - - let completion_error: LanguageModelCompletionError = api_error.into(); - - match completion_error { - LanguageModelCompletionError::UpstreamProviderError { - message, - status, - retry_after, - } => { - assert_eq!(message, "Upstream Anthropic rate limit exceeded."); - assert_eq!(status, StatusCode::TOO_MANY_REQUESTS); - assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5))); - } - _ => panic!( - "Expected UpstreamProviderError for upstream_http_429, got: {:?}", - completion_error - ), - } - - // Invalid JSON in error body should fall back to regular error handling - let error_body = "Not JSON at all"; - - let api_error = ApiError { - status: StatusCode::INTERNAL_SERVER_ERROR, - body: error_body.to_string(), - headers: HeaderMap::new(), - }; - - let completion_error: LanguageModelCompletionError = api_error.into(); - - match completion_error { - LanguageModelCompletionError::ApiInternalServerError { provider, .. } => { - assert_eq!(provider, PROVIDER_NAME); - } - _ => panic!( - "Expected ApiInternalServerError for invalid JSON, got: {:?}", - completion_error - ), - } - } -} diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index a2d39e1945e2791d9d5c998cc717a07498ebc157..a77e3f880be18d8f9f0e97ec8717c32bc780e267 100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -32,7 +32,7 @@ use ui::prelude::*; use util::debug_panic; use crate::provider::anthropic::{AnthropicEventMapper, into_anthropic}; -use crate::provider::util::{fix_streamed_json, parse_tool_arguments}; +use language_model::util::{fix_streamed_json, parse_tool_arguments}; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("copilot_chat"); const PROVIDER_NAME: LanguageModelProviderName = @@ -268,15 +268,15 @@ impl LanguageModel for CopilotChatLanguageModel { levels .iter() .map(|level| { - let name: SharedString = match level.as_str() { + let name = match level.as_str() { "low" => "Low".into(), "medium" => "Medium".into(), "high" => "High".into(), - _ => SharedString::from(level.clone()), + _ => language_model::SharedString::from(level.clone()), }; LanguageModelEffortLevel { name, - value: SharedString::from(level.clone()), + value: language_model::SharedString::from(level.clone()), is_default: level == "high", } }) diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs index 0cfb1af425c7cb0279d98fa124a589437f1bb1a1..f3dccd5cc1a2e1a5ddfe2bc6b43901f2b549e532 100644 --- a/crates/language_models/src/provider/deepseek.rs +++ b/crates/language_models/src/provider/deepseek.rs @@ -23,7 +23,7 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*}; use ui_input::InputField; use util::ResultExt; -use crate::provider::util::{fix_streamed_json, parse_tool_arguments}; +use language_model::util::{fix_streamed_json, parse_tool_arguments}; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("deepseek"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("DeepSeek"); diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index 244f7835a85ff67f0c4826321910ea13516371cb..92278839c6ff5119849f8881409928686f055331 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -1,32 +1,25 @@ use anyhow::{Context as _, Result}; use collections::BTreeMap; use credentials_provider::CredentialsProvider; -use futures::{FutureExt, Stream, StreamExt, future::BoxFuture}; -use google_ai::{ - FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction, - ThinkingConfig, UsageMetadata, -}; +use futures::{FutureExt, StreamExt, future::BoxFuture}; +pub use google_ai::completion::{GoogleEventMapper, count_google_tokens, into_google}; +use google_ai::{GenerateContentResponse, GoogleModelMode}; use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window}; use http_client::HttpClient; use language_model::{ AuthenticateError, ConfigurationViewTargetAgent, EnvVar, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelToolChoice, LanguageModelToolSchemaFormat, - LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason, }; use language_model::{ GOOGLE_PROVIDER_ID, GOOGLE_PROVIDER_NAME, IconOrSvg, LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, - LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role, + LanguageModelProviderState, LanguageModelRequest, RateLimiter, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; pub use settings::GoogleAvailableModel as AvailableModel; use settings::{Settings, SettingsStore}; -use std::pin::Pin; -use std::sync::{ - Arc, LazyLock, - atomic::{self, AtomicU64}, -}; +use std::sync::{Arc, LazyLock}; use strum::IntoEnumIterator; use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*}; use ui_input::InputField; @@ -394,369 +387,6 @@ impl LanguageModel for GoogleLanguageModel { } } -pub fn into_google( - mut request: LanguageModelRequest, - model_id: String, - mode: GoogleModelMode, -) -> google_ai::GenerateContentRequest { - fn map_content(content: Vec) -> Vec { - content - .into_iter() - .flat_map(|content| match content { - language_model::MessageContent::Text(text) => { - if !text.is_empty() { - vec![Part::TextPart(google_ai::TextPart { text })] - } else { - vec![] - } - } - language_model::MessageContent::Thinking { - text: _, - signature: Some(signature), - } => { - if !signature.is_empty() { - vec![Part::ThoughtPart(google_ai::ThoughtPart { - thought: true, - thought_signature: signature, - })] - } else { - vec![] - } - } - language_model::MessageContent::Thinking { .. } => { - vec![] - } - language_model::MessageContent::RedactedThinking(_) => vec![], - language_model::MessageContent::Image(image) => { - vec![Part::InlineDataPart(google_ai::InlineDataPart { - inline_data: google_ai::GenerativeContentBlob { - mime_type: "image/png".to_string(), - data: image.source.to_string(), - }, - })] - } - language_model::MessageContent::ToolUse(tool_use) => { - // Normalize empty string signatures to None - let thought_signature = tool_use.thought_signature.filter(|s| !s.is_empty()); - - vec![Part::FunctionCallPart(google_ai::FunctionCallPart { - function_call: google_ai::FunctionCall { - name: tool_use.name.to_string(), - args: tool_use.input, - }, - thought_signature, - })] - } - language_model::MessageContent::ToolResult(tool_result) => { - match tool_result.content { - language_model::LanguageModelToolResultContent::Text(text) => { - vec![Part::FunctionResponsePart( - google_ai::FunctionResponsePart { - function_response: google_ai::FunctionResponse { - name: tool_result.tool_name.to_string(), - // The API expects a valid JSON object - response: serde_json::json!({ - "output": text - }), - }, - }, - )] - } - language_model::LanguageModelToolResultContent::Image(image) => { - vec![ - Part::FunctionResponsePart(google_ai::FunctionResponsePart { - function_response: google_ai::FunctionResponse { - name: tool_result.tool_name.to_string(), - // The API expects a valid JSON object - response: serde_json::json!({ - "output": "Tool responded with an image" - }), - }, - }), - Part::InlineDataPart(google_ai::InlineDataPart { - inline_data: google_ai::GenerativeContentBlob { - mime_type: "image/png".to_string(), - data: image.source.to_string(), - }, - }), - ] - } - } - } - }) - .collect() - } - - let system_instructions = if request - .messages - .first() - .is_some_and(|msg| matches!(msg.role, Role::System)) - { - let message = request.messages.remove(0); - Some(SystemInstruction { - parts: map_content(message.content), - }) - } else { - None - }; - - google_ai::GenerateContentRequest { - model: google_ai::ModelName { model_id }, - system_instruction: system_instructions, - contents: request - .messages - .into_iter() - .filter_map(|message| { - let parts = map_content(message.content); - if parts.is_empty() { - None - } else { - Some(google_ai::Content { - parts, - role: match message.role { - Role::User => google_ai::Role::User, - Role::Assistant => google_ai::Role::Model, - Role::System => google_ai::Role::User, // Google AI doesn't have a system role - }, - }) - } - }) - .collect(), - generation_config: Some(google_ai::GenerationConfig { - candidate_count: Some(1), - stop_sequences: Some(request.stop), - max_output_tokens: None, - temperature: request.temperature.map(|t| t as f64).or(Some(1.0)), - thinking_config: match (request.thinking_allowed, mode) { - (true, GoogleModelMode::Thinking { budget_tokens }) => { - budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget }) - } - _ => None, - }, - top_p: None, - top_k: None, - }), - safety_settings: None, - tools: (!request.tools.is_empty()).then(|| { - vec![google_ai::Tool { - function_declarations: request - .tools - .into_iter() - .map(|tool| FunctionDeclaration { - name: tool.name, - description: tool.description, - parameters: tool.input_schema, - }) - .collect(), - }] - }), - tool_config: request.tool_choice.map(|choice| google_ai::ToolConfig { - function_calling_config: google_ai::FunctionCallingConfig { - mode: match choice { - LanguageModelToolChoice::Auto => google_ai::FunctionCallingMode::Auto, - LanguageModelToolChoice::Any => google_ai::FunctionCallingMode::Any, - LanguageModelToolChoice::None => google_ai::FunctionCallingMode::None, - }, - allowed_function_names: None, - }, - }), - } -} - -pub struct GoogleEventMapper { - usage: UsageMetadata, - stop_reason: StopReason, -} - -impl GoogleEventMapper { - pub fn new() -> Self { - Self { - usage: UsageMetadata::default(), - stop_reason: StopReason::EndTurn, - } - } - - pub fn map_stream( - mut self, - events: Pin>>>, - ) -> impl Stream> - { - events - .map(Some) - .chain(futures::stream::once(async { None })) - .flat_map(move |event| { - futures::stream::iter(match event { - Some(Ok(event)) => self.map_event(event), - Some(Err(error)) => { - vec![Err(LanguageModelCompletionError::from(error))] - } - None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))], - }) - }) - } - - pub fn map_event( - &mut self, - event: GenerateContentResponse, - ) -> Vec> { - static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0); - - let mut events: Vec<_> = Vec::new(); - let mut wants_to_use_tool = false; - if let Some(usage_metadata) = event.usage_metadata { - update_usage(&mut self.usage, &usage_metadata); - events.push(Ok(LanguageModelCompletionEvent::UsageUpdate( - convert_usage(&self.usage), - ))) - } - - if let Some(prompt_feedback) = event.prompt_feedback - && let Some(block_reason) = prompt_feedback.block_reason.as_deref() - { - self.stop_reason = match block_reason { - "SAFETY" | "OTHER" | "BLOCKLIST" | "PROHIBITED_CONTENT" | "IMAGE_SAFETY" => { - StopReason::Refusal - } - _ => { - log::error!("Unexpected Google block_reason: {block_reason}"); - StopReason::Refusal - } - }; - events.push(Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))); - - return events; - } - - if let Some(candidates) = event.candidates { - for candidate in candidates { - if let Some(finish_reason) = candidate.finish_reason.as_deref() { - self.stop_reason = match finish_reason { - "STOP" => StopReason::EndTurn, - "MAX_TOKENS" => StopReason::MaxTokens, - _ => { - log::error!("Unexpected google finish_reason: {finish_reason}"); - StopReason::EndTurn - } - }; - } - candidate - .content - .parts - .into_iter() - .for_each(|part| match part { - Part::TextPart(text_part) => { - events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text))) - } - Part::InlineDataPart(_) => {} - Part::FunctionCallPart(function_call_part) => { - wants_to_use_tool = true; - let name: Arc = function_call_part.function_call.name.into(); - let next_tool_id = - TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst); - let id: LanguageModelToolUseId = - format!("{}-{}", name, next_tool_id).into(); - - // Normalize empty string signatures to None - let thought_signature = function_call_part - .thought_signature - .filter(|s| !s.is_empty()); - - events.push(Ok(LanguageModelCompletionEvent::ToolUse( - LanguageModelToolUse { - id, - name, - is_input_complete: true, - raw_input: function_call_part.function_call.args.to_string(), - input: function_call_part.function_call.args, - thought_signature, - }, - ))); - } - Part::FunctionResponsePart(_) => {} - Part::ThoughtPart(part) => { - events.push(Ok(LanguageModelCompletionEvent::Thinking { - text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries? - signature: Some(part.thought_signature), - })); - } - }); - } - } - - // Even when Gemini wants to use a Tool, the API - // responds with `finish_reason: STOP` - if wants_to_use_tool { - self.stop_reason = StopReason::ToolUse; - events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse))); - } - events - } -} - -pub fn count_google_tokens( - request: LanguageModelRequest, - cx: &App, -) -> BoxFuture<'static, Result> { - // We couldn't use the GoogleLanguageModelProvider to count tokens because the github copilot doesn't have the access to google_ai directly. - // So we have to use tokenizer from tiktoken_rs to count tokens. - cx.background_spawn(async move { - let messages = request - .messages - .into_iter() - .map(|message| tiktoken_rs::ChatCompletionRequestMessage { - role: match message.role { - Role::User => "user".into(), - Role::Assistant => "assistant".into(), - Role::System => "system".into(), - }, - content: Some(message.string_contents()), - name: None, - function_call: None, - }) - .collect::>(); - - // Tiktoken doesn't yet support these models, so we manually use the - // same tokenizer as GPT-4. - tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64) - }) - .boxed() -} - -fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) { - if let Some(prompt_token_count) = new.prompt_token_count { - usage.prompt_token_count = Some(prompt_token_count); - } - if let Some(cached_content_token_count) = new.cached_content_token_count { - usage.cached_content_token_count = Some(cached_content_token_count); - } - if let Some(candidates_token_count) = new.candidates_token_count { - usage.candidates_token_count = Some(candidates_token_count); - } - if let Some(tool_use_prompt_token_count) = new.tool_use_prompt_token_count { - usage.tool_use_prompt_token_count = Some(tool_use_prompt_token_count); - } - if let Some(thoughts_token_count) = new.thoughts_token_count { - usage.thoughts_token_count = Some(thoughts_token_count); - } - if let Some(total_token_count) = new.total_token_count { - usage.total_token_count = Some(total_token_count); - } -} - -fn convert_usage(usage: &UsageMetadata) -> language_model::TokenUsage { - let prompt_tokens = usage.prompt_token_count.unwrap_or(0); - let cached_tokens = usage.cached_content_token_count.unwrap_or(0); - let input_tokens = prompt_tokens - cached_tokens; - let output_tokens = usage.candidates_token_count.unwrap_or(0); - - language_model::TokenUsage { - input_tokens, - output_tokens, - cache_read_input_tokens: cached_tokens, - cache_creation_input_tokens: 0, - } -} - struct ConfigurationView { api_key_editor: Entity, state: Entity, @@ -895,428 +525,3 @@ impl Render for ConfigurationView { } } } - -#[cfg(test)] -mod tests { - use super::*; - use google_ai::{ - Content, FunctionCall, FunctionCallPart, GenerateContentCandidate, GenerateContentResponse, - Part, Role as GoogleRole, TextPart, - }; - use language_model::{LanguageModelToolUseId, MessageContent, Role}; - use serde_json::json; - - #[test] - fn test_function_call_with_signature_creates_tool_use_with_signature() { - let mut mapper = GoogleEventMapper::new(); - - let response = GenerateContentResponse { - candidates: Some(vec![GenerateContentCandidate { - index: Some(0), - content: Content { - parts: vec![Part::FunctionCallPart(FunctionCallPart { - function_call: FunctionCall { - name: "test_function".to_string(), - args: json!({"arg": "value"}), - }, - thought_signature: Some("test_signature_123".to_string()), - })], - role: GoogleRole::Model, - }, - finish_reason: None, - finish_message: None, - safety_ratings: None, - citation_metadata: None, - }]), - prompt_feedback: None, - usage_metadata: None, - }; - - let events = mapper.map_event(response); - - assert_eq!(events.len(), 2); // ToolUse event + Stop event - - if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { - assert_eq!(tool_use.name.as_ref(), "test_function"); - assert_eq!( - tool_use.thought_signature.as_deref(), - Some("test_signature_123") - ); - } else { - panic!("Expected ToolUse event"); - } - } - - #[test] - fn test_function_call_without_signature_has_none() { - let mut mapper = GoogleEventMapper::new(); - - let response = GenerateContentResponse { - candidates: Some(vec![GenerateContentCandidate { - index: Some(0), - content: Content { - parts: vec![Part::FunctionCallPart(FunctionCallPart { - function_call: FunctionCall { - name: "test_function".to_string(), - args: json!({"arg": "value"}), - }, - thought_signature: None, - })], - role: GoogleRole::Model, - }, - finish_reason: None, - finish_message: None, - safety_ratings: None, - citation_metadata: None, - }]), - prompt_feedback: None, - usage_metadata: None, - }; - - let events = mapper.map_event(response); - - if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { - assert_eq!(tool_use.thought_signature, None); - } else { - panic!("Expected ToolUse event"); - } - } - - #[test] - fn test_empty_string_signature_normalized_to_none() { - let mut mapper = GoogleEventMapper::new(); - - let response = GenerateContentResponse { - candidates: Some(vec![GenerateContentCandidate { - index: Some(0), - content: Content { - parts: vec![Part::FunctionCallPart(FunctionCallPart { - function_call: FunctionCall { - name: "test_function".to_string(), - args: json!({"arg": "value"}), - }, - thought_signature: Some("".to_string()), - })], - role: GoogleRole::Model, - }, - finish_reason: None, - finish_message: None, - safety_ratings: None, - citation_metadata: None, - }]), - prompt_feedback: None, - usage_metadata: None, - }; - - let events = mapper.map_event(response); - - if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { - assert_eq!(tool_use.thought_signature, None); - } else { - panic!("Expected ToolUse event"); - } - } - - #[test] - fn test_parallel_function_calls_preserve_signatures() { - let mut mapper = GoogleEventMapper::new(); - - let response = GenerateContentResponse { - candidates: Some(vec![GenerateContentCandidate { - index: Some(0), - content: Content { - parts: vec![ - Part::FunctionCallPart(FunctionCallPart { - function_call: FunctionCall { - name: "function_1".to_string(), - args: json!({"arg": "value1"}), - }, - thought_signature: Some("signature_1".to_string()), - }), - Part::FunctionCallPart(FunctionCallPart { - function_call: FunctionCall { - name: "function_2".to_string(), - args: json!({"arg": "value2"}), - }, - thought_signature: None, - }), - ], - role: GoogleRole::Model, - }, - finish_reason: None, - finish_message: None, - safety_ratings: None, - citation_metadata: None, - }]), - prompt_feedback: None, - usage_metadata: None, - }; - - let events = mapper.map_event(response); - - assert_eq!(events.len(), 3); // 2 ToolUse events + Stop event - - if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { - assert_eq!(tool_use.name.as_ref(), "function_1"); - assert_eq!(tool_use.thought_signature.as_deref(), Some("signature_1")); - } else { - panic!("Expected ToolUse event for function_1"); - } - - if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[1] { - assert_eq!(tool_use.name.as_ref(), "function_2"); - assert_eq!(tool_use.thought_signature, None); - } else { - panic!("Expected ToolUse event for function_2"); - } - } - - #[test] - fn test_tool_use_with_signature_converts_to_function_call_part() { - let tool_use = language_model::LanguageModelToolUse { - id: LanguageModelToolUseId::from("test_id"), - name: "test_function".into(), - raw_input: json!({"arg": "value"}).to_string(), - input: json!({"arg": "value"}), - is_input_complete: true, - thought_signature: Some("test_signature_456".to_string()), - }; - - let request = super::into_google( - LanguageModelRequest { - messages: vec![language_model::LanguageModelRequestMessage { - role: Role::Assistant, - content: vec![MessageContent::ToolUse(tool_use)], - cache: false, - reasoning_details: None, - }], - ..Default::default() - }, - "gemini-2.5-flash".to_string(), - GoogleModelMode::Default, - ); - - assert_eq!(request.contents[0].parts.len(), 1); - if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] { - assert_eq!(fc_part.function_call.name, "test_function"); - assert_eq!( - fc_part.thought_signature.as_deref(), - Some("test_signature_456") - ); - } else { - panic!("Expected FunctionCallPart"); - } - } - - #[test] - fn test_tool_use_without_signature_omits_field() { - let tool_use = language_model::LanguageModelToolUse { - id: LanguageModelToolUseId::from("test_id"), - name: "test_function".into(), - raw_input: json!({"arg": "value"}).to_string(), - input: json!({"arg": "value"}), - is_input_complete: true, - thought_signature: None, - }; - - let request = super::into_google( - LanguageModelRequest { - messages: vec![language_model::LanguageModelRequestMessage { - role: Role::Assistant, - content: vec![MessageContent::ToolUse(tool_use)], - cache: false, - reasoning_details: None, - }], - ..Default::default() - }, - "gemini-2.5-flash".to_string(), - GoogleModelMode::Default, - ); - - assert_eq!(request.contents[0].parts.len(), 1); - if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] { - assert_eq!(fc_part.thought_signature, None); - } else { - panic!("Expected FunctionCallPart"); - } - } - - #[test] - fn test_empty_signature_in_tool_use_normalized_to_none() { - let tool_use = language_model::LanguageModelToolUse { - id: LanguageModelToolUseId::from("test_id"), - name: "test_function".into(), - raw_input: json!({"arg": "value"}).to_string(), - input: json!({"arg": "value"}), - is_input_complete: true, - thought_signature: Some("".to_string()), - }; - - let request = super::into_google( - LanguageModelRequest { - messages: vec![language_model::LanguageModelRequestMessage { - role: Role::Assistant, - content: vec![MessageContent::ToolUse(tool_use)], - cache: false, - reasoning_details: None, - }], - ..Default::default() - }, - "gemini-2.5-flash".to_string(), - GoogleModelMode::Default, - ); - - if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] { - assert_eq!(fc_part.thought_signature, None); - } else { - panic!("Expected FunctionCallPart"); - } - } - - #[test] - fn test_round_trip_preserves_signature() { - let mut mapper = GoogleEventMapper::new(); - - // Simulate receiving a response from Google with a signature - let response = GenerateContentResponse { - candidates: Some(vec![GenerateContentCandidate { - index: Some(0), - content: Content { - parts: vec![Part::FunctionCallPart(FunctionCallPart { - function_call: FunctionCall { - name: "test_function".to_string(), - args: json!({"arg": "value"}), - }, - thought_signature: Some("round_trip_sig".to_string()), - })], - role: GoogleRole::Model, - }, - finish_reason: None, - finish_message: None, - safety_ratings: None, - citation_metadata: None, - }]), - prompt_feedback: None, - usage_metadata: None, - }; - - let events = mapper.map_event(response); - - let tool_use = if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { - tool_use.clone() - } else { - panic!("Expected ToolUse event"); - }; - - // Convert back to Google format - let request = super::into_google( - LanguageModelRequest { - messages: vec![language_model::LanguageModelRequestMessage { - role: Role::Assistant, - content: vec![MessageContent::ToolUse(tool_use)], - cache: false, - reasoning_details: None, - }], - ..Default::default() - }, - "gemini-2.5-flash".to_string(), - GoogleModelMode::Default, - ); - - // Verify signature is preserved - if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] { - assert_eq!(fc_part.thought_signature.as_deref(), Some("round_trip_sig")); - } else { - panic!("Expected FunctionCallPart"); - } - } - - #[test] - fn test_mixed_text_and_function_call_with_signature() { - let mut mapper = GoogleEventMapper::new(); - - let response = GenerateContentResponse { - candidates: Some(vec![GenerateContentCandidate { - index: Some(0), - content: Content { - parts: vec![ - Part::TextPart(TextPart { - text: "I'll help with that.".to_string(), - }), - Part::FunctionCallPart(FunctionCallPart { - function_call: FunctionCall { - name: "helper_function".to_string(), - args: json!({"query": "help"}), - }, - thought_signature: Some("mixed_sig".to_string()), - }), - ], - role: GoogleRole::Model, - }, - finish_reason: None, - finish_message: None, - safety_ratings: None, - citation_metadata: None, - }]), - prompt_feedback: None, - usage_metadata: None, - }; - - let events = mapper.map_event(response); - - assert_eq!(events.len(), 3); // Text event + ToolUse event + Stop event - - if let Ok(LanguageModelCompletionEvent::Text(text)) = &events[0] { - assert_eq!(text, "I'll help with that."); - } else { - panic!("Expected Text event"); - } - - if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[1] { - assert_eq!(tool_use.name.as_ref(), "helper_function"); - assert_eq!(tool_use.thought_signature.as_deref(), Some("mixed_sig")); - } else { - panic!("Expected ToolUse event"); - } - } - - #[test] - fn test_special_characters_in_signature_preserved() { - let mut mapper = GoogleEventMapper::new(); - - let signature_with_special_chars = "sig<>\"'&%$#@!{}[]".to_string(); - - let response = GenerateContentResponse { - candidates: Some(vec![GenerateContentCandidate { - index: Some(0), - content: Content { - parts: vec![Part::FunctionCallPart(FunctionCallPart { - function_call: FunctionCall { - name: "test_function".to_string(), - args: json!({"arg": "value"}), - }, - thought_signature: Some(signature_with_special_chars.clone()), - })], - role: GoogleRole::Model, - }, - finish_reason: None, - finish_message: None, - safety_ratings: None, - citation_metadata: None, - }]), - prompt_feedback: None, - usage_metadata: None, - }; - - let events = mapper.map_event(response); - - if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { - assert_eq!( - tool_use.thought_signature.as_deref(), - Some(signature_with_special_chars.as_str()) - ); - } else { - panic!("Expected ToolUse event"); - } - } -} diff --git a/crates/language_models/src/provider/lmstudio.rs b/crates/language_models/src/provider/lmstudio.rs index 0d60fef16791087e35bac7d846b2ec99821d5470..a541da8cd8092d5d0fa43af1217c31833f10cdeb 100644 --- a/crates/language_models/src/provider/lmstudio.rs +++ b/crates/language_models/src/provider/lmstudio.rs @@ -28,7 +28,7 @@ use ui::{ use ui_input::InputField; use crate::AllLanguageModelSettings; -use crate::provider::util::parse_tool_arguments; +use language_model::util::parse_tool_arguments; const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download"; const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models"; diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs index 4cd1375fe50cd792a3a7bc8c85ba7b5b5af9520a..5fef40b2b1badbc77133ebe67fbe0f1fe5521259 100644 --- a/crates/language_models/src/provider/mistral.rs +++ b/crates/language_models/src/provider/mistral.rs @@ -23,7 +23,7 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*}; use ui_input::InputField; use util::ResultExt; -use crate::provider::util::{fix_streamed_json, parse_tool_arguments}; +use language_model::util::{fix_streamed_json, parse_tool_arguments}; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral"); diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index 6a2313487f4a1922cdc2aa20d23ede01c4b7d158..358a0ec5a6d517064be93d973f08eceb894ab665 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -1,41 +1,33 @@ -use anyhow::{Result, anyhow}; -use collections::{BTreeMap, HashMap}; +use anyhow::Result; +use collections::BTreeMap; use credentials_provider::CredentialsProvider; -use futures::Stream; use futures::{FutureExt, StreamExt, future::BoxFuture}; use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window}; use http_client::HttpClient; use language_model::{ ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError, - LanguageModelCompletionEvent, LanguageModelId, LanguageModelImage, LanguageModelName, - LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, - LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage, - LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, - LanguageModelToolUseId, MessageContent, OPEN_AI_PROVIDER_ID, OPEN_AI_PROVIDER_NAME, - RateLimiter, Role, StopReason, TokenUsage, env_var, + LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, + LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, + LanguageModelRequest, LanguageModelToolChoice, OPEN_AI_PROVIDER_ID, OPEN_AI_PROVIDER_NAME, + RateLimiter, env_var, }; use menu; -use open_ai::responses::{ - ResponseFunctionCallItem, ResponseFunctionCallOutputContent, ResponseFunctionCallOutputItem, - ResponseInputContent, ResponseInputItem, ResponseMessageItem, -}; use open_ai::{ - ImageUrl, Model, OPEN_AI_API_URL, ReasoningEffort, ResponseStreamEvent, - responses::{ - Request as ResponseRequest, ResponseOutputItem, ResponseSummary as ResponsesSummary, - ResponseUsage as ResponsesUsage, StreamEvent as ResponsesStreamEvent, stream_response, - }, + OPEN_AI_API_URL, ResponseStreamEvent, + responses::{Request as ResponseRequest, StreamEvent as ResponsesStreamEvent, stream_response}, stream_completion, }; use settings::{OpenAiAvailableModel as AvailableModel, Settings, SettingsStore}; -use std::pin::Pin; use std::sync::{Arc, LazyLock}; use strum::IntoEnumIterator; use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*}; use ui_input::InputField; use util::ResultExt; -use crate::provider::util::{fix_streamed_json, parse_tool_arguments}; +pub use open_ai::completion::{ + OpenAiEventMapper, OpenAiResponseEventMapper, collect_tiktoken_messages, count_open_ai_tokens, + into_open_ai, into_open_ai_response, +}; const PROVIDER_ID: LanguageModelProviderId = OPEN_AI_PROVIDER_ID; const PROVIDER_NAME: LanguageModelProviderName = OPEN_AI_PROVIDER_NAME; @@ -189,7 +181,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider { max_tokens: model.max_tokens, max_output_tokens: model.max_output_tokens, max_completion_tokens: model.max_completion_tokens, - reasoning_effort: model.reasoning_effort.clone(), + reasoning_effort: model.reasoning_effort, supports_chat_completions: model.capabilities.chat_completions, }, ); @@ -382,7 +374,9 @@ impl LanguageModel for OpenAiLanguageModel { request: LanguageModelRequest, cx: &App, ) -> BoxFuture<'static, Result> { - count_open_ai_tokens(request, self.model.clone(), cx) + let model = self.model.clone(); + cx.background_spawn(async move { count_open_ai_tokens(request, model) }) + .boxed() } fn stream_completion( @@ -433,853 +427,6 @@ impl LanguageModel for OpenAiLanguageModel { } } -pub fn into_open_ai( - request: LanguageModelRequest, - model_id: &str, - supports_parallel_tool_calls: bool, - supports_prompt_cache_key: bool, - max_output_tokens: Option, - reasoning_effort: Option, -) -> open_ai::Request { - let stream = !model_id.starts_with("o1-"); - - let mut messages = Vec::new(); - for message in request.messages { - for content in message.content { - match content { - MessageContent::Text(text) | MessageContent::Thinking { text, .. } => { - let should_add = if message.role == Role::User { - // Including whitespace-only user messages can cause error with OpenAI compatible APIs - // See https://github.com/zed-industries/zed/issues/40097 - !text.trim().is_empty() - } else { - !text.is_empty() - }; - if should_add { - add_message_content_part( - open_ai::MessagePart::Text { text }, - message.role, - &mut messages, - ); - } - } - MessageContent::RedactedThinking(_) => {} - MessageContent::Image(image) => { - add_message_content_part( - open_ai::MessagePart::Image { - image_url: ImageUrl { - url: image.to_base64_url(), - detail: None, - }, - }, - message.role, - &mut messages, - ); - } - MessageContent::ToolUse(tool_use) => { - let tool_call = open_ai::ToolCall { - id: tool_use.id.to_string(), - content: open_ai::ToolCallContent::Function { - function: open_ai::FunctionContent { - name: tool_use.name.to_string(), - arguments: serde_json::to_string(&tool_use.input) - .unwrap_or_default(), - }, - }, - }; - - if let Some(open_ai::RequestMessage::Assistant { tool_calls, .. }) = - messages.last_mut() - { - tool_calls.push(tool_call); - } else { - messages.push(open_ai::RequestMessage::Assistant { - content: None, - tool_calls: vec![tool_call], - }); - } - } - MessageContent::ToolResult(tool_result) => { - let content = match &tool_result.content { - LanguageModelToolResultContent::Text(text) => { - vec![open_ai::MessagePart::Text { - text: text.to_string(), - }] - } - LanguageModelToolResultContent::Image(image) => { - vec![open_ai::MessagePart::Image { - image_url: ImageUrl { - url: image.to_base64_url(), - detail: None, - }, - }] - } - }; - - messages.push(open_ai::RequestMessage::Tool { - content: content.into(), - tool_call_id: tool_result.tool_use_id.to_string(), - }); - } - } - } - } - - open_ai::Request { - model: model_id.into(), - messages, - stream, - stream_options: if stream { - Some(open_ai::StreamOptions::default()) - } else { - None - }, - stop: request.stop, - temperature: request.temperature.or(Some(1.0)), - max_completion_tokens: max_output_tokens, - parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() { - Some(supports_parallel_tool_calls) - } else { - None - }, - prompt_cache_key: if supports_prompt_cache_key { - request.thread_id - } else { - None - }, - tools: request - .tools - .into_iter() - .map(|tool| open_ai::ToolDefinition::Function { - function: open_ai::FunctionDefinition { - name: tool.name, - description: Some(tool.description), - parameters: Some(tool.input_schema), - }, - }) - .collect(), - tool_choice: request.tool_choice.map(|choice| match choice { - LanguageModelToolChoice::Auto => open_ai::ToolChoice::Auto, - LanguageModelToolChoice::Any => open_ai::ToolChoice::Required, - LanguageModelToolChoice::None => open_ai::ToolChoice::None, - }), - reasoning_effort, - } -} - -pub fn into_open_ai_response( - request: LanguageModelRequest, - model_id: &str, - supports_parallel_tool_calls: bool, - supports_prompt_cache_key: bool, - max_output_tokens: Option, - reasoning_effort: Option, -) -> ResponseRequest { - let stream = !model_id.starts_with("o1-"); - - let LanguageModelRequest { - thread_id, - prompt_id: _, - intent: _, - messages, - tools, - tool_choice, - stop: _, - temperature, - thinking_allowed: _, - thinking_effort: _, - speed: _, - } = request; - - let mut input_items = Vec::new(); - for (index, message) in messages.into_iter().enumerate() { - append_message_to_response_items(message, index, &mut input_items); - } - - let tools: Vec<_> = tools - .into_iter() - .map(|tool| open_ai::responses::ToolDefinition::Function { - name: tool.name, - description: Some(tool.description), - parameters: Some(tool.input_schema), - strict: None, - }) - .collect(); - - ResponseRequest { - model: model_id.into(), - input: input_items, - stream, - temperature, - top_p: None, - max_output_tokens, - parallel_tool_calls: if tools.is_empty() { - None - } else { - Some(supports_parallel_tool_calls) - }, - tool_choice: tool_choice.map(|choice| match choice { - LanguageModelToolChoice::Auto => open_ai::ToolChoice::Auto, - LanguageModelToolChoice::Any => open_ai::ToolChoice::Required, - LanguageModelToolChoice::None => open_ai::ToolChoice::None, - }), - tools, - prompt_cache_key: if supports_prompt_cache_key { - thread_id - } else { - None - }, - reasoning: reasoning_effort.map(|effort| open_ai::responses::ReasoningConfig { - effort, - summary: Some(open_ai::responses::ReasoningSummaryMode::Auto), - }), - } -} - -fn append_message_to_response_items( - message: LanguageModelRequestMessage, - index: usize, - input_items: &mut Vec, -) { - let mut content_parts: Vec = Vec::new(); - - for content in message.content { - match content { - MessageContent::Text(text) => { - push_response_text_part(&message.role, text, &mut content_parts); - } - MessageContent::Thinking { text, .. } => { - push_response_text_part(&message.role, text, &mut content_parts); - } - MessageContent::RedactedThinking(_) => {} - MessageContent::Image(image) => { - push_response_image_part(&message.role, image, &mut content_parts); - } - MessageContent::ToolUse(tool_use) => { - flush_response_parts(&message.role, index, &mut content_parts, input_items); - let call_id = tool_use.id.to_string(); - input_items.push(ResponseInputItem::FunctionCall(ResponseFunctionCallItem { - call_id, - name: tool_use.name.to_string(), - arguments: tool_use.raw_input, - })); - } - MessageContent::ToolResult(tool_result) => { - flush_response_parts(&message.role, index, &mut content_parts, input_items); - input_items.push(ResponseInputItem::FunctionCallOutput( - ResponseFunctionCallOutputItem { - call_id: tool_result.tool_use_id.to_string(), - output: match tool_result.content { - LanguageModelToolResultContent::Text(text) => { - ResponseFunctionCallOutputContent::Text(text.to_string()) - } - LanguageModelToolResultContent::Image(image) => { - ResponseFunctionCallOutputContent::List(vec![ - ResponseInputContent::Image { - image_url: image.to_base64_url(), - }, - ]) - } - }, - }, - )); - } - } - } - - flush_response_parts(&message.role, index, &mut content_parts, input_items); -} - -fn push_response_text_part( - role: &Role, - text: impl Into, - parts: &mut Vec, -) { - let text = text.into(); - if text.trim().is_empty() { - return; - } - - match role { - Role::Assistant => parts.push(ResponseInputContent::OutputText { - text, - annotations: Vec::new(), - }), - _ => parts.push(ResponseInputContent::Text { text }), - } -} - -fn push_response_image_part( - role: &Role, - image: LanguageModelImage, - parts: &mut Vec, -) { - match role { - Role::Assistant => parts.push(ResponseInputContent::OutputText { - text: "[image omitted]".to_string(), - annotations: Vec::new(), - }), - _ => parts.push(ResponseInputContent::Image { - image_url: image.to_base64_url(), - }), - } -} - -fn flush_response_parts( - role: &Role, - _index: usize, - parts: &mut Vec, - input_items: &mut Vec, -) { - if parts.is_empty() { - return; - } - - let item = ResponseInputItem::Message(ResponseMessageItem { - role: match role { - Role::User => open_ai::Role::User, - Role::Assistant => open_ai::Role::Assistant, - Role::System => open_ai::Role::System, - }, - content: parts.clone(), - }); - - input_items.push(item); - parts.clear(); -} - -fn add_message_content_part( - new_part: open_ai::MessagePart, - role: Role, - messages: &mut Vec, -) { - match (role, messages.last_mut()) { - (Role::User, Some(open_ai::RequestMessage::User { content })) - | ( - Role::Assistant, - Some(open_ai::RequestMessage::Assistant { - content: Some(content), - .. - }), - ) - | (Role::System, Some(open_ai::RequestMessage::System { content, .. })) => { - content.push_part(new_part); - } - _ => { - messages.push(match role { - Role::User => open_ai::RequestMessage::User { - content: open_ai::MessageContent::from(vec![new_part]), - }, - Role::Assistant => open_ai::RequestMessage::Assistant { - content: Some(open_ai::MessageContent::from(vec![new_part])), - tool_calls: Vec::new(), - }, - Role::System => open_ai::RequestMessage::System { - content: open_ai::MessageContent::from(vec![new_part]), - }, - }); - } - } -} - -pub struct OpenAiEventMapper { - tool_calls_by_index: HashMap, -} - -impl OpenAiEventMapper { - pub fn new() -> Self { - Self { - tool_calls_by_index: HashMap::default(), - } - } - - pub fn map_stream( - mut self, - events: Pin>>>, - ) -> impl Stream> - { - events.flat_map(move |event| { - futures::stream::iter(match event { - Ok(event) => self.map_event(event), - Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))], - }) - }) - } - - pub fn map_event( - &mut self, - event: ResponseStreamEvent, - ) -> Vec> { - let mut events = Vec::new(); - if let Some(usage) = event.usage { - events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage { - input_tokens: usage.prompt_tokens, - output_tokens: usage.completion_tokens, - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, - }))); - } - - let Some(choice) = event.choices.first() else { - return events; - }; - - if let Some(delta) = choice.delta.as_ref() { - if let Some(reasoning_content) = delta.reasoning_content.clone() { - if !reasoning_content.is_empty() { - events.push(Ok(LanguageModelCompletionEvent::Thinking { - text: reasoning_content, - signature: None, - })); - } - } - if let Some(content) = delta.content.clone() { - if !content.is_empty() { - events.push(Ok(LanguageModelCompletionEvent::Text(content))); - } - } - - if let Some(tool_calls) = delta.tool_calls.as_ref() { - for tool_call in tool_calls { - let entry = self.tool_calls_by_index.entry(tool_call.index).or_default(); - - if let Some(tool_id) = tool_call.id.clone() { - entry.id = tool_id; - } - - if let Some(function) = tool_call.function.as_ref() { - if let Some(name) = function.name.clone() { - entry.name = name; - } - - if let Some(arguments) = function.arguments.clone() { - entry.arguments.push_str(&arguments); - } - } - - if !entry.id.is_empty() && !entry.name.is_empty() { - if let Ok(input) = serde_json::from_str::( - &fix_streamed_json(&entry.arguments), - ) { - events.push(Ok(LanguageModelCompletionEvent::ToolUse( - LanguageModelToolUse { - id: entry.id.clone().into(), - name: entry.name.as_str().into(), - is_input_complete: false, - input, - raw_input: entry.arguments.clone(), - thought_signature: None, - }, - ))); - } - } - } - } - } - - match choice.finish_reason.as_deref() { - Some("stop") => { - events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))); - } - Some("tool_calls") => { - events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| { - match parse_tool_arguments(&tool_call.arguments) { - Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse( - LanguageModelToolUse { - id: tool_call.id.clone().into(), - name: tool_call.name.as_str().into(), - is_input_complete: true, - input, - raw_input: tool_call.arguments.clone(), - thought_signature: None, - }, - )), - Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { - id: tool_call.id.into(), - tool_name: tool_call.name.into(), - raw_input: tool_call.arguments.clone().into(), - json_parse_error: error.to_string(), - }), - } - })); - - events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse))); - } - Some(stop_reason) => { - log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",); - events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))); - } - None => {} - } - - events - } -} - -#[derive(Default)] -struct RawToolCall { - id: String, - name: String, - arguments: String, -} - -pub struct OpenAiResponseEventMapper { - function_calls_by_item: HashMap, - pending_stop_reason: Option, -} - -#[derive(Default)] -struct PendingResponseFunctionCall { - call_id: String, - name: Arc, - arguments: String, -} - -impl OpenAiResponseEventMapper { - pub fn new() -> Self { - Self { - function_calls_by_item: HashMap::default(), - pending_stop_reason: None, - } - } - - pub fn map_stream( - mut self, - events: Pin>>>, - ) -> impl Stream> - { - events.flat_map(move |event| { - futures::stream::iter(match event { - Ok(event) => self.map_event(event), - Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))], - }) - }) - } - - pub fn map_event( - &mut self, - event: ResponsesStreamEvent, - ) -> Vec> { - match event { - ResponsesStreamEvent::OutputItemAdded { item, .. } => { - let mut events = Vec::new(); - - match &item { - ResponseOutputItem::Message(message) => { - if let Some(id) = &message.id { - events.push(Ok(LanguageModelCompletionEvent::StartMessage { - message_id: id.clone(), - })); - } - } - ResponseOutputItem::FunctionCall(function_call) => { - if let Some(item_id) = function_call.id.clone() { - let call_id = function_call - .call_id - .clone() - .or_else(|| function_call.id.clone()) - .unwrap_or_else(|| item_id.clone()); - let entry = PendingResponseFunctionCall { - call_id, - name: Arc::::from( - function_call.name.clone().unwrap_or_default(), - ), - arguments: function_call.arguments.clone(), - }; - self.function_calls_by_item.insert(item_id, entry); - } - } - ResponseOutputItem::Reasoning(_) | ResponseOutputItem::Unknown => {} - } - events - } - ResponsesStreamEvent::ReasoningSummaryTextDelta { delta, .. } => { - if delta.is_empty() { - Vec::new() - } else { - vec![Ok(LanguageModelCompletionEvent::Thinking { - text: delta, - signature: None, - })] - } - } - ResponsesStreamEvent::OutputTextDelta { delta, .. } => { - if delta.is_empty() { - Vec::new() - } else { - vec![Ok(LanguageModelCompletionEvent::Text(delta))] - } - } - ResponsesStreamEvent::FunctionCallArgumentsDelta { item_id, delta, .. } => { - if let Some(entry) = self.function_calls_by_item.get_mut(&item_id) { - entry.arguments.push_str(&delta); - if let Ok(input) = serde_json::from_str::( - &fix_streamed_json(&entry.arguments), - ) { - return vec![Ok(LanguageModelCompletionEvent::ToolUse( - LanguageModelToolUse { - id: LanguageModelToolUseId::from(entry.call_id.clone()), - name: entry.name.clone(), - is_input_complete: false, - input, - raw_input: entry.arguments.clone(), - thought_signature: None, - }, - ))]; - } - } - Vec::new() - } - ResponsesStreamEvent::FunctionCallArgumentsDone { - item_id, arguments, .. - } => { - if let Some(mut entry) = self.function_calls_by_item.remove(&item_id) { - if !arguments.is_empty() { - entry.arguments = arguments; - } - let raw_input = entry.arguments.clone(); - self.pending_stop_reason = Some(StopReason::ToolUse); - match parse_tool_arguments(&entry.arguments) { - Ok(input) => { - vec![Ok(LanguageModelCompletionEvent::ToolUse( - LanguageModelToolUse { - id: LanguageModelToolUseId::from(entry.call_id.clone()), - name: entry.name.clone(), - is_input_complete: true, - input, - raw_input, - thought_signature: None, - }, - ))] - } - Err(error) => { - vec![Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { - id: LanguageModelToolUseId::from(entry.call_id.clone()), - tool_name: entry.name.clone(), - raw_input: Arc::::from(raw_input), - json_parse_error: error.to_string(), - })] - } - } - } else { - Vec::new() - } - } - ResponsesStreamEvent::Completed { response } => { - self.handle_completion(response, StopReason::EndTurn) - } - ResponsesStreamEvent::Incomplete { response } => { - let reason = response - .status_details - .as_ref() - .and_then(|details| details.reason.as_deref()); - let stop_reason = match reason { - Some("max_output_tokens") => StopReason::MaxTokens, - Some("content_filter") => { - self.pending_stop_reason = Some(StopReason::Refusal); - StopReason::Refusal - } - _ => self - .pending_stop_reason - .take() - .unwrap_or(StopReason::EndTurn), - }; - - let mut events = Vec::new(); - if self.pending_stop_reason.is_none() { - events.extend(self.emit_tool_calls_from_output(&response.output)); - } - if let Some(usage) = response.usage.as_ref() { - events.push(Ok(LanguageModelCompletionEvent::UsageUpdate( - token_usage_from_response_usage(usage), - ))); - } - events.push(Ok(LanguageModelCompletionEvent::Stop(stop_reason))); - events - } - ResponsesStreamEvent::Failed { response } => { - let message = response - .status_details - .and_then(|details| details.error) - .map(|error| error.to_string()) - .unwrap_or_else(|| "response failed".to_string()); - vec![Err(LanguageModelCompletionError::Other(anyhow!(message)))] - } - ResponsesStreamEvent::Error { error } - | ResponsesStreamEvent::GenericError { error } => { - vec![Err(LanguageModelCompletionError::Other(anyhow!( - error.message - )))] - } - ResponsesStreamEvent::ReasoningSummaryPartAdded { summary_index, .. } => { - if summary_index > 0 { - vec![Ok(LanguageModelCompletionEvent::Thinking { - text: "\n\n".to_string(), - signature: None, - })] - } else { - Vec::new() - } - } - ResponsesStreamEvent::OutputTextDone { .. } - | ResponsesStreamEvent::OutputItemDone { .. } - | ResponsesStreamEvent::ContentPartAdded { .. } - | ResponsesStreamEvent::ContentPartDone { .. } - | ResponsesStreamEvent::ReasoningSummaryTextDone { .. } - | ResponsesStreamEvent::ReasoningSummaryPartDone { .. } - | ResponsesStreamEvent::Created { .. } - | ResponsesStreamEvent::InProgress { .. } - | ResponsesStreamEvent::Unknown => Vec::new(), - } - } - - fn handle_completion( - &mut self, - response: ResponsesSummary, - default_reason: StopReason, - ) -> Vec> { - let mut events = Vec::new(); - - if self.pending_stop_reason.is_none() { - events.extend(self.emit_tool_calls_from_output(&response.output)); - } - - if let Some(usage) = response.usage.as_ref() { - events.push(Ok(LanguageModelCompletionEvent::UsageUpdate( - token_usage_from_response_usage(usage), - ))); - } - - let stop_reason = self.pending_stop_reason.take().unwrap_or(default_reason); - events.push(Ok(LanguageModelCompletionEvent::Stop(stop_reason))); - events - } - - fn emit_tool_calls_from_output( - &mut self, - output: &[ResponseOutputItem], - ) -> Vec> { - let mut events = Vec::new(); - for item in output { - if let ResponseOutputItem::FunctionCall(function_call) = item { - let Some(call_id) = function_call - .call_id - .clone() - .or_else(|| function_call.id.clone()) - else { - log::error!( - "Function call item missing both call_id and id: {:?}", - function_call - ); - continue; - }; - let name: Arc = Arc::from(function_call.name.clone().unwrap_or_default()); - let arguments = &function_call.arguments; - self.pending_stop_reason = Some(StopReason::ToolUse); - match parse_tool_arguments(arguments) { - Ok(input) => { - events.push(Ok(LanguageModelCompletionEvent::ToolUse( - LanguageModelToolUse { - id: LanguageModelToolUseId::from(call_id.clone()), - name: name.clone(), - is_input_complete: true, - input, - raw_input: arguments.clone(), - thought_signature: None, - }, - ))); - } - Err(error) => { - events.push(Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { - id: LanguageModelToolUseId::from(call_id.clone()), - tool_name: name.clone(), - raw_input: Arc::::from(arguments.clone()), - json_parse_error: error.to_string(), - })); - } - } - } - } - events - } -} - -fn token_usage_from_response_usage(usage: &ResponsesUsage) -> TokenUsage { - TokenUsage { - input_tokens: usage.input_tokens.unwrap_or_default(), - output_tokens: usage.output_tokens.unwrap_or_default(), - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, - } -} - -pub(crate) fn collect_tiktoken_messages( - request: LanguageModelRequest, -) -> Vec { - request - .messages - .into_iter() - .map(|message| tiktoken_rs::ChatCompletionRequestMessage { - role: match message.role { - Role::User => "user".into(), - Role::Assistant => "assistant".into(), - Role::System => "system".into(), - }, - content: Some(message.string_contents()), - name: None, - function_call: None, - }) - .collect::>() -} - -pub fn count_open_ai_tokens( - request: LanguageModelRequest, - model: Model, - cx: &App, -) -> BoxFuture<'static, Result> { - cx.background_spawn(async move { - let messages = collect_tiktoken_messages(request); - match model { - Model::Custom { max_tokens, .. } => { - let model = if max_tokens >= 100_000 { - // If the max tokens is 100k or more, it likely uses the o200k_base tokenizer - "gpt-4o" - } else { - // Otherwise fallback to gpt-4, since only cl100k_base and o200k_base are - // supported with this tiktoken method - "gpt-4" - }; - tiktoken_rs::num_tokens_from_messages(model, &messages) - } - // Currently supported by tiktoken_rs - // Sometimes tiktoken-rs is behind on model support. If that is the case, make a new branch - // arm with an override. We enumerate all supported models here so that we can check if new - // models are supported yet or not. - Model::ThreePointFiveTurbo - | Model::Four - | Model::FourTurbo - | Model::FourOmniMini - | Model::FourPointOneNano - | Model::O1 - | Model::O3 - | Model::O3Mini - | Model::Five - | Model::FiveCodex - | Model::FiveMini - | Model::FiveNano => tiktoken_rs::num_tokens_from_messages(model.id(), &messages), - // GPT-5.1, 5.2, 5.2-codex, 5.3-codex, 5.4, and 5.4-pro don't have dedicated tiktoken support; use gpt-5 tokenizer - Model::FivePointOne - | Model::FivePointTwo - | Model::FivePointTwoCodex - | Model::FivePointThreeCodex - | Model::FivePointFour - | Model::FivePointFourPro => tiktoken_rs::num_tokens_from_messages("gpt-5", &messages), - } - .map(|tokens| tokens as u64) - }) - .boxed() -} - struct ConfigurationView { api_key_editor: Entity, state: Entity, @@ -1459,874 +606,3 @@ impl Render for ConfigurationView { } } } - -#[cfg(test)] -mod tests { - use futures::{StreamExt, executor::block_on}; - use gpui::TestAppContext; - use language_model::{ - LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, - }; - use open_ai::responses::{ - ReasoningSummaryPart, ResponseFunctionToolCall, ResponseOutputItem, ResponseOutputMessage, - ResponseReasoningItem, ResponseStatusDetails, ResponseSummary, ResponseUsage, - StreamEvent as ResponsesStreamEvent, - }; - use pretty_assertions::assert_eq; - use serde_json::json; - - use super::*; - - fn map_response_events(events: Vec) -> Vec { - block_on(async { - OpenAiResponseEventMapper::new() - .map_stream(Box::pin(futures::stream::iter(events.into_iter().map(Ok)))) - .collect::>() - .await - .into_iter() - .map(Result::unwrap) - .collect() - }) - } - - fn response_item_message(id: &str) -> ResponseOutputItem { - ResponseOutputItem::Message(ResponseOutputMessage { - id: Some(id.to_string()), - role: Some("assistant".to_string()), - status: Some("in_progress".to_string()), - content: vec![], - }) - } - - fn response_item_function_call(id: &str, args: Option<&str>) -> ResponseOutputItem { - ResponseOutputItem::FunctionCall(ResponseFunctionToolCall { - id: Some(id.to_string()), - status: Some("in_progress".to_string()), - name: Some("get_weather".to_string()), - call_id: Some("call_123".to_string()), - arguments: args.map(|s| s.to_string()).unwrap_or_default(), - }) - } - - #[gpui::test] - fn tiktoken_rs_support(cx: &TestAppContext) { - let request = LanguageModelRequest { - thread_id: None, - prompt_id: None, - intent: None, - messages: vec![LanguageModelRequestMessage { - role: Role::User, - content: vec![MessageContent::Text("message".into())], - cache: false, - reasoning_details: None, - }], - tools: vec![], - tool_choice: None, - stop: vec![], - temperature: None, - thinking_allowed: true, - thinking_effort: None, - speed: None, - }; - - // Validate that all models are supported by tiktoken-rs - for model in Model::iter() { - let count = cx - .foreground_executor() - .block_on(count_open_ai_tokens( - request.clone(), - model, - &cx.app.borrow(), - )) - .unwrap(); - assert!(count > 0); - } - } - - #[test] - fn responses_stream_maps_text_and_usage() { - let events = vec![ - ResponsesStreamEvent::OutputItemAdded { - output_index: 0, - sequence_number: None, - item: response_item_message("msg_123"), - }, - ResponsesStreamEvent::OutputTextDelta { - item_id: "msg_123".into(), - output_index: 0, - content_index: Some(0), - delta: "Hello".into(), - }, - ResponsesStreamEvent::Completed { - response: ResponseSummary { - usage: Some(ResponseUsage { - input_tokens: Some(5), - output_tokens: Some(3), - total_tokens: Some(8), - }), - ..Default::default() - }, - }, - ]; - - let mapped = map_response_events(events); - assert!(matches!( - mapped[0], - LanguageModelCompletionEvent::StartMessage { ref message_id } if message_id == "msg_123" - )); - assert!(matches!( - mapped[1], - LanguageModelCompletionEvent::Text(ref text) if text == "Hello" - )); - assert!(matches!( - mapped[2], - LanguageModelCompletionEvent::UsageUpdate(TokenUsage { - input_tokens: 5, - output_tokens: 3, - .. - }) - )); - assert!(matches!( - mapped[3], - LanguageModelCompletionEvent::Stop(StopReason::EndTurn) - )); - } - - #[test] - fn into_open_ai_response_builds_complete_payload() { - let tool_call_id = LanguageModelToolUseId::from("call-42"); - let tool_input = json!({ "city": "Boston" }); - let tool_arguments = serde_json::to_string(&tool_input).unwrap(); - let tool_use = LanguageModelToolUse { - id: tool_call_id.clone(), - name: Arc::from("get_weather"), - raw_input: tool_arguments.clone(), - input: tool_input, - is_input_complete: true, - thought_signature: None, - }; - let tool_result = LanguageModelToolResult { - tool_use_id: tool_call_id, - tool_name: Arc::from("get_weather"), - is_error: false, - content: LanguageModelToolResultContent::Text(Arc::from("Sunny")), - output: Some(json!({ "forecast": "Sunny" })), - }; - let user_image = LanguageModelImage { - source: SharedString::from("aGVsbG8="), - size: None, - }; - let expected_image_url = user_image.to_base64_url(); - - let request = LanguageModelRequest { - thread_id: Some("thread-123".into()), - prompt_id: None, - intent: None, - messages: vec![ - LanguageModelRequestMessage { - role: Role::System, - content: vec![MessageContent::Text("System context".into())], - cache: false, - reasoning_details: None, - }, - LanguageModelRequestMessage { - role: Role::User, - content: vec![ - MessageContent::Text("Please check the weather.".into()), - MessageContent::Image(user_image), - ], - cache: false, - reasoning_details: None, - }, - LanguageModelRequestMessage { - role: Role::Assistant, - content: vec![ - MessageContent::Text("Looking that up.".into()), - MessageContent::ToolUse(tool_use), - ], - cache: false, - reasoning_details: None, - }, - LanguageModelRequestMessage { - role: Role::Assistant, - content: vec![MessageContent::ToolResult(tool_result)], - cache: false, - reasoning_details: None, - }, - ], - tools: vec![LanguageModelRequestTool { - name: "get_weather".into(), - description: "Fetches the weather".into(), - input_schema: json!({ "type": "object" }), - use_input_streaming: false, - }], - tool_choice: Some(LanguageModelToolChoice::Any), - stop: vec!["".into()], - temperature: None, - thinking_allowed: false, - thinking_effort: None, - speed: None, - }; - - let response = into_open_ai_response( - request, - "custom-model", - true, - true, - Some(2048), - Some(ReasoningEffort::Low), - ); - - let serialized = serde_json::to_value(&response).unwrap(); - let expected = json!({ - "model": "custom-model", - "input": [ - { - "type": "message", - "role": "system", - "content": [ - { "type": "input_text", "text": "System context" } - ] - }, - { - "type": "message", - "role": "user", - "content": [ - { "type": "input_text", "text": "Please check the weather." }, - { "type": "input_image", "image_url": expected_image_url } - ] - }, - { - "type": "message", - "role": "assistant", - "content": [ - { "type": "output_text", "text": "Looking that up.", "annotations": [] } - ] - }, - { - "type": "function_call", - "call_id": "call-42", - "name": "get_weather", - "arguments": tool_arguments - }, - { - "type": "function_call_output", - "call_id": "call-42", - "output": "Sunny" - } - ], - "stream": true, - "max_output_tokens": 2048, - "parallel_tool_calls": true, - "tool_choice": "required", - "tools": [ - { - "type": "function", - "name": "get_weather", - "description": "Fetches the weather", - "parameters": { "type": "object" } - } - ], - "prompt_cache_key": "thread-123", - "reasoning": { "effort": "low", "summary": "auto" } - }); - - assert_eq!(serialized, expected); - } - - #[test] - fn responses_stream_maps_tool_calls() { - let events = vec![ - ResponsesStreamEvent::OutputItemAdded { - output_index: 0, - sequence_number: None, - item: response_item_function_call("item_fn", Some("{\"city\":\"Bos")), - }, - ResponsesStreamEvent::FunctionCallArgumentsDelta { - item_id: "item_fn".into(), - output_index: 0, - delta: "ton\"}".into(), - sequence_number: None, - }, - ResponsesStreamEvent::FunctionCallArgumentsDone { - item_id: "item_fn".into(), - output_index: 0, - arguments: "{\"city\":\"Boston\"}".into(), - sequence_number: None, - }, - ResponsesStreamEvent::Completed { - response: ResponseSummary::default(), - }, - ]; - - let mapped = map_response_events(events); - assert_eq!(mapped.len(), 3); - // First event is the partial tool use (from FunctionCallArgumentsDelta) - assert!(matches!( - mapped[0], - LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { - is_input_complete: false, - .. - }) - )); - // Second event is the complete tool use (from FunctionCallArgumentsDone) - assert!(matches!( - mapped[1], - LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { - ref id, - ref name, - ref raw_input, - is_input_complete: true, - .. - }) if id.to_string() == "call_123" - && name.as_ref() == "get_weather" - && raw_input == "{\"city\":\"Boston\"}" - )); - assert!(matches!( - mapped[2], - LanguageModelCompletionEvent::Stop(StopReason::ToolUse) - )); - } - - #[test] - fn responses_stream_uses_max_tokens_stop_reason() { - let events = vec![ResponsesStreamEvent::Incomplete { - response: ResponseSummary { - status_details: Some(ResponseStatusDetails { - reason: Some("max_output_tokens".into()), - r#type: Some("incomplete".into()), - error: None, - }), - usage: Some(ResponseUsage { - input_tokens: Some(10), - output_tokens: Some(20), - total_tokens: Some(30), - }), - ..Default::default() - }, - }]; - - let mapped = map_response_events(events); - assert!(matches!( - mapped[0], - LanguageModelCompletionEvent::UsageUpdate(TokenUsage { - input_tokens: 10, - output_tokens: 20, - .. - }) - )); - assert!(matches!( - mapped[1], - LanguageModelCompletionEvent::Stop(StopReason::MaxTokens) - )); - } - - #[test] - fn responses_stream_handles_multiple_tool_calls() { - let events = vec![ - ResponsesStreamEvent::OutputItemAdded { - output_index: 0, - sequence_number: None, - item: response_item_function_call("item_fn1", Some("{\"city\":\"NYC\"}")), - }, - ResponsesStreamEvent::FunctionCallArgumentsDone { - item_id: "item_fn1".into(), - output_index: 0, - arguments: "{\"city\":\"NYC\"}".into(), - sequence_number: None, - }, - ResponsesStreamEvent::OutputItemAdded { - output_index: 1, - sequence_number: None, - item: response_item_function_call("item_fn2", Some("{\"city\":\"LA\"}")), - }, - ResponsesStreamEvent::FunctionCallArgumentsDone { - item_id: "item_fn2".into(), - output_index: 1, - arguments: "{\"city\":\"LA\"}".into(), - sequence_number: None, - }, - ResponsesStreamEvent::Completed { - response: ResponseSummary::default(), - }, - ]; - - let mapped = map_response_events(events); - assert_eq!(mapped.len(), 3); - assert!(matches!( - mapped[0], - LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, .. }) - if raw_input == "{\"city\":\"NYC\"}" - )); - assert!(matches!( - mapped[1], - LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, .. }) - if raw_input == "{\"city\":\"LA\"}" - )); - assert!(matches!( - mapped[2], - LanguageModelCompletionEvent::Stop(StopReason::ToolUse) - )); - } - - #[test] - fn responses_stream_handles_mixed_text_and_tool_calls() { - let events = vec![ - ResponsesStreamEvent::OutputItemAdded { - output_index: 0, - sequence_number: None, - item: response_item_message("msg_123"), - }, - ResponsesStreamEvent::OutputTextDelta { - item_id: "msg_123".into(), - output_index: 0, - content_index: Some(0), - delta: "Let me check that".into(), - }, - ResponsesStreamEvent::OutputItemAdded { - output_index: 1, - sequence_number: None, - item: response_item_function_call("item_fn", Some("{\"query\":\"test\"}")), - }, - ResponsesStreamEvent::FunctionCallArgumentsDone { - item_id: "item_fn".into(), - output_index: 1, - arguments: "{\"query\":\"test\"}".into(), - sequence_number: None, - }, - ResponsesStreamEvent::Completed { - response: ResponseSummary::default(), - }, - ]; - - let mapped = map_response_events(events); - assert!(matches!( - mapped[0], - LanguageModelCompletionEvent::StartMessage { .. } - )); - assert!(matches!( - mapped[1], - LanguageModelCompletionEvent::Text(ref text) if text == "Let me check that" - )); - assert!(matches!( - mapped[2], - LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, .. }) - if raw_input == "{\"query\":\"test\"}" - )); - assert!(matches!( - mapped[3], - LanguageModelCompletionEvent::Stop(StopReason::ToolUse) - )); - } - - #[test] - fn responses_stream_handles_json_parse_error() { - let events = vec![ - ResponsesStreamEvent::OutputItemAdded { - output_index: 0, - sequence_number: None, - item: response_item_function_call("item_fn", Some("{invalid json")), - }, - ResponsesStreamEvent::FunctionCallArgumentsDone { - item_id: "item_fn".into(), - output_index: 0, - arguments: "{invalid json".into(), - sequence_number: None, - }, - ResponsesStreamEvent::Completed { - response: ResponseSummary::default(), - }, - ]; - - let mapped = map_response_events(events); - assert!(matches!( - mapped[0], - LanguageModelCompletionEvent::ToolUseJsonParseError { - ref raw_input, - .. - } if raw_input.as_ref() == "{invalid json" - )); - } - - #[test] - fn responses_stream_handles_incomplete_function_call() { - let events = vec![ - ResponsesStreamEvent::OutputItemAdded { - output_index: 0, - sequence_number: None, - item: response_item_function_call("item_fn", Some("{\"city\":")), - }, - ResponsesStreamEvent::FunctionCallArgumentsDelta { - item_id: "item_fn".into(), - output_index: 0, - delta: "\"Boston\"".into(), - sequence_number: None, - }, - ResponsesStreamEvent::Incomplete { - response: ResponseSummary { - status_details: Some(ResponseStatusDetails { - reason: Some("max_output_tokens".into()), - r#type: Some("incomplete".into()), - error: None, - }), - output: vec![response_item_function_call( - "item_fn", - Some("{\"city\":\"Boston\"}"), - )], - ..Default::default() - }, - }, - ]; - - let mapped = map_response_events(events); - assert_eq!(mapped.len(), 3); - // First event is the partial tool use (from FunctionCallArgumentsDelta) - assert!(matches!( - mapped[0], - LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { - is_input_complete: false, - .. - }) - )); - // Second event is the complete tool use (from the Incomplete response output) - assert!(matches!( - mapped[1], - LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { - ref raw_input, - is_input_complete: true, - .. - }) - if raw_input == "{\"city\":\"Boston\"}" - )); - assert!(matches!( - mapped[2], - LanguageModelCompletionEvent::Stop(StopReason::MaxTokens) - )); - } - - #[test] - fn responses_stream_incomplete_does_not_duplicate_tool_calls() { - let events = vec![ - ResponsesStreamEvent::OutputItemAdded { - output_index: 0, - sequence_number: None, - item: response_item_function_call("item_fn", Some("{\"city\":\"Boston\"}")), - }, - ResponsesStreamEvent::FunctionCallArgumentsDone { - item_id: "item_fn".into(), - output_index: 0, - arguments: "{\"city\":\"Boston\"}".into(), - sequence_number: None, - }, - ResponsesStreamEvent::Incomplete { - response: ResponseSummary { - status_details: Some(ResponseStatusDetails { - reason: Some("max_output_tokens".into()), - r#type: Some("incomplete".into()), - error: None, - }), - output: vec![response_item_function_call( - "item_fn", - Some("{\"city\":\"Boston\"}"), - )], - ..Default::default() - }, - }, - ]; - - let mapped = map_response_events(events); - assert_eq!(mapped.len(), 2); - assert!(matches!( - mapped[0], - LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, .. }) - if raw_input == "{\"city\":\"Boston\"}" - )); - assert!(matches!( - mapped[1], - LanguageModelCompletionEvent::Stop(StopReason::MaxTokens) - )); - } - - #[test] - fn responses_stream_handles_empty_tool_arguments() { - // Test that tools with no arguments (empty string) are handled correctly - let events = vec![ - ResponsesStreamEvent::OutputItemAdded { - output_index: 0, - sequence_number: None, - item: response_item_function_call("item_fn", Some("")), - }, - ResponsesStreamEvent::FunctionCallArgumentsDone { - item_id: "item_fn".into(), - output_index: 0, - arguments: "".into(), - sequence_number: None, - }, - ResponsesStreamEvent::Completed { - response: ResponseSummary::default(), - }, - ]; - - let mapped = map_response_events(events); - assert_eq!(mapped.len(), 2); - - // Should produce a ToolUse event with an empty object - assert!(matches!( - &mapped[0], - LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { - id, - name, - raw_input, - input, - .. - }) if id.to_string() == "call_123" - && name.as_ref() == "get_weather" - && raw_input == "" - && input.is_object() - && input.as_object().unwrap().is_empty() - )); - - assert!(matches!( - mapped[1], - LanguageModelCompletionEvent::Stop(StopReason::ToolUse) - )); - } - - #[test] - fn responses_stream_emits_partial_tool_use_events() { - let events = vec![ - ResponsesStreamEvent::OutputItemAdded { - output_index: 0, - sequence_number: None, - item: ResponseOutputItem::FunctionCall(ResponseFunctionToolCall { - id: Some("item_fn".to_string()), - status: Some("in_progress".to_string()), - name: Some("get_weather".to_string()), - call_id: Some("call_abc".to_string()), - arguments: String::new(), - }), - }, - ResponsesStreamEvent::FunctionCallArgumentsDelta { - item_id: "item_fn".into(), - output_index: 0, - delta: "{\"city\":\"Bos".into(), - sequence_number: None, - }, - ResponsesStreamEvent::FunctionCallArgumentsDelta { - item_id: "item_fn".into(), - output_index: 0, - delta: "ton\"}".into(), - sequence_number: None, - }, - ResponsesStreamEvent::FunctionCallArgumentsDone { - item_id: "item_fn".into(), - output_index: 0, - arguments: "{\"city\":\"Boston\"}".into(), - sequence_number: None, - }, - ResponsesStreamEvent::Completed { - response: ResponseSummary::default(), - }, - ]; - - let mapped = map_response_events(events); - // Two partial events + one complete event + Stop - assert!(mapped.len() >= 3); - - // The last complete ToolUse event should have is_input_complete: true - let complete_tool_use = mapped.iter().find(|e| { - matches!( - e, - LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { - is_input_complete: true, - .. - }) - ) - }); - assert!( - complete_tool_use.is_some(), - "should have a complete tool use event" - ); - - // All ToolUse events before the final one should have is_input_complete: false - let tool_uses: Vec<_> = mapped - .iter() - .filter(|e| matches!(e, LanguageModelCompletionEvent::ToolUse(_))) - .collect(); - assert!( - tool_uses.len() >= 2, - "should have at least one partial and one complete event" - ); - - let last = tool_uses.last().unwrap(); - assert!(matches!( - last, - LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { - is_input_complete: true, - .. - }) - )); - } - - #[test] - fn responses_stream_maps_reasoning_summary_deltas() { - let events = vec![ - ResponsesStreamEvent::OutputItemAdded { - output_index: 0, - sequence_number: None, - item: ResponseOutputItem::Reasoning(ResponseReasoningItem { - id: Some("rs_123".into()), - summary: vec![], - }), - }, - ResponsesStreamEvent::ReasoningSummaryPartAdded { - item_id: "rs_123".into(), - output_index: 0, - summary_index: 0, - }, - ResponsesStreamEvent::ReasoningSummaryTextDelta { - item_id: "rs_123".into(), - output_index: 0, - delta: "Thinking about".into(), - }, - ResponsesStreamEvent::ReasoningSummaryTextDelta { - item_id: "rs_123".into(), - output_index: 0, - delta: " the answer".into(), - }, - ResponsesStreamEvent::ReasoningSummaryTextDone { - item_id: "rs_123".into(), - output_index: 0, - text: "Thinking about the answer".into(), - }, - ResponsesStreamEvent::ReasoningSummaryPartDone { - item_id: "rs_123".into(), - output_index: 0, - summary_index: 0, - }, - ResponsesStreamEvent::ReasoningSummaryPartAdded { - item_id: "rs_123".into(), - output_index: 0, - summary_index: 1, - }, - ResponsesStreamEvent::ReasoningSummaryTextDelta { - item_id: "rs_123".into(), - output_index: 0, - delta: "Second part".into(), - }, - ResponsesStreamEvent::ReasoningSummaryTextDone { - item_id: "rs_123".into(), - output_index: 0, - text: "Second part".into(), - }, - ResponsesStreamEvent::ReasoningSummaryPartDone { - item_id: "rs_123".into(), - output_index: 0, - summary_index: 1, - }, - ResponsesStreamEvent::OutputItemDone { - output_index: 0, - sequence_number: None, - item: ResponseOutputItem::Reasoning(ResponseReasoningItem { - id: Some("rs_123".into()), - summary: vec![ - ReasoningSummaryPart::SummaryText { - text: "Thinking about the answer".into(), - }, - ReasoningSummaryPart::SummaryText { - text: "Second part".into(), - }, - ], - }), - }, - ResponsesStreamEvent::OutputItemAdded { - output_index: 1, - sequence_number: None, - item: response_item_message("msg_456"), - }, - ResponsesStreamEvent::OutputTextDelta { - item_id: "msg_456".into(), - output_index: 1, - content_index: Some(0), - delta: "The answer is 42".into(), - }, - ResponsesStreamEvent::Completed { - response: ResponseSummary::default(), - }, - ]; - - let mapped = map_response_events(events); - - let thinking_events: Vec<_> = mapped - .iter() - .filter(|e| matches!(e, LanguageModelCompletionEvent::Thinking { .. })) - .collect(); - assert_eq!( - thinking_events.len(), - 4, - "expected 4 thinking events (2 deltas + separator + second delta), got {:?}", - thinking_events, - ); - - assert!(matches!( - &thinking_events[0], - LanguageModelCompletionEvent::Thinking { text, .. } if text == "Thinking about" - )); - assert!(matches!( - &thinking_events[1], - LanguageModelCompletionEvent::Thinking { text, .. } if text == " the answer" - )); - assert!( - matches!( - &thinking_events[2], - LanguageModelCompletionEvent::Thinking { text, .. } if text == "\n\n" - ), - "expected separator between summary parts" - ); - assert!(matches!( - &thinking_events[3], - LanguageModelCompletionEvent::Thinking { text, .. } if text == "Second part" - )); - - assert!(mapped.iter().any(|e| matches!( - e, - LanguageModelCompletionEvent::Text(t) if t == "The answer is 42" - ))); - } - - #[test] - fn responses_stream_maps_reasoning_from_done_only() { - let events = vec![ - ResponsesStreamEvent::OutputItemAdded { - output_index: 0, - sequence_number: None, - item: ResponseOutputItem::Reasoning(ResponseReasoningItem { - id: Some("rs_789".into()), - summary: vec![], - }), - }, - ResponsesStreamEvent::OutputItemDone { - output_index: 0, - sequence_number: None, - item: ResponseOutputItem::Reasoning(ResponseReasoningItem { - id: Some("rs_789".into()), - summary: vec![ReasoningSummaryPart::SummaryText { - text: "Summary without deltas".into(), - }], - }), - }, - ResponsesStreamEvent::Completed { - response: ResponseSummary::default(), - }, - ]; - - let mapped = map_response_events(events); - - assert!( - !mapped - .iter() - .any(|e| matches!(e, LanguageModelCompletionEvent::Thinking { .. })), - "OutputItemDone reasoning should not produce Thinking events (no delta/done text events)" - ); - } -} diff --git a/crates/language_models/src/provider/open_ai_compatible.rs b/crates/language_models/src/provider/open_ai_compatible.rs index 1c3268749c3340826cd2f50d29e80eecfa1826d4..7a3126f8f33beb7851ea914cfe063b76f8b4443f 100644 --- a/crates/language_models/src/provider/open_ai_compatible.rs +++ b/crates/language_models/src/provider/open_ai_compatible.rs @@ -402,7 +402,7 @@ impl LanguageModel for OpenAiCompatibleLanguageModel { self.model.capabilities.parallel_tool_calls, self.model.capabilities.prompt_cache_key, self.max_output_tokens(), - self.model.reasoning_effort.clone(), + self.model.reasoning_effort, ); let completions = self.stream_completion(request, cx); async move { @@ -417,7 +417,7 @@ impl LanguageModel for OpenAiCompatibleLanguageModel { self.model.capabilities.parallel_tool_calls, self.model.capabilities.prompt_cache_key, self.max_output_tokens(), - self.model.reasoning_effort.clone(), + self.model.reasoning_effort, ); let completions = self.stream_response(request, cx); async move { diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs index 09c8eb768d12c61ed1dc86a1251ad52114be6162..fba3a6938aecf1db80680e014e408e4d59c42ff7 100644 --- a/crates/language_models/src/provider/open_router.rs +++ b/crates/language_models/src/provider/open_router.rs @@ -22,7 +22,7 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*}; use ui_input::InputField; use util::ResultExt; -use crate::provider::util::{fix_streamed_json, parse_tool_arguments}; +use language_model::util::{fix_streamed_json, parse_tool_arguments}; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openrouter"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("OpenRouter"); diff --git a/crates/language_models/src/provider/x_ai.rs b/crates/language_models/src/provider/x_ai.rs index 88189864c7b4b650a24afb2b872c1d6105cf9782..e95bc1ba72fabcf9632b2ed2efd94254fb1313cd 100644 --- a/crates/language_models/src/provider/x_ai.rs +++ b/crates/language_models/src/provider/x_ai.rs @@ -9,7 +9,7 @@ use language_model::{ LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter, - Role, env_var, + env_var, }; use open_ai::ResponseStreamEvent; pub use settings::XaiAvailableModel as AvailableModel; @@ -19,7 +19,8 @@ use strum::IntoEnumIterator; use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*}; use ui_input::InputField; use util::ResultExt; -use x_ai::{Model, XAI_API_URL}; +use x_ai::XAI_API_URL; +pub use x_ai::completion::count_xai_tokens; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI"); @@ -320,7 +321,9 @@ impl LanguageModel for XAiLanguageModel { request: LanguageModelRequest, cx: &App, ) -> BoxFuture<'static, Result> { - count_xai_tokens(request, self.model.clone(), cx) + let model = self.model.clone(); + cx.background_spawn(async move { count_xai_tokens(request, model) }) + .boxed() } fn stream_completion( @@ -354,37 +357,6 @@ impl LanguageModel for XAiLanguageModel { } } -pub fn count_xai_tokens( - request: LanguageModelRequest, - model: Model, - cx: &App, -) -> BoxFuture<'static, Result> { - cx.background_spawn(async move { - let messages = request - .messages - .into_iter() - .map(|message| tiktoken_rs::ChatCompletionRequestMessage { - role: match message.role { - Role::User => "user".into(), - Role::Assistant => "assistant".into(), - Role::System => "system".into(), - }, - content: Some(message.string_contents()), - name: None, - function_call: None, - }) - .collect::>(); - - let model_name = if model.max_token_count() >= 100_000 { - "gpt-4o" - } else { - "gpt-4" - }; - tiktoken_rs::num_tokens_from_messages(model_name, &messages).map(|tokens| tokens as u64) - }) - .boxed() -} - struct ConfigurationView { api_key_editor: Entity, state: Entity, diff --git a/crates/language_models_cloud/Cargo.toml b/crates/language_models_cloud/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..b08acc5ecd5c2a718e936378c2dbfbc3d1c32df0 --- /dev/null +++ b/crates/language_models_cloud/Cargo.toml @@ -0,0 +1,33 @@ +[package] +name = "language_models_cloud" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/language_models_cloud.rs" + +[dependencies] +anthropic = { workspace = true, features = ["schemars"] } +anyhow.workspace = true +cloud_llm_client.workspace = true +futures.workspace = true +google_ai = { workspace = true, features = ["schemars"] } +gpui.workspace = true +http_client.workspace = true +language_model.workspace = true +open_ai = { workspace = true, features = ["schemars"] } +schemars.workspace = true +semver.workspace = true +serde.workspace = true +serde_json.workspace = true +smol.workspace = true +thiserror.workspace = true +x_ai = { workspace = true, features = ["schemars"] } + +[dev-dependencies] +language_model = { workspace = true, features = ["test-support"] } diff --git a/crates/language_models_cloud/LICENSE-GPL b/crates/language_models_cloud/LICENSE-GPL new file mode 120000 index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4 --- /dev/null +++ b/crates/language_models_cloud/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/language_models_cloud/src/language_models_cloud.rs b/crates/language_models_cloud/src/language_models_cloud.rs new file mode 100644 index 0000000000000000000000000000000000000000..24c8ec87d5c672dbc18b20164f2fe28c9b46b2e1 --- /dev/null +++ b/crates/language_models_cloud/src/language_models_cloud.rs @@ -0,0 +1,1059 @@ +use anthropic::AnthropicModelMode; +use anyhow::{Context as _, Result, anyhow}; +use cloud_llm_client::{ + CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, + CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus, + CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, + OUTDATED_LLM_TOKEN_HEADER_NAME, SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, + ZED_VERSION_HEADER_NAME, +}; +use futures::{ + AsyncBufReadExt, FutureExt, Stream, StreamExt, + future::BoxFuture, + stream::{self, BoxStream}, +}; +use google_ai::GoogleModelMode; +use gpui::{App, AppContext, AsyncApp, Context, Task}; +use http_client::http::{HeaderMap, HeaderValue}; +use http_client::{ + AsyncBody, HttpClient, HttpClientWithUrl, HttpRequestExt, Method, Response, StatusCode, +}; +use language_model::{ + ANTHROPIC_PROVIDER_ID, ANTHROPIC_PROVIDER_NAME, GOOGLE_PROVIDER_ID, GOOGLE_PROVIDER_NAME, + LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionError, + LanguageModelCompletionEvent, LanguageModelEffortLevel, LanguageModelId, LanguageModelName, + LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest, + LanguageModelToolChoice, LanguageModelToolSchemaFormat, OPEN_AI_PROVIDER_ID, + OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME, + ZED_CLOUD_PROVIDER_ID, ZED_CLOUD_PROVIDER_NAME, +}; + +use schemars::JsonSchema; +use semver::Version; +use serde::{Deserialize, Serialize, de::DeserializeOwned}; +use smol::io::{AsyncReadExt, BufReader}; +use std::collections::VecDeque; +use std::pin::Pin; +use std::str::FromStr; +use std::sync::Arc; +use std::task::Poll; +use std::time::Duration; +use thiserror::Error; + +use anthropic::completion::{ + AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic, +}; +use google_ai::completion::{GoogleEventMapper, into_google}; +use open_ai::completion::{ + OpenAiEventMapper, OpenAiResponseEventMapper, count_open_ai_tokens, into_open_ai, + into_open_ai_response, +}; +use x_ai::completion::count_xai_tokens; + +const PROVIDER_ID: LanguageModelProviderId = ZED_CLOUD_PROVIDER_ID; +const PROVIDER_NAME: LanguageModelProviderName = ZED_CLOUD_PROVIDER_NAME; + +/// Trait for acquiring and refreshing LLM authentication tokens. +pub trait CloudLlmTokenProvider: Send + Sync { + type AuthContext: Clone + Send + 'static; + + fn auth_context(&self, cx: &AsyncApp) -> Self::AuthContext; + fn acquire_token(&self, auth_context: Self::AuthContext) -> BoxFuture<'static, Result>; + fn refresh_token(&self, auth_context: Self::AuthContext) -> BoxFuture<'static, Result>; +} + +#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum ModelMode { + #[default] + Default, + Thinking { + /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`. + budget_tokens: Option, + }, +} + +impl From for AnthropicModelMode { + fn from(value: ModelMode) -> Self { + match value { + ModelMode::Default => AnthropicModelMode::Default, + ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens }, + } + } +} + +pub struct CloudLanguageModel { + pub id: LanguageModelId, + pub model: Arc, + pub token_provider: Arc, + pub http_client: Arc, + pub app_version: Option, + pub request_limiter: RateLimiter, +} + +pub struct PerformLlmCompletionResponse { + pub response: Response, + pub includes_status_messages: bool, +} + +impl CloudLanguageModel { + pub async fn perform_llm_completion( + http_client: &HttpClientWithUrl, + token_provider: &TP, + auth_context: TP::AuthContext, + app_version: Option, + body: CompletionBody, + ) -> Result { + let mut token = token_provider.acquire_token(auth_context.clone()).await?; + let mut refreshed_token = false; + + loop { + let request = http_client::Request::builder() + .method(Method::POST) + .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref()) + .when_some(app_version.as_ref(), |builder, app_version| { + builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string()) + }) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {token}")) + .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true") + .header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true") + .body(serde_json::to_string(&body)?.into())?; + + let mut response = http_client.send(request).await?; + let status = response.status(); + if status.is_success() { + let includes_status_messages = response + .headers() + .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME) + .is_some(); + + return Ok(PerformLlmCompletionResponse { + response, + includes_status_messages, + }); + } + + if !refreshed_token && needs_llm_token_refresh(&response) { + token = token_provider.refresh_token(auth_context.clone()).await?; + refreshed_token = true; + continue; + } + + if status == StatusCode::PAYMENT_REQUIRED { + return Err(anyhow!(PaymentRequiredError)); + } + + let mut body = String::new(); + let headers = response.headers().clone(); + response.body_mut().read_to_string(&mut body).await?; + return Err(anyhow!(ApiError { + status, + body, + headers + })); + } + } +} + +fn needs_llm_token_refresh(response: &Response) -> bool { + response + .headers() + .get(EXPIRED_LLM_TOKEN_HEADER_NAME) + .is_some() + || response + .headers() + .get(OUTDATED_LLM_TOKEN_HEADER_NAME) + .is_some() +} + +#[derive(Debug, Error)] +#[error("cloud language model request failed with status {status}: {body}")] +struct ApiError { + status: StatusCode, + body: String, + headers: HeaderMap, +} + +/// Represents error responses from Zed's cloud API. +/// +/// Example JSON for an upstream HTTP error: +/// ```json +/// { +/// "code": "upstream_http_error", +/// "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout", +/// "upstream_status": 503 +/// } +/// ``` +#[derive(Debug, serde::Deserialize)] +struct CloudApiError { + code: String, + message: String, + #[serde(default)] + #[serde(deserialize_with = "deserialize_optional_status_code")] + upstream_status: Option, + #[serde(default)] + retry_after: Option, +} + +fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result, D::Error> +where + D: serde::Deserializer<'de>, +{ + let opt: Option = Option::deserialize(deserializer)?; + Ok(opt.and_then(|code| StatusCode::from_u16(code).ok())) +} + +impl From for LanguageModelCompletionError { + fn from(error: ApiError) -> Self { + if let Ok(cloud_error) = serde_json::from_str::(&error.body) { + if cloud_error.code.starts_with("upstream_http_") { + let status = if let Some(status) = cloud_error.upstream_status { + status + } else if cloud_error.code.ends_with("_error") { + error.status + } else { + // If there's a status code in the code string (e.g. "upstream_http_429") + // then use that; otherwise, see if the JSON contains a status code. + cloud_error + .code + .strip_prefix("upstream_http_") + .and_then(|code_str| code_str.parse::().ok()) + .and_then(|code| StatusCode::from_u16(code).ok()) + .unwrap_or(error.status) + }; + + return LanguageModelCompletionError::UpstreamProviderError { + message: cloud_error.message, + status, + retry_after: cloud_error.retry_after.map(Duration::from_secs_f64), + }; + } + + return LanguageModelCompletionError::from_http_status( + PROVIDER_NAME, + error.status, + cloud_error.message, + None, + ); + } + + let retry_after = None; + LanguageModelCompletionError::from_http_status( + PROVIDER_NAME, + error.status, + error.body, + retry_after, + ) + } +} + +impl LanguageModel for CloudLanguageModel { + fn id(&self) -> LanguageModelId { + self.id.clone() + } + + fn name(&self) -> LanguageModelName { + LanguageModelName::from(self.model.display_name.clone()) + } + + fn provider_id(&self) -> LanguageModelProviderId { + PROVIDER_ID + } + + fn provider_name(&self) -> LanguageModelProviderName { + PROVIDER_NAME + } + + fn upstream_provider_id(&self) -> LanguageModelProviderId { + use cloud_llm_client::LanguageModelProvider::*; + match self.model.provider { + Anthropic => ANTHROPIC_PROVIDER_ID, + OpenAi => OPEN_AI_PROVIDER_ID, + Google => GOOGLE_PROVIDER_ID, + XAi => X_AI_PROVIDER_ID, + } + } + + fn upstream_provider_name(&self) -> LanguageModelProviderName { + use cloud_llm_client::LanguageModelProvider::*; + match self.model.provider { + Anthropic => ANTHROPIC_PROVIDER_NAME, + OpenAi => OPEN_AI_PROVIDER_NAME, + Google => GOOGLE_PROVIDER_NAME, + XAi => X_AI_PROVIDER_NAME, + } + } + + fn is_latest(&self) -> bool { + self.model.is_latest + } + + fn supports_tools(&self) -> bool { + self.model.supports_tools + } + + fn supports_images(&self) -> bool { + self.model.supports_images + } + + fn supports_thinking(&self) -> bool { + self.model.supports_thinking + } + + fn supports_fast_mode(&self) -> bool { + self.model.supports_fast_mode + } + + fn supported_effort_levels(&self) -> Vec { + self.model + .supported_effort_levels + .iter() + .map(|effort_level| LanguageModelEffortLevel { + name: effort_level.name.clone().into(), + value: effort_level.value.clone().into(), + is_default: effort_level.is_default.unwrap_or(false), + }) + .collect() + } + + fn supports_streaming_tools(&self) -> bool { + self.model.supports_streaming_tools + } + + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { + match choice { + LanguageModelToolChoice::Auto + | LanguageModelToolChoice::Any + | LanguageModelToolChoice::None => true, + } + } + + fn supports_split_token_display(&self) -> bool { + use cloud_llm_client::LanguageModelProvider::*; + matches!(self.model.provider, OpenAi | XAi) + } + + fn telemetry_id(&self) -> String { + format!("zed.dev/{}", self.model.id) + } + + fn tool_input_format(&self) -> LanguageModelToolSchemaFormat { + match self.model.provider { + cloud_llm_client::LanguageModelProvider::Anthropic + | cloud_llm_client::LanguageModelProvider::OpenAi => { + LanguageModelToolSchemaFormat::JsonSchema + } + cloud_llm_client::LanguageModelProvider::Google + | cloud_llm_client::LanguageModelProvider::XAi => { + LanguageModelToolSchemaFormat::JsonSchemaSubset + } + } + } + + fn max_token_count(&self) -> u64 { + self.model.max_token_count as u64 + } + + fn max_output_tokens(&self) -> Option { + Some(self.model.max_output_tokens as u64) + } + + fn cache_configuration(&self) -> Option { + match &self.model.provider { + cloud_llm_client::LanguageModelProvider::Anthropic => { + Some(LanguageModelCacheConfiguration { + min_total_token: 2_048, + should_speculate: true, + max_cache_anchors: 4, + }) + } + cloud_llm_client::LanguageModelProvider::OpenAi + | cloud_llm_client::LanguageModelProvider::XAi + | cloud_llm_client::LanguageModelProvider::Google => None, + } + } + + fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &App, + ) -> BoxFuture<'static, Result> { + match self.model.provider { + cloud_llm_client::LanguageModelProvider::Anthropic => cx + .background_spawn(async move { count_anthropic_tokens_with_tiktoken(request) }) + .boxed(), + cloud_llm_client::LanguageModelProvider::OpenAi => { + let model = match open_ai::Model::from_id(&self.model.id.0) { + Ok(model) => model, + Err(err) => return async move { Err(anyhow!(err)) }.boxed(), + }; + cx.background_spawn(async move { count_open_ai_tokens(request, model) }) + .boxed() + } + cloud_llm_client::LanguageModelProvider::XAi => { + let model = match x_ai::Model::from_id(&self.model.id.0) { + Ok(model) => model, + Err(err) => return async move { Err(anyhow!(err)) }.boxed(), + }; + cx.background_spawn(async move { count_xai_tokens(request, model) }) + .boxed() + } + cloud_llm_client::LanguageModelProvider::Google => { + let http_client = self.http_client.clone(); + let token_provider = self.token_provider.clone(); + let model_id = self.model.id.to_string(); + let generate_content_request = + into_google(request, model_id.clone(), GoogleModelMode::Default); + let auth_context = token_provider.auth_context(&cx.to_async()); + async move { + let token = token_provider.acquire_token(auth_context).await?; + + let request_body = CountTokensBody { + provider: cloud_llm_client::LanguageModelProvider::Google, + model: model_id, + provider_request: serde_json::to_value(&google_ai::CountTokensRequest { + generate_content_request, + })?, + }; + let request = http_client::Request::builder() + .method(Method::POST) + .uri( + http_client + .build_zed_llm_url("/count_tokens", &[])? + .as_ref(), + ) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {token}")) + .body(serde_json::to_string(&request_body)?.into())?; + let mut response = http_client.send(request).await?; + let status = response.status(); + let headers = response.headers().clone(); + let mut response_body = String::new(); + response + .body_mut() + .read_to_string(&mut response_body) + .await?; + + if status.is_success() { + let response_body: CountTokensResponse = + serde_json::from_str(&response_body)?; + + Ok(response_body.tokens as u64) + } else { + Err(anyhow!(ApiError { + status, + body: response_body, + headers + })) + } + } + .boxed() + } + } + } + + fn stream_completion( + &self, + request: LanguageModelRequest, + cx: &AsyncApp, + ) -> BoxFuture< + 'static, + Result< + BoxStream<'static, Result>, + LanguageModelCompletionError, + >, + > { + let thread_id = request.thread_id.clone(); + let prompt_id = request.prompt_id.clone(); + let app_version = self.app_version.clone(); + let thinking_allowed = request.thinking_allowed; + let enable_thinking = thinking_allowed && self.model.supports_thinking; + let provider_name = provider_name(&self.model.provider); + match self.model.provider { + cloud_llm_client::LanguageModelProvider::Anthropic => { + let effort = request + .thinking_effort + .as_ref() + .and_then(|effort| anthropic::Effort::from_str(effort).ok()); + + let mut request = into_anthropic( + request, + self.model.id.to_string(), + 1.0, + self.model.max_output_tokens as u64, + if enable_thinking { + AnthropicModelMode::Thinking { + budget_tokens: Some(4_096), + } + } else { + AnthropicModelMode::Default + }, + ); + + if enable_thinking && effort.is_some() { + request.thinking = Some(anthropic::Thinking::Adaptive); + request.output_config = Some(anthropic::OutputConfig { effort }); + } + + let http_client = self.http_client.clone(); + let token_provider = self.token_provider.clone(); + let auth_context = token_provider.auth_context(cx); + let future = self.request_limiter.stream(async move { + let PerformLlmCompletionResponse { + response, + includes_status_messages, + } = Self::perform_llm_completion( + &http_client, + &*token_provider, + auth_context, + app_version, + CompletionBody { + thread_id, + prompt_id, + provider: cloud_llm_client::LanguageModelProvider::Anthropic, + model: request.model.clone(), + provider_request: serde_json::to_value(&request) + .map_err(|e| anyhow!(e))?, + }, + ) + .await + .map_err(|err| match err.downcast::() { + Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)), + Err(err) => anyhow!(err), + })?; + + let mut mapper = AnthropicEventMapper::new(); + Ok(map_cloud_completion_events( + Box::pin(response_lines(response, includes_status_messages)), + &provider_name, + move |event| mapper.map_event(event), + )) + }); + async move { Ok(future.await?.boxed()) }.boxed() + } + cloud_llm_client::LanguageModelProvider::OpenAi => { + let http_client = self.http_client.clone(); + let token_provider = self.token_provider.clone(); + let effort = request + .thinking_effort + .as_ref() + .and_then(|effort| open_ai::ReasoningEffort::from_str(effort).ok()); + + let mut request = into_open_ai_response( + request, + &self.model.id.0, + self.model.supports_parallel_tool_calls, + true, + None, + None, + ); + + if enable_thinking && let Some(effort) = effort { + request.reasoning = Some(open_ai::responses::ReasoningConfig { + effort, + summary: Some(open_ai::responses::ReasoningSummaryMode::Auto), + }); + } + + let auth_context = token_provider.auth_context(cx); + let future = self.request_limiter.stream(async move { + let PerformLlmCompletionResponse { + response, + includes_status_messages, + } = Self::perform_llm_completion( + &http_client, + &*token_provider, + auth_context, + app_version, + CompletionBody { + thread_id, + prompt_id, + provider: cloud_llm_client::LanguageModelProvider::OpenAi, + model: request.model.clone(), + provider_request: serde_json::to_value(&request) + .map_err(|e| anyhow!(e))?, + }, + ) + .await?; + + let mut mapper = OpenAiResponseEventMapper::new(); + Ok(map_cloud_completion_events( + Box::pin(response_lines(response, includes_status_messages)), + &provider_name, + move |event| mapper.map_event(event), + )) + }); + async move { Ok(future.await?.boxed()) }.boxed() + } + cloud_llm_client::LanguageModelProvider::XAi => { + let http_client = self.http_client.clone(); + let token_provider = self.token_provider.clone(); + let request = into_open_ai( + request, + &self.model.id.0, + self.model.supports_parallel_tool_calls, + false, + None, + None, + ); + let auth_context = token_provider.auth_context(cx); + let future = self.request_limiter.stream(async move { + let PerformLlmCompletionResponse { + response, + includes_status_messages, + } = Self::perform_llm_completion( + &http_client, + &*token_provider, + auth_context, + app_version, + CompletionBody { + thread_id, + prompt_id, + provider: cloud_llm_client::LanguageModelProvider::XAi, + model: request.model.clone(), + provider_request: serde_json::to_value(&request) + .map_err(|e| anyhow!(e))?, + }, + ) + .await?; + + let mut mapper = OpenAiEventMapper::new(); + Ok(map_cloud_completion_events( + Box::pin(response_lines(response, includes_status_messages)), + &provider_name, + move |event| mapper.map_event(event), + )) + }); + async move { Ok(future.await?.boxed()) }.boxed() + } + cloud_llm_client::LanguageModelProvider::Google => { + let http_client = self.http_client.clone(); + let token_provider = self.token_provider.clone(); + let request = + into_google(request, self.model.id.to_string(), GoogleModelMode::Default); + let auth_context = token_provider.auth_context(cx); + let future = self.request_limiter.stream(async move { + let PerformLlmCompletionResponse { + response, + includes_status_messages, + } = Self::perform_llm_completion( + &http_client, + &*token_provider, + auth_context, + app_version, + CompletionBody { + thread_id, + prompt_id, + provider: cloud_llm_client::LanguageModelProvider::Google, + model: request.model.model_id.clone(), + provider_request: serde_json::to_value(&request) + .map_err(|e| anyhow!(e))?, + }, + ) + .await?; + + let mut mapper = GoogleEventMapper::new(); + Ok(map_cloud_completion_events( + Box::pin(response_lines(response, includes_status_messages)), + &provider_name, + move |event| mapper.map_event(event), + )) + }); + async move { Ok(future.await?.boxed()) }.boxed() + } + } + } +} + +pub struct CloudModelProvider { + token_provider: Arc, + http_client: Arc, + app_version: Option, + models: Vec>, + default_model: Option>, + default_fast_model: Option>, + recommended_models: Vec>, +} + +impl CloudModelProvider { + pub fn new( + token_provider: Arc, + http_client: Arc, + app_version: Option, + ) -> Self { + Self { + token_provider, + http_client, + app_version, + models: Vec::new(), + default_model: None, + default_fast_model: None, + recommended_models: Vec::new(), + } + } + + pub fn refresh_models(&self, cx: &mut Context) -> Task> { + let http_client = self.http_client.clone(); + let token_provider = self.token_provider.clone(); + cx.spawn(async move |this, cx| { + let auth_context = token_provider.auth_context(cx); + let response = + Self::fetch_models_request(&http_client, &*token_provider, auth_context).await?; + this.update(cx, |this, cx| { + this.update_models(response); + cx.notify(); + }) + }) + } + + async fn fetch_models_request( + http_client: &HttpClientWithUrl, + token_provider: &TP, + auth_context: TP::AuthContext, + ) -> Result { + let token = token_provider.acquire_token(auth_context).await?; + + let request = http_client::Request::builder() + .method(Method::GET) + .header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true") + .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref()) + .header("Authorization", format!("Bearer {token}")) + .body(AsyncBody::empty())?; + let mut response = http_client + .send(request) + .await + .context("failed to send list models request")?; + + if response.status().is_success() { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + Ok(serde_json::from_str(&body)?) + } else { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + anyhow::bail!( + "error listing models.\nStatus: {:?}\nBody: {body}", + response.status(), + ); + } + } + + pub fn update_models(&mut self, response: ListModelsResponse) { + let models: Vec<_> = response.models.into_iter().map(Arc::new).collect(); + + self.default_model = models + .iter() + .find(|model| { + response + .default_model + .as_ref() + .is_some_and(|default_model_id| &model.id == default_model_id) + }) + .cloned(); + self.default_fast_model = models + .iter() + .find(|model| { + response + .default_fast_model + .as_ref() + .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id) + }) + .cloned(); + self.recommended_models = response + .recommended_models + .iter() + .filter_map(|id| models.iter().find(|model| &model.id == id)) + .cloned() + .collect(); + self.models = models; + } + + pub fn create_model( + &self, + model: &Arc, + ) -> Arc { + Arc::new(CloudLanguageModel:: { + id: LanguageModelId::from(model.id.0.to_string()), + model: model.clone(), + token_provider: self.token_provider.clone(), + http_client: self.http_client.clone(), + app_version: self.app_version.clone(), + request_limiter: RateLimiter::new(4), + }) + } + + pub fn models(&self) -> &[Arc] { + &self.models + } + + pub fn default_model(&self) -> Option<&Arc> { + self.default_model.as_ref() + } + + pub fn default_fast_model(&self) -> Option<&Arc> { + self.default_fast_model.as_ref() + } + + pub fn recommended_models(&self) -> &[Arc] { + &self.recommended_models + } +} + +pub fn map_cloud_completion_events( + stream: Pin>> + Send>>, + provider: &LanguageModelProviderName, + mut map_callback: F, +) -> BoxStream<'static, Result> +where + T: DeserializeOwned + 'static, + F: FnMut(T) -> Vec> + + Send + + 'static, +{ + let provider = provider.clone(); + let mut stream = stream.fuse(); + + let mut saw_stream_ended = false; + + let mut done = false; + let mut pending = VecDeque::new(); + + stream::poll_fn(move |cx| { + loop { + if let Some(item) = pending.pop_front() { + return Poll::Ready(Some(item)); + } + + if done { + return Poll::Ready(None); + } + + match stream.poll_next_unpin(cx) { + Poll::Ready(Some(event)) => { + let items = match event { + Err(error) => { + vec![Err(LanguageModelCompletionError::from(error))] + } + Ok(CompletionEvent::Status(CompletionRequestStatus::StreamEnded)) => { + saw_stream_ended = true; + vec![] + } + Ok(CompletionEvent::Status(status)) => { + LanguageModelCompletionEvent::from_completion_request_status( + status, + provider.clone(), + ) + .transpose() + .map(|event| vec![event]) + .unwrap_or_default() + } + Ok(CompletionEvent::Event(event)) => map_callback(event), + }; + pending.extend(items); + } + Poll::Ready(None) => { + done = true; + + if !saw_stream_ended { + return Poll::Ready(Some(Err( + LanguageModelCompletionError::StreamEndedUnexpectedly { + provider: provider.clone(), + }, + ))); + } + } + Poll::Pending => return Poll::Pending, + } + } + }) + .boxed() +} + +pub fn provider_name( + provider: &cloud_llm_client::LanguageModelProvider, +) -> LanguageModelProviderName { + match provider { + cloud_llm_client::LanguageModelProvider::Anthropic => ANTHROPIC_PROVIDER_NAME, + cloud_llm_client::LanguageModelProvider::OpenAi => OPEN_AI_PROVIDER_NAME, + cloud_llm_client::LanguageModelProvider::Google => GOOGLE_PROVIDER_NAME, + cloud_llm_client::LanguageModelProvider::XAi => X_AI_PROVIDER_NAME, + } +} + +pub fn response_lines( + response: Response, + includes_status_messages: bool, +) -> impl Stream>> { + futures::stream::try_unfold( + (String::new(), BufReader::new(response.into_body())), + move |(mut line, mut body)| async move { + match body.read_line(&mut line).await { + Ok(0) => Ok(None), + Ok(_) => { + let event = if includes_status_messages { + serde_json::from_str::>(&line)? + } else { + CompletionEvent::Event(serde_json::from_str::(&line)?) + }; + + line.clear(); + Ok(Some((event, (line, body)))) + } + Err(e) => Err(e.into()), + } + }, + ) +} + +#[cfg(test)] +mod tests { + use super::*; + use http_client::http::{HeaderMap, StatusCode}; + use language_model::LanguageModelCompletionError; + + #[test] + fn test_api_error_conversion_with_upstream_http_error() { + // upstream_http_error with 503 status should become ServerOverloaded + let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#; + + let api_error = ApiError { + status: StatusCode::INTERNAL_SERVER_ERROR, + body: error_body.to_string(), + headers: HeaderMap::new(), + }; + + let completion_error: LanguageModelCompletionError = api_error.into(); + + match completion_error { + LanguageModelCompletionError::UpstreamProviderError { message, .. } => { + assert_eq!( + message, + "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout" + ); + } + _ => panic!( + "Expected UpstreamProviderError for upstream 503, got: {:?}", + completion_error + ), + } + + // upstream_http_error with 500 status should become ApiInternalServerError + let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#; + + let api_error = ApiError { + status: StatusCode::INTERNAL_SERVER_ERROR, + body: error_body.to_string(), + headers: HeaderMap::new(), + }; + + let completion_error: LanguageModelCompletionError = api_error.into(); + + match completion_error { + LanguageModelCompletionError::UpstreamProviderError { message, .. } => { + assert_eq!( + message, + "Received an error from the OpenAI API: internal server error" + ); + } + _ => panic!( + "Expected UpstreamProviderError for upstream 500, got: {:?}", + completion_error + ), + } + + // upstream_http_error with 429 status should become RateLimitExceeded + let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#; + + let api_error = ApiError { + status: StatusCode::INTERNAL_SERVER_ERROR, + body: error_body.to_string(), + headers: HeaderMap::new(), + }; + + let completion_error: LanguageModelCompletionError = api_error.into(); + + match completion_error { + LanguageModelCompletionError::UpstreamProviderError { message, .. } => { + assert_eq!( + message, + "Received an error from the Google API: rate limit exceeded" + ); + } + _ => panic!( + "Expected UpstreamProviderError for upstream 429, got: {:?}", + completion_error + ), + } + + // Regular 500 error without upstream_http_error should remain ApiInternalServerError for Zed + let error_body = "Regular internal server error"; + + let api_error = ApiError { + status: StatusCode::INTERNAL_SERVER_ERROR, + body: error_body.to_string(), + headers: HeaderMap::new(), + }; + + let completion_error: LanguageModelCompletionError = api_error.into(); + + match completion_error { + LanguageModelCompletionError::ApiInternalServerError { provider, message } => { + assert_eq!(provider, PROVIDER_NAME); + assert_eq!(message, "Regular internal server error"); + } + _ => panic!( + "Expected ApiInternalServerError for regular 500, got: {:?}", + completion_error + ), + } + + // upstream_http_429 format should be converted to UpstreamProviderError + let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#; + + let api_error = ApiError { + status: StatusCode::INTERNAL_SERVER_ERROR, + body: error_body.to_string(), + headers: HeaderMap::new(), + }; + + let completion_error: LanguageModelCompletionError = api_error.into(); + + match completion_error { + LanguageModelCompletionError::UpstreamProviderError { + message, + status, + retry_after, + } => { + assert_eq!(message, "Upstream Anthropic rate limit exceeded."); + assert_eq!(status, StatusCode::TOO_MANY_REQUESTS); + assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5))); + } + _ => panic!( + "Expected UpstreamProviderError for upstream_http_429, got: {:?}", + completion_error + ), + } + + // Invalid JSON in error body should fall back to regular error handling + let error_body = "Not JSON at all"; + + let api_error = ApiError { + status: StatusCode::INTERNAL_SERVER_ERROR, + body: error_body.to_string(), + headers: HeaderMap::new(), + }; + + let completion_error: LanguageModelCompletionError = api_error.into(); + + match completion_error { + LanguageModelCompletionError::ApiInternalServerError { provider, .. } => { + assert_eq!(provider, PROVIDER_NAME); + } + _ => panic!( + "Expected ApiInternalServerError for invalid JSON, got: {:?}", + completion_error + ), + } + } +} diff --git a/crates/open_ai/Cargo.toml b/crates/open_ai/Cargo.toml index 3de3a4dc3fcb8c9519f4c67be7cead75401f6281..9a73e73196fa225691fa68e2ca839a19783bc3ca 100644 --- a/crates/open_ai/Cargo.toml +++ b/crates/open_ai/Cargo.toml @@ -17,13 +17,18 @@ schemars = ["dep:schemars"] [dependencies] anyhow.workspace = true +collections.workspace = true futures.workspace = true http_client.workspace = true +language_model_core.workspace = true rand.workspace = true schemars = { workspace = true, optional = true } log.workspace = true serde.workspace = true serde_json.workspace = true -settings.workspace = true strum.workspace = true thiserror.workspace = true +tiktoken-rs.workspace = true + +[dev-dependencies] +pretty_assertions.workspace = true diff --git a/crates/open_ai/src/completion.rs b/crates/open_ai/src/completion.rs new file mode 100644 index 0000000000000000000000000000000000000000..81fa79d35ee134ef4fee7618aec17d34e9382cec --- /dev/null +++ b/crates/open_ai/src/completion.rs @@ -0,0 +1,1693 @@ +use anyhow::{Result, anyhow}; +use collections::HashMap; +use futures::{Stream, StreamExt}; +use language_model_core::{ + LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage, + LanguageModelRequest, LanguageModelRequestMessage, LanguageModelToolChoice, + LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, + Role, StopReason, TokenUsage, + util::{fix_streamed_json, parse_tool_arguments}, +}; +use std::pin::Pin; +use std::sync::Arc; + +use crate::responses::{ + Request as ResponseRequest, ResponseFunctionCallItem, ResponseFunctionCallOutputContent, + ResponseFunctionCallOutputItem, ResponseInputContent, ResponseInputItem, ResponseMessageItem, + ResponseOutputItem, ResponseSummary as ResponsesSummary, ResponseUsage as ResponsesUsage, + StreamEvent as ResponsesStreamEvent, +}; +use crate::{ + FunctionContent, FunctionDefinition, ImageUrl, MessagePart, Model, ReasoningEffort, + ResponseStreamEvent, ToolCall, ToolCallContent, +}; + +pub fn into_open_ai( + request: LanguageModelRequest, + model_id: &str, + supports_parallel_tool_calls: bool, + supports_prompt_cache_key: bool, + max_output_tokens: Option, + reasoning_effort: Option, +) -> crate::Request { + let stream = !model_id.starts_with("o1-"); + + let mut messages = Vec::new(); + for message in request.messages { + for content in message.content { + match content { + MessageContent::Text(text) | MessageContent::Thinking { text, .. } => { + let should_add = if message.role == Role::User { + // Including whitespace-only user messages can cause error with OpenAI compatible APIs + // See https://github.com/zed-industries/zed/issues/40097 + !text.trim().is_empty() + } else { + !text.is_empty() + }; + if should_add { + add_message_content_part( + MessagePart::Text { text }, + message.role, + &mut messages, + ); + } + } + MessageContent::RedactedThinking(_) => {} + MessageContent::Image(image) => { + add_message_content_part( + MessagePart::Image { + image_url: ImageUrl { + url: image.to_base64_url(), + detail: None, + }, + }, + message.role, + &mut messages, + ); + } + MessageContent::ToolUse(tool_use) => { + let tool_call = ToolCall { + id: tool_use.id.to_string(), + content: ToolCallContent::Function { + function: FunctionContent { + name: tool_use.name.to_string(), + arguments: serde_json::to_string(&tool_use.input) + .unwrap_or_default(), + }, + }, + }; + + if let Some(crate::RequestMessage::Assistant { tool_calls, .. }) = + messages.last_mut() + { + tool_calls.push(tool_call); + } else { + messages.push(crate::RequestMessage::Assistant { + content: None, + tool_calls: vec![tool_call], + }); + } + } + MessageContent::ToolResult(tool_result) => { + let content = match &tool_result.content { + LanguageModelToolResultContent::Text(text) => { + vec![MessagePart::Text { + text: text.to_string(), + }] + } + LanguageModelToolResultContent::Image(image) => { + vec![MessagePart::Image { + image_url: ImageUrl { + url: image.to_base64_url(), + detail: None, + }, + }] + } + }; + + messages.push(crate::RequestMessage::Tool { + content: content.into(), + tool_call_id: tool_result.tool_use_id.to_string(), + }); + } + } + } + } + + crate::Request { + model: model_id.into(), + messages, + stream, + stream_options: if stream { + Some(crate::StreamOptions::default()) + } else { + None + }, + stop: request.stop, + temperature: request.temperature.or(Some(1.0)), + max_completion_tokens: max_output_tokens, + parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() { + Some(supports_parallel_tool_calls) + } else { + None + }, + prompt_cache_key: if supports_prompt_cache_key { + request.thread_id + } else { + None + }, + tools: request + .tools + .into_iter() + .map(|tool| crate::ToolDefinition::Function { + function: FunctionDefinition { + name: tool.name, + description: Some(tool.description), + parameters: Some(tool.input_schema), + }, + }) + .collect(), + tool_choice: request.tool_choice.map(|choice| match choice { + LanguageModelToolChoice::Auto => crate::ToolChoice::Auto, + LanguageModelToolChoice::Any => crate::ToolChoice::Required, + LanguageModelToolChoice::None => crate::ToolChoice::None, + }), + reasoning_effort, + } +} + +pub fn into_open_ai_response( + request: LanguageModelRequest, + model_id: &str, + supports_parallel_tool_calls: bool, + supports_prompt_cache_key: bool, + max_output_tokens: Option, + reasoning_effort: Option, +) -> ResponseRequest { + let stream = !model_id.starts_with("o1-"); + + let LanguageModelRequest { + thread_id, + prompt_id: _, + intent: _, + messages, + tools, + tool_choice, + stop: _, + temperature, + thinking_allowed: _, + thinking_effort: _, + speed: _, + } = request; + + let mut input_items = Vec::new(); + for (index, message) in messages.into_iter().enumerate() { + append_message_to_response_items(message, index, &mut input_items); + } + + let tools: Vec<_> = tools + .into_iter() + .map(|tool| crate::responses::ToolDefinition::Function { + name: tool.name, + description: Some(tool.description), + parameters: Some(tool.input_schema), + strict: None, + }) + .collect(); + + ResponseRequest { + model: model_id.into(), + input: input_items, + stream, + temperature, + top_p: None, + max_output_tokens, + parallel_tool_calls: if tools.is_empty() { + None + } else { + Some(supports_parallel_tool_calls) + }, + tool_choice: tool_choice.map(|choice| match choice { + LanguageModelToolChoice::Auto => crate::ToolChoice::Auto, + LanguageModelToolChoice::Any => crate::ToolChoice::Required, + LanguageModelToolChoice::None => crate::ToolChoice::None, + }), + tools, + prompt_cache_key: if supports_prompt_cache_key { + thread_id + } else { + None + }, + reasoning: reasoning_effort.map(|effort| crate::responses::ReasoningConfig { + effort, + summary: Some(crate::responses::ReasoningSummaryMode::Auto), + }), + } +} + +fn append_message_to_response_items( + message: LanguageModelRequestMessage, + index: usize, + input_items: &mut Vec, +) { + let mut content_parts: Vec = Vec::new(); + + for content in message.content { + match content { + MessageContent::Text(text) => { + push_response_text_part(&message.role, text, &mut content_parts); + } + MessageContent::Thinking { text, .. } => { + push_response_text_part(&message.role, text, &mut content_parts); + } + MessageContent::RedactedThinking(_) => {} + MessageContent::Image(image) => { + push_response_image_part(&message.role, image, &mut content_parts); + } + MessageContent::ToolUse(tool_use) => { + flush_response_parts(&message.role, index, &mut content_parts, input_items); + let call_id = tool_use.id.to_string(); + input_items.push(ResponseInputItem::FunctionCall(ResponseFunctionCallItem { + call_id, + name: tool_use.name.to_string(), + arguments: tool_use.raw_input, + })); + } + MessageContent::ToolResult(tool_result) => { + flush_response_parts(&message.role, index, &mut content_parts, input_items); + input_items.push(ResponseInputItem::FunctionCallOutput( + ResponseFunctionCallOutputItem { + call_id: tool_result.tool_use_id.to_string(), + output: match tool_result.content { + LanguageModelToolResultContent::Text(text) => { + ResponseFunctionCallOutputContent::Text(text.to_string()) + } + LanguageModelToolResultContent::Image(image) => { + ResponseFunctionCallOutputContent::List(vec![ + ResponseInputContent::Image { + image_url: image.to_base64_url(), + }, + ]) + } + }, + }, + )); + } + } + } + + flush_response_parts(&message.role, index, &mut content_parts, input_items); +} + +fn push_response_text_part( + role: &Role, + text: impl Into, + parts: &mut Vec, +) { + let text = text.into(); + if text.trim().is_empty() { + return; + } + + match role { + Role::Assistant => parts.push(ResponseInputContent::OutputText { + text, + annotations: Vec::new(), + }), + _ => parts.push(ResponseInputContent::Text { text }), + } +} + +fn push_response_image_part( + role: &Role, + image: LanguageModelImage, + parts: &mut Vec, +) { + match role { + Role::Assistant => parts.push(ResponseInputContent::OutputText { + text: "[image omitted]".to_string(), + annotations: Vec::new(), + }), + _ => parts.push(ResponseInputContent::Image { + image_url: image.to_base64_url(), + }), + } +} + +fn flush_response_parts( + role: &Role, + _index: usize, + parts: &mut Vec, + input_items: &mut Vec, +) { + if parts.is_empty() { + return; + } + + let item = ResponseInputItem::Message(ResponseMessageItem { + role: match role { + Role::User => crate::Role::User, + Role::Assistant => crate::Role::Assistant, + Role::System => crate::Role::System, + }, + content: parts.clone(), + }); + + input_items.push(item); + parts.clear(); +} + +fn add_message_content_part( + new_part: MessagePart, + role: Role, + messages: &mut Vec, +) { + match (role, messages.last_mut()) { + (Role::User, Some(crate::RequestMessage::User { content })) + | ( + Role::Assistant, + Some(crate::RequestMessage::Assistant { + content: Some(content), + .. + }), + ) + | (Role::System, Some(crate::RequestMessage::System { content, .. })) => { + content.push_part(new_part); + } + _ => { + messages.push(match role { + Role::User => crate::RequestMessage::User { + content: crate::MessageContent::from(vec![new_part]), + }, + Role::Assistant => crate::RequestMessage::Assistant { + content: Some(crate::MessageContent::from(vec![new_part])), + tool_calls: Vec::new(), + }, + Role::System => crate::RequestMessage::System { + content: crate::MessageContent::from(vec![new_part]), + }, + }); + } + } +} + +pub struct OpenAiEventMapper { + tool_calls_by_index: HashMap, +} + +impl OpenAiEventMapper { + pub fn new() -> Self { + Self { + tool_calls_by_index: HashMap::default(), + } + } + + pub fn map_stream( + mut self, + events: Pin>>>, + ) -> impl Stream> + { + events.flat_map(move |event| { + futures::stream::iter(match event { + Ok(event) => self.map_event(event), + Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))], + }) + }) + } + + pub fn map_event( + &mut self, + event: ResponseStreamEvent, + ) -> Vec> { + let mut events = Vec::new(); + if let Some(usage) = event.usage { + events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage { + input_tokens: usage.prompt_tokens, + output_tokens: usage.completion_tokens, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }))); + } + + let Some(choice) = event.choices.first() else { + return events; + }; + + if let Some(delta) = choice.delta.as_ref() { + if let Some(reasoning_content) = delta.reasoning_content.clone() { + if !reasoning_content.is_empty() { + events.push(Ok(LanguageModelCompletionEvent::Thinking { + text: reasoning_content, + signature: None, + })); + } + } + if let Some(content) = delta.content.clone() { + if !content.is_empty() { + events.push(Ok(LanguageModelCompletionEvent::Text(content))); + } + } + + if let Some(tool_calls) = delta.tool_calls.as_ref() { + for tool_call in tool_calls { + let entry = self.tool_calls_by_index.entry(tool_call.index).or_default(); + + if let Some(tool_id) = tool_call.id.clone() { + entry.id = tool_id; + } + + if let Some(function) = tool_call.function.as_ref() { + if let Some(name) = function.name.clone() { + entry.name = name; + } + + if let Some(arguments) = function.arguments.clone() { + entry.arguments.push_str(&arguments); + } + } + + if !entry.id.is_empty() && !entry.name.is_empty() { + if let Ok(input) = serde_json::from_str::( + &fix_streamed_json(&entry.arguments), + ) { + events.push(Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: entry.id.clone().into(), + name: entry.name.as_str().into(), + is_input_complete: false, + input, + raw_input: entry.arguments.clone(), + thought_signature: None, + }, + ))); + } + } + } + } + } + + match choice.finish_reason.as_deref() { + Some("stop") => { + events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))); + } + Some("tool_calls") => { + events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| { + match parse_tool_arguments(&tool_call.arguments) { + Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: tool_call.id.clone().into(), + name: tool_call.name.as_str().into(), + is_input_complete: true, + input, + raw_input: tool_call.arguments.clone(), + thought_signature: None, + }, + )), + Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { + id: tool_call.id.into(), + tool_name: tool_call.name.into(), + raw_input: tool_call.arguments.clone().into(), + json_parse_error: error.to_string(), + }), + } + })); + + events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse))); + } + Some(stop_reason) => { + log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",); + events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))); + } + None => {} + } + + events + } +} + +#[derive(Default)] +struct RawToolCall { + id: String, + name: String, + arguments: String, +} + +pub struct OpenAiResponseEventMapper { + function_calls_by_item: HashMap, + pending_stop_reason: Option, +} + +#[derive(Default)] +struct PendingResponseFunctionCall { + call_id: String, + name: Arc, + arguments: String, +} + +impl OpenAiResponseEventMapper { + pub fn new() -> Self { + Self { + function_calls_by_item: HashMap::default(), + pending_stop_reason: None, + } + } + + pub fn map_stream( + mut self, + events: Pin>>>, + ) -> impl Stream> + { + events.flat_map(move |event| { + futures::stream::iter(match event { + Ok(event) => self.map_event(event), + Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))], + }) + }) + } + + pub fn map_event( + &mut self, + event: ResponsesStreamEvent, + ) -> Vec> { + match event { + ResponsesStreamEvent::OutputItemAdded { item, .. } => { + let mut events = Vec::new(); + + match &item { + ResponseOutputItem::Message(message) => { + if let Some(id) = &message.id { + events.push(Ok(LanguageModelCompletionEvent::StartMessage { + message_id: id.clone(), + })); + } + } + ResponseOutputItem::FunctionCall(function_call) => { + if let Some(item_id) = function_call.id.clone() { + let call_id = function_call + .call_id + .clone() + .or_else(|| function_call.id.clone()) + .unwrap_or_else(|| item_id.clone()); + let entry = PendingResponseFunctionCall { + call_id, + name: Arc::::from( + function_call.name.clone().unwrap_or_default(), + ), + arguments: function_call.arguments.clone(), + }; + self.function_calls_by_item.insert(item_id, entry); + } + } + ResponseOutputItem::Reasoning(_) | ResponseOutputItem::Unknown => {} + } + events + } + ResponsesStreamEvent::ReasoningSummaryTextDelta { delta, .. } => { + if delta.is_empty() { + Vec::new() + } else { + vec![Ok(LanguageModelCompletionEvent::Thinking { + text: delta, + signature: None, + })] + } + } + ResponsesStreamEvent::OutputTextDelta { delta, .. } => { + if delta.is_empty() { + Vec::new() + } else { + vec![Ok(LanguageModelCompletionEvent::Text(delta))] + } + } + ResponsesStreamEvent::FunctionCallArgumentsDelta { item_id, delta, .. } => { + if let Some(entry) = self.function_calls_by_item.get_mut(&item_id) { + entry.arguments.push_str(&delta); + if let Ok(input) = serde_json::from_str::( + &fix_streamed_json(&entry.arguments), + ) { + return vec![Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: LanguageModelToolUseId::from(entry.call_id.clone()), + name: entry.name.clone(), + is_input_complete: false, + input, + raw_input: entry.arguments.clone(), + thought_signature: None, + }, + ))]; + } + } + Vec::new() + } + ResponsesStreamEvent::FunctionCallArgumentsDone { + item_id, arguments, .. + } => { + if let Some(mut entry) = self.function_calls_by_item.remove(&item_id) { + if !arguments.is_empty() { + entry.arguments = arguments; + } + let raw_input = entry.arguments.clone(); + self.pending_stop_reason = Some(StopReason::ToolUse); + match parse_tool_arguments(&entry.arguments) { + Ok(input) => { + vec![Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: LanguageModelToolUseId::from(entry.call_id.clone()), + name: entry.name.clone(), + is_input_complete: true, + input, + raw_input, + thought_signature: None, + }, + ))] + } + Err(error) => { + vec![Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { + id: LanguageModelToolUseId::from(entry.call_id.clone()), + tool_name: entry.name.clone(), + raw_input: Arc::::from(raw_input), + json_parse_error: error.to_string(), + })] + } + } + } else { + Vec::new() + } + } + ResponsesStreamEvent::Completed { response } => { + self.handle_completion(response, StopReason::EndTurn) + } + ResponsesStreamEvent::Incomplete { response } => { + let reason = response + .status_details + .as_ref() + .and_then(|details| details.reason.as_deref()); + let stop_reason = match reason { + Some("max_output_tokens") => StopReason::MaxTokens, + Some("content_filter") => { + self.pending_stop_reason = Some(StopReason::Refusal); + StopReason::Refusal + } + _ => self + .pending_stop_reason + .take() + .unwrap_or(StopReason::EndTurn), + }; + + let mut events = Vec::new(); + if self.pending_stop_reason.is_none() { + events.extend(self.emit_tool_calls_from_output(&response.output)); + } + if let Some(usage) = response.usage.as_ref() { + events.push(Ok(LanguageModelCompletionEvent::UsageUpdate( + token_usage_from_response_usage(usage), + ))); + } + events.push(Ok(LanguageModelCompletionEvent::Stop(stop_reason))); + events + } + ResponsesStreamEvent::Failed { response } => { + let message = response + .status_details + .and_then(|details| details.error) + .map(|error| error.to_string()) + .unwrap_or_else(|| "response failed".to_string()); + vec![Err(LanguageModelCompletionError::Other(anyhow!(message)))] + } + ResponsesStreamEvent::Error { error } + | ResponsesStreamEvent::GenericError { error } => { + vec![Err(LanguageModelCompletionError::Other(anyhow!( + error.message + )))] + } + ResponsesStreamEvent::ReasoningSummaryPartAdded { summary_index, .. } => { + if summary_index > 0 { + vec![Ok(LanguageModelCompletionEvent::Thinking { + text: "\n\n".to_string(), + signature: None, + })] + } else { + Vec::new() + } + } + ResponsesStreamEvent::OutputTextDone { .. } + | ResponsesStreamEvent::OutputItemDone { .. } + | ResponsesStreamEvent::ContentPartAdded { .. } + | ResponsesStreamEvent::ContentPartDone { .. } + | ResponsesStreamEvent::ReasoningSummaryTextDone { .. } + | ResponsesStreamEvent::ReasoningSummaryPartDone { .. } + | ResponsesStreamEvent::Created { .. } + | ResponsesStreamEvent::InProgress { .. } + | ResponsesStreamEvent::Unknown => Vec::new(), + } + } + + fn handle_completion( + &mut self, + response: ResponsesSummary, + default_reason: StopReason, + ) -> Vec> { + let mut events = Vec::new(); + + if self.pending_stop_reason.is_none() { + events.extend(self.emit_tool_calls_from_output(&response.output)); + } + + if let Some(usage) = response.usage.as_ref() { + events.push(Ok(LanguageModelCompletionEvent::UsageUpdate( + token_usage_from_response_usage(usage), + ))); + } + + let stop_reason = self.pending_stop_reason.take().unwrap_or(default_reason); + events.push(Ok(LanguageModelCompletionEvent::Stop(stop_reason))); + events + } + + fn emit_tool_calls_from_output( + &mut self, + output: &[ResponseOutputItem], + ) -> Vec> { + let mut events = Vec::new(); + for item in output { + if let ResponseOutputItem::FunctionCall(function_call) = item { + let Some(call_id) = function_call + .call_id + .clone() + .or_else(|| function_call.id.clone()) + else { + log::error!( + "Function call item missing both call_id and id: {:?}", + function_call + ); + continue; + }; + let name: Arc = Arc::from(function_call.name.clone().unwrap_or_default()); + let arguments = &function_call.arguments; + self.pending_stop_reason = Some(StopReason::ToolUse); + match parse_tool_arguments(arguments) { + Ok(input) => { + events.push(Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: LanguageModelToolUseId::from(call_id.clone()), + name: name.clone(), + is_input_complete: true, + input, + raw_input: arguments.clone(), + thought_signature: None, + }, + ))); + } + Err(error) => { + events.push(Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { + id: LanguageModelToolUseId::from(call_id.clone()), + tool_name: name.clone(), + raw_input: Arc::::from(arguments.clone()), + json_parse_error: error.to_string(), + })); + } + } + } + } + events + } +} + +fn token_usage_from_response_usage(usage: &ResponsesUsage) -> TokenUsage { + TokenUsage { + input_tokens: usage.input_tokens.unwrap_or_default(), + output_tokens: usage.output_tokens.unwrap_or_default(), + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + } +} + +pub fn collect_tiktoken_messages( + request: LanguageModelRequest, +) -> Vec { + request + .messages + .into_iter() + .map(|message| tiktoken_rs::ChatCompletionRequestMessage { + role: match message.role { + Role::User => "user".into(), + Role::Assistant => "assistant".into(), + Role::System => "system".into(), + }, + content: Some(message.string_contents()), + name: None, + function_call: None, + }) + .collect::>() +} + +/// Count tokens for an OpenAI model. This is synchronous; callers should spawn +/// it on a background thread if needed. +pub fn count_open_ai_tokens(request: LanguageModelRequest, model: Model) -> Result { + let messages = collect_tiktoken_messages(request); + match model { + Model::Custom { max_tokens, .. } => { + let model = if max_tokens >= 100_000 { + // If the max tokens is 100k or more, it likely uses the o200k_base tokenizer + "gpt-4o" + } else { + // Otherwise fallback to gpt-4, since only cl100k_base and o200k_base are + // supported with this tiktoken method + "gpt-4" + }; + tiktoken_rs::num_tokens_from_messages(model, &messages) + } + // Currently supported by tiktoken_rs + // Sometimes tiktoken-rs is behind on model support. If that is the case, make a new branch + // arm with an override. We enumerate all supported models here so that we can check if new + // models are supported yet or not. + Model::ThreePointFiveTurbo + | Model::Four + | Model::FourTurbo + | Model::FourOmniMini + | Model::FourPointOneNano + | Model::O1 + | Model::O3 + | Model::O3Mini + | Model::Five + | Model::FiveCodex + | Model::FiveMini + | Model::FiveNano => tiktoken_rs::num_tokens_from_messages(model.id(), &messages), + // GPT-5.1, 5.2, 5.2-codex, 5.3-codex, 5.4, and 5.4-pro don't have dedicated tiktoken support; use gpt-5 tokenizer + Model::FivePointOne + | Model::FivePointTwo + | Model::FivePointTwoCodex + | Model::FivePointThreeCodex + | Model::FivePointFour + | Model::FivePointFourPro => tiktoken_rs::num_tokens_from_messages("gpt-5", &messages), + } + .map(|tokens| tokens as u64) +} + +#[cfg(test)] +mod tests { + use crate::responses::{ + ReasoningSummaryPart, ResponseFunctionToolCall, ResponseOutputItem, ResponseOutputMessage, + ResponseReasoningItem, ResponseStatusDetails, ResponseSummary, ResponseUsage, + StreamEvent as ResponsesStreamEvent, + }; + use futures::{StreamExt, executor::block_on}; + use language_model_core::{ + LanguageModelImage, LanguageModelRequestMessage, LanguageModelRequestTool, + LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolUse, + LanguageModelToolUseId, SharedString, + }; + use pretty_assertions::assert_eq; + use serde_json::json; + + use super::*; + + fn map_response_events(events: Vec) -> Vec { + block_on(async { + OpenAiResponseEventMapper::new() + .map_stream(Box::pin(futures::stream::iter(events.into_iter().map(Ok)))) + .collect::>() + .await + .into_iter() + .map(Result::unwrap) + .collect() + }) + } + + fn response_item_message(id: &str) -> ResponseOutputItem { + ResponseOutputItem::Message(ResponseOutputMessage { + id: Some(id.to_string()), + role: Some("assistant".to_string()), + status: Some("in_progress".to_string()), + content: vec![], + }) + } + + fn response_item_function_call(id: &str, args: Option<&str>) -> ResponseOutputItem { + ResponseOutputItem::FunctionCall(ResponseFunctionToolCall { + id: Some(id.to_string()), + status: Some("in_progress".to_string()), + name: Some("get_weather".to_string()), + call_id: Some("call_123".to_string()), + arguments: args.map(|s| s.to_string()).unwrap_or_default(), + }) + } + + #[test] + fn tiktoken_rs_support() { + let request = LanguageModelRequest { + thread_id: None, + prompt_id: None, + intent: None, + messages: vec![LanguageModelRequestMessage { + role: Role::User, + content: vec![MessageContent::Text("message".into())], + cache: false, + reasoning_details: None, + }], + tools: vec![], + tool_choice: None, + stop: vec![], + temperature: None, + thinking_allowed: true, + thinking_effort: None, + speed: None, + }; + + // Validate that all models are supported by tiktoken-rs + for model in ::iter() { + let count = count_open_ai_tokens(request.clone(), model).unwrap(); + assert!(count > 0); + } + } + + #[test] + fn responses_stream_maps_text_and_usage() { + let events = vec![ + ResponsesStreamEvent::OutputItemAdded { + output_index: 0, + sequence_number: None, + item: response_item_message("msg_123"), + }, + ResponsesStreamEvent::OutputTextDelta { + item_id: "msg_123".into(), + output_index: 0, + content_index: Some(0), + delta: "Hello".into(), + }, + ResponsesStreamEvent::Completed { + response: ResponseSummary { + usage: Some(ResponseUsage { + input_tokens: Some(5), + output_tokens: Some(3), + total_tokens: Some(8), + }), + ..Default::default() + }, + }, + ]; + + let mapped = map_response_events(events); + assert!(matches!( + mapped[0], + LanguageModelCompletionEvent::StartMessage { ref message_id } if message_id == "msg_123" + )); + assert!(matches!( + mapped[1], + LanguageModelCompletionEvent::Text(ref text) if text == "Hello" + )); + assert!(matches!( + mapped[2], + LanguageModelCompletionEvent::UsageUpdate(TokenUsage { + input_tokens: 5, + output_tokens: 3, + .. + }) + )); + assert!(matches!( + mapped[3], + LanguageModelCompletionEvent::Stop(StopReason::EndTurn) + )); + } + + #[test] + fn into_open_ai_response_builds_complete_payload() { + let tool_call_id = LanguageModelToolUseId::from("call-42"); + let tool_input = json!({ "city": "Boston" }); + let tool_arguments = serde_json::to_string(&tool_input).unwrap(); + let tool_use = LanguageModelToolUse { + id: tool_call_id.clone(), + name: Arc::from("get_weather"), + raw_input: tool_arguments.clone(), + input: tool_input, + is_input_complete: true, + thought_signature: None, + }; + let tool_result = LanguageModelToolResult { + tool_use_id: tool_call_id, + tool_name: Arc::from("get_weather"), + is_error: false, + content: LanguageModelToolResultContent::Text(Arc::from("Sunny")), + output: Some(json!({ "forecast": "Sunny" })), + }; + let user_image = LanguageModelImage { + source: SharedString::from("aGVsbG8="), + size: None, + }; + let expected_image_url = user_image.to_base64_url(); + + let request = LanguageModelRequest { + thread_id: Some("thread-123".into()), + prompt_id: None, + intent: None, + messages: vec![ + LanguageModelRequestMessage { + role: Role::System, + content: vec![MessageContent::Text("System context".into())], + cache: false, + reasoning_details: None, + }, + LanguageModelRequestMessage { + role: Role::User, + content: vec![ + MessageContent::Text("Please check the weather.".into()), + MessageContent::Image(user_image), + ], + cache: false, + reasoning_details: None, + }, + LanguageModelRequestMessage { + role: Role::Assistant, + content: vec![ + MessageContent::Text("Looking that up.".into()), + MessageContent::ToolUse(tool_use), + ], + cache: false, + reasoning_details: None, + }, + LanguageModelRequestMessage { + role: Role::Assistant, + content: vec![MessageContent::ToolResult(tool_result)], + cache: false, + reasoning_details: None, + }, + ], + tools: vec![LanguageModelRequestTool { + name: "get_weather".into(), + description: "Fetches the weather".into(), + input_schema: json!({ "type": "object" }), + use_input_streaming: false, + }], + tool_choice: Some(LanguageModelToolChoice::Any), + stop: vec!["".into()], + temperature: None, + thinking_allowed: false, + thinking_effort: None, + speed: None, + }; + + let response = into_open_ai_response( + request, + "custom-model", + true, + true, + Some(2048), + Some(ReasoningEffort::Low), + ); + + let serialized = serde_json::to_value(&response).unwrap(); + let expected = json!({ + "model": "custom-model", + "input": [ + { + "type": "message", + "role": "system", + "content": [ + { "type": "input_text", "text": "System context" } + ] + }, + { + "type": "message", + "role": "user", + "content": [ + { "type": "input_text", "text": "Please check the weather." }, + { "type": "input_image", "image_url": expected_image_url } + ] + }, + { + "type": "message", + "role": "assistant", + "content": [ + { "type": "output_text", "text": "Looking that up.", "annotations": [] } + ] + }, + { + "type": "function_call", + "call_id": "call-42", + "name": "get_weather", + "arguments": tool_arguments + }, + { + "type": "function_call_output", + "call_id": "call-42", + "output": "Sunny" + } + ], + "stream": true, + "max_output_tokens": 2048, + "parallel_tool_calls": true, + "tool_choice": "required", + "tools": [ + { + "type": "function", + "name": "get_weather", + "description": "Fetches the weather", + "parameters": { "type": "object" } + } + ], + "prompt_cache_key": "thread-123", + "reasoning": { "effort": "low", "summary": "auto" } + }); + + assert_eq!(serialized, expected); + } + + #[test] + fn responses_stream_maps_tool_calls() { + let events = vec![ + ResponsesStreamEvent::OutputItemAdded { + output_index: 0, + sequence_number: None, + item: response_item_function_call("item_fn", Some("{\"city\":\"Bos")), + }, + ResponsesStreamEvent::FunctionCallArgumentsDelta { + item_id: "item_fn".into(), + output_index: 0, + delta: "ton\"}".into(), + sequence_number: None, + }, + ResponsesStreamEvent::FunctionCallArgumentsDone { + item_id: "item_fn".into(), + output_index: 0, + arguments: "{\"city\":\"Boston\"}".into(), + sequence_number: None, + }, + ResponsesStreamEvent::Completed { + response: ResponseSummary::default(), + }, + ]; + + let mapped = map_response_events(events); + assert_eq!(mapped.len(), 3); + assert!(matches!( + mapped[0], + LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { + is_input_complete: false, + .. + }) + )); + assert!(matches!( + mapped[1], + LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { + ref id, + ref name, + ref raw_input, + is_input_complete: true, + .. + }) if id.to_string() == "call_123" + && name.as_ref() == "get_weather" + && raw_input == "{\"city\":\"Boston\"}" + )); + assert!(matches!( + mapped[2], + LanguageModelCompletionEvent::Stop(StopReason::ToolUse) + )); + } + + #[test] + fn responses_stream_uses_max_tokens_stop_reason() { + let events = vec![ResponsesStreamEvent::Incomplete { + response: ResponseSummary { + status_details: Some(ResponseStatusDetails { + reason: Some("max_output_tokens".into()), + r#type: Some("incomplete".into()), + error: None, + }), + usage: Some(ResponseUsage { + input_tokens: Some(10), + output_tokens: Some(20), + total_tokens: Some(30), + }), + ..Default::default() + }, + }]; + + let mapped = map_response_events(events); + assert!(matches!( + mapped[0], + LanguageModelCompletionEvent::UsageUpdate(TokenUsage { + input_tokens: 10, + output_tokens: 20, + .. + }) + )); + assert!(matches!( + mapped[1], + LanguageModelCompletionEvent::Stop(StopReason::MaxTokens) + )); + } + + #[test] + fn responses_stream_handles_multiple_tool_calls() { + let events = vec![ + ResponsesStreamEvent::OutputItemAdded { + output_index: 0, + sequence_number: None, + item: response_item_function_call("item_fn1", Some("{\"city\":\"NYC\"}")), + }, + ResponsesStreamEvent::FunctionCallArgumentsDone { + item_id: "item_fn1".into(), + output_index: 0, + arguments: "{\"city\":\"NYC\"}".into(), + sequence_number: None, + }, + ResponsesStreamEvent::OutputItemAdded { + output_index: 1, + sequence_number: None, + item: response_item_function_call("item_fn2", Some("{\"city\":\"LA\"}")), + }, + ResponsesStreamEvent::FunctionCallArgumentsDone { + item_id: "item_fn2".into(), + output_index: 1, + arguments: "{\"city\":\"LA\"}".into(), + sequence_number: None, + }, + ResponsesStreamEvent::Completed { + response: ResponseSummary::default(), + }, + ]; + + let mapped = map_response_events(events); + assert_eq!(mapped.len(), 3); + assert!(matches!( + mapped[0], + LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, .. }) + if raw_input == "{\"city\":\"NYC\"}" + )); + assert!(matches!( + mapped[1], + LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, .. }) + if raw_input == "{\"city\":\"LA\"}" + )); + assert!(matches!( + mapped[2], + LanguageModelCompletionEvent::Stop(StopReason::ToolUse) + )); + } + + #[test] + fn responses_stream_handles_mixed_text_and_tool_calls() { + let events = vec![ + ResponsesStreamEvent::OutputItemAdded { + output_index: 0, + sequence_number: None, + item: response_item_message("msg_123"), + }, + ResponsesStreamEvent::OutputTextDelta { + item_id: "msg_123".into(), + output_index: 0, + content_index: Some(0), + delta: "Let me check that".into(), + }, + ResponsesStreamEvent::OutputItemAdded { + output_index: 1, + sequence_number: None, + item: response_item_function_call("item_fn", Some("{\"query\":\"test\"}")), + }, + ResponsesStreamEvent::FunctionCallArgumentsDone { + item_id: "item_fn".into(), + output_index: 1, + arguments: "{\"query\":\"test\"}".into(), + sequence_number: None, + }, + ResponsesStreamEvent::Completed { + response: ResponseSummary::default(), + }, + ]; + + let mapped = map_response_events(events); + assert!(matches!( + mapped[0], + LanguageModelCompletionEvent::StartMessage { .. } + )); + assert!( + matches!(mapped[1], LanguageModelCompletionEvent::Text(ref text) if text == "Let me check that") + ); + assert!( + matches!(mapped[2], LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, .. }) if raw_input == "{\"query\":\"test\"}") + ); + assert!(matches!( + mapped[3], + LanguageModelCompletionEvent::Stop(StopReason::ToolUse) + )); + } + + #[test] + fn responses_stream_handles_json_parse_error() { + let events = vec![ + ResponsesStreamEvent::OutputItemAdded { + output_index: 0, + sequence_number: None, + item: response_item_function_call("item_fn", Some("{invalid json")), + }, + ResponsesStreamEvent::FunctionCallArgumentsDone { + item_id: "item_fn".into(), + output_index: 0, + arguments: "{invalid json".into(), + sequence_number: None, + }, + ResponsesStreamEvent::Completed { + response: ResponseSummary::default(), + }, + ]; + + let mapped = map_response_events(events); + assert!(matches!( + mapped[0], + LanguageModelCompletionEvent::ToolUseJsonParseError { ref raw_input, .. } + if raw_input.as_ref() == "{invalid json" + )); + } + + #[test] + fn responses_stream_handles_incomplete_function_call() { + let events = vec![ + ResponsesStreamEvent::OutputItemAdded { + output_index: 0, + sequence_number: None, + item: response_item_function_call("item_fn", Some("{\"city\":")), + }, + ResponsesStreamEvent::FunctionCallArgumentsDelta { + item_id: "item_fn".into(), + output_index: 0, + delta: "\"Boston\"".into(), + sequence_number: None, + }, + ResponsesStreamEvent::Incomplete { + response: ResponseSummary { + status_details: Some(ResponseStatusDetails { + reason: Some("max_output_tokens".into()), + r#type: Some("incomplete".into()), + error: None, + }), + output: vec![response_item_function_call( + "item_fn", + Some("{\"city\":\"Boston\"}"), + )], + ..Default::default() + }, + }, + ]; + + let mapped = map_response_events(events); + assert_eq!(mapped.len(), 3); + assert!(matches!( + mapped[0], + LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { + is_input_complete: false, + .. + }) + )); + assert!( + matches!(mapped[1], LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, is_input_complete: true, .. }) if raw_input == "{\"city\":\"Boston\"}") + ); + assert!(matches!( + mapped[2], + LanguageModelCompletionEvent::Stop(StopReason::MaxTokens) + )); + } + + #[test] + fn responses_stream_incomplete_does_not_duplicate_tool_calls() { + let events = vec![ + ResponsesStreamEvent::OutputItemAdded { + output_index: 0, + sequence_number: None, + item: response_item_function_call("item_fn", Some("{\"city\":\"Boston\"}")), + }, + ResponsesStreamEvent::FunctionCallArgumentsDone { + item_id: "item_fn".into(), + output_index: 0, + arguments: "{\"city\":\"Boston\"}".into(), + sequence_number: None, + }, + ResponsesStreamEvent::Incomplete { + response: ResponseSummary { + status_details: Some(ResponseStatusDetails { + reason: Some("max_output_tokens".into()), + r#type: Some("incomplete".into()), + error: None, + }), + output: vec![response_item_function_call( + "item_fn", + Some("{\"city\":\"Boston\"}"), + )], + ..Default::default() + }, + }, + ]; + + let mapped = map_response_events(events); + assert_eq!(mapped.len(), 2); + assert!( + matches!(mapped[0], LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, .. }) if raw_input == "{\"city\":\"Boston\"}") + ); + assert!(matches!( + mapped[1], + LanguageModelCompletionEvent::Stop(StopReason::MaxTokens) + )); + } + + #[test] + fn responses_stream_handles_empty_tool_arguments() { + let events = vec![ + ResponsesStreamEvent::OutputItemAdded { + output_index: 0, + sequence_number: None, + item: response_item_function_call("item_fn", Some("")), + }, + ResponsesStreamEvent::FunctionCallArgumentsDone { + item_id: "item_fn".into(), + output_index: 0, + arguments: "".into(), + sequence_number: None, + }, + ResponsesStreamEvent::Completed { + response: ResponseSummary::default(), + }, + ]; + + let mapped = map_response_events(events); + assert_eq!(mapped.len(), 2); + assert!(matches!( + &mapped[0], + LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { + id, name, raw_input, input, .. + }) if id.to_string() == "call_123" + && name.as_ref() == "get_weather" + && raw_input == "" + && input.is_object() + && input.as_object().unwrap().is_empty() + )); + assert!(matches!( + mapped[1], + LanguageModelCompletionEvent::Stop(StopReason::ToolUse) + )); + } + + #[test] + fn responses_stream_emits_partial_tool_use_events() { + let events = vec![ + ResponsesStreamEvent::OutputItemAdded { + output_index: 0, + sequence_number: None, + item: ResponseOutputItem::FunctionCall( + crate::responses::ResponseFunctionToolCall { + id: Some("item_fn".to_string()), + status: Some("in_progress".to_string()), + name: Some("get_weather".to_string()), + call_id: Some("call_abc".to_string()), + arguments: String::new(), + }, + ), + }, + ResponsesStreamEvent::FunctionCallArgumentsDelta { + item_id: "item_fn".into(), + output_index: 0, + delta: "{\"city\":\"Bos".into(), + sequence_number: None, + }, + ResponsesStreamEvent::FunctionCallArgumentsDelta { + item_id: "item_fn".into(), + output_index: 0, + delta: "ton\"}".into(), + sequence_number: None, + }, + ResponsesStreamEvent::FunctionCallArgumentsDone { + item_id: "item_fn".into(), + output_index: 0, + arguments: "{\"city\":\"Boston\"}".into(), + sequence_number: None, + }, + ResponsesStreamEvent::Completed { + response: ResponseSummary::default(), + }, + ]; + + let mapped = map_response_events(events); + assert!(mapped.len() >= 3); + + let complete_tool_use = mapped.iter().find(|e| { + matches!( + e, + LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { + is_input_complete: true, + .. + }) + ) + }); + assert!( + complete_tool_use.is_some(), + "should have a complete tool use event" + ); + + let tool_uses: Vec<_> = mapped + .iter() + .filter(|e| matches!(e, LanguageModelCompletionEvent::ToolUse(_))) + .collect(); + assert!( + tool_uses.len() >= 2, + "should have at least one partial and one complete event" + ); + assert!(matches!( + tool_uses.last().unwrap(), + LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { + is_input_complete: true, + .. + }) + )); + } + + #[test] + fn responses_stream_maps_reasoning_summary_deltas() { + let events = vec![ + ResponsesStreamEvent::OutputItemAdded { + output_index: 0, + sequence_number: None, + item: ResponseOutputItem::Reasoning(ResponseReasoningItem { + id: Some("rs_123".into()), + summary: vec![], + }), + }, + ResponsesStreamEvent::ReasoningSummaryPartAdded { + item_id: "rs_123".into(), + output_index: 0, + summary_index: 0, + }, + ResponsesStreamEvent::ReasoningSummaryTextDelta { + item_id: "rs_123".into(), + output_index: 0, + delta: "Thinking about".into(), + }, + ResponsesStreamEvent::ReasoningSummaryTextDelta { + item_id: "rs_123".into(), + output_index: 0, + delta: " the answer".into(), + }, + ResponsesStreamEvent::ReasoningSummaryTextDone { + item_id: "rs_123".into(), + output_index: 0, + text: "Thinking about the answer".into(), + }, + ResponsesStreamEvent::ReasoningSummaryPartDone { + item_id: "rs_123".into(), + output_index: 0, + summary_index: 0, + }, + ResponsesStreamEvent::ReasoningSummaryPartAdded { + item_id: "rs_123".into(), + output_index: 0, + summary_index: 1, + }, + ResponsesStreamEvent::ReasoningSummaryTextDelta { + item_id: "rs_123".into(), + output_index: 0, + delta: "Second part".into(), + }, + ResponsesStreamEvent::ReasoningSummaryTextDone { + item_id: "rs_123".into(), + output_index: 0, + text: "Second part".into(), + }, + ResponsesStreamEvent::ReasoningSummaryPartDone { + item_id: "rs_123".into(), + output_index: 0, + summary_index: 1, + }, + ResponsesStreamEvent::OutputItemDone { + output_index: 0, + sequence_number: None, + item: ResponseOutputItem::Reasoning(ResponseReasoningItem { + id: Some("rs_123".into()), + summary: vec![ + ReasoningSummaryPart::SummaryText { + text: "Thinking about the answer".into(), + }, + ReasoningSummaryPart::SummaryText { + text: "Second part".into(), + }, + ], + }), + }, + ResponsesStreamEvent::OutputItemAdded { + output_index: 1, + sequence_number: None, + item: response_item_message("msg_456"), + }, + ResponsesStreamEvent::OutputTextDelta { + item_id: "msg_456".into(), + output_index: 1, + content_index: Some(0), + delta: "The answer is 42".into(), + }, + ResponsesStreamEvent::Completed { + response: ResponseSummary::default(), + }, + ]; + + let mapped = map_response_events(events); + + let thinking_events: Vec<_> = mapped + .iter() + .filter(|e| matches!(e, LanguageModelCompletionEvent::Thinking { .. })) + .collect(); + assert_eq!( + thinking_events.len(), + 4, + "expected 4 thinking events, got {:?}", + thinking_events + ); + assert!( + matches!(&thinking_events[0], LanguageModelCompletionEvent::Thinking { text, .. } if text == "Thinking about") + ); + assert!( + matches!(&thinking_events[1], LanguageModelCompletionEvent::Thinking { text, .. } if text == " the answer") + ); + assert!( + matches!(&thinking_events[2], LanguageModelCompletionEvent::Thinking { text, .. } if text == "\n\n"), + "expected separator between summary parts" + ); + assert!( + matches!(&thinking_events[3], LanguageModelCompletionEvent::Thinking { text, .. } if text == "Second part") + ); + + assert!(mapped.iter().any( + |e| matches!(e, LanguageModelCompletionEvent::Text(t) if t == "The answer is 42") + )); + } + + #[test] + fn responses_stream_maps_reasoning_from_done_only() { + let events = vec![ + ResponsesStreamEvent::OutputItemAdded { + output_index: 0, + sequence_number: None, + item: ResponseOutputItem::Reasoning(ResponseReasoningItem { + id: Some("rs_789".into()), + summary: vec![], + }), + }, + ResponsesStreamEvent::OutputItemDone { + output_index: 0, + sequence_number: None, + item: ResponseOutputItem::Reasoning(ResponseReasoningItem { + id: Some("rs_789".into()), + summary: vec![ReasoningSummaryPart::SummaryText { + text: "Summary without deltas".into(), + }], + }), + }, + ResponsesStreamEvent::Completed { + response: ResponseSummary::default(), + }, + ]; + + let mapped = map_response_events(events); + assert!( + !mapped + .iter() + .any(|e| matches!(e, LanguageModelCompletionEvent::Thinking { .. })), + "OutputItemDone reasoning should not produce Thinking events" + ); + } +} diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs index c4a3e078d76eb028b90e5b80fe95b1281b795f34..5423d9c5dcaa13589a8a7d658548b42fd467f67f 100644 --- a/crates/open_ai/src/open_ai.rs +++ b/crates/open_ai/src/open_ai.rs @@ -1,4 +1,5 @@ pub mod batches; +pub mod completion; pub mod responses; use anyhow::{Context as _, Result, anyhow}; @@ -7,9 +8,9 @@ use http_client::{ AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode, http::{HeaderMap, HeaderValue}, }; +pub use language_model_core::ReasoningEffort; use serde::{Deserialize, Serialize}; use serde_json::Value; -pub use settings::OpenAiReasoningEffort as ReasoningEffort; use std::{convert::TryFrom, future::Future}; use strum::EnumIter; use thiserror::Error; @@ -717,3 +718,26 @@ pub fn embed<'a>( Ok(response) } } + +// -- Conversions to `language_model_core` types -- + +impl From for language_model_core::LanguageModelCompletionError { + fn from(error: RequestError) -> Self { + match error { + RequestError::HttpResponseError { + provider, + status_code, + body, + headers, + } => { + let retry_after = headers + .get(http_client::http::header::RETRY_AFTER) + .and_then(|val| val.to_str().ok()?.parse::().ok()) + .map(std::time::Duration::from_secs); + + Self::from_http_status(provider.into(), status_code, body, retry_after) + } + RequestError::Other(e) => Self::Other(e), + } + } +} diff --git a/crates/open_router/Cargo.toml b/crates/open_router/Cargo.toml index cccb92c33b05b8fff0e5e78277c9f7fa29844ace..2cc5d3d00e2eb5d755cef971be51a315bcdf254f 100644 --- a/crates/open_router/Cargo.toml +++ b/crates/open_router/Cargo.toml @@ -19,6 +19,7 @@ schemars = ["dep:schemars"] anyhow.workspace = true futures.workspace = true http_client.workspace = true +language_model_core.workspace = true schemars = { workspace = true, optional = true } serde.workspace = true serde_json.workspace = true diff --git a/crates/open_router/src/open_router.rs b/crates/open_router/src/open_router.rs index 9841c7b1ae19a57878fd8e84625bc4058b809613..b94631f9a0e6764ab5cfe487e7851a820fa80b1d 100644 --- a/crates/open_router/src/open_router.rs +++ b/crates/open_router/src/open_router.rs @@ -744,3 +744,71 @@ impl ApiErrorCode { } } } + +// -- Conversions to `language_model_core` types -- + +impl From for language_model_core::LanguageModelCompletionError { + fn from(error: OpenRouterError) -> Self { + let provider = language_model_core::LanguageModelProviderName::new("OpenRouter"); + match error { + OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error }, + OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error }, + OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error }, + OpenRouterError::DeserializeResponse(error) => { + Self::DeserializeResponse { provider, error } + } + OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error }, + OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded { + provider, + retry_after: Some(retry_after), + }, + OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded { + provider, + retry_after, + }, + OpenRouterError::ApiError(api_error) => api_error.into(), + } + } +} + +impl From for language_model_core::LanguageModelCompletionError { + fn from(error: ApiError) -> Self { + use ApiErrorCode::*; + let provider = language_model_core::LanguageModelProviderName::new("OpenRouter"); + match error.code { + InvalidRequestError => Self::BadRequestFormat { + provider, + message: error.message, + }, + AuthenticationError => Self::AuthenticationError { + provider, + message: error.message, + }, + PaymentRequiredError => Self::AuthenticationError { + provider, + message: format!("Payment required: {}", error.message), + }, + PermissionError => Self::PermissionError { + provider, + message: error.message, + }, + RequestTimedOut => Self::HttpResponseError { + provider, + status_code: http_client::StatusCode::REQUEST_TIMEOUT, + message: error.message, + }, + RateLimitError => Self::RateLimitExceeded { + provider, + retry_after: None, + }, + ApiError => Self::ApiInternalServerError { + provider, + message: error.message, + }, + OverloadedError => Self::ServerOverloaded { + provider, + retry_after: None, + }, + } + } +} diff --git a/crates/project/src/prettier_store.rs b/crates/project/src/prettier_store.rs index b66f2d5e0c041e104cf109a48b6bad249b492b88..faa2cca79866f31682a497eebab819b75e778ffb 100644 --- a/crates/project/src/prettier_store.rs +++ b/crates/project/src/prettier_store.rs @@ -412,7 +412,7 @@ impl PrettierStore { prettier_store .update(cx, |prettier_store, cx| { let name = if is_default { - LanguageServerName("prettier (default)".to_string().into()) + LanguageServerName("prettier (default)".into()) } else { let worktree_path = worktree_id .and_then(|id| { diff --git a/crates/settings_content/Cargo.toml b/crates/settings_content/Cargo.toml index b3599e9eef3b7ac5680f441369a7cbdc98a5d043..59cccb4167ed64a2ece8ae5a73ac570ca7dabd97 100644 --- a/crates/settings_content/Cargo.toml +++ b/crates/settings_content/Cargo.toml @@ -19,6 +19,7 @@ anyhow.workspace = true collections.workspace = true derive_more.workspace = true gpui.workspace = true +language_model_core.workspace = true log.workspace = true schemars.workspace = true serde.workspace = true diff --git a/crates/settings_content/src/language_model.rs b/crates/settings_content/src/language_model.rs index 4b72c2ad3f47d834dfa38555d80a8646e3940f51..00ecf42537459496102495c51628b54405968214 100644 --- a/crates/settings_content/src/language_model.rs +++ b/crates/settings_content/src/language_model.rs @@ -1,8 +1,8 @@ +use crate::merge_from::MergeFrom; use collections::HashMap; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings_macros::{MergeFrom, with_fallible_options}; -use strum::EnumString; use std::sync::Arc; @@ -237,15 +237,12 @@ pub struct OpenAiAvailableModel { pub capabilities: OpenAiModelCapabilities, } -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, EnumString, JsonSchema, MergeFrom)] -#[serde(rename_all = "lowercase")] -#[strum(serialize_all = "lowercase")] -pub enum OpenAiReasoningEffort { - Minimal, - Low, - Medium, - High, - XHigh, +pub use language_model_core::ReasoningEffort as OpenAiReasoningEffort; + +impl MergeFrom for OpenAiReasoningEffort { + fn merge_from(&mut self, other: &Self) { + *self = *other; + } } #[with_fallible_options] @@ -479,15 +476,10 @@ pub struct LanguageModelCacheConfiguration { pub min_total_token: u64, } -#[derive( - Copy, Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize, JsonSchema, MergeFrom, -)] -#[serde(tag = "type", rename_all = "lowercase")] -pub enum ModelMode { - #[default] - Default, - Thinking { - /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`. - budget_tokens: Option, - }, +pub use language_model_core::ModelMode; + +impl MergeFrom for ModelMode { + fn merge_from(&mut self, other: &Self) { + *self = *other; + } } diff --git a/crates/web_search_providers/Cargo.toml b/crates/web_search_providers/Cargo.toml index ff264edcb150063237c633de746b2f6b9f6f250c..e2bbc1aeb2dd5718596b905788b4a88826357401 100644 --- a/crates/web_search_providers/Cargo.toml +++ b/crates/web_search_providers/Cargo.toml @@ -14,6 +14,7 @@ path = "src/web_search_providers.rs" [dependencies] anyhow.workspace = true client.workspace = true +cloud_api_client.workspace = true cloud_api_types.workspace = true cloud_llm_client.workspace = true futures.workspace = true diff --git a/crates/web_search_providers/src/cloud.rs b/crates/web_search_providers/src/cloud.rs index 11227d8fb5c7152dc5b7e03b95fadea6cb714717..16707003c49921bce6244b69d0e7387f935ed8e1 100644 --- a/crates/web_search_providers/src/cloud.rs +++ b/crates/web_search_providers/src/cloud.rs @@ -2,12 +2,12 @@ use std::sync::Arc; use anyhow::{Context as _, Result}; use client::{Client, NeedsLlmTokenRefresh, UserStore, global_llm_token}; +use cloud_api_client::LlmApiToken; use cloud_api_types::OrganizationId; use cloud_llm_client::{WebSearchBody, WebSearchResponse}; use futures::AsyncReadExt as _; use gpui::{App, AppContext, Context, Entity, Task}; use http_client::{HttpClient, Method}; -use language_model::LlmApiToken; use web_search::{WebSearchProvider, WebSearchProviderId}; pub struct CloudWebSearchProvider { diff --git a/crates/x_ai/Cargo.toml b/crates/x_ai/Cargo.toml index 8ff020df8c1ccaf284157d8b46ddaa0e678b3cd7..2d1c9d0ecebeb8a1e0965b0ac914603b41383f00 100644 --- a/crates/x_ai/Cargo.toml +++ b/crates/x_ai/Cargo.toml @@ -17,6 +17,8 @@ schemars = ["dep:schemars"] [dependencies] anyhow.workspace = true +language_model_core.workspace = true schemars = { workspace = true, optional = true } serde.workspace = true strum.workspace = true +tiktoken-rs.workspace = true diff --git a/crates/x_ai/src/completion.rs b/crates/x_ai/src/completion.rs new file mode 100644 index 0000000000000000000000000000000000000000..aad03d227eb82768c972283f7e1617ea7486f22f --- /dev/null +++ b/crates/x_ai/src/completion.rs @@ -0,0 +1,30 @@ +use anyhow::Result; +use language_model_core::{LanguageModelRequest, Role}; + +use crate::Model; + +/// Count tokens for an xAI model using tiktoken. This is synchronous; +/// callers should spawn it on a background thread if needed. +pub fn count_xai_tokens(request: LanguageModelRequest, model: Model) -> Result { + let messages = request + .messages + .into_iter() + .map(|message| tiktoken_rs::ChatCompletionRequestMessage { + role: match message.role { + Role::User => "user".into(), + Role::Assistant => "assistant".into(), + Role::System => "system".into(), + }, + content: Some(message.string_contents()), + name: None, + function_call: None, + }) + .collect::>(); + + let model_name = if model.max_token_count() >= 100_000 { + "gpt-4o" + } else { + "gpt-4" + }; + tiktoken_rs::num_tokens_from_messages(model_name, &messages).map(|tokens| tokens as u64) +} diff --git a/crates/x_ai/src/x_ai.rs b/crates/x_ai/src/x_ai.rs index 1abb2b53771fa1e29e2979560e9f394744b26158..fd141a1723a28d235311d5d875bf4cc0388cab61 100644 --- a/crates/x_ai/src/x_ai.rs +++ b/crates/x_ai/src/x_ai.rs @@ -1,3 +1,5 @@ +pub mod completion; + use anyhow::Result; use serde::{Deserialize, Serialize}; use strum::EnumIter; From d3a9d5fb9d889b4d9b6382051e5fb6268fe25cf2 Mon Sep 17 00:00:00 2001 From: Cameron Mcloughlin Date: Tue, 7 Apr 2026 16:34:53 +0100 Subject: [PATCH 16/22] sidebar: Drop test cases (#53315) --- crates/sidebar/src/sidebar_tests.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/crates/sidebar/src/sidebar_tests.rs b/crates/sidebar/src/sidebar_tests.rs index 60881acfe9461f7897d6013831970444b7a65544..09fd44af35679a69908e1d86d203ea8c3aa5c545 100644 --- a/crates/sidebar/src/sidebar_tests.rs +++ b/crates/sidebar/src/sidebar_tests.rs @@ -5064,6 +5064,7 @@ async fn test_legacy_thread_with_canonical_path_opens_main_repo_workspace(cx: &m mod property_test { use super::*; + use gpui::proptest::prelude::*; struct UnopenedWorktree { path: String, @@ -5658,7 +5659,10 @@ mod property_test { Ok(()) } - #[gpui::property_test] + #[gpui::property_test(config = ProptestConfig { + cases: 10, + ..Default::default() + })] async fn test_sidebar_invariants( #[strategy = gpui::proptest::collection::vec(0u32..DISTRIBUTION_SLOTS * 10, 1..5)] raw_operations: Vec, From 833a015dc65f5b1658af1ae8cbb58ebe313cdf66 Mon Sep 17 00:00:00 2001 From: Danilo Leal <67129314+danilo-leal@users.noreply.github.com> Date: Tue, 7 Apr 2026 12:44:19 -0300 Subject: [PATCH 17/22] recent_projects: Make the currently active project visible in the picker (#53302) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR improves the recent projects picker in the context of multi-workspace: - The currently active project now appears in the "This Window" section with a checkmark indicator. Clicking it simply dismisses the picker, since there's nothing to switch to. This feels like a better UX because it gives you visual confirmation of where you are. - The remove button is hidden for the current project entry, both in the row and the footer, to prevent accidentally removing the workspace you're actively using. - The "Add to Workspace" button now uses a more descriptive icon (`FolderOpenAdd`) and shows a meta tooltip clarifying that it adds the project as a multi-root folder project. The primary click/enter behavior remains unchanged—it opens the selected project in the current window's multi-workspace. The "Open in New Window" action continues to be available via the icon button or shift+enter. Release Notes: - Improved the recent projects picker to show the currently active project in the "This Window" section with a checkmark indicator. --- assets/icons/folder_open_add.svg | 5 + assets/icons/folder_plus.svg | 5 - assets/icons/open_new_window.svg | 7 + crates/agent_ui/src/threads_archive_view.rs | 1 + crates/icons/src/icons.rs | 3 +- .../src/highlighted_match_with_paths.rs | 23 +++- crates/recent_projects/src/recent_projects.rs | 127 ++++++++++++------ .../src/sidebar_recent_projects.rs | 10 +- 8 files changed, 127 insertions(+), 54 deletions(-) create mode 100644 assets/icons/folder_open_add.svg delete mode 100644 assets/icons/folder_plus.svg create mode 100644 assets/icons/open_new_window.svg diff --git a/assets/icons/folder_open_add.svg b/assets/icons/folder_open_add.svg new file mode 100644 index 0000000000000000000000000000000000000000..d5ebbdaa8b080037a2faee0ee0fc3606eec9c6ca --- /dev/null +++ b/assets/icons/folder_open_add.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/assets/icons/folder_plus.svg b/assets/icons/folder_plus.svg deleted file mode 100644 index a543448ed6197043291369bee640e23b6ad729b9..0000000000000000000000000000000000000000 --- a/assets/icons/folder_plus.svg +++ /dev/null @@ -1,5 +0,0 @@ - - - - - diff --git a/assets/icons/open_new_window.svg b/assets/icons/open_new_window.svg new file mode 100644 index 0000000000000000000000000000000000000000..c81d49f9ff9edfbc965055568efc72e0214efb41 --- /dev/null +++ b/assets/icons/open_new_window.svg @@ -0,0 +1,7 @@ + + + + + + + diff --git a/crates/agent_ui/src/threads_archive_view.rs b/crates/agent_ui/src/threads_archive_view.rs index 13b2aa1a37cd506c338d13db78bce751882e426a..7cb8410e5017438b0e8adde673887c13397d9abf 100644 --- a/crates/agent_ui/src/threads_archive_view.rs +++ b/crates/agent_ui/src/threads_archive_view.rs @@ -1236,6 +1236,7 @@ impl PickerDelegate for ProjectPickerDelegate { }, match_label: HighlightedMatch::join(match_labels.into_iter().flatten(), ", "), paths: Vec::new(), + active: false, }; Some( diff --git a/crates/icons/src/icons.rs b/crates/icons/src/icons.rs index e29b7d3593025556771d62dc0124786672c540de..bdc3890432414e0a78f69a226bb9174510453331 100644 --- a/crates/icons/src/icons.rs +++ b/crates/icons/src/icons.rs @@ -134,7 +134,7 @@ pub enum IconName { Flame, Folder, FolderOpen, - FolderPlus, + FolderOpenAdd, FolderSearch, Font, FontSize, @@ -184,6 +184,7 @@ pub enum IconName { NewThread, Notepad, OpenFolder, + OpenNewWindow, Option, PageDown, PageUp, diff --git a/crates/picker/src/highlighted_match_with_paths.rs b/crates/picker/src/highlighted_match_with_paths.rs index 74271047621b26be573dc2eebfffe9e9e0f1a138..7c88213437feea17e6b431dff9c97b0b8557872a 100644 --- a/crates/picker/src/highlighted_match_with_paths.rs +++ b/crates/picker/src/highlighted_match_with_paths.rs @@ -5,6 +5,7 @@ pub struct HighlightedMatchWithPaths { pub prefix: Option, pub match_label: HighlightedMatch, pub paths: Vec, + pub active: bool, } #[derive(Debug, Clone, IntoElement)] @@ -63,18 +64,30 @@ impl HighlightedMatchWithPaths { .color(Color::Muted) })) } + + pub fn is_active(mut self, active: bool) -> Self { + self.active = active; + self + } } impl RenderOnce for HighlightedMatchWithPaths { fn render(mut self, _window: &mut Window, _: &mut App) -> impl IntoElement { v_flex() .child( - h_flex().gap_1().child(self.match_label.clone()).when_some( - self.prefix.as_ref(), - |this, prefix| { + h_flex() + .gap_1() + .child(self.match_label.clone()) + .when_some(self.prefix.as_ref(), |this, prefix| { this.child(Label::new(format!("({})", prefix)).color(Color::Muted)) - }, - ), + }) + .when(self.active, |this| { + this.child( + Icon::new(IconName::Check) + .size(IconSize::Small) + .color(Color::Accent), + ) + }), ) .when(!self.paths.is_empty(), |this| { self.render_paths_children(this) diff --git a/crates/recent_projects/src/recent_projects.rs b/crates/recent_projects/src/recent_projects.rs index e3bfc0dc08c95c0ce57b818e50965433a6c6bc98..57754dadec20146cb1f21039266de88a0bd5da9f 100644 --- a/crates/recent_projects/src/recent_projects.rs +++ b/crates/recent_projects/src/recent_projects.rs @@ -720,6 +720,9 @@ impl RecentProjects { picker.delegate.workspaces.get(hit.candidate_id) { let workspace_id = *workspace_id; + if picker.delegate.is_current_workspace(workspace_id, cx) { + return; + } picker .delegate .remove_sibling_workspace(workspace_id, window, cx); @@ -939,7 +942,7 @@ impl PickerDelegate for RecentProjectsDelegate { .workspaces .iter() .enumerate() - .filter(|(_, (id, _, _, _))| self.is_sibling_workspace(*id, cx)) + .filter(|(_, (id, _, _, _))| self.sibling_workspace_ids.contains(id)) .map(|(id, (_, _, paths, _))| { let combined_string = paths .ordered_paths() @@ -1028,7 +1031,7 @@ impl PickerDelegate for RecentProjectsDelegate { if is_empty_query { for (id, (workspace_id, _, _, _)) in self.workspaces.iter().enumerate() { - if self.is_sibling_workspace(*workspace_id, cx) { + if self.sibling_workspace_ids.contains(workspace_id) { entries.push(ProjectPickerEntry::OpenProject(StringMatch { candidate_id: id, score: 0.0, @@ -1106,6 +1109,11 @@ impl PickerDelegate for RecentProjectsDelegate { }; let workspace_id = *workspace_id; + if self.is_current_workspace(workspace_id, cx) { + cx.emit(DismissEvent); + return; + } + if let Some(handle) = window.window_handle().downcast::() { cx.defer(move |cx| { handle @@ -1349,6 +1357,7 @@ impl PickerDelegate for RecentProjectsDelegate { ProjectPickerEntry::OpenProject(hit) => { let (workspace_id, location, paths, _) = self.workspaces.get(hit.candidate_id)?; let workspace_id = *workspace_id; + let is_current = self.is_current_workspace(workspace_id, cx); let ordered_paths: Vec<_> = paths .ordered_paths() .map(|p| p.compact().to_string_lossy().to_string()) @@ -1388,6 +1397,7 @@ impl PickerDelegate for RecentProjectsDelegate { prefix, match_label: HighlightedMatch::join(match_labels.into_iter().flatten(), ", "), paths, + active: is_current, }; let icon = icon_for_remote_connection(match location { @@ -1397,20 +1407,24 @@ impl PickerDelegate for RecentProjectsDelegate { let secondary_actions = h_flex() .gap_1() - .child( - IconButton::new("remove_open_project", IconName::Close) - .icon_size(IconSize::Small) - .tooltip(Tooltip::text("Remove Project from Window")) - .on_click(cx.listener(move |picker, _, window, cx| { - cx.stop_propagation(); - window.prevent_default(); - picker - .delegate - .remove_sibling_workspace(workspace_id, window, cx); - let query = picker.query(cx); - picker.update_matches(query, window, cx); - })), - ) + .when(!is_current, |this| { + this.child( + IconButton::new("remove_open_project", IconName::Close) + .icon_size(IconSize::Small) + .tooltip(Tooltip::text("Remove Project from Window")) + .on_click(cx.listener(move |picker, _, window, cx| { + cx.stop_propagation(); + window.prevent_default(); + picker.delegate.remove_sibling_workspace( + workspace_id, + window, + cx, + ); + let query = picker.query(cx); + picker.update_matches(query, window, cx); + })), + ) + }) .into_any_element(); Some( @@ -1483,6 +1497,7 @@ impl PickerDelegate for RecentProjectsDelegate { prefix, match_label: HighlightedMatch::join(match_labels.into_iter().flatten(), ", "), paths, + active: false, }; let focus_handle = self.focus_handle.clone(); @@ -1491,9 +1506,16 @@ impl PickerDelegate for RecentProjectsDelegate { .gap_px() .when(is_local, |this| { this.child( - IconButton::new("add_to_workspace", IconName::FolderPlus) + IconButton::new("add_to_workspace", IconName::FolderOpenAdd) .icon_size(IconSize::Small) - .tooltip(Tooltip::text("Add Project to this Workspace")) + .tooltip(move |_, cx| { + Tooltip::with_meta( + "Add Project to this Workspace", + None, + "As a multi-root folder project", + cx, + ) + }) .on_click({ let paths_to_add = paths_to_add.clone(); cx.listener(move |picker, _event, window, cx| { @@ -1509,8 +1531,8 @@ impl PickerDelegate for RecentProjectsDelegate { ) }) .child( - IconButton::new("open_new_window", IconName::ArrowUpRight) - .icon_size(IconSize::XSmall) + IconButton::new("open_new_window", IconName::OpenNewWindow) + .icon_size(IconSize::Small) .tooltip({ move |_, cx| { Tooltip::for_action_in( @@ -1565,7 +1587,14 @@ impl PickerDelegate for RecentProjectsDelegate { } highlighted.render(window, cx) }) - .tooltip(Tooltip::text(tooltip_path)), + .tooltip(move |_, cx| { + Tooltip::with_meta( + "Open Project in This Window", + None, + tooltip_path.clone(), + cx, + ) + }), ) .end_slot(secondary_actions) .show_end_slot_on_hover() @@ -1625,27 +1654,41 @@ impl PickerDelegate for RecentProjectsDelegate { let selected_entry = self.filtered_entries.get(self.selected_index); + let is_current_workspace_entry = + if let Some(ProjectPickerEntry::OpenProject(hit)) = selected_entry { + self.workspaces + .get(hit.candidate_id) + .map(|(id, ..)| self.is_current_workspace(*id, cx)) + .unwrap_or(false) + } else { + false + }; + let secondary_footer_actions: Option = match selected_entry { - Some(ProjectPickerEntry::OpenFolder { .. } | ProjectPickerEntry::OpenProject(_)) => { - let label = if matches!(selected_entry, Some(ProjectPickerEntry::OpenFolder { .. })) - { - "Remove Folder" - } else { - "Remove from Window" - }; - Some( - Button::new("remove_selected", label) - .key_binding(KeyBinding::for_action_in( - &RemoveSelected, - &focus_handle, - cx, - )) - .on_click(|_, window, cx| { - window.dispatch_action(RemoveSelected.boxed_clone(), cx) - }) - .into_any_element(), - ) - } + Some(ProjectPickerEntry::OpenFolder { .. }) => Some( + Button::new("remove_selected", "Remove Folder") + .key_binding(KeyBinding::for_action_in( + &RemoveSelected, + &focus_handle, + cx, + )) + .on_click(|_, window, cx| { + window.dispatch_action(RemoveSelected.boxed_clone(), cx) + }) + .into_any_element(), + ), + Some(ProjectPickerEntry::OpenProject(_)) if !is_current_workspace_entry => Some( + Button::new("remove_selected", "Remove from Window") + .key_binding(KeyBinding::for_action_in( + &RemoveSelected, + &focus_handle, + cx, + )) + .on_click(|_, window, cx| { + window.dispatch_action(RemoveSelected.boxed_clone(), cx) + }) + .into_any_element(), + ), Some(ProjectPickerEntry::RecentProject(_)) => Some( Button::new("delete_recent", "Delete") .key_binding(KeyBinding::for_action_in( @@ -1748,7 +1791,7 @@ impl PickerDelegate for RecentProjectsDelegate { menu.context(focus_handle) .when(show_add_to_workspace, |menu| { menu.action( - "Add to Workspace", + "Add to this Workspace", AddToWorkspace.boxed_clone(), ) .separator() diff --git a/crates/recent_projects/src/sidebar_recent_projects.rs b/crates/recent_projects/src/sidebar_recent_projects.rs index 1fe0d2ae86aefdad45136c496f8049689d77e048..dec269c07eada3a1d6172482cb886f9ed44d784c 100644 --- a/crates/recent_projects/src/sidebar_recent_projects.rs +++ b/crates/recent_projects/src/sidebar_recent_projects.rs @@ -374,6 +374,7 @@ impl PickerDelegate for SidebarRecentProjectsDelegate { prefix, match_label: HighlightedMatch::join(match_labels.into_iter().flatten(), ", "), paths: Vec::new(), + active: false, }; let icon = icon_for_remote_connection(match location { @@ -395,7 +396,14 @@ impl PickerDelegate for SidebarRecentProjectsDelegate { }) .child(highlighted_match.render(window, cx)), ) - .tooltip(Tooltip::text(tooltip_path)) + .tooltip(move |_, cx| { + Tooltip::with_meta( + "Open Project in This Window", + None, + tooltip_path.clone(), + cx, + ) + }) .into_any_element(), ) } From 43867668f44549dd8da9954c62b1229c2fb6bec7 Mon Sep 17 00:00:00 2001 From: Jozsef Lazar Date: Tue, 7 Apr 2026 18:11:39 +0200 Subject: [PATCH 18/22] Add query and search options to pane::DeploySearch action (#47331) Extend the DeploySearch action to accept additional parameters for configuring the project search from keymaps: - query: prefilled search query string - regex: enable regex search mode - case_sensitive: match case exactly - whole_word: match whole words only - include_ignored: search in gitignored files With this change, the following keymap becomes possible: ```json ["pane::DeploySearch", { "query": "TODO|FIXME|NOTE|BUG|HACK|XXX|WARN", "regex": true }], ``` Release Notes: - Added options to `pane::DeploySearch` for keymap-driven search initiation --- crates/search/src/project_search.rs | 313 +++++++++++++++++++++++++++- crates/workspace/src/pane.rs | 30 +-- crates/zed/src/zed/app_menus.rs | 2 +- 3 files changed, 317 insertions(+), 28 deletions(-) diff --git a/crates/search/src/project_search.rs b/crates/search/src/project_search.rs index 1bccf1ae52fb2c52a8d01e53aabb1b3ff5c7c16f..7c9d3f176ed3f17ec5e21faa7c1b483252657614 100644 --- a/crates/search/src/project_search.rs +++ b/crates/search/src/project_search.rs @@ -769,6 +769,17 @@ impl ProjectSearchView { } } + fn set_search_option_enabled( + &mut self, + option: SearchOptions, + enabled: bool, + cx: &mut Context, + ) { + if self.search_options.contains(option) != enabled { + self.toggle_search_option(option, cx); + } + } + fn toggle_search_option(&mut self, option: SearchOptions, cx: &mut Context) { self.search_options.toggle(option); ActiveSettings::update_global(cx, |settings, cx| { @@ -1153,7 +1164,7 @@ impl ProjectSearchView { window: &mut Window, cx: &mut Context, ) { - Self::existing_or_new_search(workspace, None, &DeploySearch::find(), window, cx) + Self::existing_or_new_search(workspace, None, &DeploySearch::default(), window, cx) } fn existing_or_new_search( @@ -1203,8 +1214,29 @@ impl ProjectSearchView { search.update(cx, |search, cx| { search.replace_enabled |= action.replace_enabled; + if let Some(regex) = action.regex { + search.set_search_option_enabled(SearchOptions::REGEX, regex, cx); + } + if let Some(case_sensitive) = action.case_sensitive { + search.set_search_option_enabled(SearchOptions::CASE_SENSITIVE, case_sensitive, cx); + } + if let Some(whole_word) = action.whole_word { + search.set_search_option_enabled(SearchOptions::WHOLE_WORD, whole_word, cx); + } + if let Some(include_ignored) = action.include_ignored { + search.set_search_option_enabled( + SearchOptions::INCLUDE_IGNORED, + include_ignored, + cx, + ); + } + let query = action + .query + .as_deref() + .filter(|q| !q.is_empty()) + .or(query.as_deref()); if let Some(query) = query { - search.set_query(&query, window, cx); + search.set_query(query, window, cx); } if let Some(included_files) = action.included_files.as_deref() { search @@ -3101,7 +3133,7 @@ pub mod tests { ProjectSearchView::deploy_search( workspace, - &workspace::DeploySearch::find(), + &workspace::DeploySearch::default(), window, cx, ) @@ -3252,7 +3284,7 @@ pub mod tests { workspace.update_in(cx, |workspace, window, cx| { ProjectSearchView::deploy_search( workspace, - &workspace::DeploySearch::find(), + &workspace::DeploySearch::default(), window, cx, ) @@ -3325,7 +3357,7 @@ pub mod tests { ProjectSearchView::deploy_search( workspace, - &workspace::DeploySearch::find(), + &workspace::DeploySearch::default(), window, cx, ) @@ -4560,7 +4592,7 @@ pub mod tests { }); // Deploy a new search - cx.dispatch_action(DeploySearch::find()); + cx.dispatch_action(DeploySearch::default()); // Both panes should now have a project search in them workspace.update_in(cx, |workspace, window, cx| { @@ -4585,7 +4617,7 @@ pub mod tests { .unwrap(); // Deploy a new search - cx.dispatch_action(DeploySearch::find()); + cx.dispatch_action(DeploySearch::default()); // The project search view should now be focused in the second pane // And the number of items should be unchanged. @@ -4823,7 +4855,7 @@ pub mod tests { assert!(workspace.has_active_modal(window, cx)); }); - cx.dispatch_action(DeploySearch::find()); + cx.dispatch_action(DeploySearch::default()); workspace.update_in(cx, |workspace, window, cx| { assert!(!workspace.has_active_modal(window, cx)); @@ -5136,6 +5168,271 @@ pub mod tests { .unwrap(); } + #[gpui::test] + async fn test_deploy_search_applies_and_resets_options(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree( + path!("/dir"), + json!({ + "one.rs": "const ONE: usize = 1;", + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await; + let window = cx.add_window(|window, cx| MultiWorkspace::test_new(project, window, cx)); + let workspace = window + .read_with(cx, |mw, _| mw.workspace().clone()) + .unwrap(); + let cx = &mut VisualTestContext::from_window(window.into(), cx); + let search_bar = window.build_entity(cx, |_, _| ProjectSearchBar::new()); + + workspace.update_in(cx, |workspace, window, cx| { + workspace.panes()[0].update(cx, |pane, cx| { + pane.toolbar() + .update(cx, |toolbar, cx| toolbar.add_item(search_bar, window, cx)) + }); + + ProjectSearchView::deploy_search( + workspace, + &workspace::DeploySearch { + regex: Some(true), + case_sensitive: Some(true), + whole_word: Some(true), + include_ignored: Some(true), + query: Some("Test_Query".into()), + ..Default::default() + }, + window, + cx, + ) + }); + + let search_view = cx + .read(|cx| { + workspace + .read(cx) + .active_pane() + .read(cx) + .active_item() + .and_then(|item| item.downcast::()) + }) + .expect("Search view should be active after deploy"); + + search_view.update_in(cx, |search_view, _window, cx| { + assert!( + search_view.search_options.contains(SearchOptions::REGEX), + "Regex option should be enabled" + ); + assert!( + search_view + .search_options + .contains(SearchOptions::CASE_SENSITIVE), + "Case sensitive option should be enabled" + ); + assert!( + search_view + .search_options + .contains(SearchOptions::WHOLE_WORD), + "Whole word option should be enabled" + ); + assert!( + search_view + .search_options + .contains(SearchOptions::INCLUDE_IGNORED), + "Include ignored option should be enabled" + ); + let query_text = search_view.query_editor.read(cx).text(cx); + assert_eq!( + query_text, "Test_Query", + "Query should be set from the action" + ); + }); + + // Redeploy with only regex - unspecified options should be preserved. + cx.dispatch_action(menu::Cancel); + workspace.update_in(cx, |workspace, window, cx| { + ProjectSearchView::deploy_search( + workspace, + &workspace::DeploySearch { + regex: Some(true), + ..Default::default() + }, + window, + cx, + ) + }); + + search_view.update_in(cx, |search_view, _window, _cx| { + assert!( + search_view.search_options.contains(SearchOptions::REGEX), + "Regex should still be enabled" + ); + assert!( + search_view + .search_options + .contains(SearchOptions::CASE_SENSITIVE), + "Case sensitive should be preserved from previous deploy" + ); + assert!( + search_view + .search_options + .contains(SearchOptions::WHOLE_WORD), + "Whole word should be preserved from previous deploy" + ); + assert!( + search_view + .search_options + .contains(SearchOptions::INCLUDE_IGNORED), + "Include ignored should be preserved from previous deploy" + ); + }); + + // Redeploy explicitly turning off options. + cx.dispatch_action(menu::Cancel); + workspace.update_in(cx, |workspace, window, cx| { + ProjectSearchView::deploy_search( + workspace, + &workspace::DeploySearch { + regex: Some(true), + case_sensitive: Some(false), + whole_word: Some(false), + include_ignored: Some(false), + ..Default::default() + }, + window, + cx, + ) + }); + + search_view.update_in(cx, |search_view, _window, _cx| { + assert_eq!( + search_view.search_options, + SearchOptions::REGEX, + "Explicit Some(false) should turn off options" + ); + }); + + // Redeploy with an empty query - should not overwrite the existing query. + cx.dispatch_action(menu::Cancel); + workspace.update_in(cx, |workspace, window, cx| { + ProjectSearchView::deploy_search( + workspace, + &workspace::DeploySearch { + query: Some("".into()), + ..Default::default() + }, + window, + cx, + ) + }); + + search_view.update_in(cx, |search_view, _window, cx| { + let query_text = search_view.query_editor.read(cx).text(cx); + assert_eq!( + query_text, "Test_Query", + "Empty query string should not overwrite the existing query" + ); + }); + } + + #[gpui::test] + async fn test_smartcase_overrides_explicit_case_sensitive(cx: &mut TestAppContext) { + init_test(cx); + + cx.update(|cx| { + cx.update_global::(|store, cx| { + store.update_default_settings(cx, |settings| { + settings.editor.use_smartcase_search = Some(true); + }); + }); + }); + + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree( + path!("/dir"), + json!({ + "one.rs": "const ONE: usize = 1;", + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await; + let window = cx.add_window(|window, cx| MultiWorkspace::test_new(project, window, cx)); + let workspace = window + .read_with(cx, |mw, _| mw.workspace().clone()) + .unwrap(); + let cx = &mut VisualTestContext::from_window(window.into(), cx); + let search_bar = window.build_entity(cx, |_, _| ProjectSearchBar::new()); + + workspace.update_in(cx, |workspace, window, cx| { + workspace.panes()[0].update(cx, |pane, cx| { + pane.toolbar() + .update(cx, |toolbar, cx| toolbar.add_item(search_bar, window, cx)) + }); + + ProjectSearchView::deploy_search( + workspace, + &workspace::DeploySearch { + case_sensitive: Some(true), + query: Some("lowercase_query".into()), + ..Default::default() + }, + window, + cx, + ) + }); + + let search_view = cx + .read(|cx| { + workspace + .read(cx) + .active_pane() + .read(cx) + .active_item() + .and_then(|item| item.downcast::()) + }) + .expect("Search view should be active after deploy"); + + // Smartcase should override the explicit case_sensitive flag + // because the query is all lowercase. + search_view.update_in(cx, |search_view, _window, cx| { + assert!( + !search_view + .search_options + .contains(SearchOptions::CASE_SENSITIVE), + "Smartcase should disable case sensitivity for a lowercase query, \ + even when case_sensitive was explicitly set in the action" + ); + let query_text = search_view.query_editor.read(cx).text(cx); + assert_eq!(query_text, "lowercase_query"); + }); + + // Now deploy with an uppercase query - smartcase should enable case sensitivity. + workspace.update_in(cx, |workspace, window, cx| { + ProjectSearchView::deploy_search( + workspace, + &workspace::DeploySearch { + query: Some("Uppercase_Query".into()), + ..Default::default() + }, + window, + cx, + ) + }); + + search_view.update_in(cx, |search_view, _window, cx| { + assert!( + search_view + .search_options + .contains(SearchOptions::CASE_SENSITIVE), + "Smartcase should enable case sensitivity for a query containing uppercase" + ); + let query_text = search_view.query_editor.read(cx).text(cx); + assert_eq!(query_text, "Uppercase_Query"); + }); + } + fn init_test(cx: &mut TestAppContext) { cx.update(|cx| { let settings = SettingsStore::test(cx); diff --git a/crates/workspace/src/pane.rs b/crates/workspace/src/pane.rs index a09ba73add7e94fbe6910eb400b1364bd21cd313..cbcd60b734644cb61473bef85e27f2403e3c7d3c 100644 --- a/crates/workspace/src/pane.rs +++ b/crates/workspace/src/pane.rs @@ -198,6 +198,16 @@ pub struct DeploySearch { pub included_files: Option, #[serde(default)] pub excluded_files: Option, + #[serde(default)] + pub query: Option, + #[serde(default)] + pub regex: Option, + #[serde(default)] + pub case_sensitive: Option, + #[serde(default)] + pub whole_word: Option, + #[serde(default)] + pub include_ignored: Option, } #[derive(Clone, Copy, PartialEq, Debug, Deserialize, JsonSchema, Default)] @@ -309,16 +319,6 @@ actions!( ] ); -impl DeploySearch { - pub fn find() -> Self { - Self { - replace_enabled: false, - included_files: None, - excluded_files: None, - } - } -} - const MAX_NAVIGATION_HISTORY_LEN: usize = 1024; pub enum Event { @@ -4188,15 +4188,7 @@ fn default_render_tab_bar_buttons( menu.action("New File", NewFile.boxed_clone()) .action("Open File", ToggleFileFinder::default().boxed_clone()) .separator() - .action( - "Search Project", - DeploySearch { - replace_enabled: false, - included_files: None, - excluded_files: None, - } - .boxed_clone(), - ) + .action("Search Project", DeploySearch::default().boxed_clone()) .action("Search Symbols", ToggleProjectSymbols.boxed_clone()) .separator() .action("New Terminal", NewTerminal::default().boxed_clone()) diff --git a/crates/zed/src/zed/app_menus.rs b/crates/zed/src/zed/app_menus.rs index 3edbcad2d81d63b56e777218a3db5e57a42de7bc..f3913a6556626e2919024ca02bcba0f1f41819eb 100644 --- a/crates/zed/src/zed/app_menus.rs +++ b/crates/zed/src/zed/app_menus.rs @@ -165,7 +165,7 @@ pub fn app_menus(cx: &mut App) -> Vec { MenuItem::os_action("Paste", editor::actions::Paste, OsAction::Paste), MenuItem::separator(), MenuItem::action("Find", search::buffer_search::Deploy::find()), - MenuItem::action("Find in Project", workspace::DeploySearch::find()), + MenuItem::action("Find in Project", workspace::DeploySearch::default()), MenuItem::separator(), MenuItem::action( "Toggle Line Comment", From c5845ec04cb9deaf8712501a53d8d7f284011e4f Mon Sep 17 00:00:00 2001 From: Anthony Eid <56899983+Anthony-Eid@users.noreply.github.com> Date: Tue, 7 Apr 2026 12:12:02 -0400 Subject: [PATCH 19/22] Remove notification panel (#50204) After chat functionality was removed, this panel became redundant. It only displayed three notification types: incoming contact requests, accepted contact requests, and channel invitations. This PR moves those notifications into the collab experience by adding toast popups and a badge count to the collab panel. It also removes the notification-panel-specific settings, documentation, and Vim command. Before you mark this PR as ready for review, make sure that you have: - [ ] Added a solid test coverage and/or screenshots from doing manual testing - [x] Done a self-review taking into account security and performance aspects - [x] Aligned any UI changes with the [UI checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) Release Notes: - Removed the notification panel from Zed --- Cargo.lock | 3 - assets/settings/default.json | 10 - crates/agent_settings/src/agent_settings.rs | 15 - crates/agent_ui/src/agent_ui.rs | 2 - crates/collab_ui/Cargo.toml | 3 - crates/collab_ui/src/collab_panel.rs | 362 ++++++++- crates/collab_ui/src/collab_ui.rs | 4 +- crates/collab_ui/src/notification_panel.rs | 727 ------------------ crates/collab_ui/src/panel_settings.rs | 20 - crates/settings/src/vscode_import.rs | 2 +- .../settings_content/src/settings_content.rs | 25 - crates/settings_ui/src/page_data.rs | 91 --- .../components/collab/collab_notification.rs | 56 +- crates/vim/src/command.rs | 1 - crates/zed/src/zed.rs | 16 - docs/src/visual-customization.md | 14 +- 16 files changed, 395 insertions(+), 956 deletions(-) delete mode 100644 crates/collab_ui/src/notification_panel.rs diff --git a/Cargo.lock b/Cargo.lock index 3fccd850ae697925330d15ed6b72804c39f4795e..f426d0da3392240300d15ca174013a6bdbdbb31d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3208,7 +3208,6 @@ dependencies = [ "anyhow", "call", "channel", - "chrono", "client", "collections", "db", @@ -3217,7 +3216,6 @@ dependencies = [ "fuzzy", "gpui", "livekit_client", - "log", "menu", "notifications", "picker", @@ -3232,7 +3230,6 @@ dependencies = [ "theme", "theme_settings", "time", - "time_format", "title_bar", "ui", "util", diff --git a/assets/settings/default.json b/assets/settings/default.json index a32e1b27aee08bf2676922fea3790a99b7d7844b..97fbcd546e09beefa9ff7a67e33806f3faf561d1 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -936,16 +936,6 @@ // For example: typing `:wave:` gets replaced with `👋`. "auto_replace_emoji_shortcode": true, }, - "notification_panel": { - // Whether to show the notification panel button in the status bar. - "button": true, - // Where to dock the notification panel. Can be 'left' or 'right'. - "dock": "right", - // Default width of the notification panel. - "default_width": 380, - // Whether to show a badge on the notification panel icon with the count of unread notifications. - "show_count_badge": false, - }, "agent": { // Whether the inline assistant should use streaming tools, when available "inline_assistant_use_streaming_tools": true, diff --git a/crates/agent_settings/src/agent_settings.rs b/crates/agent_settings/src/agent_settings.rs index 5d6dca9322482daecf7525f79ead63b4471b7a53..a04de2ed3be69d3f5791419a32e427fa0c26791e 100644 --- a/crates/agent_settings/src/agent_settings.rs +++ b/crates/agent_settings/src/agent_settings.rs @@ -31,7 +31,6 @@ pub struct PanelLayout { pub(crate) outline_panel_dock: Option, pub(crate) collaboration_panel_dock: Option, pub(crate) git_panel_dock: Option, - pub(crate) notification_panel_button: Option, } impl PanelLayout { @@ -41,7 +40,6 @@ impl PanelLayout { outline_panel_dock: Some(DockSide::Right), collaboration_panel_dock: Some(DockPosition::Right), git_panel_dock: Some(DockPosition::Right), - notification_panel_button: Some(false), }; const EDITOR: Self = Self { @@ -50,7 +48,6 @@ impl PanelLayout { outline_panel_dock: Some(DockSide::Left), collaboration_panel_dock: Some(DockPosition::Left), git_panel_dock: Some(DockPosition::Left), - notification_panel_button: Some(true), }; pub fn is_agent_layout(&self) -> bool { @@ -68,7 +65,6 @@ impl PanelLayout { outline_panel_dock: content.outline_panel.as_ref().and_then(|p| p.dock), collaboration_panel_dock: content.collaboration_panel.as_ref().and_then(|p| p.dock), git_panel_dock: content.git_panel.as_ref().and_then(|p| p.dock), - notification_panel_button: content.notification_panel.as_ref().and_then(|p| p.button), } } @@ -78,7 +74,6 @@ impl PanelLayout { settings.outline_panel.get_or_insert_default().dock = self.outline_panel_dock; settings.collaboration_panel.get_or_insert_default().dock = self.collaboration_panel_dock; settings.git_panel.get_or_insert_default().dock = self.git_panel_dock; - settings.notification_panel.get_or_insert_default().button = self.notification_panel_button; } fn write_diff_to(&self, current_merged: &PanelLayout, settings: &mut SettingsContent) { @@ -98,10 +93,6 @@ impl PanelLayout { if self.git_panel_dock != current_merged.git_panel_dock { settings.git_panel.get_or_insert_default().dock = self.git_panel_dock; } - if self.notification_panel_button != current_merged.notification_panel_button { - settings.notification_panel.get_or_insert_default().button = - self.notification_panel_button; - } } fn backfill_to(&self, user_layout: &PanelLayout, settings: &mut SettingsContent) { @@ -121,10 +112,6 @@ impl PanelLayout { if user_layout.git_panel_dock.is_none() { settings.git_panel.get_or_insert_default().dock = self.git_panel_dock; } - if user_layout.notification_panel_button.is_none() { - settings.notification_panel.get_or_insert_default().button = - self.notification_panel_button; - } } } @@ -1257,7 +1244,6 @@ mod tests { assert_eq!(user_layout.outline_panel_dock, None); assert_eq!(user_layout.collaboration_panel_dock, None); assert_eq!(user_layout.git_panel_dock, None); - assert_eq!(user_layout.notification_panel_button, None); // User sets a combination that doesn't match either preset: // agent on the left but project panel also on the left. @@ -1480,7 +1466,6 @@ mod tests { Some(DockPosition::Left) ); assert_eq!(user_layout.git_panel_dock, Some(DockPosition::Left)); - assert_eq!(user_layout.notification_panel_button, Some(true)); // Now switch defaults to agent V2. set_agent_v2_defaults(cx); diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index 58b52d9ea2eb10a4f7f483402b98c4be4b08924f..429bc184f5d889990599c196910ae8d0feb28da1 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -524,7 +524,6 @@ pub fn init( defaults.collaboration_panel.get_or_insert_default().dock = Some(DockPosition::Right); defaults.git_panel.get_or_insert_default().dock = Some(DockPosition::Right); - defaults.notification_panel.get_or_insert_default().button = Some(false); } else { defaults.agent.get_or_insert_default().dock = Some(DockPosition::Right); defaults.project_panel.get_or_insert_default().dock = Some(DockSide::Left); @@ -532,7 +531,6 @@ pub fn init( defaults.collaboration_panel.get_or_insert_default().dock = Some(DockPosition::Left); defaults.git_panel.get_or_insert_default().dock = Some(DockPosition::Left); - defaults.notification_panel.get_or_insert_default().button = Some(true); } }); }); diff --git a/crates/collab_ui/Cargo.toml b/crates/collab_ui/Cargo.toml index efcba05456955e308e5a00e938bf3092d894efeb..920f620e0ea2d48f514c5e0af598add193f80d98 100644 --- a/crates/collab_ui/Cargo.toml +++ b/crates/collab_ui/Cargo.toml @@ -32,7 +32,6 @@ test-support = [ anyhow.workspace = true call.workspace = true channel.workspace = true -chrono.workspace = true client.workspace = true collections.workspace = true db.workspace = true @@ -41,7 +40,6 @@ futures.workspace = true fuzzy.workspace = true gpui.workspace = true livekit_client.workspace = true -log.workspace = true menu.workspace = true notifications.workspace = true picker.workspace = true @@ -56,7 +54,6 @@ telemetry.workspace = true theme.workspace = true theme_settings.workspace = true time.workspace = true -time_format.workspace = true title_bar.workspace = true ui.workspace = true util.workspace = true diff --git a/crates/collab_ui/src/collab_panel.rs b/crates/collab_ui/src/collab_panel.rs index 327ef1cf6003eb959bd0926d67d2b0ed3b4ab0ba..1cff27ac6b2f3c61f7a90c4a9ca6749d4b1e48b7 100644 --- a/crates/collab_ui/src/collab_panel.rs +++ b/crates/collab_ui/src/collab_panel.rs @@ -6,7 +6,7 @@ use crate::{CollaborationPanelSettings, channel_view::ChannelView}; use anyhow::Context as _; use call::ActiveCall; use channel::{Channel, ChannelEvent, ChannelStore}; -use client::{ChannelId, Client, Contact, User, UserStore}; +use client::{ChannelId, Client, Contact, Notification, User, UserStore}; use collections::{HashMap, HashSet}; use contact_finder::ContactFinder; use db::kvp::KeyValueStore; @@ -21,6 +21,7 @@ use gpui::{ }; use menu::{Cancel, Confirm, SecondaryConfirm, SelectNext, SelectPrevious}; +use notifications::{NotificationEntry, NotificationEvent, NotificationStore}; use project::{Fs, Project}; use rpc::{ ErrorCode, ErrorExt, @@ -29,19 +30,23 @@ use rpc::{ use serde::{Deserialize, Serialize}; use settings::Settings; use smallvec::SmallVec; -use std::{mem, sync::Arc}; +use std::{mem, sync::Arc, time::Duration}; use theme::ActiveTheme; use theme_settings::ThemeSettings; use ui::{ - Avatar, AvatarAvailabilityIndicator, ContextMenu, CopyButton, Facepile, HighlightedLabel, - IconButtonShape, Indicator, ListHeader, ListItem, Tab, Tooltip, prelude::*, tooltip_container, + Avatar, AvatarAvailabilityIndicator, CollabNotification, ContextMenu, CopyButton, Facepile, + HighlightedLabel, IconButtonShape, Indicator, ListHeader, ListItem, Tab, Tooltip, prelude::*, + tooltip_container, }; use util::{ResultExt, TryFutureExt, maybe}; use workspace::{ CopyRoomId, Deafen, LeaveCall, MultiWorkspace, Mute, OpenChannelNotes, OpenChannelNotesById, ScreenShare, ShareProject, Workspace, dock::{DockPosition, Panel, PanelEvent}, - notifications::{DetachAndPromptErr, NotifyResultExt}, + notifications::{ + DetachAndPromptErr, Notification as WorkspaceNotification, NotificationId, NotifyResultExt, + SuppressEvent, + }, }; const FILTER_OCCUPIED_CHANNELS_KEY: &str = "filter_occupied_channels"; @@ -87,6 +92,7 @@ struct ChannelMoveClipboard { } const COLLABORATION_PANEL_KEY: &str = "CollaborationPanel"; +const TOAST_DURATION: Duration = Duration::from_secs(5); pub fn init(cx: &mut App) { cx.observe_new(|workspace: &mut Workspace, _, _| { @@ -267,6 +273,9 @@ pub struct CollabPanel { collapsed_channels: Vec, filter_occupied_channels: bool, workspace: WeakEntity, + notification_store: Entity, + current_notification_toast: Option<(u64, Task<()>)>, + mark_as_read_tasks: HashMap>>, } #[derive(Serialize, Deserialize)] @@ -394,6 +403,9 @@ impl CollabPanel { channel_editing_state: None, selection: None, channel_store: ChannelStore::global(cx), + notification_store: NotificationStore::global(cx), + current_notification_toast: None, + mark_as_read_tasks: HashMap::default(), user_store: workspace.user_store().clone(), project: workspace.project().clone(), subscriptions: Vec::default(), @@ -437,6 +449,11 @@ impl CollabPanel { } }, )); + this.subscriptions.push(cx.subscribe_in( + &this.notification_store, + window, + Self::on_notification_event, + )); this }) @@ -2665,26 +2682,28 @@ impl CollabPanel { window: &mut Window, cx: &mut Context, ) -> AnyElement { - let entry = &self.entries[ix]; + let entry = self.entries[ix].clone(); let is_selected = self.selection == Some(ix); match entry { ListEntry::Header(section) => { - let is_collapsed = self.collapsed_sections.contains(section); - self.render_header(*section, is_selected, is_collapsed, cx) + let is_collapsed = self.collapsed_sections.contains(§ion); + self.render_header(section, is_selected, is_collapsed, cx) + .into_any_element() + } + ListEntry::Contact { contact, calling } => { + self.mark_contact_request_accepted_notifications_read(contact.user.id, cx); + self.render_contact(&contact, calling, is_selected, cx) .into_any_element() } - ListEntry::Contact { contact, calling } => self - .render_contact(contact, *calling, is_selected, cx) - .into_any_element(), ListEntry::ContactPlaceholder => self .render_contact_placeholder(is_selected, cx) .into_any_element(), ListEntry::IncomingRequest(user) => self - .render_contact_request(user, true, is_selected, cx) + .render_contact_request(&user, true, is_selected, cx) .into_any_element(), ListEntry::OutgoingRequest(user) => self - .render_contact_request(user, false, is_selected, cx) + .render_contact_request(&user, false, is_selected, cx) .into_any_element(), ListEntry::Channel { channel, @@ -2694,9 +2713,9 @@ impl CollabPanel { .. } => self .render_channel( - channel, - *depth, - *has_children, + &channel, + depth, + has_children, is_selected, ix, string_match.as_ref(), @@ -2704,10 +2723,10 @@ impl CollabPanel { ) .into_any_element(), ListEntry::ChannelEditor { depth } => self - .render_channel_editor(*depth, window, cx) + .render_channel_editor(depth, window, cx) .into_any_element(), ListEntry::ChannelInvite(channel) => self - .render_channel_invite(channel, is_selected, cx) + .render_channel_invite(&channel, is_selected, cx) .into_any_element(), ListEntry::CallParticipant { user, @@ -2715,7 +2734,7 @@ impl CollabPanel { is_pending, role, } => self - .render_call_participant(user, *peer_id, *is_pending, *role, is_selected, cx) + .render_call_participant(&user, peer_id, is_pending, role, is_selected, cx) .into_any_element(), ListEntry::ParticipantProject { project_id, @@ -2724,20 +2743,20 @@ impl CollabPanel { is_last, } => self .render_participant_project( - *project_id, - worktree_root_names, - *host_user_id, - *is_last, + project_id, + &worktree_root_names, + host_user_id, + is_last, is_selected, window, cx, ) .into_any_element(), ListEntry::ParticipantScreen { peer_id, is_last } => self - .render_participant_screen(*peer_id, *is_last, is_selected, window, cx) + .render_participant_screen(peer_id, is_last, is_selected, window, cx) .into_any_element(), ListEntry::ChannelNotes { channel_id } => self - .render_channel_notes(*channel_id, is_selected, window, cx) + .render_channel_notes(channel_id, is_selected, window, cx) .into_any_element(), } } @@ -3397,6 +3416,178 @@ impl CollabPanel { item.child(self.channel_name_editor.clone()) } } + + fn on_notification_event( + &mut self, + _: &Entity, + event: &NotificationEvent, + _window: &mut Window, + cx: &mut Context, + ) { + match event { + NotificationEvent::NewNotification { entry } => { + self.add_toast(entry, cx); + cx.notify(); + } + NotificationEvent::NotificationRemoved { entry } + | NotificationEvent::NotificationRead { entry } => { + self.remove_toast(entry.id, cx); + cx.notify(); + } + NotificationEvent::NotificationsUpdated { .. } => { + cx.notify(); + } + } + } + + fn present_notification( + &self, + entry: &NotificationEntry, + cx: &App, + ) -> Option<(Option>, String)> { + let user_store = self.user_store.read(cx); + match &entry.notification { + Notification::ContactRequest { sender_id } => { + let requester = user_store.get_cached_user(*sender_id)?; + Some(( + Some(requester.clone()), + format!("{} wants to add you as a contact", requester.github_login), + )) + } + Notification::ContactRequestAccepted { responder_id } => { + let responder = user_store.get_cached_user(*responder_id)?; + Some(( + Some(responder.clone()), + format!("{} accepted your contact request", responder.github_login), + )) + } + Notification::ChannelInvitation { + channel_name, + inviter_id, + .. + } => { + let inviter = user_store.get_cached_user(*inviter_id)?; + Some(( + Some(inviter.clone()), + format!( + "{} invited you to join the #{channel_name} channel", + inviter.github_login + ), + )) + } + } + } + + fn add_toast(&mut self, entry: &NotificationEntry, cx: &mut Context) { + let Some((actor, text)) = self.present_notification(entry, cx) else { + return; + }; + + let notification = entry.notification.clone(); + let needs_response = matches!( + notification, + Notification::ContactRequest { .. } | Notification::ChannelInvitation { .. } + ); + + let notification_id = entry.id; + + self.current_notification_toast = Some(( + notification_id, + cx.spawn(async move |this, cx| { + cx.background_executor().timer(TOAST_DURATION).await; + this.update(cx, |this, cx| this.remove_toast(notification_id, cx)) + .ok(); + }), + )); + + let collab_panel = cx.entity().downgrade(); + self.workspace + .update(cx, |workspace, cx| { + let id = NotificationId::unique::(); + + workspace.dismiss_notification(&id, cx); + workspace.show_notification(id, cx, |cx| { + let workspace = cx.entity().downgrade(); + cx.new(|cx| CollabNotificationToast { + actor, + text, + notification: needs_response.then(|| notification), + workspace, + collab_panel: collab_panel.clone(), + focus_handle: cx.focus_handle(), + }) + }) + }) + .ok(); + } + + fn mark_notification_read(&mut self, notification_id: u64, cx: &mut Context) { + let client = self.client.clone(); + self.mark_as_read_tasks + .entry(notification_id) + .or_insert_with(|| { + cx.spawn(async move |this, cx| { + let request_result = client + .request(proto::MarkNotificationRead { notification_id }) + .await; + + this.update(cx, |this, _| { + this.mark_as_read_tasks.remove(¬ification_id); + })?; + + request_result?; + Ok(()) + }) + }); + } + + fn mark_contact_request_accepted_notifications_read( + &mut self, + contact_user_id: u64, + cx: &mut Context, + ) { + let notification_ids = self.notification_store.read_with(cx, |store, _| { + (0..store.notification_count()) + .filter_map(|index| { + let entry = store.notification_at(index)?; + if entry.is_read { + return None; + } + + match &entry.notification { + Notification::ContactRequestAccepted { responder_id } + if *responder_id == contact_user_id => + { + Some(entry.id) + } + _ => None, + } + }) + .collect::>() + }); + + for notification_id in notification_ids { + self.mark_notification_read(notification_id, cx); + } + } + + fn remove_toast(&mut self, notification_id: u64, cx: &mut Context) { + if let Some((current_id, _)) = &self.current_notification_toast { + if *current_id == notification_id { + self.dismiss_toast(cx); + } + } + } + + fn dismiss_toast(&mut self, cx: &mut Context) { + self.current_notification_toast.take(); + self.workspace + .update(cx, |workspace, cx| { + let id = NotificationId::unique::(); + workspace.dismiss_notification(&id, cx) + }) + .ok(); + } } fn render_tree_branch( @@ -3516,12 +3707,38 @@ impl Panel for CollabPanel { CollaborationPanelSettings::get_global(cx).default_width } + fn set_active(&mut self, active: bool, _window: &mut Window, cx: &mut Context) { + if active && self.current_notification_toast.is_some() { + self.current_notification_toast.take(); + let workspace = self.workspace.clone(); + cx.defer(move |cx| { + workspace + .update(cx, |workspace, cx| { + let id = NotificationId::unique::(); + workspace.dismiss_notification(&id, cx) + }) + .ok(); + }); + } + } + fn icon(&self, _window: &Window, cx: &App) -> Option { CollaborationPanelSettings::get_global(cx) .button .then_some(ui::IconName::UserGroup) } + fn icon_label(&self, _window: &Window, cx: &App) -> Option { + let user_store = self.user_store.read(cx); + let count = user_store.incoming_contact_requests().len() + + self.channel_store.read(cx).channel_invitations().len(); + if count == 0 { + None + } else { + Some(count.to_string()) + } + } + fn icon_tooltip(&self, _window: &Window, _cx: &App) -> Option<&'static str> { Some("Collab Panel") } @@ -3702,6 +3919,101 @@ impl Render for JoinChannelTooltip { } } +pub struct CollabNotificationToast { + actor: Option>, + text: String, + notification: Option, + workspace: WeakEntity, + collab_panel: WeakEntity, + focus_handle: FocusHandle, +} + +impl Focusable for CollabNotificationToast { + fn focus_handle(&self, _cx: &App) -> FocusHandle { + self.focus_handle.clone() + } +} + +impl WorkspaceNotification for CollabNotificationToast {} + +impl CollabNotificationToast { + fn focus_collab_panel(&self, window: &mut Window, cx: &mut Context) { + let workspace = self.workspace.clone(); + window.defer(cx, move |window, cx| { + workspace + .update(cx, |workspace, cx| { + workspace.focus_panel::(window, cx) + }) + .ok(); + }) + } + + fn respond(&mut self, accept: bool, window: &mut Window, cx: &mut Context) { + if let Some(notification) = self.notification.take() { + self.collab_panel + .update(cx, |collab_panel, cx| match notification { + Notification::ContactRequest { sender_id } => { + collab_panel.respond_to_contact_request(sender_id, accept, window, cx); + } + Notification::ChannelInvitation { channel_id, .. } => { + collab_panel.respond_to_channel_invite(ChannelId(channel_id), accept, cx); + } + Notification::ContactRequestAccepted { .. } => {} + }) + .ok(); + } + cx.emit(DismissEvent); + } +} + +impl Render for CollabNotificationToast { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + let needs_response = self.notification.is_some(); + + let accept_button = if needs_response { + Button::new("accept", "Accept").on_click(cx.listener(|this, _, window, cx| { + this.respond(true, window, cx); + cx.stop_propagation(); + })) + } else { + Button::new("dismiss", "Dismiss").on_click(cx.listener(|_, _, _, cx| { + cx.emit(DismissEvent); + })) + }; + + let decline_button = if needs_response { + Button::new("decline", "Decline").on_click(cx.listener(|this, _, window, cx| { + this.respond(false, window, cx); + cx.stop_propagation(); + })) + } else { + Button::new("close", "Close").on_click(cx.listener(|_, _, _, cx| { + cx.emit(DismissEvent); + })) + }; + + let avatar_uri = self + .actor + .as_ref() + .map(|user| user.avatar_uri.clone()) + .unwrap_or_default(); + + div() + .id("collab_notification_toast") + .on_click(cx.listener(|this, _, window, cx| { + this.focus_collab_panel(window, cx); + cx.emit(DismissEvent); + })) + .child( + CollabNotification::new(avatar_uri, accept_button, decline_button) + .child(Label::new(self.text.clone())), + ) + } +} + +impl EventEmitter for CollabNotificationToast {} +impl EventEmitter for CollabNotificationToast {} + #[cfg(any(test, feature = "test-support"))] impl CollabPanel { pub fn entries_as_strings(&self) -> Vec { diff --git a/crates/collab_ui/src/collab_ui.rs b/crates/collab_ui/src/collab_ui.rs index 107b2ffa7f625d98dd9c54bb6bbf75df8b72d020..f9c463c0690343a3b4b1b9a048134265326a9f50 100644 --- a/crates/collab_ui/src/collab_ui.rs +++ b/crates/collab_ui/src/collab_ui.rs @@ -1,7 +1,6 @@ mod call_stats_modal; pub mod channel_view; pub mod collab_panel; -pub mod notification_panel; pub mod notifications; mod panel_settings; @@ -12,7 +11,7 @@ use gpui::{ App, Pixels, PlatformDisplay, Size, WindowBackgroundAppearance, WindowBounds, WindowDecorations, WindowKind, WindowOptions, point, }; -pub use panel_settings::{CollaborationPanelSettings, NotificationPanelSettings}; +pub use panel_settings::CollaborationPanelSettings; use release_channel::ReleaseChannel; use ui::px; use workspace::AppState; @@ -22,7 +21,6 @@ pub fn init(app_state: &Arc, cx: &mut App) { call_stats_modal::init(cx); channel_view::init(cx); collab_panel::init(cx); - notification_panel::init(cx); notifications::init(app_state, cx); title_bar::init(cx); } diff --git a/crates/collab_ui/src/notification_panel.rs b/crates/collab_ui/src/notification_panel.rs deleted file mode 100644 index d7fef4873c687ab23a25b3144ba902cf4c42c137..0000000000000000000000000000000000000000 --- a/crates/collab_ui/src/notification_panel.rs +++ /dev/null @@ -1,727 +0,0 @@ -use crate::NotificationPanelSettings; -use anyhow::Result; -use channel::ChannelStore; -use client::{ChannelId, Client, Notification, User, UserStore}; -use collections::HashMap; -use futures::StreamExt; -use gpui::{ - AnyElement, App, AsyncWindowContext, ClickEvent, Context, DismissEvent, Element, Entity, - EventEmitter, FocusHandle, Focusable, InteractiveElement, IntoElement, ListAlignment, - ListScrollEvent, ListState, ParentElement, Render, StatefulInteractiveElement, Styled, Task, - WeakEntity, Window, actions, div, img, list, px, -}; -use notifications::{NotificationEntry, NotificationEvent, NotificationStore}; -use project::Fs; -use rpc::proto; - -use settings::{Settings, SettingsStore}; -use std::{sync::Arc, time::Duration}; -use time::{OffsetDateTime, UtcOffset}; -use ui::{ - Avatar, Button, Icon, IconButton, IconName, Label, Tab, Tooltip, h_flex, prelude::*, v_flex, -}; -use util::ResultExt; -use workspace::notifications::{ - Notification as WorkspaceNotification, NotificationId, SuppressEvent, -}; -use workspace::{ - Workspace, - dock::{DockPosition, Panel, PanelEvent}, -}; - -const LOADING_THRESHOLD: usize = 30; -const MARK_AS_READ_DELAY: Duration = Duration::from_secs(1); -const TOAST_DURATION: Duration = Duration::from_secs(5); -const NOTIFICATION_PANEL_KEY: &str = "NotificationPanel"; - -pub struct NotificationPanel { - client: Arc, - user_store: Entity, - channel_store: Entity, - notification_store: Entity, - fs: Arc, - active: bool, - notification_list: ListState, - subscriptions: Vec, - workspace: WeakEntity, - current_notification_toast: Option<(u64, Task<()>)>, - local_timezone: UtcOffset, - focus_handle: FocusHandle, - mark_as_read_tasks: HashMap>>, - unseen_notifications: Vec, -} - -#[derive(Debug)] -pub enum Event { - DockPositionChanged, - Focus, - Dismissed, -} - -pub struct NotificationPresenter { - pub actor: Option>, - pub text: String, - pub icon: &'static str, - pub needs_response: bool, -} - -actions!( - notification_panel, - [ - /// Toggles the notification panel. - Toggle, - /// Toggles focus on the notification panel. - ToggleFocus - ] -); - -pub fn init(cx: &mut App) { - cx.observe_new(|workspace: &mut Workspace, _, _| { - workspace.register_action(|workspace, _: &ToggleFocus, window, cx| { - workspace.toggle_panel_focus::(window, cx); - }); - workspace.register_action(|workspace, _: &Toggle, window, cx| { - if !workspace.toggle_panel_focus::(window, cx) { - workspace.close_panel::(window, cx); - } - }); - }) - .detach(); -} - -impl NotificationPanel { - pub fn new( - workspace: &mut Workspace, - window: &mut Window, - cx: &mut Context, - ) -> Entity { - let fs = workspace.app_state().fs.clone(); - let client = workspace.app_state().client.clone(); - let user_store = workspace.app_state().user_store.clone(); - let workspace_handle = workspace.weak_handle(); - - cx.new(|cx| { - let mut status = client.status(); - cx.spawn_in(window, async move |this, cx| { - while (status.next().await).is_some() { - if this - .update(cx, |_: &mut Self, cx| { - cx.notify(); - }) - .is_err() - { - break; - } - } - }) - .detach(); - - let notification_list = ListState::new(0, ListAlignment::Top, px(1000.)); - notification_list.set_scroll_handler(cx.listener( - |this, event: &ListScrollEvent, _, cx| { - if event.count.saturating_sub(event.visible_range.end) < LOADING_THRESHOLD - && let Some(task) = this - .notification_store - .update(cx, |store, cx| store.load_more_notifications(false, cx)) - { - task.detach(); - } - }, - )); - - let local_offset = chrono::Local::now().offset().local_minus_utc(); - let mut this = Self { - fs, - client, - user_store, - local_timezone: UtcOffset::from_whole_seconds(local_offset).unwrap(), - channel_store: ChannelStore::global(cx), - notification_store: NotificationStore::global(cx), - notification_list, - workspace: workspace_handle, - focus_handle: cx.focus_handle(), - subscriptions: Default::default(), - current_notification_toast: None, - active: false, - mark_as_read_tasks: Default::default(), - unseen_notifications: Default::default(), - }; - - let mut old_dock_position = this.position(window, cx); - this.subscriptions.extend([ - cx.observe(&this.notification_store, |_, _, cx| cx.notify()), - cx.subscribe_in( - &this.notification_store, - window, - Self::on_notification_event, - ), - cx.observe_global_in::( - window, - move |this: &mut Self, window, cx| { - let new_dock_position = this.position(window, cx); - if new_dock_position != old_dock_position { - old_dock_position = new_dock_position; - cx.emit(Event::DockPositionChanged); - } - cx.notify(); - }, - ), - ]); - this - }) - } - - pub fn load( - workspace: WeakEntity, - cx: AsyncWindowContext, - ) -> Task>> { - cx.spawn(async move |cx| { - workspace.update_in(cx, |workspace, window, cx| Self::new(workspace, window, cx)) - }) - } - - fn render_notification( - &mut self, - ix: usize, - window: &mut Window, - cx: &mut Context, - ) -> Option { - let entry = self.notification_store.read(cx).notification_at(ix)?; - let notification_id = entry.id; - let now = OffsetDateTime::now_utc(); - let timestamp = entry.timestamp; - let NotificationPresenter { - actor, - text, - needs_response, - .. - } = self.present_notification(entry, cx)?; - - let response = entry.response; - let notification = entry.notification.clone(); - - if self.active && !entry.is_read { - self.did_render_notification(notification_id, ¬ification, window, cx); - } - - let relative_timestamp = time_format::format_localized_timestamp( - timestamp, - now, - self.local_timezone, - time_format::TimestampFormat::Relative, - ); - - let absolute_timestamp = time_format::format_localized_timestamp( - timestamp, - now, - self.local_timezone, - time_format::TimestampFormat::Absolute, - ); - - Some( - div() - .id(ix) - .flex() - .flex_row() - .size_full() - .px_2() - .py_1() - .gap_2() - .hover(|style| style.bg(cx.theme().colors().element_hover)) - .children(actor.map(|actor| { - img(actor.avatar_uri.clone()) - .flex_none() - .w_8() - .h_8() - .rounded_full() - })) - .child( - v_flex() - .gap_1() - .size_full() - .overflow_hidden() - .child(Label::new(text)) - .child( - h_flex() - .child( - div() - .id("notification_timestamp") - .hover(|style| { - style - .bg(cx.theme().colors().element_selected) - .rounded_sm() - }) - .child(Label::new(relative_timestamp).color(Color::Muted)) - .tooltip(move |_, cx| { - Tooltip::simple(absolute_timestamp.clone(), cx) - }), - ) - .children(if let Some(is_accepted) = response { - Some(div().flex().flex_grow().justify_end().child(Label::new( - if is_accepted { - "You accepted" - } else { - "You declined" - }, - ))) - } else if needs_response { - Some( - h_flex() - .flex_grow() - .justify_end() - .child(Button::new("decline", "Decline").on_click({ - let notification = notification.clone(); - let entity = cx.entity(); - move |_, _, cx| { - entity.update(cx, |this, cx| { - this.respond_to_notification( - notification.clone(), - false, - cx, - ) - }); - } - })) - .child(Button::new("accept", "Accept").on_click({ - let notification = notification.clone(); - let entity = cx.entity(); - move |_, _, cx| { - entity.update(cx, |this, cx| { - this.respond_to_notification( - notification.clone(), - true, - cx, - ) - }); - } - })), - ) - } else { - None - }), - ), - ) - .into_any(), - ) - } - - fn present_notification( - &self, - entry: &NotificationEntry, - cx: &App, - ) -> Option { - let user_store = self.user_store.read(cx); - let channel_store = self.channel_store.read(cx); - match entry.notification { - Notification::ContactRequest { sender_id } => { - let requester = user_store.get_cached_user(sender_id)?; - Some(NotificationPresenter { - icon: "icons/plus.svg", - text: format!("{} wants to add you as a contact", requester.github_login), - needs_response: user_store.has_incoming_contact_request(requester.id), - actor: Some(requester), - }) - } - Notification::ContactRequestAccepted { responder_id } => { - let responder = user_store.get_cached_user(responder_id)?; - Some(NotificationPresenter { - icon: "icons/plus.svg", - text: format!("{} accepted your contact invite", responder.github_login), - needs_response: false, - actor: Some(responder), - }) - } - Notification::ChannelInvitation { - ref channel_name, - channel_id, - inviter_id, - } => { - let inviter = user_store.get_cached_user(inviter_id)?; - Some(NotificationPresenter { - icon: "icons/hash.svg", - text: format!( - "{} invited you to join the #{channel_name} channel", - inviter.github_login - ), - needs_response: channel_store.has_channel_invitation(ChannelId(channel_id)), - actor: Some(inviter), - }) - } - } - } - - fn did_render_notification( - &mut self, - notification_id: u64, - notification: &Notification, - window: &mut Window, - cx: &mut Context, - ) { - let should_mark_as_read = match notification { - Notification::ContactRequestAccepted { .. } => true, - Notification::ContactRequest { .. } | Notification::ChannelInvitation { .. } => false, - }; - - if should_mark_as_read { - self.mark_as_read_tasks - .entry(notification_id) - .or_insert_with(|| { - let client = self.client.clone(); - cx.spawn_in(window, async move |this, cx| { - cx.background_executor().timer(MARK_AS_READ_DELAY).await; - client - .request(proto::MarkNotificationRead { notification_id }) - .await?; - this.update(cx, |this, _| { - this.mark_as_read_tasks.remove(¬ification_id); - })?; - Ok(()) - }) - }); - } - } - - fn on_notification_event( - &mut self, - _: &Entity, - event: &NotificationEvent, - window: &mut Window, - cx: &mut Context, - ) { - match event { - NotificationEvent::NewNotification { entry } => { - self.unseen_notifications.push(entry.clone()); - self.add_toast(entry, window, cx); - } - NotificationEvent::NotificationRemoved { entry } - | NotificationEvent::NotificationRead { entry } => { - self.unseen_notifications.retain(|n| n.id != entry.id); - self.remove_toast(entry.id, cx); - } - NotificationEvent::NotificationsUpdated { - old_range, - new_count, - } => { - self.notification_list.splice(old_range.clone(), *new_count); - cx.notify(); - } - } - } - - fn add_toast( - &mut self, - entry: &NotificationEntry, - window: &mut Window, - cx: &mut Context, - ) { - let Some(NotificationPresenter { actor, text, .. }) = self.present_notification(entry, cx) - else { - return; - }; - - let notification_id = entry.id; - self.current_notification_toast = Some(( - notification_id, - cx.spawn_in(window, async move |this, cx| { - cx.background_executor().timer(TOAST_DURATION).await; - this.update(cx, |this, cx| this.remove_toast(notification_id, cx)) - .ok(); - }), - )); - - self.workspace - .update(cx, |workspace, cx| { - let id = NotificationId::unique::(); - - workspace.dismiss_notification(&id, cx); - workspace.show_notification(id, cx, |cx| { - let workspace = cx.entity().downgrade(); - cx.new(|cx| NotificationToast { - actor, - text, - workspace, - focus_handle: cx.focus_handle(), - }) - }) - }) - .ok(); - } - - fn remove_toast(&mut self, notification_id: u64, cx: &mut Context) { - if let Some((current_id, _)) = &self.current_notification_toast - && *current_id == notification_id - { - self.current_notification_toast.take(); - self.workspace - .update(cx, |workspace, cx| { - let id = NotificationId::unique::(); - workspace.dismiss_notification(&id, cx) - }) - .ok(); - } - } - - fn respond_to_notification( - &mut self, - notification: Notification, - response: bool, - - cx: &mut Context, - ) { - self.notification_store.update(cx, |store, cx| { - store.respond_to_notification(notification, response, cx); - }); - } -} - -impl Render for NotificationPanel { - fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - v_flex() - .size_full() - .child( - h_flex() - .justify_between() - .px_2() - .py_1() - // Match the height of the tab bar so they line up. - .h(Tab::container_height(cx)) - .border_b_1() - .border_color(cx.theme().colors().border) - .child(Label::new("Notifications")) - .child(Icon::new(IconName::Envelope)), - ) - .map(|this| { - if !self.client.status().borrow().is_connected() { - this.child( - v_flex() - .gap_2() - .p_4() - .child( - Button::new("connect_prompt_button", "Connect") - .start_icon(Icon::new(IconName::Github).color(Color::Muted)) - .style(ButtonStyle::Filled) - .full_width() - .on_click({ - let client = self.client.clone(); - move |_, window, cx| { - let client = client.clone(); - window - .spawn(cx, async move |cx| { - match client.connect(true, cx).await { - util::ConnectionResult::Timeout => { - log::error!("Connection timeout"); - } - util::ConnectionResult::ConnectionReset => { - log::error!("Connection reset"); - } - util::ConnectionResult::Result(r) => { - r.log_err(); - } - } - }) - .detach() - } - }), - ) - .child( - div().flex().w_full().items_center().child( - Label::new("Connect to view notifications.") - .color(Color::Muted) - .size(LabelSize::Small), - ), - ), - ) - } else if self.notification_list.item_count() == 0 { - this.child( - v_flex().p_4().child( - div().flex().w_full().items_center().child( - Label::new("You have no notifications.") - .color(Color::Muted) - .size(LabelSize::Small), - ), - ), - ) - } else { - this.child( - list( - self.notification_list.clone(), - cx.processor(|this, ix, window, cx| { - this.render_notification(ix, window, cx) - .unwrap_or_else(|| div().into_any()) - }), - ) - .size_full(), - ) - } - }) - } -} - -impl Focusable for NotificationPanel { - fn focus_handle(&self, _: &App) -> FocusHandle { - self.focus_handle.clone() - } -} - -impl EventEmitter for NotificationPanel {} -impl EventEmitter for NotificationPanel {} - -impl Panel for NotificationPanel { - fn persistent_name() -> &'static str { - "NotificationPanel" - } - - fn panel_key() -> &'static str { - NOTIFICATION_PANEL_KEY - } - - fn position(&self, _: &Window, cx: &App) -> DockPosition { - NotificationPanelSettings::get_global(cx).dock - } - - fn position_is_valid(&self, position: DockPosition) -> bool { - matches!(position, DockPosition::Left | DockPosition::Right) - } - - fn set_position(&mut self, position: DockPosition, _: &mut Window, cx: &mut Context) { - settings::update_settings_file(self.fs.clone(), cx, move |settings, _| { - settings.notification_panel.get_or_insert_default().dock = Some(position.into()) - }); - } - - fn default_size(&self, _: &Window, cx: &App) -> Pixels { - NotificationPanelSettings::get_global(cx).default_width - } - - fn set_active(&mut self, active: bool, _: &mut Window, cx: &mut Context) { - self.active = active; - - if self.active { - self.unseen_notifications = Vec::new(); - cx.notify(); - } - - if self.notification_store.read(cx).notification_count() == 0 { - cx.emit(Event::Dismissed); - } - } - - fn icon(&self, _: &Window, cx: &App) -> Option { - let show_button = NotificationPanelSettings::get_global(cx).button; - if !show_button { - return None; - } - - if self.unseen_notifications.is_empty() { - return Some(IconName::Bell); - } - - Some(IconName::BellDot) - } - - fn icon_tooltip(&self, _window: &Window, _cx: &App) -> Option<&'static str> { - Some("Notification Panel") - } - - fn icon_label(&self, _window: &Window, cx: &App) -> Option { - if !NotificationPanelSettings::get_global(cx).show_count_badge { - return None; - } - let count = self.notification_store.read(cx).unread_notification_count(); - if count == 0 { - None - } else { - Some(count.to_string()) - } - } - - fn toggle_action(&self) -> Box { - Box::new(ToggleFocus) - } - - fn activation_priority(&self) -> u32 { - 4 - } -} - -pub struct NotificationToast { - actor: Option>, - text: String, - workspace: WeakEntity, - focus_handle: FocusHandle, -} - -impl Focusable for NotificationToast { - fn focus_handle(&self, _cx: &App) -> FocusHandle { - self.focus_handle.clone() - } -} - -impl WorkspaceNotification for NotificationToast {} - -impl NotificationToast { - fn focus_notification_panel(&self, window: &mut Window, cx: &mut Context) { - let workspace = self.workspace.clone(); - window.defer(cx, move |window, cx| { - workspace - .update(cx, |workspace, cx| { - workspace.focus_panel::(window, cx) - }) - .ok(); - }) - } -} - -impl Render for NotificationToast { - fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { - let user = self.actor.clone(); - - let suppress = window.modifiers().shift; - let (close_id, close_icon) = if suppress { - ("suppress", IconName::Minimize) - } else { - ("close", IconName::Close) - }; - - h_flex() - .id("notification_panel_toast") - .elevation_3(cx) - .p_2() - .justify_between() - .children(user.map(|user| Avatar::new(user.avatar_uri.clone()))) - .child(Label::new(self.text.clone())) - .on_modifiers_changed(cx.listener(|_, _, _, cx| cx.notify())) - .child( - IconButton::new(close_id, close_icon) - .tooltip(move |_window, cx| { - if suppress { - Tooltip::for_action( - "Suppress.\nClose with click.", - &workspace::SuppressNotification, - cx, - ) - } else { - Tooltip::for_action( - "Close.\nSuppress with shift-click", - &menu::Cancel, - cx, - ) - } - }) - .on_click(cx.listener(move |_, _: &ClickEvent, _, cx| { - if suppress { - cx.emit(SuppressEvent); - } else { - cx.emit(DismissEvent); - } - })), - ) - .on_click(cx.listener(|this, _, window, cx| { - this.focus_notification_panel(window, cx); - cx.emit(DismissEvent); - })) - } -} - -impl EventEmitter for NotificationToast {} -impl EventEmitter for NotificationToast {} diff --git a/crates/collab_ui/src/panel_settings.rs b/crates/collab_ui/src/panel_settings.rs index 938d33159e9adb7a9e63ceb73219b70724efee17..3d6de1015a3751751c13c8ccb6d4c5639755be20 100644 --- a/crates/collab_ui/src/panel_settings.rs +++ b/crates/collab_ui/src/panel_settings.rs @@ -10,14 +10,6 @@ pub struct CollaborationPanelSettings { pub default_width: Pixels, } -#[derive(Debug, RegisterSetting)] -pub struct NotificationPanelSettings { - pub button: bool, - pub dock: DockPosition, - pub default_width: Pixels, - pub show_count_badge: bool, -} - impl Settings for CollaborationPanelSettings { fn from_settings(content: &settings::SettingsContent) -> Self { let panel = content.collaboration_panel.as_ref().unwrap(); @@ -29,15 +21,3 @@ impl Settings for CollaborationPanelSettings { } } } - -impl Settings for NotificationPanelSettings { - fn from_settings(content: &settings::SettingsContent) -> Self { - let panel = content.notification_panel.as_ref().unwrap(); - return Self { - button: panel.button.unwrap(), - dock: panel.dock.unwrap().into(), - default_width: panel.default_width.map(px).unwrap(), - show_count_badge: panel.show_count_badge.unwrap(), - }; - } -} diff --git a/crates/settings/src/vscode_import.rs b/crates/settings/src/vscode_import.rs index 1211cbd8a4519ea295773eb0d979b48258908311..4c7ce085aed5ad0cf7c48308b4211815cf5aad75 100644 --- a/crates/settings/src/vscode_import.rs +++ b/crates/settings/src/vscode_import.rs @@ -198,7 +198,7 @@ impl VsCodeSettings { log: None, message_editor: None, node: self.node_binary_settings(), - notification_panel: None, + outline_panel: self.outline_panel_settings_content(), preview_tabs: self.preview_tabs_settings_content(), project: self.project_settings_content(), diff --git a/crates/settings_content/src/settings_content.rs b/crates/settings_content/src/settings_content.rs index 6c60a7010f7cfc5b4fadf9a8cc386fe6e3267abc..3c3c0f600769b8437dc56016426eee4f84d2fc7a 100644 --- a/crates/settings_content/src/settings_content.rs +++ b/crates/settings_content/src/settings_content.rs @@ -174,9 +174,6 @@ pub struct SettingsContent { /// Configuration for Node-related features pub node: Option, - /// Configuration for the Notification Panel - pub notification_panel: Option, - pub proxy: Option, /// The URL of the Zed server to connect to. @@ -631,28 +628,6 @@ pub struct ScrollbarSettings { pub show: Option, } -#[with_fallible_options] -#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, MergeFrom, Debug, PartialEq)] -pub struct NotificationPanelSettingsContent { - /// Whether to show the panel button in the status bar. - /// - /// Default: true - pub button: Option, - /// Where to dock the panel. - /// - /// Default: right - pub dock: Option, - /// Default width of the panel in pixels. - /// - /// Default: 300 - #[serde(serialize_with = "crate::serialize_optional_f32_with_two_decimal_places")] - pub default_width: Option, - /// Whether to show a badge on the notification panel icon with the count of unread notifications. - /// - /// Default: false - pub show_count_badge: Option, -} - #[with_fallible_options] #[derive(Clone, Default, Serialize, Deserialize, JsonSchema, MergeFrom, Debug, PartialEq)] pub struct PanelSettingsContent { diff --git a/crates/settings_ui/src/page_data.rs b/crates/settings_ui/src/page_data.rs index 259ee2cf261f9e435a5431ddf3c470640daf41f9..c77bf5a326c6b48dea2c85f0744de0066d8c0236 100644 --- a/crates/settings_ui/src/page_data.rs +++ b/crates/settings_ui/src/page_data.rs @@ -5579,96 +5579,6 @@ fn panels_page() -> SettingsPage { ] } - fn notification_panel_section() -> [SettingsPageItem; 5] { - [ - SettingsPageItem::SectionHeader("Notification Panel"), - SettingsPageItem::SettingItem(SettingItem { - title: "Notification Panel Button", - description: "Show the notification panel button in the status bar.", - field: Box::new(SettingField { - json_path: Some("notification_panel.button"), - pick: |settings_content| { - settings_content - .notification_panel - .as_ref()? - .button - .as_ref() - }, - write: |settings_content, value| { - settings_content - .notification_panel - .get_or_insert_default() - .button = value; - }, - }), - metadata: None, - files: USER, - }), - SettingsPageItem::SettingItem(SettingItem { - title: "Notification Panel Dock", - description: "Where to dock the notification panel.", - field: Box::new(SettingField { - json_path: Some("notification_panel.dock"), - pick: |settings_content| { - settings_content.notification_panel.as_ref()?.dock.as_ref() - }, - write: |settings_content, value| { - settings_content - .notification_panel - .get_or_insert_default() - .dock = value; - }, - }), - metadata: None, - files: USER, - }), - SettingsPageItem::SettingItem(SettingItem { - title: "Notification Panel Default Width", - description: "Default width of the notification panel in pixels.", - field: Box::new(SettingField { - json_path: Some("notification_panel.default_width"), - pick: |settings_content| { - settings_content - .notification_panel - .as_ref()? - .default_width - .as_ref() - }, - write: |settings_content, value| { - settings_content - .notification_panel - .get_or_insert_default() - .default_width = value; - }, - }), - metadata: None, - files: USER, - }), - SettingsPageItem::SettingItem(SettingItem { - title: "Show Count Badge", - description: "Show a badge on the notification panel icon with the count of unread notifications.", - field: Box::new(SettingField { - json_path: Some("notification_panel.show_count_badge"), - pick: |settings_content| { - settings_content - .notification_panel - .as_ref()? - .show_count_badge - .as_ref() - }, - write: |settings_content, value| { - settings_content - .notification_panel - .get_or_insert_default() - .show_count_badge = value; - }, - }), - metadata: None, - files: USER, - }), - ] - } - fn collaboration_panel_section() -> [SettingsPageItem; 4] { [ SettingsPageItem::SectionHeader("Collaboration Panel"), @@ -5841,7 +5751,6 @@ fn panels_page() -> SettingsPage { outline_panel_section(), git_panel_section(), debugger_panel_section(), - notification_panel_section(), collaboration_panel_section(), agent_panel_section(), ], diff --git a/crates/ui/src/components/collab/collab_notification.rs b/crates/ui/src/components/collab/collab_notification.rs index 0c3fca84e9b9fb3246de20b9b1f077202fa3ebdb..28d28b0a292076a575a5443b80eae9b788e2b62e 100644 --- a/crates/ui/src/components/collab/collab_notification.rs +++ b/crates/ui/src/components/collab/collab_notification.rs @@ -67,7 +67,7 @@ impl Component for CollabNotification { let avatar = "https://avatars.githubusercontent.com/u/67129314?v=4"; let container = || div().h(px(72.)).w(px(400.)); // Size of the actual notification window - let examples = vec![ + let call_examples = vec![ single_example( "Incoming Call", container() @@ -129,6 +129,58 @@ impl Component for CollabNotification { ), ]; - Some(example_group(examples).vertical().into_any_element()) + let toast_examples = vec![ + single_example( + "Contact Request", + container() + .child( + CollabNotification::new( + avatar, + Button::new("accept", "Accept"), + Button::new("decline", "Decline"), + ) + .child(Label::new("maxbrunsfeld wants to add you as a contact")), + ) + .into_any_element(), + ), + single_example( + "Contact Request Accepted", + container() + .child( + CollabNotification::new( + avatar, + Button::new("dismiss", "Dismiss"), + Button::new("close", "Close"), + ) + .child(Label::new("maxbrunsfeld accepted your contact request")), + ) + .into_any_element(), + ), + single_example( + "Channel Invitation", + container() + .child( + CollabNotification::new( + avatar, + Button::new("accept", "Accept"), + Button::new("decline", "Decline"), + ) + .child(Label::new( + "maxbrunsfeld invited you to join the #zed channel", + )), + ) + .into_any_element(), + ), + ]; + + Some( + v_flex() + .gap_6() + .child(example_group_with_title("Calls & Projects", call_examples).vertical()) + .child( + example_group_with_title("Contact & Channel Toasts", toast_examples).vertical(), + ) + .into_any_element(), + ) } } diff --git a/crates/vim/src/command.rs b/crates/vim/src/command.rs index fd19a5dc400a24b9f27617c44bd71fe38073c757..06fa6ead775809c3df775d959fb080a93ee84aad 100644 --- a/crates/vim/src/command.rs +++ b/crates/vim/src/command.rs @@ -1782,7 +1782,6 @@ fn generate_commands(_: &App) -> Vec { VimCommand::str(("te", "rm"), "terminal_panel::Toggle"), VimCommand::str(("T", "erm"), "terminal_panel::Toggle"), VimCommand::str(("C", "ollab"), "collab_panel::ToggleFocus"), - VimCommand::str(("No", "tifications"), "notification_panel::ToggleFocus"), VimCommand::str(("A", "I"), "agent::ToggleFocus"), VimCommand::str(("G", "it"), "git_panel::ToggleFocus"), VimCommand::str(("D", "ebug"), "debug_panel::ToggleFocus"), diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index 03e128415e1aa8390d1b95816755d3644064dada..293125c0089e0a4315eb9c28f30be5f840bd6052 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -652,10 +652,6 @@ fn initialize_panels(window: &mut Window, cx: &mut Context) -> Task) -> Task(window, cx); }, ) - .register_action( - |workspace: &mut Workspace, - _: &collab_ui::notification_panel::ToggleFocus, - window: &mut Window, - cx: &mut Context| { - workspace.toggle_panel_focus::( - window, cx, - ); - }, - ) .register_action( |workspace: &mut Workspace, _: &terminal_panel::ToggleFocus, @@ -4962,7 +4947,6 @@ mod tests { "multi_workspace", "new_process_modal", "notebook", - "notification_panel", "onboarding", "outline", "outline_panel", diff --git a/docs/src/visual-customization.md b/docs/src/visual-customization.md index 3c285bc3d10fc3bcb5fba6f735304ede438104a3..7597cdac293dd842b6a6a9f5747551a6f172bbf3 100644 --- a/docs/src/visual-customization.md +++ b/docs/src/visual-customization.md @@ -105,7 +105,7 @@ To disable this behavior use: // "outline_panel": {"button": false }, // "collaboration_panel": {"button": false }, // "git_panel": {"button": false }, - // "notification_panel": {"button": false }, + // "agent": {"button": false }, // "debugger": {"button": false }, // "diagnostics": {"button": false }, @@ -588,16 +588,6 @@ See [Terminal settings](./reference/all-settings.md#terminal) for additional non "dock": "left", // Where to dock: left, right "default_width": 240 // Default width of the collaboration panel. }, - "show_call_status_icon": true, // Shown call status in the OS status bar. - - // Notification Panel - "notification_panel": { - // Whether to show the notification panel button in the status bar. - "button": true, - // Where to dock the notification panel. Can be 'left' or 'right'. - "dock": "right", - // Default width of the notification panel. - "default_width": 380 - } + "show_call_status_icon": true // Shown call status in the OS status bar. } ``` From 0de9b553b5126a1a0e375c435d27e7f2124d9242 Mon Sep 17 00:00:00 2001 From: Chris Biscardi Date: Tue, 7 Apr 2026 09:14:47 -0700 Subject: [PATCH 20/22] Save Settings text inputs on blur (#53036) Co-authored-by: Anthony Eid --- .../settings_ui/src/components/input_field.rs | 59 +++++++++++++++++-- 1 file changed, 54 insertions(+), 5 deletions(-) diff --git a/crates/settings_ui/src/components/input_field.rs b/crates/settings_ui/src/components/input_field.rs index 35e63078c154dd324c8dd622b8d98c2de36beb68..e93944cf32cddc02e10a5e4f3251e80563c992b4 100644 --- a/crates/settings_ui/src/components/input_field.rs +++ b/crates/settings_ui/src/components/input_field.rs @@ -109,16 +109,37 @@ impl RenderOnce for SettingsInputField { ..Default::default() }; + let first_render_initial_text = window.use_state(cx, |_, _| self.initial_text.clone()); + let editor = if let Some(id) = self.id { window.use_keyed_state(id, cx, { let initial_text = self.initial_text.clone(); let placeholder = self.placeholder; + let mut confirm = self.confirm.clone(); + move |window, cx| { let mut editor = Editor::single_line(window, cx); + let editor_focus_handle = editor.focus_handle(cx); if let Some(text) = initial_text { editor.set_text(text, window, cx); } + if let Some(confirm) = confirm.take() + && !self.display_confirm_button + && !self.display_clear_button + && !self.clear_on_confirm + { + cx.on_focus_out( + &editor_focus_handle, + window, + move |editor, _, window, cx| { + let text = Some(editor.text(cx)); + confirm(text, window, cx); + }, + ) + .detach(); + } + if let Some(placeholder) = placeholder { editor.set_placeholder_text(placeholder, window, cx); } @@ -130,12 +151,31 @@ impl RenderOnce for SettingsInputField { window.use_state(cx, { let initial_text = self.initial_text.clone(); let placeholder = self.placeholder; + let mut confirm = self.confirm.clone(); + move |window, cx| { let mut editor = Editor::single_line(window, cx); + let editor_focus_handle = editor.focus_handle(cx); if let Some(text) = initial_text { editor.set_text(text, window, cx); } + if let Some(confirm) = confirm.take() + && !self.display_confirm_button + && !self.display_clear_button + && !self.clear_on_confirm + { + cx.on_focus_out( + &editor_focus_handle, + window, + move |editor, _, window, cx| { + let text = Some(editor.text(cx)); + confirm(text, window, cx); + }, + ) + .detach(); + } + if let Some(placeholder) = placeholder { editor.set_placeholder_text(placeholder, window, cx); } @@ -149,11 +189,20 @@ impl RenderOnce for SettingsInputField { // re-renders but use_keyed_state returns the cached editor with stale text. // Reconcile with the expected initial_text when the editor is not focused, // so we don't clobber what the user is actively typing. - if let Some(initial_text) = &self.initial_text { - let current_text = editor.read(cx).text(cx); - if current_text != *initial_text && !editor.read(cx).is_focused(window) { - editor.update(cx, |editor, cx| { - editor.set_text(initial_text.clone(), window, cx); + if let Some(initial_text) = &self.initial_text + && let Some(first_initial) = first_render_initial_text.read(cx) + { + if initial_text != first_initial && !editor.read(cx).is_focused(window) { + *first_render_initial_text.as_mut(cx) = self.initial_text.clone(); + let weak_editor = editor.downgrade(); + let initial_text = initial_text.clone(); + + window.defer(cx, move |window, cx| { + weak_editor + .update(cx, |editor, cx| { + editor.set_text(initial_text, window, cx); + }) + .ok(); }); } } From 4f3e4d2f46d0e01fff5b80f6e4583a0d46847e72 Mon Sep 17 00:00:00 2001 From: Kirill Bulatov Date: Tue, 7 Apr 2026 19:17:49 +0300 Subject: [PATCH 21/22] Stop highlighting selection matches in the search inputs (#53307) Follow-up of https://github.com/zed-industries/zed/pull/52553 Restores previous search inputs' behavior where no extra highlights were applied. Before: https://github.com/user-attachments/assets/38b6e70c-d5d5-4e06-abec-97d20af44f39 After: https://github.com/user-attachments/assets/6e4b3931-adf0-4c2a-afc3-f3c839fc9add Release Notes: - N/A --- crates/editor/src/editor.rs | 8 +++++++- crates/search/src/buffer_search.rs | 1 + crates/search/src/project_search.rs | 1 + 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index ae852b1055b33f151b402ee999ce50ba064788a4..e6f597de7ff9138b226cd2474353ef8c2ce16ebb 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -1265,6 +1265,7 @@ pub struct Editor { >, use_autoclose: bool, use_auto_surround: bool, + use_selection_highlight: bool, auto_replace_emoji_shortcode: bool, jsx_tag_auto_close_enabled_in_any_buffer: bool, show_git_blame_gutter: bool, @@ -2468,6 +2469,7 @@ impl Editor { read_only: is_minimap, use_autoclose: true, use_auto_surround: true, + use_selection_highlight: true, auto_replace_emoji_shortcode: false, jsx_tag_auto_close_enabled_in_any_buffer: false, leader_id: None, @@ -3547,6 +3549,10 @@ impl Editor { self.use_autoclose = autoclose; } + pub fn set_use_selection_highlight(&mut self, highlight: bool) { + self.use_selection_highlight = highlight; + } + pub fn set_use_auto_surround(&mut self, auto_surround: bool) { self.use_auto_surround = auto_surround; } @@ -7699,7 +7705,7 @@ impl Editor { if matches!(self.mode, EditorMode::SingleLine) { return None; } - if !EditorSettings::get_global(cx).selection_highlight { + if !self.use_selection_highlight || !EditorSettings::get_global(cx).selection_highlight { return None; } if self.selections.count() != 1 || self.selections.line_mode() { diff --git a/crates/search/src/buffer_search.rs b/crates/search/src/buffer_search.rs index 46177c5642a8d05daaf22e9fb24b205cd10ca42b..3a5fbe3fcae6241495deb43930b83bb78ba81968 100644 --- a/crates/search/src/buffer_search.rs +++ b/crates/search/src/buffer_search.rs @@ -849,6 +849,7 @@ impl BufferSearchBar { let query_editor = cx.new(|cx| { let mut editor = Editor::auto_height(1, 4, window, cx); editor.set_use_autoclose(false); + editor.set_use_selection_highlight(false); editor }); cx.subscribe_in(&query_editor, window, Self::on_query_editor_event) diff --git a/crates/search/src/project_search.rs b/crates/search/src/project_search.rs index 7c9d3f176ed3f17ec5e21faa7c1b483252657614..7e7903674e3d883bfb98ac8d57b5f407237f66d1 100644 --- a/crates/search/src/project_search.rs +++ b/crates/search/src/project_search.rs @@ -939,6 +939,7 @@ impl ProjectSearchView { let mut editor = Editor::auto_height(1, 4, window, cx); editor.set_placeholder_text("Search all files…", window, cx); editor.set_use_autoclose(false); + editor.set_use_selection_highlight(false); editor.set_text(query_text, window, cx); editor }); From 70d6c2bdc4e6238d55b147555a1e3a06c8c8fc71 Mon Sep 17 00:00:00 2001 From: Anthony Eid <56899983+Anthony-Eid@users.noreply.github.com> Date: Tue, 7 Apr 2026 12:18:15 -0400 Subject: [PATCH 22/22] git_graph: Show propagated errors from git binary command (#53320) Based on commit fba49809b39b0f9e58d68e3956f5c24fd47121d7 that I worked with Dino on in PR: #50288 Co-authored-by Dino \ Self-Review Checklist: - [x] I've reviewed my own diff for quality, security, and reliability - [x] Unsafe blocks (if any) have justifying comments - [x] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) - [x] Tests cover the new/changed behavior - [x] Performance impact has been considered and is acceptable Closes #ISSUE Release Notes: - N/A or Added/Fixed/Improved ... --- crates/fs/src/fake_git_repo.rs | 15 ++++++- crates/fs/src/fs.rs | 7 +++ crates/git/src/repository.rs | 18 +++++++- crates/git_graph/src/git_graph.rs | 71 +++++++++++++++++++++++++++++-- 4 files changed, 103 insertions(+), 8 deletions(-) diff --git a/crates/fs/src/fake_git_repo.rs b/crates/fs/src/fake_git_repo.rs index fbebeabf0ac15dde80016958eb358f792f46dd50..7b89a0751f17ef8c2bba837882f2a31c7d5451e5 100644 --- a/crates/fs/src/fake_git_repo.rs +++ b/crates/fs/src/fake_git_repo.rs @@ -61,6 +61,7 @@ pub struct FakeGitRepositoryState { pub remotes: HashMap, pub simulated_index_write_error_message: Option, pub simulated_create_worktree_error: Option, + pub simulated_graph_error: Option, pub refs: HashMap, pub graph_commits: Vec>, pub stash_entries: GitStash, @@ -78,6 +79,7 @@ impl FakeGitRepositoryState { branches: Default::default(), simulated_index_write_error_message: Default::default(), simulated_create_worktree_error: Default::default(), + simulated_graph_error: None, refs: HashMap::from_iter([("HEAD".into(), "abc".into())]), merge_base_contents: Default::default(), oids: Default::default(), @@ -1327,8 +1329,17 @@ impl GitRepository for FakeGitRepository { let fs = self.fs.clone(); let dot_git_path = self.dot_git_path.clone(); async move { - let graph_commits = - fs.with_git_state(&dot_git_path, false, |state| state.graph_commits.clone())?; + let (graph_commits, simulated_error) = + fs.with_git_state(&dot_git_path, false, |state| { + ( + state.graph_commits.clone(), + state.simulated_graph_error.clone(), + ) + })?; + + if let Some(error) = simulated_error { + anyhow::bail!("{}", error); + } for chunk in graph_commits.chunks(GRAPH_CHUNK_SIZE) { request_tx.send(chunk.to_vec()).await.ok(); diff --git a/crates/fs/src/fs.rs b/crates/fs/src/fs.rs index a26abb81255003e4059f9bcc8a68aa3c6212a73a..52cae537b6f00837b50123af0cae7c093699dedf 100644 --- a/crates/fs/src/fs.rs +++ b/crates/fs/src/fs.rs @@ -2168,6 +2168,13 @@ impl FakeFs { .unwrap(); } + pub fn set_graph_error(&self, dot_git: &Path, error: Option) { + self.with_git_state(dot_git, true, |state| { + state.simulated_graph_error = error; + }) + .unwrap(); + } + /// Put the given git repository into a state with the given status, /// by mutating the head, index, and unmerged state. pub fn set_status_for_repo(&self, dot_git: &Path, statuses: &[(&str, FileStatus)]) { diff --git a/crates/git/src/repository.rs b/crates/git/src/repository.rs index ba717d00c5e40374f5315d3ee8bc12e671f09552..d7049c0a50cb94c049556e395e818dbbddfb89bf 100644 --- a/crates/git/src/repository.rs +++ b/crates/git/src/repository.rs @@ -2784,10 +2784,11 @@ impl GitRepository for RealGitRepository { log_source.get_arg()?, ]); command.stdout(Stdio::piped()); - command.stderr(Stdio::null()); + command.stderr(Stdio::piped()); let mut child = command.spawn()?; let stdout = child.stdout.take().context("failed to get stdout")?; + let stderr = child.stderr.take().context("failed to get stderr")?; let mut reader = BufReader::new(stdout); let mut line_buffer = String::new(); @@ -2822,7 +2823,20 @@ impl GitRepository for RealGitRepository { } } - child.status().await?; + let status = child.status().await?; + if !status.success() { + let mut stderr_output = String::new(); + BufReader::new(stderr) + .read_to_string(&mut stderr_output) + .await + .log_err(); + + if stderr_output.is_empty() { + anyhow::bail!("git log command failed with {}", status); + } else { + anyhow::bail!("git log command failed with {}: {}", status, stderr_output); + } + } Ok(()) } .boxed() diff --git a/crates/git_graph/src/git_graph.rs b/crates/git_graph/src/git_graph.rs index aa5f6bc6e1293cfd057baa0c5e9f77819da71086..7594a206f14705bf47a673dee9abefad5a3446de 100644 --- a/crates/git_graph/src/git_graph.rs +++ b/crates/git_graph/src/git_graph.rs @@ -2536,11 +2536,19 @@ impl Render for GitGraph { } }; + let error = self.get_repository(cx).and_then(|repo| { + repo.read(cx) + .get_graph_data(self.log_source.clone(), self.log_order) + .and_then(|data| data.error.clone()) + }); + let content = if commit_count == 0 { - let message = if is_loading { - "Loading" + let message = if let Some(error) = &error { + format!("Error loading: {}", error) + } else if is_loading { + "Loading".to_string() } else { - "No commits found" + "No commits found".to_string() }; let label = Label::new(message) .color(Color::Muted) @@ -2552,7 +2560,7 @@ impl Render for GitGraph { .items_center() .justify_center() .child(label) - .when(is_loading, |this| { + .when(is_loading && error.is_none(), |this| { this.child(self.render_loading_spinner(cx)) }) } else { @@ -3757,6 +3765,61 @@ mod tests { ); } + #[gpui::test] + async fn test_initial_graph_data_propagates_error(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + Path::new("/project"), + json!({ + ".git": {}, + "file.txt": "content", + }), + ) + .await; + + fs.set_graph_error( + Path::new("/project/.git"), + Some("fatal: bad default revision 'HEAD'".to_string()), + ); + + let project = Project::test(fs.clone(), [Path::new("/project")], cx).await; + + let repository = project.read_with(cx, |project, cx| { + project + .active_repository(cx) + .expect("should have a repository") + }); + + repository.update(cx, |repo, cx| { + repo.graph_data( + crate::LogSource::default(), + crate::LogOrder::default(), + 0..usize::MAX, + cx, + ); + }); + + cx.run_until_parked(); + + let error = repository.read_with(cx, |repo, _| { + repo.get_graph_data(crate::LogSource::default(), crate::LogOrder::default()) + .and_then(|data| data.error.clone()) + }); + + assert!( + error.is_some(), + "graph data should contain an error after initial_graph_data fails" + ); + let error_message = error.unwrap(); + assert!( + error_message.contains("bad default revision"), + "error should contain the git error message, got: {}", + error_message + ); + } + #[gpui::test] async fn test_graph_data_repopulated_from_cache_after_repo_switch(cx: &mut TestAppContext) { init_test(cx);