@@ -3,7 +3,7 @@ use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use futures::StreamExt;
use gpui::{AnyWindowHandle, App, Entity, Task};
-use language::OffsetRangeExt;
+use language::{OffsetRangeExt, ParseStatus, Point};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::{
Project,
@@ -13,6 +13,7 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{cmp, fmt::Write, sync::Arc};
use ui::IconName;
+use util::RangeExt;
use util::markdown::MarkdownInlineCode;
use util::paths::PathMatcher;
@@ -102,6 +103,7 @@ impl Tool for GrepTool {
cx: &mut App,
) -> ToolResult {
const CONTEXT_LINES: u32 = 2;
+ const MAX_ANCESTOR_LINES: u32 = 10;
let input = match serde_json::from_value::<GrepToolInput>(input) {
Ok(input) => input,
@@ -140,7 +142,7 @@ impl Tool for GrepTool {
let results = project.update(cx, |project, cx| project.search(query, cx));
- cx.spawn(async move|cx| {
+ cx.spawn(async move |cx| {
futures::pin_mut!(results);
let mut output = String::new();
@@ -148,68 +150,113 @@ impl Tool for GrepTool {
let mut matches_found = 0;
let mut has_more_matches = false;
- while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
+ 'outer: while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
if ranges.is_empty() {
continue;
}
- buffer.read_with(cx, |buffer, cx| -> Result<(), anyhow::Error> {
- if let Some(path) = buffer.file().map(|file| file.full_path(cx)) {
- let mut file_header_written = false;
- let mut ranges = ranges
- .into_iter()
- .map(|range| {
- let mut point_range = range.to_point(buffer);
- point_range.start.row =
- point_range.start.row.saturating_sub(CONTEXT_LINES);
- point_range.start.column = 0;
- point_range.end.row = cmp::min(
- buffer.max_point().row,
- point_range.end.row + CONTEXT_LINES,
- );
- point_range.end.column = buffer.line_len(point_range.end.row);
- point_range
- })
- .peekable();
-
- while let Some(mut range) = ranges.next() {
- if skips_remaining > 0 {
- skips_remaining -= 1;
- continue;
- }
+ let (Some(path), mut parse_status) = buffer.read_with(cx, |buffer, cx| {
+ (buffer.file().map(|file| file.full_path(cx)), buffer.parse_status())
+ })? else {
+ continue;
+ };
- // We'd already found a full page of matches, and we just found one more.
- if matches_found >= RESULTS_PER_PAGE {
- has_more_matches = true;
- return Ok(());
- }
- while let Some(next_range) = ranges.peek() {
- if range.end.row >= next_range.start.row {
- range.end = next_range.end;
- ranges.next();
- } else {
- break;
- }
- }
+ while *parse_status.borrow() != ParseStatus::Idle {
+ parse_status.changed().await?;
+ }
+
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
+
+ let mut ranges = ranges
+ .into_iter()
+ .map(|range| {
+ let matched = range.to_point(&snapshot);
+ let matched_end_line_len = snapshot.line_len(matched.end.row);
+ let full_lines = Point::new(matched.start.row, 0)..Point::new(matched.end.row, matched_end_line_len);
+ let symbols = snapshot.symbols_containing(matched.start, None);
- if !file_header_written {
- writeln!(output, "\n## Matches in {}", path.display())?;
- file_header_written = true;
+ if let Some(ancestor_node) = snapshot.syntax_ancestor(full_lines.clone()) {
+ let full_ancestor_range = ancestor_node.byte_range().to_point(&snapshot);
+ let end_row = full_ancestor_range.end.row.min(full_ancestor_range.start.row + MAX_ANCESTOR_LINES);
+ let end_col = snapshot.line_len(end_row);
+ let capped_ancestor_range = Point::new(full_ancestor_range.start.row, 0)..Point::new(end_row, end_col);
+
+ if capped_ancestor_range.contains_inclusive(&full_lines) {
+ return (capped_ancestor_range, Some(full_ancestor_range), symbols)
}
+ }
+
+ let mut matched = matched;
+ matched.start.column = 0;
+ matched.start.row =
+ matched.start.row.saturating_sub(CONTEXT_LINES);
+ matched.end.row = cmp::min(
+ snapshot.max_point().row,
+ matched.end.row + CONTEXT_LINES,
+ );
+ matched.end.column = snapshot.line_len(matched.end.row);
+
+ (matched, None, symbols)
+ })
+ .peekable();
+
+ let mut file_header_written = false;
+
+ while let Some((mut range, ancestor_range, parent_symbols)) = ranges.next(){
+ if skips_remaining > 0 {
+ skips_remaining -= 1;
+ continue;
+ }
+
+ // We'd already found a full page of matches, and we just found one more.
+ if matches_found >= RESULTS_PER_PAGE {
+ has_more_matches = true;
+ break 'outer;
+ }
+
+ while let Some((next_range, _, _)) = ranges.peek() {
+ if range.end.row >= next_range.start.row {
+ range.end = next_range.end;
+ ranges.next();
+ } else {
+ break;
+ }
+ }
+
+ if !file_header_written {
+ writeln!(output, "\n## Matches in {}", path.display())?;
+ file_header_written = true;
+ }
- let start_line = range.start.row + 1;
- let end_line = range.end.row + 1;
- writeln!(output, "\n### Lines {start_line}-{end_line}\n```")?;
- output.extend(buffer.text_for_range(range));
- output.push_str("\n```\n");
+ let end_row = range.end.row;
+ output.push_str("\n### ");
- matches_found += 1;
+ if let Some(parent_symbols) = &parent_symbols {
+ for symbol in parent_symbols {
+ write!(output, "{} › ", symbol.text)?;
}
}
- Ok(())
- })??;
+ if range.start.row == end_row {
+ writeln!(output, "L{}", range.start.row + 1)?;
+ } else {
+ writeln!(output, "L{}-{}", range.start.row + 1, end_row + 1)?;
+ }
+
+ output.push_str("```\n");
+ output.extend(snapshot.text_for_range(range));
+ output.push_str("\n```\n");
+
+ if let Some(ancestor_range) = ancestor_range {
+ if end_row < ancestor_range.end.row {
+ let remaining_lines = ancestor_range.end.row - end_row;
+ writeln!(output, "\n{} lines remaining in ancestor node. Read the file to see all.", remaining_lines)?;
+ }
+ }
+
+ matches_found += 1;
+ }
}
if matches_found == 0 {
@@ -233,13 +280,16 @@ mod tests {
use super::*;
use assistant_tool::Tool;
use gpui::{AppContext, TestAppContext};
+ use language::{Language, LanguageConfig, LanguageMatcher};
use project::{FakeFs, Project};
use settings::SettingsStore;
+ use unindent::Unindent;
use util::path;
#[gpui::test]
async fn test_grep_tool_with_include_pattern(cx: &mut TestAppContext) {
init_test(cx);
+ cx.executor().allow_parking();
let fs = FakeFs::new(cx.executor().clone());
fs.insert_tree(
@@ -327,6 +377,7 @@ mod tests {
#[gpui::test]
async fn test_grep_tool_with_case_sensitivity(cx: &mut TestAppContext) {
init_test(cx);
+ cx.executor().allow_parking();
let fs = FakeFs::new(cx.executor().clone());
fs.insert_tree(
@@ -401,6 +452,290 @@ mod tests {
);
}
+ /// Helper function to set up a syntax test environment
+ async fn setup_syntax_test(cx: &mut TestAppContext) -> Entity<Project> {
+ use unindent::Unindent;
+ init_test(cx);
+ cx.executor().allow_parking();
+
+ let fs = FakeFs::new(cx.executor().clone());
+
+ // Create test file with syntax structures
+ fs.insert_tree(
+ "/root",
+ serde_json::json!({
+ "test_syntax.rs": r#"
+ fn top_level_function() {
+ println!("This is at the top level");
+ }
+
+ mod feature_module {
+ pub mod nested_module {
+ pub fn nested_function(
+ first_arg: String,
+ second_arg: i32,
+ ) {
+ println!("Function in nested module");
+ println!("{first_arg}");
+ println!("{second_arg}");
+ }
+ }
+ }
+
+ struct MyStruct {
+ field1: String,
+ field2: i32,
+ }
+
+ impl MyStruct {
+ fn method_with_block() {
+ let condition = true;
+ if condition {
+ println!("Inside if block");
+ }
+ }
+
+ fn long_function() {
+ println!("Line 1");
+ println!("Line 2");
+ println!("Line 3");
+ println!("Line 4");
+ println!("Line 5");
+ println!("Line 6");
+ println!("Line 7");
+ println!("Line 8");
+ println!("Line 9");
+ println!("Line 10");
+ println!("Line 11");
+ println!("Line 12");
+ }
+ }
+
+ trait Processor {
+ fn process(&self, input: &str) -> String;
+ }
+
+ impl Processor for MyStruct {
+ fn process(&self, input: &str) -> String {
+ format!("Processed: {}", input)
+ }
+ }
+ "#.unindent().trim(),
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+
+ project.update(cx, |project, _cx| {
+ project.languages().add(rust_lang().into())
+ });
+
+ project
+ }
+
+ #[gpui::test]
+ async fn test_grep_top_level_function(cx: &mut TestAppContext) {
+ let project = setup_syntax_test(cx).await;
+
+ // Test: Line at the top level of the file
+ let input = serde_json::to_value(GrepToolInput {
+ regex: "This is at the top level".to_string(),
+ include_pattern: Some("**/*.rs".to_string()),
+ offset: 0,
+ case_sensitive: false,
+ })
+ .unwrap();
+
+ let result = run_grep_tool(input, project.clone(), cx).await;
+ let expected = r#"
+ Found 1 matches:
+
+ ## Matches in root/test_syntax.rs
+
+ ### fn top_level_function › L1-3
+ ```
+ fn top_level_function() {
+ println!("This is at the top level");
+ }
+ ```
+ "#
+ .unindent();
+ assert_eq!(result, expected);
+ }
+
+ #[gpui::test]
+ async fn test_grep_function_body(cx: &mut TestAppContext) {
+ let project = setup_syntax_test(cx).await;
+
+ // Test: Line inside a function body
+ let input = serde_json::to_value(GrepToolInput {
+ regex: "Function in nested module".to_string(),
+ include_pattern: Some("**/*.rs".to_string()),
+ offset: 0,
+ case_sensitive: false,
+ })
+ .unwrap();
+
+ let result = run_grep_tool(input, project.clone(), cx).await;
+ let expected = r#"
+ Found 1 matches:
+
+ ## Matches in root/test_syntax.rs
+
+ ### mod feature_module › pub mod nested_module › pub fn nested_function › L10-14
+ ```
+ ) {
+ println!("Function in nested module");
+ println!("{first_arg}");
+ println!("{second_arg}");
+ }
+ ```
+ "#
+ .unindent();
+ assert_eq!(result, expected);
+ }
+
+ #[gpui::test]
+ async fn test_grep_function_args_and_body(cx: &mut TestAppContext) {
+ let project = setup_syntax_test(cx).await;
+
+ // Test: Line with a function argument
+ let input = serde_json::to_value(GrepToolInput {
+ regex: "second_arg".to_string(),
+ include_pattern: Some("**/*.rs".to_string()),
+ offset: 0,
+ case_sensitive: false,
+ })
+ .unwrap();
+
+ let result = run_grep_tool(input, project.clone(), cx).await;
+ let expected = r#"
+ Found 1 matches:
+
+ ## Matches in root/test_syntax.rs
+
+ ### mod feature_module › pub mod nested_module › pub fn nested_function › L7-14
+ ```
+ pub fn nested_function(
+ first_arg: String,
+ second_arg: i32,
+ ) {
+ println!("Function in nested module");
+ println!("{first_arg}");
+ println!("{second_arg}");
+ }
+ ```
+ "#
+ .unindent();
+ assert_eq!(result, expected);
+ }
+
+ #[gpui::test]
+ async fn test_grep_if_block(cx: &mut TestAppContext) {
+ use unindent::Unindent;
+ let project = setup_syntax_test(cx).await;
+
+ // Test: Line inside an if block
+ let input = serde_json::to_value(GrepToolInput {
+ regex: "Inside if block".to_string(),
+ include_pattern: Some("**/*.rs".to_string()),
+ offset: 0,
+ case_sensitive: false,
+ })
+ .unwrap();
+
+ let result = run_grep_tool(input, project.clone(), cx).await;
+ let expected = r#"
+ Found 1 matches:
+
+ ## Matches in root/test_syntax.rs
+
+ ### impl MyStruct › fn method_with_block › L26-28
+ ```
+ if condition {
+ println!("Inside if block");
+ }
+ ```
+ "#
+ .unindent();
+ assert_eq!(result, expected);
+ }
+
+ #[gpui::test]
+ async fn test_grep_long_function_top(cx: &mut TestAppContext) {
+ use unindent::Unindent;
+ let project = setup_syntax_test(cx).await;
+
+ // Test: Line in the middle of a long function - should show message about remaining lines
+ let input = serde_json::to_value(GrepToolInput {
+ regex: "Line 5".to_string(),
+ include_pattern: Some("**/*.rs".to_string()),
+ offset: 0,
+ case_sensitive: false,
+ })
+ .unwrap();
+
+ let result = run_grep_tool(input, project.clone(), cx).await;
+ let expected = r#"
+ Found 1 matches:
+
+ ## Matches in root/test_syntax.rs
+
+ ### impl MyStruct › fn long_function › L31-41
+ ```
+ fn long_function() {
+ println!("Line 1");
+ println!("Line 2");
+ println!("Line 3");
+ println!("Line 4");
+ println!("Line 5");
+ println!("Line 6");
+ println!("Line 7");
+ println!("Line 8");
+ println!("Line 9");
+ println!("Line 10");
+ ```
+
+ 3 lines remaining in ancestor node. Read the file to see all.
+ "#
+ .unindent();
+ assert_eq!(result, expected);
+ }
+
+ #[gpui::test]
+ async fn test_grep_long_function_bottom(cx: &mut TestAppContext) {
+ use unindent::Unindent;
+ let project = setup_syntax_test(cx).await;
+
+ // Test: Line in the long function
+ let input = serde_json::to_value(GrepToolInput {
+ regex: "Line 12".to_string(),
+ include_pattern: Some("**/*.rs".to_string()),
+ offset: 0,
+ case_sensitive: false,
+ })
+ .unwrap();
+
+ let result = run_grep_tool(input, project.clone(), cx).await;
+ let expected = r#"
+ Found 1 matches:
+
+ ## Matches in root/test_syntax.rs
+
+ ### impl MyStruct › fn long_function › L41-45
+ ```
+ println!("Line 10");
+ println!("Line 11");
+ println!("Line 12");
+ }
+ }
+ ```
+ "#
+ .unindent();
+ assert_eq!(result, expected);
+ }
+
async fn run_grep_tool(
input: serde_json::Value,
project: Entity<Project>,
@@ -411,7 +746,13 @@ mod tests {
let task = cx.update(|cx| tool.run(input, &[], project, action_log, None, cx));
match task.output.await {
- Ok(result) => result,
+ Ok(result) => {
+ if cfg!(windows) {
+ result.replace("root\\", "root/")
+ } else {
+ result
+ }
+ }
Err(e) => panic!("Failed to run grep tool: {}", e),
}
}
@@ -424,4 +765,20 @@ mod tests {
Project::init_settings(cx);
});
}
+
+ fn rust_lang() -> Language {
+ Language::new(
+ LanguageConfig {
+ name: "Rust".into(),
+ matcher: LanguageMatcher {
+ path_suffixes: vec!["rs".to_string()],
+ ..Default::default()
+ },
+ ..Default::default()
+ },
+ Some(tree_sitter_rust::LANGUAGE.into()),
+ )
+ .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
+ .unwrap()
+ }
}