diff --git a/crates/assistant_tools/src/grep_tool.rs b/crates/assistant_tools/src/grep_tool.rs index 439dca17d0aa7030d549292d84e8dcd65f859d0e..e296a472b22e68b43f19bce02c4d9c602cb7144e 100644 --- a/crates/assistant_tools/src/grep_tool.rs +++ b/crates/assistant_tools/src/grep_tool.rs @@ -3,7 +3,7 @@ use anyhow::{Result, anyhow}; use assistant_tool::{ActionLog, Tool, ToolResult}; use futures::StreamExt; use gpui::{AnyWindowHandle, App, Entity, Task}; -use language::OffsetRangeExt; +use language::{OffsetRangeExt, ParseStatus, Point}; use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat}; use project::{ Project, @@ -13,6 +13,7 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use std::{cmp, fmt::Write, sync::Arc}; use ui::IconName; +use util::RangeExt; use util::markdown::MarkdownInlineCode; use util::paths::PathMatcher; @@ -102,6 +103,7 @@ impl Tool for GrepTool { cx: &mut App, ) -> ToolResult { const CONTEXT_LINES: u32 = 2; + const MAX_ANCESTOR_LINES: u32 = 10; let input = match serde_json::from_value::(input) { Ok(input) => input, @@ -140,7 +142,7 @@ impl Tool for GrepTool { let results = project.update(cx, |project, cx| project.search(query, cx)); - cx.spawn(async move|cx| { + cx.spawn(async move |cx| { futures::pin_mut!(results); let mut output = String::new(); @@ -148,68 +150,113 @@ impl Tool for GrepTool { let mut matches_found = 0; let mut has_more_matches = false; - while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await { + 'outer: while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await { if ranges.is_empty() { continue; } - buffer.read_with(cx, |buffer, cx| -> Result<(), anyhow::Error> { - if let Some(path) = buffer.file().map(|file| file.full_path(cx)) { - let mut file_header_written = false; - let mut ranges = ranges - .into_iter() - .map(|range| { - let mut point_range = range.to_point(buffer); - point_range.start.row = - point_range.start.row.saturating_sub(CONTEXT_LINES); - point_range.start.column = 0; - point_range.end.row = cmp::min( - buffer.max_point().row, - point_range.end.row + CONTEXT_LINES, - ); - point_range.end.column = buffer.line_len(point_range.end.row); - point_range - }) - .peekable(); - - while let Some(mut range) = ranges.next() { - if skips_remaining > 0 { - skips_remaining -= 1; - continue; - } + let (Some(path), mut parse_status) = buffer.read_with(cx, |buffer, cx| { + (buffer.file().map(|file| file.full_path(cx)), buffer.parse_status()) + })? else { + continue; + }; - // We'd already found a full page of matches, and we just found one more. - if matches_found >= RESULTS_PER_PAGE { - has_more_matches = true; - return Ok(()); - } - while let Some(next_range) = ranges.peek() { - if range.end.row >= next_range.start.row { - range.end = next_range.end; - ranges.next(); - } else { - break; - } - } + while *parse_status.borrow() != ParseStatus::Idle { + parse_status.changed().await?; + } + + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; + + let mut ranges = ranges + .into_iter() + .map(|range| { + let matched = range.to_point(&snapshot); + let matched_end_line_len = snapshot.line_len(matched.end.row); + let full_lines = Point::new(matched.start.row, 0)..Point::new(matched.end.row, matched_end_line_len); + let symbols = snapshot.symbols_containing(matched.start, None); - if !file_header_written { - writeln!(output, "\n## Matches in {}", path.display())?; - file_header_written = true; + if let Some(ancestor_node) = snapshot.syntax_ancestor(full_lines.clone()) { + let full_ancestor_range = ancestor_node.byte_range().to_point(&snapshot); + let end_row = full_ancestor_range.end.row.min(full_ancestor_range.start.row + MAX_ANCESTOR_LINES); + let end_col = snapshot.line_len(end_row); + let capped_ancestor_range = Point::new(full_ancestor_range.start.row, 0)..Point::new(end_row, end_col); + + if capped_ancestor_range.contains_inclusive(&full_lines) { + return (capped_ancestor_range, Some(full_ancestor_range), symbols) } + } + + let mut matched = matched; + matched.start.column = 0; + matched.start.row = + matched.start.row.saturating_sub(CONTEXT_LINES); + matched.end.row = cmp::min( + snapshot.max_point().row, + matched.end.row + CONTEXT_LINES, + ); + matched.end.column = snapshot.line_len(matched.end.row); + + (matched, None, symbols) + }) + .peekable(); + + let mut file_header_written = false; + + while let Some((mut range, ancestor_range, parent_symbols)) = ranges.next(){ + if skips_remaining > 0 { + skips_remaining -= 1; + continue; + } + + // We'd already found a full page of matches, and we just found one more. + if matches_found >= RESULTS_PER_PAGE { + has_more_matches = true; + break 'outer; + } + + while let Some((next_range, _, _)) = ranges.peek() { + if range.end.row >= next_range.start.row { + range.end = next_range.end; + ranges.next(); + } else { + break; + } + } + + if !file_header_written { + writeln!(output, "\n## Matches in {}", path.display())?; + file_header_written = true; + } - let start_line = range.start.row + 1; - let end_line = range.end.row + 1; - writeln!(output, "\n### Lines {start_line}-{end_line}\n```")?; - output.extend(buffer.text_for_range(range)); - output.push_str("\n```\n"); + let end_row = range.end.row; + output.push_str("\n### "); - matches_found += 1; + if let Some(parent_symbols) = &parent_symbols { + for symbol in parent_symbols { + write!(output, "{} › ", symbol.text)?; } } - Ok(()) - })??; + if range.start.row == end_row { + writeln!(output, "L{}", range.start.row + 1)?; + } else { + writeln!(output, "L{}-{}", range.start.row + 1, end_row + 1)?; + } + + output.push_str("```\n"); + output.extend(snapshot.text_for_range(range)); + output.push_str("\n```\n"); + + if let Some(ancestor_range) = ancestor_range { + if end_row < ancestor_range.end.row { + let remaining_lines = ancestor_range.end.row - end_row; + writeln!(output, "\n{} lines remaining in ancestor node. Read the file to see all.", remaining_lines)?; + } + } + + matches_found += 1; + } } if matches_found == 0 { @@ -233,13 +280,16 @@ mod tests { use super::*; use assistant_tool::Tool; use gpui::{AppContext, TestAppContext}; + use language::{Language, LanguageConfig, LanguageMatcher}; use project::{FakeFs, Project}; use settings::SettingsStore; + use unindent::Unindent; use util::path; #[gpui::test] async fn test_grep_tool_with_include_pattern(cx: &mut TestAppContext) { init_test(cx); + cx.executor().allow_parking(); let fs = FakeFs::new(cx.executor().clone()); fs.insert_tree( @@ -327,6 +377,7 @@ mod tests { #[gpui::test] async fn test_grep_tool_with_case_sensitivity(cx: &mut TestAppContext) { init_test(cx); + cx.executor().allow_parking(); let fs = FakeFs::new(cx.executor().clone()); fs.insert_tree( @@ -401,6 +452,290 @@ mod tests { ); } + /// Helper function to set up a syntax test environment + async fn setup_syntax_test(cx: &mut TestAppContext) -> Entity { + use unindent::Unindent; + init_test(cx); + cx.executor().allow_parking(); + + let fs = FakeFs::new(cx.executor().clone()); + + // Create test file with syntax structures + fs.insert_tree( + "/root", + serde_json::json!({ + "test_syntax.rs": r#" + fn top_level_function() { + println!("This is at the top level"); + } + + mod feature_module { + pub mod nested_module { + pub fn nested_function( + first_arg: String, + second_arg: i32, + ) { + println!("Function in nested module"); + println!("{first_arg}"); + println!("{second_arg}"); + } + } + } + + struct MyStruct { + field1: String, + field2: i32, + } + + impl MyStruct { + fn method_with_block() { + let condition = true; + if condition { + println!("Inside if block"); + } + } + + fn long_function() { + println!("Line 1"); + println!("Line 2"); + println!("Line 3"); + println!("Line 4"); + println!("Line 5"); + println!("Line 6"); + println!("Line 7"); + println!("Line 8"); + println!("Line 9"); + println!("Line 10"); + println!("Line 11"); + println!("Line 12"); + } + } + + trait Processor { + fn process(&self, input: &str) -> String; + } + + impl Processor for MyStruct { + fn process(&self, input: &str) -> String { + format!("Processed: {}", input) + } + } + "#.unindent().trim(), + }), + ) + .await; + + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + + project.update(cx, |project, _cx| { + project.languages().add(rust_lang().into()) + }); + + project + } + + #[gpui::test] + async fn test_grep_top_level_function(cx: &mut TestAppContext) { + let project = setup_syntax_test(cx).await; + + // Test: Line at the top level of the file + let input = serde_json::to_value(GrepToolInput { + regex: "This is at the top level".to_string(), + include_pattern: Some("**/*.rs".to_string()), + offset: 0, + case_sensitive: false, + }) + .unwrap(); + + let result = run_grep_tool(input, project.clone(), cx).await; + let expected = r#" + Found 1 matches: + + ## Matches in root/test_syntax.rs + + ### fn top_level_function › L1-3 + ``` + fn top_level_function() { + println!("This is at the top level"); + } + ``` + "# + .unindent(); + assert_eq!(result, expected); + } + + #[gpui::test] + async fn test_grep_function_body(cx: &mut TestAppContext) { + let project = setup_syntax_test(cx).await; + + // Test: Line inside a function body + let input = serde_json::to_value(GrepToolInput { + regex: "Function in nested module".to_string(), + include_pattern: Some("**/*.rs".to_string()), + offset: 0, + case_sensitive: false, + }) + .unwrap(); + + let result = run_grep_tool(input, project.clone(), cx).await; + let expected = r#" + Found 1 matches: + + ## Matches in root/test_syntax.rs + + ### mod feature_module › pub mod nested_module › pub fn nested_function › L10-14 + ``` + ) { + println!("Function in nested module"); + println!("{first_arg}"); + println!("{second_arg}"); + } + ``` + "# + .unindent(); + assert_eq!(result, expected); + } + + #[gpui::test] + async fn test_grep_function_args_and_body(cx: &mut TestAppContext) { + let project = setup_syntax_test(cx).await; + + // Test: Line with a function argument + let input = serde_json::to_value(GrepToolInput { + regex: "second_arg".to_string(), + include_pattern: Some("**/*.rs".to_string()), + offset: 0, + case_sensitive: false, + }) + .unwrap(); + + let result = run_grep_tool(input, project.clone(), cx).await; + let expected = r#" + Found 1 matches: + + ## Matches in root/test_syntax.rs + + ### mod feature_module › pub mod nested_module › pub fn nested_function › L7-14 + ``` + pub fn nested_function( + first_arg: String, + second_arg: i32, + ) { + println!("Function in nested module"); + println!("{first_arg}"); + println!("{second_arg}"); + } + ``` + "# + .unindent(); + assert_eq!(result, expected); + } + + #[gpui::test] + async fn test_grep_if_block(cx: &mut TestAppContext) { + use unindent::Unindent; + let project = setup_syntax_test(cx).await; + + // Test: Line inside an if block + let input = serde_json::to_value(GrepToolInput { + regex: "Inside if block".to_string(), + include_pattern: Some("**/*.rs".to_string()), + offset: 0, + case_sensitive: false, + }) + .unwrap(); + + let result = run_grep_tool(input, project.clone(), cx).await; + let expected = r#" + Found 1 matches: + + ## Matches in root/test_syntax.rs + + ### impl MyStruct › fn method_with_block › L26-28 + ``` + if condition { + println!("Inside if block"); + } + ``` + "# + .unindent(); + assert_eq!(result, expected); + } + + #[gpui::test] + async fn test_grep_long_function_top(cx: &mut TestAppContext) { + use unindent::Unindent; + let project = setup_syntax_test(cx).await; + + // Test: Line in the middle of a long function - should show message about remaining lines + let input = serde_json::to_value(GrepToolInput { + regex: "Line 5".to_string(), + include_pattern: Some("**/*.rs".to_string()), + offset: 0, + case_sensitive: false, + }) + .unwrap(); + + let result = run_grep_tool(input, project.clone(), cx).await; + let expected = r#" + Found 1 matches: + + ## Matches in root/test_syntax.rs + + ### impl MyStruct › fn long_function › L31-41 + ``` + fn long_function() { + println!("Line 1"); + println!("Line 2"); + println!("Line 3"); + println!("Line 4"); + println!("Line 5"); + println!("Line 6"); + println!("Line 7"); + println!("Line 8"); + println!("Line 9"); + println!("Line 10"); + ``` + + 3 lines remaining in ancestor node. Read the file to see all. + "# + .unindent(); + assert_eq!(result, expected); + } + + #[gpui::test] + async fn test_grep_long_function_bottom(cx: &mut TestAppContext) { + use unindent::Unindent; + let project = setup_syntax_test(cx).await; + + // Test: Line in the long function + let input = serde_json::to_value(GrepToolInput { + regex: "Line 12".to_string(), + include_pattern: Some("**/*.rs".to_string()), + offset: 0, + case_sensitive: false, + }) + .unwrap(); + + let result = run_grep_tool(input, project.clone(), cx).await; + let expected = r#" + Found 1 matches: + + ## Matches in root/test_syntax.rs + + ### impl MyStruct › fn long_function › L41-45 + ``` + println!("Line 10"); + println!("Line 11"); + println!("Line 12"); + } + } + ``` + "# + .unindent(); + assert_eq!(result, expected); + } + async fn run_grep_tool( input: serde_json::Value, project: Entity, @@ -411,7 +746,13 @@ mod tests { let task = cx.update(|cx| tool.run(input, &[], project, action_log, None, cx)); match task.output.await { - Ok(result) => result, + Ok(result) => { + if cfg!(windows) { + result.replace("root\\", "root/") + } else { + result + } + } Err(e) => panic!("Failed to run grep tool: {}", e), } } @@ -424,4 +765,20 @@ mod tests { Project::init_settings(cx); }); } + + fn rust_lang() -> Language { + Language::new( + LanguageConfig { + name: "Rust".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + ..Default::default() + }, + Some(tree_sitter_rust::LANGUAGE.into()), + ) + .with_outline_query(include_str!("../../languages/src/rust/outline.scm")) + .unwrap() + } } diff --git a/crates/eval/src/example.rs b/crates/eval/src/example.rs index 9a9b7fd577ffba0b6fb635539741405615477b97..328fb25df817c78600c12b7e904bc4d5e6f4264a 100644 --- a/crates/eval/src/example.rs +++ b/crates/eval/src/example.rs @@ -387,6 +387,7 @@ impl Response { cx.assert_some(result, format!("called `{}`", tool_name)) } + #[allow(dead_code)] pub fn tool_uses(&self) -> impl Iterator { self.messages.iter().flat_map(|msg| &msg.tool_use) } diff --git a/crates/eval/src/examples/add_arg_to_trait_method.rs b/crates/eval/src/examples/add_arg_to_trait_method.rs index 5c3fb788f0bdbb062cb561b1ac061b0d23147f14..d797d08ce2ac9f2d9838f8a4a389b6d8e94a17e2 100644 --- a/crates/eval/src/examples/add_arg_to_trait_method.rs +++ b/crates/eval/src/examples/add_arg_to_trait_method.rs @@ -1,7 +1,6 @@ -use std::{collections::HashSet, path::Path}; +use std::path::Path; use anyhow::Result; -use assistant_tools::{CreateFileToolInput, EditFileToolInput, ReadFileToolInput}; use async_trait::async_trait; use crate::example::{Example, ExampleContext, ExampleMetadata, JudgeAssertion, LanguageServer}; @@ -32,39 +31,7 @@ impl Example for AddArgToTraitMethod { "# )); - let response = cx.run_to_end().await?; - - // Reads files before it edits them - - let mut read_files = HashSet::new(); - - for tool_use in response.tool_uses() { - match tool_use.name.as_str() { - "read_file" => { - if let Ok(input) = tool_use.parse_input::() { - read_files.insert(input.path); - } - } - "create_file" => { - if let Ok(input) = tool_use.parse_input::() { - read_files.insert(input.path); - } - } - "edit_file" => { - if let Ok(input) = tool_use.parse_input::() { - cx.assert( - read_files.contains(input.path.to_str().unwrap()), - format!( - "Read before edit: {}", - &input.path.file_stem().unwrap().to_str().unwrap() - ), - ) - .ok(); - } - } - _ => {} - } - } + let _ = cx.run_to_end().await?; // Adds ignored argument to all but `batch_tool`