From fd17f2d8ae0b32a8fd475f9d1b25951963266c5e Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Tue, 29 Apr 2025 14:03:02 -0300 Subject: [PATCH] agent: Enrich `grep` tool output with syntax information (#29601) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The `grep` tool used to include 4 lines of context around the match, but the lines included would often be unhelpful. This PR improves this behavior by using the range of the parent syntax node that contains the full line(s) matched. The match headers will also now include symbol breadcrumbs so that the model can already gather code structure before/without reading files. ````md ### impl GitRepository for RealGitRepository › fn compare_checkpoints › L1278-1284 ```rust let result = git .run(&[ "diff-tree", "--quiet", &left.commit_sha.to_string(), &right.commit_sha.to_string(), ]) ``` ```` This positively impacts the `add_arg_to_trait_method` eval example with better diff output, fewer tool failures, and reduced total turns. Note: We have some plans to use a an "elision" approach where we would combine all matches for a given file, skipping lines between them while keeping symbol declaration lines. The theory is that this would be map more closely to the expected input for edits. For now, this PR is a significant improvement. Release Notes: - Agent: Enrich `grep` tool output with syntax information --- crates/assistant_tools/src/grep_tool.rs | 461 ++++++++++++++++-- crates/eval/src/example.rs | 1 + .../src/examples/add_arg_to_trait_method.rs | 37 +- 3 files changed, 412 insertions(+), 87 deletions(-) diff --git a/crates/assistant_tools/src/grep_tool.rs b/crates/assistant_tools/src/grep_tool.rs index 439dca17d0aa7030d549292d84e8dcd65f859d0e..e296a472b22e68b43f19bce02c4d9c602cb7144e 100644 --- a/crates/assistant_tools/src/grep_tool.rs +++ b/crates/assistant_tools/src/grep_tool.rs @@ -3,7 +3,7 @@ use anyhow::{Result, anyhow}; use assistant_tool::{ActionLog, Tool, ToolResult}; use futures::StreamExt; use gpui::{AnyWindowHandle, App, Entity, Task}; -use language::OffsetRangeExt; +use language::{OffsetRangeExt, ParseStatus, Point}; use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat}; use project::{ Project, @@ -13,6 +13,7 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use std::{cmp, fmt::Write, sync::Arc}; use ui::IconName; +use util::RangeExt; use util::markdown::MarkdownInlineCode; use util::paths::PathMatcher; @@ -102,6 +103,7 @@ impl Tool for GrepTool { cx: &mut App, ) -> ToolResult { const CONTEXT_LINES: u32 = 2; + const MAX_ANCESTOR_LINES: u32 = 10; let input = match serde_json::from_value::(input) { Ok(input) => input, @@ -140,7 +142,7 @@ impl Tool for GrepTool { let results = project.update(cx, |project, cx| project.search(query, cx)); - cx.spawn(async move|cx| { + cx.spawn(async move |cx| { futures::pin_mut!(results); let mut output = String::new(); @@ -148,68 +150,113 @@ impl Tool for GrepTool { let mut matches_found = 0; let mut has_more_matches = false; - while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await { + 'outer: while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await { if ranges.is_empty() { continue; } - buffer.read_with(cx, |buffer, cx| -> Result<(), anyhow::Error> { - if let Some(path) = buffer.file().map(|file| file.full_path(cx)) { - let mut file_header_written = false; - let mut ranges = ranges - .into_iter() - .map(|range| { - let mut point_range = range.to_point(buffer); - point_range.start.row = - point_range.start.row.saturating_sub(CONTEXT_LINES); - point_range.start.column = 0; - point_range.end.row = cmp::min( - buffer.max_point().row, - point_range.end.row + CONTEXT_LINES, - ); - point_range.end.column = buffer.line_len(point_range.end.row); - point_range - }) - .peekable(); - - while let Some(mut range) = ranges.next() { - if skips_remaining > 0 { - skips_remaining -= 1; - continue; - } + let (Some(path), mut parse_status) = buffer.read_with(cx, |buffer, cx| { + (buffer.file().map(|file| file.full_path(cx)), buffer.parse_status()) + })? else { + continue; + }; - // We'd already found a full page of matches, and we just found one more. - if matches_found >= RESULTS_PER_PAGE { - has_more_matches = true; - return Ok(()); - } - while let Some(next_range) = ranges.peek() { - if range.end.row >= next_range.start.row { - range.end = next_range.end; - ranges.next(); - } else { - break; - } - } + while *parse_status.borrow() != ParseStatus::Idle { + parse_status.changed().await?; + } + + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; + + let mut ranges = ranges + .into_iter() + .map(|range| { + let matched = range.to_point(&snapshot); + let matched_end_line_len = snapshot.line_len(matched.end.row); + let full_lines = Point::new(matched.start.row, 0)..Point::new(matched.end.row, matched_end_line_len); + let symbols = snapshot.symbols_containing(matched.start, None); - if !file_header_written { - writeln!(output, "\n## Matches in {}", path.display())?; - file_header_written = true; + if let Some(ancestor_node) = snapshot.syntax_ancestor(full_lines.clone()) { + let full_ancestor_range = ancestor_node.byte_range().to_point(&snapshot); + let end_row = full_ancestor_range.end.row.min(full_ancestor_range.start.row + MAX_ANCESTOR_LINES); + let end_col = snapshot.line_len(end_row); + let capped_ancestor_range = Point::new(full_ancestor_range.start.row, 0)..Point::new(end_row, end_col); + + if capped_ancestor_range.contains_inclusive(&full_lines) { + return (capped_ancestor_range, Some(full_ancestor_range), symbols) } + } + + let mut matched = matched; + matched.start.column = 0; + matched.start.row = + matched.start.row.saturating_sub(CONTEXT_LINES); + matched.end.row = cmp::min( + snapshot.max_point().row, + matched.end.row + CONTEXT_LINES, + ); + matched.end.column = snapshot.line_len(matched.end.row); + + (matched, None, symbols) + }) + .peekable(); + + let mut file_header_written = false; + + while let Some((mut range, ancestor_range, parent_symbols)) = ranges.next(){ + if skips_remaining > 0 { + skips_remaining -= 1; + continue; + } + + // We'd already found a full page of matches, and we just found one more. + if matches_found >= RESULTS_PER_PAGE { + has_more_matches = true; + break 'outer; + } + + while let Some((next_range, _, _)) = ranges.peek() { + if range.end.row >= next_range.start.row { + range.end = next_range.end; + ranges.next(); + } else { + break; + } + } + + if !file_header_written { + writeln!(output, "\n## Matches in {}", path.display())?; + file_header_written = true; + } - let start_line = range.start.row + 1; - let end_line = range.end.row + 1; - writeln!(output, "\n### Lines {start_line}-{end_line}\n```")?; - output.extend(buffer.text_for_range(range)); - output.push_str("\n```\n"); + let end_row = range.end.row; + output.push_str("\n### "); - matches_found += 1; + if let Some(parent_symbols) = &parent_symbols { + for symbol in parent_symbols { + write!(output, "{} › ", symbol.text)?; } } - Ok(()) - })??; + if range.start.row == end_row { + writeln!(output, "L{}", range.start.row + 1)?; + } else { + writeln!(output, "L{}-{}", range.start.row + 1, end_row + 1)?; + } + + output.push_str("```\n"); + output.extend(snapshot.text_for_range(range)); + output.push_str("\n```\n"); + + if let Some(ancestor_range) = ancestor_range { + if end_row < ancestor_range.end.row { + let remaining_lines = ancestor_range.end.row - end_row; + writeln!(output, "\n{} lines remaining in ancestor node. Read the file to see all.", remaining_lines)?; + } + } + + matches_found += 1; + } } if matches_found == 0 { @@ -233,13 +280,16 @@ mod tests { use super::*; use assistant_tool::Tool; use gpui::{AppContext, TestAppContext}; + use language::{Language, LanguageConfig, LanguageMatcher}; use project::{FakeFs, Project}; use settings::SettingsStore; + use unindent::Unindent; use util::path; #[gpui::test] async fn test_grep_tool_with_include_pattern(cx: &mut TestAppContext) { init_test(cx); + cx.executor().allow_parking(); let fs = FakeFs::new(cx.executor().clone()); fs.insert_tree( @@ -327,6 +377,7 @@ mod tests { #[gpui::test] async fn test_grep_tool_with_case_sensitivity(cx: &mut TestAppContext) { init_test(cx); + cx.executor().allow_parking(); let fs = FakeFs::new(cx.executor().clone()); fs.insert_tree( @@ -401,6 +452,290 @@ mod tests { ); } + /// Helper function to set up a syntax test environment + async fn setup_syntax_test(cx: &mut TestAppContext) -> Entity { + use unindent::Unindent; + init_test(cx); + cx.executor().allow_parking(); + + let fs = FakeFs::new(cx.executor().clone()); + + // Create test file with syntax structures + fs.insert_tree( + "/root", + serde_json::json!({ + "test_syntax.rs": r#" + fn top_level_function() { + println!("This is at the top level"); + } + + mod feature_module { + pub mod nested_module { + pub fn nested_function( + first_arg: String, + second_arg: i32, + ) { + println!("Function in nested module"); + println!("{first_arg}"); + println!("{second_arg}"); + } + } + } + + struct MyStruct { + field1: String, + field2: i32, + } + + impl MyStruct { + fn method_with_block() { + let condition = true; + if condition { + println!("Inside if block"); + } + } + + fn long_function() { + println!("Line 1"); + println!("Line 2"); + println!("Line 3"); + println!("Line 4"); + println!("Line 5"); + println!("Line 6"); + println!("Line 7"); + println!("Line 8"); + println!("Line 9"); + println!("Line 10"); + println!("Line 11"); + println!("Line 12"); + } + } + + trait Processor { + fn process(&self, input: &str) -> String; + } + + impl Processor for MyStruct { + fn process(&self, input: &str) -> String { + format!("Processed: {}", input) + } + } + "#.unindent().trim(), + }), + ) + .await; + + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + + project.update(cx, |project, _cx| { + project.languages().add(rust_lang().into()) + }); + + project + } + + #[gpui::test] + async fn test_grep_top_level_function(cx: &mut TestAppContext) { + let project = setup_syntax_test(cx).await; + + // Test: Line at the top level of the file + let input = serde_json::to_value(GrepToolInput { + regex: "This is at the top level".to_string(), + include_pattern: Some("**/*.rs".to_string()), + offset: 0, + case_sensitive: false, + }) + .unwrap(); + + let result = run_grep_tool(input, project.clone(), cx).await; + let expected = r#" + Found 1 matches: + + ## Matches in root/test_syntax.rs + + ### fn top_level_function › L1-3 + ``` + fn top_level_function() { + println!("This is at the top level"); + } + ``` + "# + .unindent(); + assert_eq!(result, expected); + } + + #[gpui::test] + async fn test_grep_function_body(cx: &mut TestAppContext) { + let project = setup_syntax_test(cx).await; + + // Test: Line inside a function body + let input = serde_json::to_value(GrepToolInput { + regex: "Function in nested module".to_string(), + include_pattern: Some("**/*.rs".to_string()), + offset: 0, + case_sensitive: false, + }) + .unwrap(); + + let result = run_grep_tool(input, project.clone(), cx).await; + let expected = r#" + Found 1 matches: + + ## Matches in root/test_syntax.rs + + ### mod feature_module › pub mod nested_module › pub fn nested_function › L10-14 + ``` + ) { + println!("Function in nested module"); + println!("{first_arg}"); + println!("{second_arg}"); + } + ``` + "# + .unindent(); + assert_eq!(result, expected); + } + + #[gpui::test] + async fn test_grep_function_args_and_body(cx: &mut TestAppContext) { + let project = setup_syntax_test(cx).await; + + // Test: Line with a function argument + let input = serde_json::to_value(GrepToolInput { + regex: "second_arg".to_string(), + include_pattern: Some("**/*.rs".to_string()), + offset: 0, + case_sensitive: false, + }) + .unwrap(); + + let result = run_grep_tool(input, project.clone(), cx).await; + let expected = r#" + Found 1 matches: + + ## Matches in root/test_syntax.rs + + ### mod feature_module › pub mod nested_module › pub fn nested_function › L7-14 + ``` + pub fn nested_function( + first_arg: String, + second_arg: i32, + ) { + println!("Function in nested module"); + println!("{first_arg}"); + println!("{second_arg}"); + } + ``` + "# + .unindent(); + assert_eq!(result, expected); + } + + #[gpui::test] + async fn test_grep_if_block(cx: &mut TestAppContext) { + use unindent::Unindent; + let project = setup_syntax_test(cx).await; + + // Test: Line inside an if block + let input = serde_json::to_value(GrepToolInput { + regex: "Inside if block".to_string(), + include_pattern: Some("**/*.rs".to_string()), + offset: 0, + case_sensitive: false, + }) + .unwrap(); + + let result = run_grep_tool(input, project.clone(), cx).await; + let expected = r#" + Found 1 matches: + + ## Matches in root/test_syntax.rs + + ### impl MyStruct › fn method_with_block › L26-28 + ``` + if condition { + println!("Inside if block"); + } + ``` + "# + .unindent(); + assert_eq!(result, expected); + } + + #[gpui::test] + async fn test_grep_long_function_top(cx: &mut TestAppContext) { + use unindent::Unindent; + let project = setup_syntax_test(cx).await; + + // Test: Line in the middle of a long function - should show message about remaining lines + let input = serde_json::to_value(GrepToolInput { + regex: "Line 5".to_string(), + include_pattern: Some("**/*.rs".to_string()), + offset: 0, + case_sensitive: false, + }) + .unwrap(); + + let result = run_grep_tool(input, project.clone(), cx).await; + let expected = r#" + Found 1 matches: + + ## Matches in root/test_syntax.rs + + ### impl MyStruct › fn long_function › L31-41 + ``` + fn long_function() { + println!("Line 1"); + println!("Line 2"); + println!("Line 3"); + println!("Line 4"); + println!("Line 5"); + println!("Line 6"); + println!("Line 7"); + println!("Line 8"); + println!("Line 9"); + println!("Line 10"); + ``` + + 3 lines remaining in ancestor node. Read the file to see all. + "# + .unindent(); + assert_eq!(result, expected); + } + + #[gpui::test] + async fn test_grep_long_function_bottom(cx: &mut TestAppContext) { + use unindent::Unindent; + let project = setup_syntax_test(cx).await; + + // Test: Line in the long function + let input = serde_json::to_value(GrepToolInput { + regex: "Line 12".to_string(), + include_pattern: Some("**/*.rs".to_string()), + offset: 0, + case_sensitive: false, + }) + .unwrap(); + + let result = run_grep_tool(input, project.clone(), cx).await; + let expected = r#" + Found 1 matches: + + ## Matches in root/test_syntax.rs + + ### impl MyStruct › fn long_function › L41-45 + ``` + println!("Line 10"); + println!("Line 11"); + println!("Line 12"); + } + } + ``` + "# + .unindent(); + assert_eq!(result, expected); + } + async fn run_grep_tool( input: serde_json::Value, project: Entity, @@ -411,7 +746,13 @@ mod tests { let task = cx.update(|cx| tool.run(input, &[], project, action_log, None, cx)); match task.output.await { - Ok(result) => result, + Ok(result) => { + if cfg!(windows) { + result.replace("root\\", "root/") + } else { + result + } + } Err(e) => panic!("Failed to run grep tool: {}", e), } } @@ -424,4 +765,20 @@ mod tests { Project::init_settings(cx); }); } + + fn rust_lang() -> Language { + Language::new( + LanguageConfig { + name: "Rust".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + ..Default::default() + }, + Some(tree_sitter_rust::LANGUAGE.into()), + ) + .with_outline_query(include_str!("../../languages/src/rust/outline.scm")) + .unwrap() + } } diff --git a/crates/eval/src/example.rs b/crates/eval/src/example.rs index 9a9b7fd577ffba0b6fb635539741405615477b97..328fb25df817c78600c12b7e904bc4d5e6f4264a 100644 --- a/crates/eval/src/example.rs +++ b/crates/eval/src/example.rs @@ -387,6 +387,7 @@ impl Response { cx.assert_some(result, format!("called `{}`", tool_name)) } + #[allow(dead_code)] pub fn tool_uses(&self) -> impl Iterator { self.messages.iter().flat_map(|msg| &msg.tool_use) } diff --git a/crates/eval/src/examples/add_arg_to_trait_method.rs b/crates/eval/src/examples/add_arg_to_trait_method.rs index 5c3fb788f0bdbb062cb561b1ac061b0d23147f14..d797d08ce2ac9f2d9838f8a4a389b6d8e94a17e2 100644 --- a/crates/eval/src/examples/add_arg_to_trait_method.rs +++ b/crates/eval/src/examples/add_arg_to_trait_method.rs @@ -1,7 +1,6 @@ -use std::{collections::HashSet, path::Path}; +use std::path::Path; use anyhow::Result; -use assistant_tools::{CreateFileToolInput, EditFileToolInput, ReadFileToolInput}; use async_trait::async_trait; use crate::example::{Example, ExampleContext, ExampleMetadata, JudgeAssertion, LanguageServer}; @@ -32,39 +31,7 @@ impl Example for AddArgToTraitMethod { "# )); - let response = cx.run_to_end().await?; - - // Reads files before it edits them - - let mut read_files = HashSet::new(); - - for tool_use in response.tool_uses() { - match tool_use.name.as_str() { - "read_file" => { - if let Ok(input) = tool_use.parse_input::() { - read_files.insert(input.path); - } - } - "create_file" => { - if let Ok(input) = tool_use.parse_input::() { - read_files.insert(input.path); - } - } - "edit_file" => { - if let Ok(input) = tool_use.parse_input::() { - cx.assert( - read_files.contains(input.path.to_str().unwrap()), - format!( - "Read before edit: {}", - &input.path.file_stem().unwrap().to_str().unwrap() - ), - ) - .ok(); - } - } - _ => {} - } - } + let _ = cx.run_to_end().await?; // Adds ignored argument to all but `batch_tool`